420 lines
14 KiB
Python
420 lines
14 KiB
Python
|
|
"""
|
||
|
|
Session Manager
|
||
|
|
|
||
|
|
Manages persistent Claude Code sessions with MCP integration.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
import codecs
import contextlib
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Literal, Optional

from .conversation_store import ConversationStore
from .context_builder import ContextBuilder
|
||
|
|
|
||
|
|
# Paths
# Repository root: five directory levels above this module file.
ATOMIZER_ROOT = Path(__file__).parent.parent.parent.parent.parent
# Directory holding the atomizer-tools MCP server (see _build_mcp_config).
MCP_SERVER_PATH = ATOMIZER_ROOT.joinpath("mcp-server", "atomizer-tools")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ClaudeSession:
    """Represents an active Claude Code session.

    Plain record tracked by SessionManager; ``process`` is the CLI
    subprocess started for the session, or ``None`` if none was started.
    """

    # Short identifier used as the key in SessionManager.sessions.
    session_id: str
    # "user" (restricted tools) or "power" (full access).
    mode: Literal["user", "power"]
    # Study the session is scoped to, if any.
    study_id: Optional[str]
    # Spawned ``claude`` subprocess; None when the CLI could not be started.
    process: Optional[asyncio.subprocess.Process] = None
    # Timestamps taken with naive local time (datetime.now).
    created_at: datetime = field(default_factory=datetime.now)
    last_active: datetime = field(default_factory=datetime.now)

    def is_alive(self) -> bool:
        """Return True only while a subprocess exists and has not exited."""
        proc = self.process
        if proc is None:
            return False
        # asyncio sets returncode once the child has terminated.
        return proc.returncode is None
|
||
|
|
|
||
|
|
|
||
|
|
class SessionManager:
    """Manages Claude Code sessions with MCP tools.

    Keeps live ``ClaudeSession`` objects in memory, persists session metadata
    and messages through ``ConversationStore``, builds prompts with
    ``ContextBuilder``, and (once ``start`` is called) periodically
    terminates sessions that have been idle for too long.
    """

    # Seconds between passes of the background cleanup loop.
    CLEANUP_INTERVAL_SECONDS: float = 300
    # In-memory sessions idle longer than this many seconds are terminated.
    SESSION_IDLE_TIMEOUT_SECONDS: float = 3600
    # Forwarded to ConversationStore.cleanup_stale_sessions(max_age_hours=...).
    STORE_MAX_AGE_HOURS: int = 24
    # Number of bytes requested from the CLI's stdout per streamed read.
    STREAM_CHUNK_SIZE: int = 100

    def __init__(self):
        self.sessions: Dict[str, ClaudeSession] = {}
        self.store = ConversationStore()
        self.context_builder = ContextBuilder()
        self._cleanup_task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()

    async def start(self):
        """Start the session manager's periodic cleanup of stale sessions.

        Calling this while the cleanup task is already running is a no-op,
        so repeated calls do not spawn duplicate cleanup loops.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop the cleanup task and terminate all tracked sessions."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        # Iterate over a snapshot: _terminate_session removes dict entries.
        for session in list(self.sessions.values()):
            await self._terminate_session(session)

    async def create_session(
        self,
        mode: Literal["user", "power"] = "user",
        study_id: Optional[str] = None,
        resume_session_id: Optional[str] = None,
    ) -> ClaudeSession:
        """Create or resume a Claude Code session.

        If ``resume_session_id`` names a tracked session whose process is
        still running, that session is touched and returned as-is.
        Otherwise the session record is created or updated in the store,
        per-session MCP-config and system-prompt files are written under
        ``ATOMIZER_ROOT``, and a ``claude`` subprocess is started.

        Args:
            mode: "user" for safe mode, "power" for full access
            study_id: Optional study context
            resume_session_id: Optional session ID to resume

        Returns:
            ClaudeSession object. Its ``process`` is ``None`` when the
            ``claude`` executable could not be found.
        """
        async with self._lock:
            # Resume existing session if requested and alive
            if resume_session_id and resume_session_id in self.sessions:
                session = self.sessions[resume_session_id]
                if session.is_alive():
                    session.last_active = datetime.now()
                    self.store.touch_session(session.session_id)
                    return session

            session_id = resume_session_id or str(uuid.uuid4())[:8]

            # Create or update session in store
            if self.store.get_session(session_id):
                self.store.update_session(session_id, mode=mode, study_id=study_id)
            else:
                self.store.create_session(session_id, mode, study_id)

            # Per-session MCP config. Written as UTF-8 explicitly so the file
            # does not depend on the platform's default encoding.
            mcp_config_path = ATOMIZER_ROOT / f".claude-mcp-{session_id}.json"
            mcp_config_path.write_text(
                json.dumps(self._build_mcp_config(mode)), encoding="utf-8"
            )

            # System prompt with context (history only when resuming).
            history = self.store.get_history(session_id) if resume_session_id else []
            system_prompt = self.context_builder.build(
                mode=mode,
                study_id=study_id,
                conversation_history=history,
            )
            prompt_path = ATOMIZER_ROOT / f".claude-prompt-{session_id}.md"
            # The prompt may contain non-ASCII study text; force UTF-8.
            prompt_path.write_text(system_prompt, encoding="utf-8")

            env = self._build_env(mode, study_id)

            # Start Claude Code subprocess
            # NOTE(review): "--system-prompt" receives the prompt file's *path*,
            # not its contents; confirm this matches what the CLI expects.
            try:
                process = await asyncio.create_subprocess_exec(
                    "claude",
                    "--print",  # Non-interactive mode
                    "--output-format", "stream-json",
                    "--mcp-config", str(mcp_config_path),
                    "--system-prompt", str(prompt_path),
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    cwd=str(ATOMIZER_ROOT),
                    env=env,
                )
            except FileNotFoundError:
                # Claude CLI not found - create session without process
                # Frontend will get error on first message
                process = None

            session = ClaudeSession(
                session_id=session_id,
                mode=mode,
                study_id=study_id,
                process=process,
            )
            self.sessions[session_id] = session
            return session

    async def send_message(
        self,
        session_id: str,
        message: str,
    ) -> AsyncGenerator[Dict, None]:
        """Send a message to a session and stream the response.

        Uses one-shot Claude CLI calls (``claude --print``) since the CLI
        doesn't support persistent interactive sessions via stdin/stdout.
        Each call rebuilds the prompt from the session's mode, study and the
        last stored messages, writes it to the CLI's stdin, and streams the
        CLI's stdout back as it arrives.

        Args:
            session_id: Session ID
            message: User message

        Yields:
            Dicts keyed by ``"type"``: ``"text"`` chunks of output,
            ``"error"`` entries, and a final ``"done"`` entry carrying
            ``tool_calls``. An unknown session yields a single error and
            no ``"done"``.
        """
        session = self.sessions.get(session_id)
        if not session:
            yield {"type": "error", "message": "Session not found"}
            return

        session.last_active = datetime.now()

        # Store the user turn first so it is part of the fetched history.
        self.store.add_message(session_id, "user", message)

        history = self.store.get_history(session_id, limit=10)
        full_prompt = self.context_builder.build(
            mode=session.mode,
            study_id=session.study_id,
            conversation_history=history[:-1],  # Exclude current message
        )
        full_prompt += f"\n\nUser: {message}\n\nRespond helpfully and concisely:"

        response_parts: List[str] = []
        # NOTE: nothing in this method populates tool_calls yet; it is kept
        # for the stored assistant message and the final "done" event.
        tool_calls: List[Dict] = []
        # NOTE(review): unlike create_session, this call does not pass the
        # per-session --mcp-config file; confirm whether that is intended.
        cli_args = self._build_cli_args(session.mode)

        process: Optional[asyncio.subprocess.Process] = None
        stderr_task: Optional[asyncio.Task] = None
        try:
            process = await asyncio.create_subprocess_exec(
                *cli_args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(ATOMIZER_ROOT),
                env=self._build_env(session.mode, session.study_id),
            )
            # Drain stderr concurrently: if it were only read after stdout
            # hits EOF, a CLI that fills the stderr pipe would block forever.
            stderr_task = asyncio.create_task(process.stderr.read())

            # Send prompt via stdin (handles long prompts and special characters)
            process.stdin.write(full_prompt.encode("utf-8"))
            await process.stdin.drain()
            process.stdin.close()
            await process.stdin.wait_closed()

            # Decode incrementally: a fixed-size read can split a multi-byte
            # UTF-8 character, which a per-chunk decode() would reject.
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            while True:
                chunk = await process.stdout.read(self.STREAM_CHUNK_SIZE)
                # An empty chunk means EOF; final=True flushes any tail bytes.
                text = decoder.decode(chunk, final=not chunk)
                if text:
                    response_parts.append(text)
                    yield {"type": "text", "content": text}
                if not chunk:
                    break

            await process.wait()
            stderr = await stderr_task

            if process.returncode != 0:
                error_msg = stderr.decode("utf-8", errors="replace") if stderr else "Unknown error"
                yield {"type": "error", "message": f"CLI error: {error_msg}"}

        except asyncio.TimeoutError:
            yield {"type": "error", "message": "Response timeout"}
        except FileNotFoundError:
            yield {"type": "error", "message": "Claude CLI not found in PATH"}
        except Exception as e:
            yield {"type": "error", "message": str(e)}
        finally:
            # Runs on errors and when the consumer abandons the stream
            # (aclose/cancellation), so the one-shot CLI is never orphaned.
            await self._reap_process(process, stderr_task)

        # Store assistant response
        full_response = "".join(response_parts)
        if full_response:
            self.store.add_message(
                session_id,
                "assistant",
                full_response.strip(),
                tool_calls=tool_calls if tool_calls else None,
            )

        yield {"type": "done", "tool_calls": tool_calls}

    async def switch_mode(
        self,
        session_id: str,
        new_mode: Literal["user", "power"],
    ) -> ClaudeSession:
        """Switch a session's mode (requires restart).

        Args:
            session_id: Session to switch
            new_mode: New mode ("user" or "power")

        Returns:
            New ClaudeSession with updated mode

        Raises:
            ValueError: if ``session_id`` is not a tracked session.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        study_id = session.study_id

        # Terminate existing session
        await self._terminate_session(session)

        # Create new session with same ID but different mode
        return await self.create_session(
            mode=new_mode,
            study_id=study_id,
            resume_session_id=session_id,
        )

    async def set_study_context(
        self,
        session_id: str,
        study_id: str,
    ):
        """Update the study context for a session.

        Unknown session IDs are ignored. If the session's process is alive,
        a context-update message is also written to its stdin (best-effort).
        """
        session = self.sessions.get(session_id)
        if not session:
            return

        session.study_id = study_id
        self.store.update_session(session_id, study_id=study_id)

        # If session is alive, send context update
        if session.is_alive() and session.process:
            context_update = self.context_builder.build_study_context(study_id)
            context_msg = f"[CONTEXT UPDATE] Study changed to: {study_id}\n\n{context_update}"
            try:
                session.process.stdin.write(f"{context_msg}\n".encode())
                await session.process.stdin.drain()
            except Exception:
                # Best-effort: send_message rebuilds the prompt from
                # session.study_id, so a failed update is not fatal.
                logging.getLogger(__name__).debug(
                    "Context update for session %s failed", session_id, exc_info=True
                )

    def get_session(self, session_id: str) -> Optional[ClaudeSession]:
        """Get session by ID (None if not tracked)."""
        return self.sessions.get(session_id)

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Get session info including database record (None if not tracked)."""
        session = self.sessions.get(session_id)
        if not session:
            return None

        db_record = self.store.get_session(session_id)
        return {
            "session_id": session.session_id,
            "mode": session.mode,
            "study_id": session.study_id,
            "is_alive": session.is_alive(),
            "created_at": session.created_at.isoformat(),
            "last_active": session.last_active.isoformat(),
            "message_count": self.store.get_message_count(session_id),
            **({} if not db_record else {"db_record": db_record}),
        }

    def _build_mcp_config(self, mode: Literal["user", "power"]) -> dict:
        """Build MCP configuration for Claude (runs the atomizer server via node)."""
        return {
            "mcpServers": {
                "atomizer": {
                    "command": "node",
                    "args": [str(MCP_SERVER_PATH / "dist" / "index.js")],
                    "env": {
                        "ATOMIZER_MODE": mode,
                        "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
                    },
                },
            },
        }

    @staticmethod
    def _build_env(mode: str, study_id: Optional[str]) -> Dict[str, str]:
        """Return a copy of os.environ extended with the ATOMIZER_* variables."""
        env = os.environ.copy()
        env["ATOMIZER_MODE"] = mode
        env["ATOMIZER_ROOT"] = str(ATOMIZER_ROOT)
        if study_id:
            env["ATOMIZER_STUDY"] = study_id
        return env

    @staticmethod
    def _build_cli_args(mode: str) -> List[str]:
        """Return the one-shot ``claude`` argv used by send_message for *mode*."""
        args = ["claude", "--print"]
        if mode == "user":
            # User mode: Allow safe operations including report generation
            # Allow Write tool for report files (STUDY_REPORT.md, *.md in study dirs)
            args.extend([
                "--allowedTools",
                "Read Write(**/STUDY_REPORT.md) Write(**/3_results/*.md) Bash(python:*)",
            ])
        else:
            # Power mode: Full access
            args.append("--dangerously-skip-permissions")
        args.append("-")  # Read prompt from stdin
        return args

    @staticmethod
    async def _reap_process(
        process: Optional[asyncio.subprocess.Process],
        stderr_task: Optional[asyncio.Task],
    ) -> None:
        """Ensure a one-shot CLI process has exited and its stderr reader is done."""
        if process is not None and process.returncode is None:
            # The child may exit on its own between the check and the signal.
            with contextlib.suppress(ProcessLookupError):
                process.kill()
            await process.wait()
        if stderr_task is not None and not stderr_task.done():
            stderr_task.cancel()

    async def _terminate_session(self, session: ClaudeSession):
        """Terminate a Claude session, delete its temp files and untrack it.

        The process gets SIGTERM and 5 seconds to exit before being killed.
        """
        process = session.process
        if process is not None and session.is_alive():
            # The process may exit between is_alive() and the signal.
            with contextlib.suppress(ProcessLookupError):
                process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                with contextlib.suppress(ProcessLookupError):
                    process.kill()
                await process.wait()

        # Clean up temp files (best-effort; a leftover file is harmless).
        for name in (
            f".claude-mcp-{session.session_id}.json",
            f".claude-prompt-{session.session_id}.md",
        ):
            with contextlib.suppress(OSError):
                (ATOMIZER_ROOT / name).unlink(missing_ok=True)

        # Remove from active sessions
        self.sessions.pop(session.session_id, None)

    async def _cleanup_loop(self):
        """Periodically clean up stale sessions until cancelled.

        Every CLEANUP_INTERVAL_SECONDS, terminates sessions idle for longer
        than SESSION_IDLE_TIMEOUT_SECONDS and asks the store to drop records
        older than STORE_MAX_AGE_HOURS.
        """
        while True:
            try:
                await asyncio.sleep(self.CLEANUP_INTERVAL_SECONDS)

                now = datetime.now()
                stale = [
                    sid
                    for sid, session in list(self.sessions.items())
                    if (now - session.last_active).total_seconds()
                    > self.SESSION_IDLE_TIMEOUT_SECONDS
                ]

                for sid in stale:
                    session = self.sessions.get(sid)
                    if session:
                        await self._terminate_session(session)

                # Also clean up database
                self.store.cleanup_stale_sessions(max_age_hours=self.STORE_MAX_AGE_HOURS)

            except asyncio.CancelledError:
                break
            except Exception:
                # Keep the loop running, but leave a trace instead of
                # silently discarding the failure.
                logging.getLogger(__name__).exception("Session cleanup pass failed")
|
||
|
|
|
||
|
|
|
||
|
|
# Global instance for the application (created lazily on first request)
_session_manager: Optional[SessionManager] = None


def get_session_manager() -> SessionManager:
    """Return the process-wide SessionManager, constructing it on first call."""
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = SessionManager()
    return _session_manager
|