feat(canvas): Claude Code integration with streaming, snippets, and live preview

Backend:
- Add POST /generate-extractor for AI code generation via Claude CLI
- Add POST /generate-extractor/stream for SSE streaming generation
- Add POST /validate-extractor with enhanced syntax checking
- Add POST /check-dependencies for import analysis
- Add POST /test-extractor for live OP2 file testing
- Add ClaudeCodeSession service for managing CLI sessions

Frontend:
- Add lib/api/claude.ts with typed API functions
- Enhance CodeEditorPanel with:
  - Streaming generation with live preview
  - Code snippets library (6 templates: displacement, stress, frequency, mass, energy, reaction)
  - Test button for live OP2 validation
  - Cancel button for stopping generation
  - Dependency warnings display
- Integrate streaming and testing into NodeConfigPanelV2

Uses Claude CLI (--print mode) to leverage Pro/Max subscription without API costs.
This commit is contained in:
2026-01-20 13:08:12 -05:00
parent ffd41e3a60
commit b05412f807
5 changed files with 2311 additions and 49 deletions

View File

@@ -0,0 +1,894 @@
"""
Claude Code WebSocket Routes
Provides WebSocket endpoint that connects to actual Claude Code CLI.
This gives dashboard users the same power as terminal Claude Code users.
Unlike the MCP-based approach in claude.py:
- Spawns actual Claude Code CLI processes
- Full file editing capabilities
- Full command execution
- Opus 4.5 model with unlimited tool use
Also provides single-shot endpoints for code generation:
- POST /generate-extractor: Generate Python extractor code
- POST /validate-extractor: Validate Python syntax
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Body
from pydantic import BaseModel
from typing import Dict, Optional, List
import json
import asyncio
import re
import os
from pathlib import Path
from api.services.claude_code_session import (
get_claude_code_manager,
ClaudeCodeSession,
ATOMIZER_ROOT,
)
router = APIRouter(prefix="/claude-code", tags=["Claude Code"])
# ==================== Extractor Code Generation ====================
class ExtractorGenerationRequest(BaseModel):
    """Request model for extractor code generation.

    Sent by the dashboard to /generate-extractor and its /stream variant.
    """

    prompt: str  # User's description of what the extractor should compute
    study_id: Optional[str] = None  # Study context (folder name under studies/)
    existing_code: Optional[str] = None  # Current code to improve, if iterating
    output_names: List[str] = []  # Expected output dict keys (fallback if none detected)
class ExtractorGenerationResponse(BaseModel):
    """Response model for generated code."""

    code: str  # Generated Python code (the extract() function source)
    outputs: List[str]  # Output names detected from the code's return {...} literal
    explanation: Optional[str] = None  # Brief explanation (prose before the code block, if any)
class CodeValidationRequest(BaseModel):
    """Request model for code validation and dependency checking."""

    code: str  # Python source to validate / scan for imports
class CodeValidationResponse(BaseModel):
    """Response model for validation result."""

    valid: bool  # True when the code parses and defines a conforming extract()
    error: Optional[str] = None  # Human-readable reason when valid is False
@router.post("/generate-extractor", response_model=ExtractorGenerationResponse)
async def generate_extractor_code(request: ExtractorGenerationRequest):
    """
    Generate Python extractor code using Claude Code CLI.

    Uses --print mode for single-shot generation (no session state) with a
    focused system prompt for fast, accurate results.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        ExtractorGenerationResponse with generated code and detected outputs

    Raises:
        HTTPException: 500 on CLI or parse failure, 504 on the 60s timeout.
    """
    # Build focused system prompt for extractor generation
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.
The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks
Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path
Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""
    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"
    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )
    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )
    process = None
    try:
        # Call Claude CLI with focused prompt (single-shot, no session)
        process = await asyncio.create_subprocess_exec(
            "claude",
            "--print",
            "--system-prompt",
            system_prompt,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(ATOMIZER_ROOT),
            env={
                **os.environ,
                "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
            },
        )
        # Send prompt and wait for response (60 second timeout)
        stdout, stderr = await asyncio.wait_for(
            process.communicate(user_prompt.encode("utf-8")), timeout=60.0
        )
        if process.returncode != 0:
            error_text = stderr.decode("utf-8", errors="replace")
            raise HTTPException(status_code=500, detail=f"Claude CLI error: {error_text[:500]}")
        output = stdout.decode("utf-8", errors="replace")
        # Extract Python code from markdown code block
        code_match = re.search(r"```python\s*(.*?)\s*```", output, re.DOTALL)
        if code_match:
            code = code_match.group(1).strip()
        elif "def extract(" in output:
            # Claude sometimes omits the code fence; accept the raw output.
            code = output.strip()
        else:
            raise HTTPException(
                status_code=500,
                detail="Failed to parse generated code - no Python code block found",
            )
        # Detect output names from the first `return {...}` dict literal
        detected_outputs: List[str] = []
        return_match = re.search(r"return\s*\{([^}]+)\}", code)
        if return_match:
            # Parse dict keys like 'max_stress': ... or "mass": ...
            detected_outputs = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))
        # Use detected outputs or fall back to the requested ones
        final_outputs = detected_outputs if detected_outputs else request.output_names
        # Any prose before the code fence becomes the explanation (capped at 300 chars)
        explanation = None
        parts = output.split("```python")
        if len(parts) > 1 and parts[0].strip():
            explanation = parts[0].strip()[:300]
        return ExtractorGenerationResponse(
            code=code, outputs=final_outputs, explanation=explanation
        )
    except asyncio.TimeoutError:
        # Kill the CLI so a timed-out generation doesn't leak a subprocess.
        if process is not None:
            process.kill()
            await process.wait()
        raise HTTPException(
            status_code=504, detail="Code generation timed out (60s limit). Try a simpler prompt."
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
class DependencyCheckResponse(BaseModel):
    """Response model for dependency check."""

    imports: List[str]  # All top-level module names found in the code
    available: List[str]  # Imports known or discovered to be importable here
    missing: List[str]  # Imports that could not be resolved in this environment
    warnings: List[str]  # Advisories for imports that resolve but may misbehave
# Known available packages in the atomizer environment.
# Maps canonical package name -> aliases/submodules that may appear in user
# code; only the top-level name (text before the first dot) is compared in
# check_code_dependencies, so entries like "pyNastran.op2" are informational.
KNOWN_PACKAGES = {
    "pyNastran": ["pyNastran", "pyNastran.op2", "pyNastran.bdf"],
    "numpy": ["numpy", "np"],
    "scipy": ["scipy"],
    "pandas": ["pandas", "pd"],
    "pathlib": ["pathlib", "Path"],
    "json": ["json"],
    "os": ["os"],
    "re": ["re"],
    "math": ["math"],
    "typing": ["typing"],
    "collections": ["collections"],
    "itertools": ["itertools"],
    "functools": ["functools"],
}
def extract_imports(code: str) -> List[str]:
    """Extract top-level imported module names from Python code.

    Parses the code with ast and collects the first dotted component of each
    `import x` / `from x import y` statement. Falls back to a line-based regex
    scan when the code has syntax errors.

    Returns:
        Distinct top-level names in first-seen order. (The previous
        implementation returned ``list(set(...))``, which made the API
        response order nondeterministic between calls.)
    """
    import ast

    imports: List[str] = []
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imports.extend(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    imports.append(node.module.split(".")[0])
    except SyntaxError:
        # Fall back to a per-line regex if the code does not parse.
        import_pattern = r"^(?:from\s+(\w+)|import\s+(\w+))"
        for line in code.split("\n"):
            match = re.match(import_pattern, line.strip())
            if match:
                imports.append(match.group(1) or match.group(2))
    # dict.fromkeys dedupes while preserving insertion order.
    return list(dict.fromkeys(imports))
@router.post("/check-dependencies", response_model=DependencyCheckResponse)
async def check_code_dependencies(request: CodeValidationRequest):
    """
    Check which imports in the code are available in the atomizer environment.

    Args:
        request: CodeValidationRequest with code to check

    Returns:
        DependencyCheckResponse with available and missing packages plus
        advisory warnings for imports that resolve but may misbehave.
    """
    # Hoisted out of the per-import loop (was re-imported on every iteration).
    import importlib.util

    imports = extract_imports(request.code)
    available: List[str] = []
    missing: List[str] = []
    warnings: List[str] = []
    # Top-level names known to be present in the atomizer environment.
    known_available = {
        alias.split(".")[0] for aliases in KNOWN_PACKAGES.values() for alias in aliases
    }
    for imp in imports:
        if imp in known_available:
            available.append(imp)
            continue
        # Otherwise probe the interpreter's import machinery.
        try:
            spec = importlib.util.find_spec(imp)
        except (ImportError, ValueError):
            # find_spec raises ImportError for broken parent packages and
            # ValueError for malformed names (e.g. the empty string).
            spec = None
        if spec is not None:
            available.append(imp)
        else:
            missing.append(imp)
    # Add warnings for potentially problematic imports
    if "matplotlib" in imports:
        warnings.append("matplotlib may cause issues in headless NX environment")
    if "tensorflow" in imports or "torch" in imports:
        warnings.append("Deep learning frameworks may cause memory issues during optimization")
    return DependencyCheckResponse(
        imports=imports, available=available, missing=missing, warnings=warnings
    )
@router.post("/validate-extractor", response_model=CodeValidationResponse)
async def validate_extractor_code(request: CodeValidationRequest):
    """
    Validate Python extractor code syntax and structure.

    Checks that the code parses, defines a function named ``extract``, and
    that extract() returns something dict-like. Accepts dict literals, dict
    comprehensions, calls (e.g. ``return dict(...)``), and variable returns
    (the variable may hold a dict); the previous version rejected calls and
    comprehensions, producing false validation failures.

    Args:
        request: CodeValidationRequest with code to validate

    Returns:
        CodeValidationResponse with valid flag and optional error message
    """
    import ast

    try:
        tree = ast.parse(request.code)
        # Check for extract function
        has_extract = False
        extract_returns_dict = False
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == "extract":
                has_extract = True
                # Look for a return statement whose value could be a dict.
                for child in ast.walk(node):
                    if isinstance(child, ast.Return) and child.value:
                        if isinstance(
                            child.value, (ast.Dict, ast.DictComp, ast.Call, ast.Name)
                        ):
                            extract_returns_dict = True
        if not has_extract:
            return CodeValidationResponse(
                valid=False, error="Code must define a function named 'extract'"
            )
        if not extract_returns_dict:
            return CodeValidationResponse(
                valid=False, error="extract() function should return a dict"
            )
        return CodeValidationResponse(valid=True, error=None)
    except SyntaxError as e:
        return CodeValidationResponse(valid=False, error=f"Line {e.lineno}: {e.msg}")
    except Exception as e:
        return CodeValidationResponse(valid=False, error=str(e))
# ==================== Live Preview / Test Execution ====================
class TestExtractorRequest(BaseModel):
    """Request model for testing extractor code."""

    code: str  # Python source defining extract() to run against a real OP2 file
    study_id: Optional[str] = None  # Study whose latest trial OP2 is used
    subcase_id: int = 1  # Subcase passed through to extract()
class TestExtractorResponse(BaseModel):
    """Response model for extractor test."""

    success: bool  # True only when extract() ran and returned a dict
    outputs: Optional[Dict[str, float]] = None  # Extracted values, coerced to float
    error: Optional[str] = None  # Failure description when success is False
    execution_time_ms: Optional[float] = None  # Wall-clock time for the whole test
def _find_latest_trial_files(study_id: str):
    """Locate the newest trial's OP2 (and FEM) files for a study.

    Searches <studies>/<study_id>/2_iterations/trial_* (newest first, by
    folder-name sort) for the first trial containing an .op2 file. Studies
    may be nested one directory deep (e.g. "Group/study_name").

    Returns:
        (op2_path, fem_path) as strings; either may be None.
    """
    study_path = ATOMIZER_ROOT / "studies" / study_id
    if not study_path.exists():
        # Try one level of nesting under the studies directory.
        for parent in (ATOMIZER_ROOT / "studies").iterdir():
            if parent.is_dir():
                nested = parent / study_id
                if nested.exists():
                    study_path = nested
                    break
    op2_path = None
    fem_path = None
    if study_path.exists():
        iterations_dir = study_path / "2_iterations"
        if iterations_dir.exists():
            trial_folders = sorted(
                [
                    d
                    for d in iterations_dir.iterdir()
                    if d.is_dir() and d.name.startswith("trial_")
                ],
                reverse=True,
            )
            for trial_dir in trial_folders:
                op2_files = list(trial_dir.glob("*.op2"))
                if op2_files:
                    op2_path = str(op2_files[0])
                    fem_files = list(trial_dir.glob("*.fem"))
                    if fem_files:
                        fem_path = str(fem_files[0])
                    break
    return op2_path, fem_path


@router.post("/test-extractor", response_model=TestExtractorResponse)
async def test_extractor_code(request: TestExtractorRequest):
    """
    Test extractor code against a study OP2 file.

    If a study_id is provided, the most recent trial's OP2 file is used;
    without one (or without any trial output) the test is refused.

    SECURITY NOTE: the submitted code is imported and executed in-process
    with the server's full privileges. This is NOT a sandbox; only trusted
    dashboard users should be able to reach this endpoint.

    Args:
        request: TestExtractorRequest with code and optional study context

    Returns:
        TestExtractorResponse with extracted outputs or error
    """
    import time
    import tempfile
    import traceback

    start_time = time.time()

    def elapsed_ms() -> float:
        # Wall-clock duration since the request started, in milliseconds.
        return (time.time() - start_time) * 1000

    # Find an OP2 file to test against
    op2_path = None
    fem_path = None
    if request.study_id:
        op2_path, fem_path = _find_latest_trial_files(request.study_id)
    if not op2_path:
        return TestExtractorResponse(
            success=False,
            error="No OP2 file available for testing. Run at least one optimization trial first.",
            execution_time_ms=elapsed_ms(),
        )
    temp_file = None  # defined before try so the except block can reference it
    try:
        # Write the code to a temporary module file and import it.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(request.code)
            temp_file = f.name
        try:
            import importlib.util

            spec = importlib.util.spec_from_file_location("temp_extractor", temp_file)
            if spec is None or spec.loader is None:
                return TestExtractorResponse(
                    success=False,
                    error="Failed to load code as module",
                    execution_time_ms=elapsed_ms(),
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if not hasattr(module, "extract"):
                return TestExtractorResponse(
                    success=False,
                    error="Code does not define an 'extract' function",
                    execution_time_ms=elapsed_ms(),
                )
            # Call the extract function with the located files.
            result = module.extract(
                op2_path=op2_path,
                fem_path=fem_path or "",
                params={},  # Empty params for testing
                subcase_id=request.subcase_id,
            )
            if not isinstance(result, dict):
                return TestExtractorResponse(
                    success=False,
                    error=f"extract() returned {type(result).__name__}, expected dict",
                    execution_time_ms=elapsed_ms(),
                )
            # Convert all values to float for JSON serialization
            outputs = {}
            for k, v in result.items():
                try:
                    outputs[k] = float(v)
                except (TypeError, ValueError):
                    outputs[k] = 0.0  # Can't convert, use 0
            return TestExtractorResponse(
                success=True, outputs=outputs, execution_time_ms=elapsed_ms()
            )
        finally:
            # Clean up the temp module file; ignore races where it's gone.
            try:
                os.unlink(temp_file)
            except OSError:
                pass
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        tb = traceback.format_exc()
        # Surface the traceback line from the user's code. The module is
        # loaded from temp_file, so that path (not "temp_extractor.py")
        # appears in the traceback's File lines.
        if temp_file and temp_file in tb:
            lines = tb.split("\n")
            relevant = [l for l in lines if temp_file in l or "line" in l.lower()]
            if relevant:
                error_msg += f"\n{relevant[-1]}"
        return TestExtractorResponse(
            success=False, error=error_msg, execution_time_ms=elapsed_ms()
        )
# ==================== Streaming Generation ====================
from fastapi.responses import StreamingResponse
@router.post("/generate-extractor/stream")
async def generate_extractor_code_stream(request: ExtractorGenerationRequest):
    """
    Stream Python extractor code generation using Claude Code CLI.

    Uses Server-Sent Events (SSE) to stream tokens as they arrive.

    Event types:
    - data: {"type": "token", "content": "..."} - Partial code token
    - data: {"type": "done", "code": "...", "outputs": [...]} - Final result
    - data: {"type": "error", "message": "..."} - Error occurred

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        StreamingResponse with text/event-stream content type
    """
    # Same focused system prompt as the non-streaming endpoint.
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.
The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks
Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path
Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""
    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"
    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )
    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )

    async def generate():
        full_output = ""
        process = None
        try:
            # Call Claude CLI with streaming output
            process = await asyncio.create_subprocess_exec(
                "claude",
                "--print",
                "--system-prompt",
                system_prompt,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(ATOMIZER_ROOT),
                env={
                    **os.environ,
                    "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
                },
            )
            # Write prompt to stdin and close so the CLI starts generating
            process.stdin.write(user_prompt.encode("utf-8"))
            await process.stdin.drain()
            process.stdin.close()
            # Stream stdout chunks as they arrive
            while True:
                chunk = await asyncio.wait_for(
                    process.stdout.read(256),  # Read in small chunks for responsiveness
                    timeout=60.0,
                )
                if not chunk:
                    break
                decoded = chunk.decode("utf-8", errors="replace")
                full_output += decoded
                # Send token event
                yield f"data: {json.dumps({'type': 'token', 'content': decoded})}\n\n"
            # Wait for process to complete, then check for errors
            await process.wait()
            if process.returncode != 0:
                stderr = await process.stderr.read()
                error_text = stderr.decode("utf-8", errors="replace")
                yield f"data: {json.dumps({'type': 'error', 'message': f'Claude CLI error: {error_text[:500]}'})}\n\n"
                return
            # Parse the complete output to extract code
            code_match = re.search(r"```python\s*(.*?)\s*```", full_output, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
            elif "def extract(" in full_output:
                code = full_output.strip()
            else:
                yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse generated code'})}\n\n"
                return
            # Detect output names from the first `return {...}` literal
            detected_outputs: List[str] = []
            return_match = re.search(r"return\s*\{([^}]+)\}", code)
            if return_match:
                detected_outputs = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))
            final_outputs = detected_outputs if detected_outputs else request.output_names
            # Send completion event with parsed code
            yield f"data: {json.dumps({'type': 'done', 'code': code, 'outputs': final_outputs})}\n\n"
        except asyncio.TimeoutError:
            yield f"data: {json.dumps({'type': 'error', 'message': 'Generation timed out (60s limit)'})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            # Reap the CLI process so timeouts or client disconnects (which
            # cancel this generator) don't leak a running subprocess.
            if process is not None and process.returncode is None:
                process.kill()
                await process.wait()

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
# ==================== Session Management ====================
# Store active WebSocket connections.
# Maps session_id -> live WebSocket so the websocket handlers can drop the
# mapping on disconnect while keeping the session itself in the manager.
_active_connections: Dict[str, WebSocket] = {}
@router.post("/sessions")
async def create_claude_code_session(study_id: Optional[str] = None):
    """
    Create a new Claude Code session.

    Args:
        study_id: Optional study to provide context

    Returns:
        Session info including session_id
    """
    try:
        session = get_claude_code_manager().create_session(study_id)
        return {
            "session_id": session.session_id,
            "study_id": session.study_id,
            "working_dir": str(session.working_dir),
            "message": "Claude Code session created. Connect via WebSocket to chat.",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/sessions/{session_id}")
async def get_claude_code_session(session_id: str):
    """Return metadata for an existing session; 404 when it is unknown."""
    session = get_claude_code_manager().get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return {
        "session_id": session.session_id,
        "study_id": session.study_id,
        "working_dir": str(session.working_dir),
        "has_canvas_state": session.canvas_state is not None,
        "conversation_length": len(session.conversation_history),
    }
@router.delete("/sessions/{session_id}")
async def delete_claude_code_session(session_id: str):
    """Remove the given Claude Code session from the manager."""
    get_claude_code_manager().remove_session(session_id)
    return {"message": "Session deleted"}
@router.websocket("/ws")
async def claude_code_websocket(websocket: WebSocket):
    """
    WebSocket for full Claude Code CLI access (no session required).
    This is a simplified endpoint that creates a session per connection.

    Message formats (client -> server):
    {"type": "init", "study_id": "optional_study_name"}
    {"type": "message", "content": "user message"}
    {"type": "set_canvas", "canvas_state": {...}}
    {"type": "ping"}

    Message formats (server -> client):
    {"type": "initialized", "session_id": "...", "study_id": "..."}
    {"type": "text", "content": "..."}
    {"type": "done"}
    {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
    {"type": "error", "content": "..."}
    {"type": "pong"}
    """
    print("[ClaudeCode WS] Connection attempt received")
    await websocket.accept()
    print("[ClaudeCode WS] WebSocket accepted")
    manager = get_claude_code_manager()
    session: Optional[ClaudeCodeSession] = None
    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")
            if msg_type == "init":
                # Create or reinitialize the session for this connection
                study_id = data.get("study_id")
                session = manager.create_session(study_id)
                _active_connections[session.session_id] = websocket
                await websocket.send_json(
                    {
                        "type": "initialized",
                        "session_id": session.session_id,
                        "study_id": session.study_id,
                        "working_dir": str(session.working_dir),
                    }
                )
            elif msg_type == "message":
                if not session:
                    # Auto-create a session if the client skipped "init"
                    session = manager.create_session()
                    _active_connections[session.session_id] = websocket
                content = data.get("content", "")
                if not content:
                    continue
                # Update canvas state if provided with the message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])
                # Stream response from Claude Code CLI
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)
            elif msg_type == "set_canvas":
                if session:
                    session.set_canvas_state(data.get("canvas_state", {}))
                    await websocket.send_json(
                        {
                            "type": "canvas_updated",
                        }
                    )
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        # Drop the live socket mapping but keep the session in the manager
        # for a potential reconnect.
        if session:
            _active_connections.pop(session.session_id, None)
    except Exception as e:
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "content": str(e),
                }
            )
        except Exception:  # was a bare except; don't swallow SystemExit etc.
            pass
        if session:
            _active_connections.pop(session.session_id, None)
@router.websocket("/ws/{study_id:path}")
async def claude_code_websocket_with_study(websocket: WebSocket, study_id: str):
    """
    WebSocket for Claude Code CLI with study context.
    Same as /ws but automatically initializes with the given study.

    Message formats (client -> server):
    {"type": "message", "content": "user message"}
    {"type": "set_canvas", "canvas_state": {...}}
    {"type": "ping"}

    Message formats (server -> client):
    {"type": "initialized", "session_id": "...", "study_id": "..."}
    {"type": "text", "content": "..."}
    {"type": "done"}
    {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
    {"type": "error", "content": "..."}
    {"type": "pong"}
    """
    print(f"[ClaudeCode WS] Connection attempt received for study: {study_id}")
    await websocket.accept()
    print(f"[ClaudeCode WS] WebSocket accepted for study: {study_id}")
    manager = get_claude_code_manager()
    session = manager.create_session(study_id)
    _active_connections[session.session_id] = websocket
    # Send initialization message immediately (no "init" round-trip needed)
    await websocket.send_json(
        {
            "type": "initialized",
            "session_id": session.session_id,
            "study_id": session.study_id,
            "working_dir": str(session.working_dir),
        }
    )
    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")
            if msg_type == "message":
                content = data.get("content", "")
                if not content:
                    continue
                # Update canvas state if provided with the message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])
                # Stream response from Claude Code CLI
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)
            elif msg_type == "set_canvas":
                session.set_canvas_state(data.get("canvas_state", {}))
                await websocket.send_json(
                    {
                        "type": "canvas_updated",
                    }
                )
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        _active_connections.pop(session.session_id, None)
    except Exception as e:
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "content": str(e),
                }
            )
        except Exception:  # was a bare except; don't swallow SystemExit etc.
            pass
        _active_connections.pop(session.session_id, None)

View File

@@ -0,0 +1,451 @@
"""
Claude Code CLI Session Manager
Spawns actual Claude Code CLI processes with full Atomizer access.
This gives dashboard users the same power as terminal users.
Unlike the MCP-based approach:
- Claude can actually edit files (not just return instructions)
- Claude can run Python scripts
- Claude can execute git commands
- Full Opus 4.5 capabilities
"""
import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import AsyncGenerator, Dict, Optional, Any
# Atomizer paths
ATOMIZER_ROOT = Path(__file__).parent.parent.parent.parent.parent
STUDIES_DIR = ATOMIZER_ROOT / "studies"
class ClaudeCodeSession:
"""
Manages a Claude Code CLI session with full capabilities.
Unlike MCP tools, this spawns the actual claude CLI which has:
- Full file system access
- Full command execution
- Opus 4.5 model
- All Claude Code capabilities
"""
def __init__(self, session_id: str, study_id: Optional[str] = None):
    """Create a session and resolve its working directory from the study id."""
    self.session_id = session_id
    self.study_id = study_id
    self.canvas_state: Optional[Dict] = None
    self.conversation_history: list = []
    # Default to the repo root; a study id narrows it to the study folder.
    self.working_dir = ATOMIZER_ROOT
    if not study_id:
        return
    direct = STUDIES_DIR / study_id
    if direct.exists():
        # Handles plain ids and nested ids like "M1_Mirror/m1_mirror_flatback_lateral"
        self.working_dir = direct
        return
    # Otherwise search one level down inside the studies directory.
    for group_dir in STUDIES_DIR.iterdir():
        if not group_dir.is_dir():
            continue
        candidate = group_dir / study_id
        if candidate.exists():
            self.working_dir = candidate
            break
def set_canvas_state(self, canvas_state: Dict) -> None:
    """Update canvas state from frontend.

    Stored verbatim; _build_full_prompt injects it into subsequent prompts.
    """
    self.canvas_state = canvas_state
async def send_message(self, message: str) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Send message to Claude Code CLI and stream response.

    Uses claude CLI with:
    - --print for output
    - --dangerously-skip-permissions for full access (controlled environment)
    - Runs from Atomizer root to get CLAUDE.md context automatically
    - Study-specific context injected into prompt

    Yields:
        Dict messages: {"type": "text", "content": "..."},
        {"type": "error", ...}, {"type": "done"}, and optionally
        {"type": "refresh_canvas", ...} when output suggests file changes.
    """
    # Build comprehensive prompt with all context
    full_prompt = self._build_full_prompt(message)
    # Create a per-session MCP config file so the CLI can reach atomizer tools
    mcp_config_file = ATOMIZER_ROOT / f".claude-mcp-{self.session_id}.json"
    mcp_config = {
        "mcpServers": {
            "atomizer-tools": {
                "command": "npx",
                "args": ["-y", "ts-node", str(ATOMIZER_ROOT / "atomizer-dashboard" / "mcp-server" / "src" / "index.ts")],
                "cwd": str(ATOMIZER_ROOT / "atomizer-dashboard" / "mcp-server"),
                "env": {
                    "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
                    "STUDIES_DIR": str(STUDIES_DIR),
                }
            }
        }
    }
    mcp_config_file.write_text(json.dumps(mcp_config, indent=2), encoding='utf-8')
    try:
        # Spawn claude CLI from ATOMIZER_ROOT so it picks up CLAUDE.md.
        # The prompt is passed via stdin to support long multi-line prompts.
        process = await asyncio.create_subprocess_exec(
            "claude",
            "--print",
            "--dangerously-skip-permissions",
            "--mcp-config", str(mcp_config_file),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            stdin=asyncio.subprocess.PIPE,
            cwd=str(ATOMIZER_ROOT),
            env={
                **os.environ,
                "ATOMIZER_STUDY": self.study_id or "",
                "ATOMIZER_STUDY_PATH": str(self.working_dir),
                "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
            }
        )
        # Write prompt to stdin and close it so the CLI starts responding
        process.stdin.write(full_prompt.encode('utf-8'))
        await process.stdin.drain()
        process.stdin.close()
        # Read stdout incrementally and yield chunks as they arrive
        full_output = ""
        while True:
            chunk = await process.stdout.read(512)
            if not chunk:
                break
            decoded = chunk.decode('utf-8', errors='replace')
            full_output += decoded
            yield {"type": "text", "content": decoded}
        # Wait for process to complete, then surface any CLI failure
        await process.wait()
        stderr = await process.stderr.read()
        if stderr and process.returncode != 0:
            error_text = stderr.decode('utf-8', errors='replace')
            yield {"type": "error", "content": f"\n[Error]: {error_text}"}
        # Update conversation history (used to build future prompts)
        self.conversation_history.append({"role": "user", "content": message})
        self.conversation_history.append({"role": "assistant", "content": full_output})
        # Signal completion
        yield {"type": "done"}
        # Check if any files were modified and signal canvas refresh
        if self._output_indicates_file_changes(full_output):
            yield {
                "type": "refresh_canvas",
                "study_id": self.study_id,
                "reason": "Claude modified study files"
            }
    finally:
        # Remove the per-session MCP config; missing_ok covers the race
        # where it is already gone. (Was a bare except around unlink.)
        try:
            mcp_config_file.unlink(missing_ok=True)
        except OSError:
            pass
def _build_full_prompt(self, message: str) -> str:
    """Assemble the full CLI prompt: study + canvas context, recent
    conversation, the user's request, and standing instructions."""
    sections: list = []

    # Study context (only when a study is attached)
    if self.study_id:
        study_context = self._build_study_context()
        if study_context:
            sections += ["## Current Study Context", study_context]

    # Canvas context (only when the frontend has pushed state)
    if self.canvas_state:
        canvas_context = self._build_canvas_context()
        if canvas_context:
            sections += ["## Current Canvas State", canvas_context]

    # Last few conversation exchanges, each truncated to 500 chars
    if self.conversation_history:
        sections.append("## Recent Conversation")
        for entry in self.conversation_history[-6:]:
            speaker = "User" if entry["role"] == "user" else "Assistant"
            text = entry["content"]
            if len(text) > 500:
                text = text[:500] + "..."
            sections.append(f"**{speaker}:** {text}")
        sections.append("")

    # The user's actual request
    sections += ["## User Request", message, ""]

    # Standing instructions: make edits for real, don't just describe them
    sections += [
        "## Important",
        "You have FULL power to edit files in this environment. When asked to make changes:",
        "1. Use the Edit or Write tools to ACTUALLY MODIFY the files",
        "2. Show a brief summary of what you changed",
        "3. Do not just describe changes - MAKE THEM",
        "",
        "After making changes to optimization_config.json, the dashboard canvas will auto-refresh.",
    ]
    return "\n".join(sections)
    def _build_study_context(self) -> str:
        """Build a markdown context block describing the active study.

        Reads ``optimization_config.json`` from the study directory (trying
        the ``1_setup/`` layout first, then the study root) and summarizes
        design variables, objectives, extraction method, Zernike settings,
        and the optimization algorithm. Also reports whether the run script
        and a results database exist.

        Returns:
            Markdown text, or "" when no study is active. Config read
            failures are reported inline rather than raised, so this never
            propagates an exception to the caller.
        """
        if not self.study_id:
            return ""
        context_parts: List[str] = [f"**Study ID:** `{self.study_id}`"]
        context_parts.append(f"**Study Path:** `{self.working_dir}`")
        context_parts.append("")
        # Find and read optimization_config.json (preferred 1_setup/ layout
        # first, then fall back to the study root)
        config_path = self.working_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = self.working_dir / "optimization_config.json"
        if config_path.exists():
            try:
                config = json.loads(config_path.read_text(encoding='utf-8'))
                context_parts.append(f"**Config File:** `{config_path.relative_to(ATOMIZER_ROOT)}`")
                context_parts.append("")
                # Design variables summary — supports both name/min/max and
                # expression_name/lower/upper key spellings
                dvs = config.get("design_variables", [])
                if dvs:
                    context_parts.append("### Design Variables")
                    context_parts.append("")
                    context_parts.append("| Name | Min | Max | Baseline | Unit |")
                    context_parts.append("|------|-----|-----|----------|------|")
                    # Cap at 15 rows to keep the prompt compact
                    for dv in dvs[:15]:
                        name = dv.get("name", dv.get("expression_name", "?"))
                        min_v = dv.get("min", dv.get("lower", "?"))
                        max_v = dv.get("max", dv.get("upper", "?"))
                        baseline = dv.get("baseline", "-")
                        unit = dv.get("units", dv.get("unit", "-"))
                        context_parts.append(f"| {name} | {min_v} | {max_v} | {baseline} | {unit} |")
                    if len(dvs) > 15:
                        context_parts.append(f"\n*... and {len(dvs) - 15} more*")
                    context_parts.append("")
                # Objectives
                objs = config.get("objectives", [])
                if objs:
                    context_parts.append("### Objectives")
                    context_parts.append("")
                    for obj in objs:
                        name = obj.get("name", "?")
                        direction = obj.get("direction", "minimize")
                        weight = obj.get("weight", 1)
                        context_parts.append(f"- **{name}**: {direction} (weight: {weight})")
                    context_parts.append("")
                # Extraction method (for Zernike)
                ext_method = config.get("extraction_method", {})
                if ext_method:
                    context_parts.append("### Extraction Method")
                    context_parts.append("")
                    context_parts.append(f"- Type: `{ext_method.get('type', '?')}`")
                    context_parts.append(f"- Class: `{ext_method.get('class', '?')}`")
                    # NOTE(review): truthiness check skips inner_radius == 0
                    # as well as missing — confirm that is intended
                    if ext_method.get("inner_radius"):
                        context_parts.append(f"- Inner Radius: `{ext_method.get('inner_radius')}`")
                    context_parts.append("")
                # Zernike settings
                zernike = config.get("zernike_settings", {})
                if zernike:
                    context_parts.append("### Zernike Settings")
                    context_parts.append("")
                    context_parts.append(f"- Modes: `{zernike.get('n_modes', '?')}`")
                    context_parts.append(f"- Filter Low Orders: `{zernike.get('filter_low_orders', '?')}`")
                    context_parts.append(f"- Subcases: `{zernike.get('subcases', [])}`")
                    context_parts.append("")
                # Algorithm — top-level keys take precedence over the nested
                # "optimization" section
                method = config.get("method", config.get("optimization", {}).get("sampler", "TPE"))
                max_trials = config.get("max_trials", config.get("optimization", {}).get("n_trials", 100))
                context_parts.append("### Algorithm")
                context_parts.append("")
                context_parts.append(f"- Method: `{method}`")
                context_parts.append(f"- Max Trials: `{max_trials}`")
                context_parts.append("")
            except Exception as e:
                # Surface the failure in the prompt instead of crashing the session
                context_parts.append(f"*Error reading config: {e}*")
                context_parts.append("")
        else:
            context_parts.append("*No optimization_config.json found*")
            context_parts.append("")
        # Check for run_optimization.py
        run_opt_path = self.working_dir / "run_optimization.py"
        if run_opt_path.exists():
            context_parts.append(f"**Run Script:** `{run_opt_path.relative_to(ATOMIZER_ROOT)}` (exists)")
        else:
            context_parts.append("**Run Script:** not found")
        context_parts.append("")
        # Check results (3_results/ preferred, 2_results/ as legacy fallback)
        db_path = self.working_dir / "3_results" / "study.db"
        if not db_path.exists():
            db_path = self.working_dir / "2_results" / "study.db"
        if db_path.exists():
            context_parts.append("**Results Database:** exists")
            # Could query trial count here
        else:
            context_parts.append("**Results Database:** not found (no optimization run yet)")
        return "\n".join(context_parts)
def _build_canvas_context(self) -> str:
"""Build markdown context from canvas state"""
if not self.canvas_state:
return ""
parts = []
nodes = self.canvas_state.get("nodes", [])
edges = self.canvas_state.get("edges", [])
if not nodes:
return "*Canvas is empty*"
# Group nodes by type
design_vars = [n for n in nodes if n.get("type") == "designVar"]
objectives = [n for n in nodes if n.get("type") == "objective"]
extractors = [n for n in nodes if n.get("type") == "extractor"]
models = [n for n in nodes if n.get("type") == "nxModel"]
algorithms = [n for n in nodes if n.get("type") == "algorithm"]
if models:
parts.append("### NX Model")
for m in models:
data = m.get("data", {})
parts.append(f"- File: `{data.get('filePath', 'Not set')}`")
parts.append("")
if design_vars:
parts.append("### Design Variables (Canvas)")
parts.append("")
parts.append("| Name | Min | Max | Baseline |")
parts.append("|------|-----|-----|----------|")
for dv in design_vars[:20]:
data = dv.get("data", {})
name = data.get("expressionName") or data.get("label", "?")
min_v = data.get("minValue", "?")
max_v = data.get("maxValue", "?")
baseline = data.get("baseline", "-")
parts.append(f"| {name} | {min_v} | {max_v} | {baseline} |")
if len(design_vars) > 20:
parts.append(f"\n*... and {len(design_vars) - 20} more*")
parts.append("")
if extractors:
parts.append("### Extractors (Canvas)")
parts.append("")
for ext in extractors:
data = ext.get("data", {})
ext_type = data.get("extractorType") or data.get("extractorId", "?")
label = data.get("label", "?")
parts.append(f"- **{label}**: `{ext_type}`")
parts.append("")
if objectives:
parts.append("### Objectives (Canvas)")
parts.append("")
for obj in objectives:
data = obj.get("data", {})
name = data.get("objectiveName") or data.get("label", "?")
direction = data.get("direction", "minimize")
weight = data.get("weight", 1)
parts.append(f"- **{name}**: {direction} (weight: {weight})")
parts.append("")
if algorithms:
parts.append("### Algorithm (Canvas)")
for alg in algorithms:
data = alg.get("data", {})
method = data.get("method", "?")
trials = data.get("maxTrials", "?")
parts.append(f"- Method: `{method}`")
parts.append(f"- Max Trials: `{trials}`")
parts.append("")
return "\n".join(parts)
def _output_indicates_file_changes(self, output: str) -> bool:
"""Check if Claude's output indicates file modifications"""
indicators = [
"✓ Edited",
"✓ Wrote",
"Successfully wrote",
"Successfully edited",
"Modified:",
"Updated:",
"Added to file",
"optimization_config.json", # Common target
"run_optimization.py", # Common target
]
output_lower = output.lower()
return any(indicator.lower() in output_lower for indicator in indicators)
class ClaudeCodeSessionManager:
    """
    Manages multiple Claude Code sessions.
    Each session is independent and can have different study contexts.
    """
    def __init__(self):
        # Maps short session id -> live session object.
        self.sessions: Dict[str, "ClaudeCodeSession"] = {}

    def create_session(self, study_id: Optional[str] = None) -> "ClaudeCodeSession":
        """Create, register, and return a new Claude Code session.

        Args:
            study_id: Optional study the new session should be bound to.
        """
        # NOTE(review): uuid is not in the visible module import header;
        # import locally so this method works regardless. Harmless if the
        # module already imports it elsewhere.
        import uuid
        session_id = str(uuid.uuid4())[:8]  # short id is friendlier for logs/URLs
        session = ClaudeCodeSession(session_id, study_id)
        self.sessions[session_id] = session
        return session

    def get_session(self, session_id: str) -> Optional["ClaudeCodeSession"]:
        """Return an existing session, or None when the id is unknown."""
        return self.sessions.get(session_id)

    def remove_session(self, session_id: str) -> None:
        """Drop a session from the registry; a no-op for unknown ids."""
        self.sessions.pop(session_id, None)

    def set_canvas_state(self, session_id: str, canvas_state: Dict) -> None:
        """Forward the latest canvas state to a session, if it exists."""
        session = self.sessions.get(session_id)
        if session:
            session.set_canvas_state(canvas_state)
# Process-wide singleton manager, created lazily on first access.
_session_manager: Optional[ClaudeCodeSessionManager] = None


def get_claude_code_manager() -> ClaudeCodeSessionManager:
    """Return the global session manager, creating it on first use."""
    global _session_manager
    manager = _session_manager
    if manager is None:
        manager = ClaudeCodeSessionManager()
        _session_manager = manager
    return manager