Files
Atomizer/atomizer-dashboard/backend/api/routes/claude_code.py
Anto01 b05412f807 feat(canvas): Claude Code integration with streaming, snippets, and live preview
Backend:
- Add POST /generate-extractor for AI code generation via Claude CLI
- Add POST /generate-extractor/stream for SSE streaming generation
- Add POST /validate-extractor with enhanced syntax checking
- Add POST /check-dependencies for import analysis
- Add POST /test-extractor for live OP2 file testing
- Add ClaudeCodeSession service for managing CLI sessions

Frontend:
- Add lib/api/claude.ts with typed API functions
- Enhance CodeEditorPanel with:
  - Streaming generation with live preview
  - Code snippets library (6 templates: displacement, stress, frequency, mass, energy, reaction)
  - Test button for live OP2 validation
  - Cancel button for stopping generation
  - Dependency warnings display
- Integrate streaming and testing into NodeConfigPanelV2

Uses Claude CLI (--print mode) to leverage Pro/Max subscription without API costs.
2026-01-20 13:08:12 -05:00

895 lines
30 KiB
Python

"""
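Claude Code WebSocket Routes

Provides WebSocket endpoints that connect to the actual Claude Code CLI,
giving dashboard users the same power as terminal Claude Code users.

Unlike the MCP-based approach in claude.py, this module:
- Spawns actual Claude Code CLI processes
- Provides full file editing capabilities
- Provides full command execution
- Uses the Opus 4.5 model with unlimited tool use

It also provides single-shot HTTP endpoints for extractor code:
- POST /generate-extractor: Generate Python extractor code
- POST /generate-extractor/stream: Stream generation as Server-Sent Events
- POST /validate-extractor: Validate Python syntax and extract() structure
- POST /check-dependencies: Report which imports are available
- POST /test-extractor: Run extractor code against a study's latest OP2 file

plus session management endpoints under /sessions.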
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Body
from pydantic import BaseModel
from typing import Dict, Optional, List
import json
import asyncio
import re
import os
from pathlib import Path
from api.services.claude_code_session import (
get_claude_code_manager,
ClaudeCodeSession,
ATOMIZER_ROOT,
)
# Every route in this module is mounted under the /claude-code prefix and
# grouped under the "Claude Code" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/claude-code", tags=["Claude Code"])
# ==================== Extractor Code Generation ====================
class ExtractorGenerationRequest(BaseModel):
    """Request model for extractor code generation.

    Accepted by both POST /generate-extractor and POST /generate-extractor/stream.
    """

    prompt: str  # User's description of what the extractor should compute
    study_id: Optional[str] = None  # Study context (not read when building the CLI prompt in this module)
    existing_code: Optional[str] = None  # Current code to improve; appended to the prompt when set
    output_names: List[str] = []  # Expected outputs; also the fallback when none are detected in the result
class ExtractorGenerationResponse(BaseModel):
    """Response model for generated code (POST /generate-extractor)."""

    code: str  # Generated Python code (first ```python block, or the raw CLI output)
    outputs: List[str]  # Keys detected in the first `return {...}`, else the requested output names
    explanation: Optional[str] = None  # Text preceding the code block, truncated to 300 characters
class CodeValidationRequest(BaseModel):
    """Request model carrying Python source code to analyze.

    Shared by POST /validate-extractor and POST /check-dependencies.
    """

    code: str  # Python source; these endpoints parse it but do not execute it
class CodeValidationResponse(BaseModel):
    """Response model for validation result."""

    valid: bool  # True when the code parses and defines an extract() that returns a dict
    error: Optional[str] = None  # Human-readable reason when valid is False
@router.post("/generate-extractor", response_model=ExtractorGenerationResponse)
async def generate_extractor_code(request: ExtractorGenerationRequest):
    """
    Generate Python extractor code using Claude Code CLI.

    Uses --print mode for single-shot generation (no session state), with a
    focused system prompt for fast, accurate results. If the CLI process is
    still running when this handler exits (timeout, error or cancellation),
    it is killed and reaped instead of being left behind.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        ExtractorGenerationResponse with generated code and detected outputs

    Raises:
        HTTPException: 504 when the CLI does not finish within 60 seconds;
            500 when the CLI exits non-zero, its output contains no code,
            or anything else fails.
    """
    # Build focused system prompt for extractor generation
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.
The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks
Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path
Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""

    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"
    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )
    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )

    process = None
    try:
        # Call Claude CLI with focused prompt (single-shot, no session)
        process = await asyncio.create_subprocess_exec(
            "claude",
            "--print",
            "--system-prompt",
            system_prompt,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(ATOMIZER_ROOT),
            env={
                **os.environ,
                "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
            },
        )

        # Send prompt and wait for response (60 second timeout)
        stdout, stderr = await asyncio.wait_for(
            process.communicate(user_prompt.encode("utf-8")), timeout=60.0
        )

        if process.returncode != 0:
            error_text = stderr.decode("utf-8", errors="replace")
            raise HTTPException(status_code=500, detail=f"Claude CLI error: {error_text[:500]}")

        output = stdout.decode("utf-8", errors="replace")

        # Extract Python code from the first markdown code block
        code_match = re.search(r"```python\s*(.*?)\s*```", output, re.DOTALL)
        if code_match:
            code = code_match.group(1).strip()
        elif "def extract(" in output:
            # Claude might answer with bare code instead of a fenced block
            code = output.strip()
        else:
            raise HTTPException(
                status_code=500,
                detail="Failed to parse generated code - no Python code block found",
            )

        # Detect output names from the first `return {...}` literal
        detected_outputs: List[str] = []
        return_match = re.search(r"return\s*\{([^}]+)\}", code)
        if return_match:
            # Parse dict keys like 'max_stress': ... or "mass": ...
            detected_outputs = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))

        # Use detected outputs or fall back to requested ones
        final_outputs = detected_outputs if detected_outputs else request.output_names

        # Keep any explanation text before the code block (first 300 chars max)
        explanation = None
        parts = output.split("```python")
        if len(parts) > 1 and parts[0].strip():
            explanation = parts[0].strip()[:300]

        return ExtractorGenerationResponse(
            code=code, outputs=final_outputs, explanation=explanation
        )
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504, detail="Code generation timed out (60s limit). Try a simpler prompt."
        ) from None
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
    finally:
        # wait_for() only cancels communicate(). Without this, the CLI would keep
        # running after a timeout or cancellation.
        if process is not None and process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # exited between the check and kill()
            await process.wait()
class DependencyCheckResponse(BaseModel):
    """Response model for dependency check."""

    imports: List[str]  # Top-level module names imported by the submitted code
    available: List[str]  # Imports found in KNOWN_PACKAGES or locatable by the import system
    missing: List[str]  # Imports that could not be located
    warnings: List[str]  # Advisory notes for imports known to cause trouble
# Known available packages in the atomizer environment.
# Keys are package names; values are names treated as "available" by
# check_code_dependencies, which compares only the first dotted component
# of each entry (e.g. "pyNastran.op2" -> "pyNastran").
# NOTE(review): "np", "pd" and "Path" are import aliases / class names, not
# module names. extract_imports reports module names, so these entries never
# match anything. They are harmless but redundant.
KNOWN_PACKAGES = {
    "pyNastran": ["pyNastran", "pyNastran.op2", "pyNastran.bdf"],
    "numpy": ["numpy", "np"],
    "scipy": ["scipy"],
    "pandas": ["pandas", "pd"],
    "pathlib": ["pathlib", "Path"],
    "json": ["json"],
    "os": ["os"],
    "re": ["re"],
    "math": ["math"],
    "typing": ["typing"],
    "collections": ["collections"],
    "itertools": ["itertools"],
    "functools": ["functools"],
}
def extract_imports(code: str) -> List[str]:
    """Extract the top-level module names imported by Python code.

    Uses the AST when the code parses. Otherwise it falls back to a line-based
    regex scan, so partially written code still yields useful results.
    Relative imports (``from . import x``, ``from .sibling import y``) name
    the code's own package rather than an installed module, so they are skipped.

    Args:
        code: Python source text.

    Returns:
        Sorted list of unique top-level module names, e.g. ``["numpy", "os"]``
        for ``import numpy as np`` and ``import os.path``.
    """
    import ast
    import re

    modules = set()
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: source containing null bytes on older Pythons.
        # Fall back to a regex scan of each line.
        for line in code.splitlines():
            stripped = line.strip()
            from_match = re.match(r"from\s+(\w+)", stripped)
            if from_match:
                modules.add(from_match.group(1))
                continue
            import_match = re.match(r"import\s+(.+)", stripped)
            if import_match:
                # "import a, b.c as d  # note" -> "a", "b"
                clause = import_match.group(1).split("#", 1)[0]
                for part in clause.split(","):
                    words = part.split()
                    if not words:
                        continue
                    name = words[0].split(".")[0]
                    if name.isidentifier():
                        modules.add(name)
    else:
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    modules.add(alias.name.split(".")[0])
            elif isinstance(node, ast.ImportFrom):
                # level > 0 means a relative import ("from .x import y")
                if node.module and node.level == 0:
                    modules.add(node.module.split(".")[0])
    return sorted(modules)
@router.post("/check-dependencies", response_model=DependencyCheckResponse)
async def check_code_dependencies(request: CodeValidationRequest):
    """
    Check which imports in the code are available in the atomizer environment.

    An import counts as available when its top-level name appears in
    KNOWN_PACKAGES, or when ``importlib.util.find_spec`` can locate it
    (standard library or any other installed package).

    Args:
        request: CodeValidationRequest with code to check

    Returns:
        DependencyCheckResponse with available and missing packages
    """
    import importlib.util

    imports = extract_imports(request.code)

    # Top-level names known to be installed in the atomizer environment
    known_available = {
        alias.split(".")[0] for aliases in KNOWN_PACKAGES.values() for alias in aliases
    }

    available: List[str] = []
    missing: List[str] = []
    warnings: List[str] = []

    for imp in imports:
        if imp in known_available:
            available.append(imp)
            continue
        # Otherwise ask the import system (e.g. a standard library module)
        try:
            found = importlib.util.find_spec(imp) is not None
        except ValueError:
            # find_spec raises ValueError when the module is already in
            # sys.modules but has no __spec__. It is importable, so count it.
            found = True
        except ImportError:
            found = False
        if found:
            available.append(imp)
        else:
            missing.append(imp)

    # Add warnings for potentially problematic imports
    if "matplotlib" in imports:
        warnings.append("matplotlib may cause issues in headless NX environment")
    if "tensorflow" in imports or "torch" in imports:
        warnings.append("Deep learning frameworks may cause memory issues during optimization")

    return DependencyCheckResponse(
        imports=imports, available=available, missing=missing, warnings=warnings
    )
@router.post("/validate-extractor", response_model=CodeValidationResponse)
async def validate_extractor_code(request: CodeValidationRequest):
    """
    Validate Python extractor code syntax and structure.

    The code is parsed, never executed. It is valid when it parses, defines
    a function named ``extract``, and that function has at least one
    ``return`` whose value is dict-like: a dict literal, a dict
    comprehension, a ``dict(...)`` call, or a bare variable name (which may
    hold a dict).

    Args:
        request: CodeValidationRequest with code to validate

    Returns:
        CodeValidationResponse with valid flag and optional error message
    """
    import ast

    def is_dict_like(value: ast.expr) -> bool:
        # Literals and comprehensions are certainly dicts. A name or dict(...)
        # call plausibly is, so accept it rather than reject working code.
        if isinstance(value, (ast.Dict, ast.DictComp, ast.Name)):
            return True
        return (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "dict"
        )

    try:
        tree = ast.parse(request.code)
    except SyntaxError as e:
        return CodeValidationResponse(valid=False, error=f"Line {e.lineno}: {e.msg}")
    except Exception as e:
        # e.g. ValueError for source containing null bytes on older Pythons
        return CodeValidationResponse(valid=False, error=str(e))

    extract_defs = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name == "extract"
    ]
    if not extract_defs:
        return CodeValidationResponse(
            valid=False, error="Code must define a function named 'extract'"
        )

    # Note: ast.walk also visits returns of functions nested inside extract()
    returns_dict = any(
        isinstance(child, ast.Return)
        and child.value is not None
        and is_dict_like(child.value)
        for func in extract_defs
        for child in ast.walk(func)
    )
    if not returns_dict:
        return CodeValidationResponse(
            valid=False, error="extract() function should return a dict"
        )

    return CodeValidationResponse(valid=True, error=None)
# ==================== Live Preview / Test Execution ====================
class TestExtractorRequest(BaseModel):
    """Request model for testing extractor code."""

    code: str  # Python source expected to define extract(op2_path, fem_path, params, subcase_id)
    study_id: Optional[str] = None  # Study whose most recent trial OP2 file is used for the run
    subcase_id: int = 1  # Forwarded to extract() as subcase_id
class TestExtractorResponse(BaseModel):
    """Response model for extractor test."""

    success: bool
    outputs: Optional[Dict[str, float]] = None  # extract() results converted to float
    error: Optional[str] = None  # Failure reason, set when success is False
    execution_time_ms: Optional[float] = None  # Time spent in the endpoint, in milliseconds
class _ExtractorSetupError(Exception):
    """Raised when submitted extractor code cannot be loaded or lacks extract()."""


def _resolve_study_dir(study_id: str) -> Optional[Path]:
    """Locate a study folder under ``ATOMIZER_ROOT/studies``.

    Tries ``studies/<study_id>`` first, then ``studies/<group>/<study_id>``
    one level down. Absolute ids and ids containing ``..`` are rejected, so a
    request cannot point the endpoint at arbitrary directories.

    Returns:
        The existing study path, or None if it cannot be found.
    """
    requested = Path(study_id)
    if requested.anchor or ".." in requested.parts:
        return None
    studies_root = ATOMIZER_ROOT / "studies"
    direct = studies_root / requested
    if direct.exists():
        return direct
    if not studies_root.is_dir():
        return None
    # Fall back to nested layouts such as studies/<group>/<study_id>
    for parent in sorted(studies_root.iterdir()):
        if parent.is_dir():
            nested = parent / requested
            if nested.exists():
                return nested
    return None


def _trial_sort_key(trial_dir: Path):
    """Sort key ordering ``trial_<n>`` folders by their numeric suffix.

    Plain string sorting would rank ``trial_9`` above ``trial_10``. Names
    whose suffix is not an integer sort before all numbered trials, by name.
    """
    suffix = trial_dir.name[len("trial_"):]
    try:
        return (1, int(suffix), trial_dir.name)
    except ValueError:
        return (0, 0, trial_dir.name)


def _find_latest_trial_files(study_path: Path):
    """Return ``(op2_path, fem_path)`` from the highest-numbered trial that has an OP2 file.

    Trials are the ``2_iterations/trial_*`` folders. Both elements are None
    when no OP2 file is found; fem_path is None when that trial has no .fem.
    """
    iterations_dir = study_path / "2_iterations"
    if not iterations_dir.is_dir():
        return None, None
    trial_dirs = [
        d for d in iterations_dir.iterdir() if d.is_dir() and d.name.startswith("trial_")
    ]
    for trial_dir in sorted(trial_dirs, key=_trial_sort_key, reverse=True):
        op2_files = sorted(trial_dir.glob("*.op2"))
        if op2_files:
            fem_files = sorted(trial_dir.glob("*.fem"))
            return str(op2_files[0]), (str(fem_files[0]) if fem_files else None)
    return None, None


def _run_extractor_file(module_path: str, op2_path: str, fem_path: str, subcase_id: int):
    """Import *module_path* as a module and return the value of its ``extract(...)`` call.

    Raises:
        _ExtractorSetupError: if the file cannot be loaded as a module or
            defines no ``extract``. Any exception raised by the user's code
            propagates unchanged.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("temp_extractor", module_path)
    if spec is None or spec.loader is None:
        raise _ExtractorSetupError("Failed to load code as module")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if not hasattr(module, "extract"):
        raise _ExtractorSetupError("Code does not define an 'extract' function")
    return module.extract(
        op2_path=op2_path,
        fem_path=fem_path,
        params={},  # Empty params for testing
        subcase_id=subcase_id,
    )


def _format_extractor_error(exc: BaseException, module_path: Optional[str], code: str) -> str:
    """Describe *exc*, adding the failing line of the user's code when it appears in the traceback.

    The temporary module has a random file name, so frames are matched on the
    exact path it was loaded from. The source line is taken from *code*, so the
    result does not depend on the temporary file still existing.
    """
    import traceback

    message = f"{type(exc).__name__}: {str(exc)}"
    if module_path is None:
        return message
    user_frames = [
        frame
        for frame in traceback.extract_tb(exc.__traceback__)
        if frame.filename == module_path
    ]
    if user_frames:
        frame = user_frames[-1]
        message += f"\nLine {frame.lineno}, in {frame.name}"
        source_lines = code.splitlines()
        if frame.lineno and 1 <= frame.lineno <= len(source_lines):
            message += f": {source_lines[frame.lineno - 1].strip()}"
    return message


def _coerce_outputs(result: dict):
    """Convert extract() results into a ``{str: float}`` mapping of finite values.

    Returns:
        ``(outputs, bad_keys)``. bad_keys lists the keys whose values cannot
        be converted to float or are NaN/inf. NaN/inf are not valid JSON, and
        substituting a placeholder would misrepresent what the extractor
        produced.
    """
    import math

    outputs: Dict[str, float] = {}
    bad_keys: List[str] = []
    for key, value in result.items():
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            bad_keys.append(str(key))
            continue
        if math.isfinite(number):
            outputs[str(key)] = number
        else:
            bad_keys.append(str(key))
    return outputs, bad_keys


@router.post("/test-extractor", response_model=TestExtractorResponse)
async def test_extractor_code(request: TestExtractorRequest):
    """
    Test extractor code against the OP2 file of a study's most recent trial.

    The code is written to a temporary module and imported. Its ``extract``
    function is then called with that trial's OP2/FEM paths, empty params and
    the requested subcase. If no study_id is given, or no trial OP2 file can be
    found, no code is run and an error response is returned.

    SECURITY: this is NOT a sandbox. The submitted code runs inside the
    server process (in a worker thread) with the server's full privileges.
    Only expose this endpoint to trusted users.

    Args:
        request: TestExtractorRequest with code and optional study context

    Returns:
        TestExtractorResponse with extracted outputs or error. Outputs that
        are not finite numbers are reported in ``error`` (success=False)
        instead of being replaced with placeholder values.
    """
    import tempfile
    import time

    start_time = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - start_time) * 1000

    # Find OP2 file to test against
    op2_path = None
    fem_path = None
    if request.study_id:
        study_path = _resolve_study_dir(request.study_id)
        if study_path is not None:
            op2_path, fem_path = _find_latest_trial_files(study_path)

    if not op2_path:
        return TestExtractorResponse(
            success=False,
            error="No OP2 file available for testing. Run at least one optimization trial first.",
            execution_time_ms=elapsed_ms(),
        )

    temp_file: Optional[str] = None
    try:
        # Write as UTF-8 explicitly: the import system decodes source as UTF-8,
        # but the platform default encoding (e.g. cp1252 on Windows) may differ.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            temp_file = f.name
            f.write(request.code)

        # Run in a worker thread so that slow OP2 parsing does not stall the
        # event loop serving other requests.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            _run_extractor_file,
            temp_file,
            op2_path,
            fem_path or "",
            request.subcase_id,
        )
    except _ExtractorSetupError as e:
        return TestExtractorResponse(
            success=False, error=str(e), execution_time_ms=elapsed_ms()
        )
    except Exception as e:
        return TestExtractorResponse(
            success=False,
            error=_format_extractor_error(e, temp_file, request.code),
            execution_time_ms=elapsed_ms(),
        )
    finally:
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass  # best-effort cleanup of the temporary module

    if not isinstance(result, dict):
        return TestExtractorResponse(
            success=False,
            error=f"extract() returned {type(result).__name__}, expected dict",
            execution_time_ms=elapsed_ms(),
        )

    outputs, bad_keys = _coerce_outputs(result)
    if bad_keys:
        return TestExtractorResponse(
            success=False,
            outputs=outputs or None,
            error=f"extract() returned non-numeric or non-finite values for: {', '.join(bad_keys)}",
            execution_time_ms=elapsed_ms(),
        )
    return TestExtractorResponse(
        success=True, outputs=outputs, execution_time_ms=elapsed_ms()
    )
# ==================== Streaming Generation ====================
from fastapi.responses import StreamingResponse
@router.post("/generate-extractor/stream")
async def generate_extractor_code_stream(request: ExtractorGenerationRequest):
    """
    Stream Python extractor code generation using Claude Code CLI.

    Uses Server-Sent Events (SSE) to stream tokens as they arrive.

    Event types:
        - data: {"type": "token", "content": "..."} - Partial code token
        - data: {"type": "done", "code": "...", "outputs": [...]} - Final result
        - data: {"type": "error", "message": "..."} - Error occurred

    The 60 second limit applies to each read from the CLI: generation is
    aborted when the CLI produces no output for 60 seconds. If the CLI is
    still running when the stream ends (error, timeout or client
    disconnect), it is killed.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        StreamingResponse with text/event-stream content type
    """
    import codecs

    # Build focused system prompt for extractor generation
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.
The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks
Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path
Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""

    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"
    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )
    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )

    def sse_event(payload: dict) -> str:
        """Format *payload* as one SSE ``data:`` event (json.dumps keeps it on one line)."""
        return f"data: {json.dumps(payload)}\n\n"

    async def generate():
        full_output = ""
        process = None
        stderr_task = None
        try:
            # Call Claude CLI with streaming output
            process = await asyncio.create_subprocess_exec(
                "claude",
                "--print",
                "--system-prompt",
                system_prompt,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(ATOMIZER_ROOT),
                env={
                    **os.environ,
                    "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
                },
            )
            # Drain stderr concurrently. Otherwise a chatty CLI could fill the
            # pipe and block while we wait on stdout.
            stderr_task = asyncio.create_task(process.stderr.read())

            # Write prompt to stdin and close
            process.stdin.write(user_prompt.encode("utf-8"))
            await process.stdin.drain()
            process.stdin.close()

            # An incremental decoder keeps multi-byte UTF-8 characters intact
            # when they straddle a 256-byte chunk boundary.
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            while True:
                chunk = await asyncio.wait_for(
                    process.stdout.read(256),  # Read in small chunks for responsiveness
                    timeout=60.0,
                )
                # At EOF (empty chunk), final=True flushes any incomplete tail
                decoded = decoder.decode(chunk, final=not chunk)
                if decoded:
                    full_output += decoded
                    yield sse_event({"type": "token", "content": decoded})
                if not chunk:
                    break

            # Wait for process to complete
            await process.wait()

            if process.returncode != 0:
                stderr = await stderr_task
                error_text = stderr.decode("utf-8", errors="replace")
                yield sse_event({"type": "error", "message": f"Claude CLI error: {error_text[:500]}"})
                return

            # Parse the complete output to extract code
            code_match = re.search(r"```python\s*(.*?)\s*```", full_output, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
            elif "def extract(" in full_output:
                code = full_output.strip()
            else:
                yield sse_event({"type": "error", "message": "Failed to parse generated code"})
                return

            # Detect output names from the first `return {...}` literal
            detected_outputs: List[str] = []
            return_match = re.search(r"return\s*\{([^}]+)\}", code)
            if return_match:
                detected_outputs = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))
            final_outputs = detected_outputs if detected_outputs else request.output_names

            # Send completion event with parsed code
            yield sse_event({"type": "done", "code": code, "outputs": final_outputs})
        except asyncio.TimeoutError:
            yield sse_event({"type": "error", "message": "Generation timed out (60s limit)"})
        except Exception as e:
            yield sse_event({"type": "error", "message": str(e)})
        finally:
            # Runs on success, error, timeout and client disconnect (generator
            # closed or cancelled), so the CLI is never left running.
            if stderr_task is not None and not stderr_task.done():
                stderr_task.cancel()
            if process is not None and process.returncode is None:
                try:
                    process.kill()
                except ProcessLookupError:
                    pass  # already exited

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
# ==================== Session Management ====================
# Store active WebSocket connections, keyed by Claude Code session_id.
# The WebSocket handlers register a connection when they bind a session and
# remove it when the socket closes. The session itself is kept in the session
# manager. NOTE(review): this chunk only writes to the dict; confirm whether
# anything elsewhere reads it.
_active_connections: Dict[str, WebSocket] = {}
@router.post("/sessions")
async def create_claude_code_session(study_id: Optional[str] = None):
    """
    Start a fresh Claude Code session, optionally bound to a study.

    Args:
        study_id: Optional study whose context the session should use.

    Returns:
        Dict with the new session's id, study id, working directory and a
        hint to connect via WebSocket.

    Raises:
        HTTPException: 500 carrying the underlying error text when the
            session cannot be created or described.
    """
    try:
        new_session = get_claude_code_manager().create_session(study_id)
        payload = {
            "session_id": new_session.session_id,
            "study_id": new_session.study_id,
            "working_dir": str(new_session.working_dir),
            "message": "Claude Code session created. Connect via WebSocket to chat.",
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return payload
@router.get("/sessions/{session_id}")
async def get_claude_code_session(session_id: str):
    """
    Describe an existing Claude Code session.

    Raises:
        HTTPException: 404 when the manager has no session with this id.
    """
    found = get_claude_code_manager().get_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found")

    return dict(
        session_id=found.session_id,
        study_id=found.study_id,
        working_dir=str(found.working_dir),
        has_canvas_state=found.canvas_state is not None,
        conversation_length=len(found.conversation_history),
    )
@router.delete("/sessions/{session_id}")
async def delete_claude_code_session(session_id: str):
    """
    Remove a session from the Claude Code session manager.

    Always answers with a confirmation message. How an unknown id is handled
    is up to the manager's remove_session.
    """
    get_claude_code_manager().remove_session(session_id)
    return {"message": "Session deleted"}
@router.websocket("/ws")
async def claude_code_websocket(websocket: WebSocket):
    """
    WebSocket for full Claude Code CLI access (no session required).

    This is a simplified endpoint that creates a session per connection
    (or per "init" message). Re-initializing unregisters the previous
    session's connection entry. The entry for the current session is always
    removed when the handler exits, while the session itself is kept in the
    manager for a potential reconnect.

    Message formats (client -> server):
        {"type": "init", "study_id": "optional_study_name"}
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "canvas_updated"}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print("[ClaudeCode WS] Connection attempt received")
    await websocket.accept()
    print("[ClaudeCode WS] WebSocket accepted")

    manager = get_claude_code_manager()
    session: Optional[ClaudeCodeSession] = None

    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")

            if msg_type == "init":
                # Unregister the previous session first. Otherwise its entry
                # would never be removed (only the current session is cleaned up).
                if session is not None:
                    _active_connections.pop(session.session_id, None)
                # Create or reinitialize session
                session = manager.create_session(data.get("study_id"))
                _active_connections[session.session_id] = websocket
                await websocket.send_json(
                    {
                        "type": "initialized",
                        "session_id": session.session_id,
                        "study_id": session.study_id,
                        "working_dir": str(session.working_dir),
                    }
                )

            elif msg_type == "message":
                if not session:
                    # Auto-create session if not initialized
                    session = manager.create_session()
                    _active_connections[session.session_id] = websocket

                content = data.get("content", "")
                if not content:
                    continue

                # Update canvas state if provided with message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])

                # Stream response from Claude Code CLI
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)

            elif msg_type == "set_canvas":
                if session:
                    session.set_canvas_state(data.get("canvas_state", {}))
                await websocket.send_json({"type": "canvas_updated"})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass  # Client went away; session stays in manager for potential reconnect
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "content": str(e)})
        except Exception:
            pass  # Socket may already be unusable; nothing more to report
    finally:
        # Also runs on cancellation, so the registry never keeps a dead socket
        if session is not None:
            _active_connections.pop(session.session_id, None)
@router.websocket("/ws/{study_id:path}")
async def claude_code_websocket_with_study(websocket: WebSocket, study_id: str):
    """
    WebSocket for Claude Code CLI with study context.

    Same as /ws but automatically initializes with the given study. The
    connection is registered right after the session is created, and the
    entry is removed however the handler exits, including when the client
    disconnects before the "initialized" message is delivered.

    Message formats (client -> server):
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "canvas_updated"}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print(f"[ClaudeCode WS] Connection attempt received for study: {study_id}")
    await websocket.accept()
    print(f"[ClaudeCode WS] WebSocket accepted for study: {study_id}")

    manager = get_claude_code_manager()
    session = manager.create_session(study_id)
    _active_connections[session.session_id] = websocket

    try:
        # Send initialization message (inside try so a disconnect here is cleaned up)
        await websocket.send_json(
            {
                "type": "initialized",
                "session_id": session.session_id,
                "study_id": session.study_id,
                "working_dir": str(session.working_dir),
            }
        )

        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")

            if msg_type == "message":
                content = data.get("content", "")
                if not content:
                    continue

                # Update canvas state if provided with message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])

                # Stream response from Claude Code CLI
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)

            elif msg_type == "set_canvas":
                session.set_canvas_state(data.get("canvas_state", {}))
                await websocket.send_json({"type": "canvas_updated"})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass  # Client went away; nothing to report
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "content": str(e)})
        except Exception:
            pass  # Socket may already be unusable; nothing more to report
    finally:
        _active_connections.pop(session.session_id, None)