"""
Claude Code WebSocket Routes

Provides WebSocket endpoint that connects to actual Claude Code CLI.
This gives dashboard users the same power as terminal Claude Code users.

Unlike the MCP-based approach in claude.py:
- Spawns actual Claude Code CLI processes
- Full file editing capabilities
- Full command execution
- Opus 4.5 model with unlimited tool use

Also provides single-shot endpoints for extractor code:
- POST /generate-extractor: Generate Python extractor code
- POST /generate-extractor/stream: Same, streamed as Server-Sent Events
- POST /validate-extractor: Validate Python syntax and extract() structure
- POST /check-dependencies: Report which imports resolve in this environment
- POST /test-extractor: Run extractor code against a study's latest OP2 file

Session management:
- POST /sessions, GET/DELETE /sessions/{session_id}
- WS /ws and WS /ws/{study_id}
"""

import ast
import asyncio
import codecs
import importlib.util
import inspect
import json
import math
import os
import re
import tempfile
import time
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from fastapi import APIRouter, Body, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from api.services.claude_code_session import (
    ATOMIZER_ROOT,
    ClaudeCodeSession,
    get_claude_code_manager,
)

router = APIRouter(prefix="/claude-code", tags=["Claude Code"])


# ==================== Shared CLI / parsing helpers ====================

# Seconds to wait on the Claude CLI. For /generate-extractor this bounds the
# whole run; for the streaming endpoint it bounds each individual read.
_CLI_TIMEOUT_S = 60.0

# First ```python fenced block in the CLI output.
_CODE_BLOCK_RE = re.compile(r"```python\s*(.*?)\s*```", re.DOTALL)
# First `return {...}` literal (textual heuristic: stops at the first "}").
_RETURN_DICT_RE = re.compile(r"return\s*\{([^}]+)\}")
# Quoted dict keys such as 'max_stress': or "mass":
_DICT_KEY_RE = re.compile(r"['\"]([^'\"]+)['\"]:")
# Regex fallback for import detection when the code does not parse.
_IMPORT_LINE_RE = re.compile(r"^(?:from\s+(\w+)|import\s+(\w+))")

_EXTRACTOR_SYSTEM_PROMPT = """You are generating a Python custom extractor function for Atomizer FEA optimization.

IMPORTANT: Choose the appropriate function signature based on what data is needed:

## Option 1: FEA Results (OP2) - Use for stresses, displacements, frequencies, forces
```python
def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict:
    from pyNastran.op2.op2 import OP2
    op2 = OP2()
    op2.read_op2(op2_path)
    # Access: op2.displacements[subcase_id], op2.cquad4_stress[subcase_id], etc.
    return {"max_stress": value}
```

## Option 2: Expression/Computed Values (no FEA needed) - Use for dimensions, volumes, derived values
```python
def extract(trial_dir: str, config: dict, context: dict) -> dict:
    import json
    from pathlib import Path

    # Read mass properties (if available from model introspection)
    mass_file = Path(trial_dir) / "mass_properties.json"
    if mass_file.exists():
        with open(mass_file) as f:
            props = json.load(f)
        mass = props.get("mass_kg", 0)

    # Or use config values directly (e.g., expression values)
    length_mm = config.get("length_expression", 100)

    # context has results from other extractors
    other_value = context.get("other_extractor_output", 0)

    return {"computed_value": length_mm * 2}
```

Available imports: pyNastran.op2.op2.OP2, numpy, pathlib.Path, json

Common OP2 patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
- Mass: op2.grid_point_weight (if available)

Return ONLY the complete Python code wrapped in ```python ... ```. No explanations."""

_STREAMING_EXTRACTOR_SYSTEM_PROMPT = """You are generating a Python custom extractor function for Atomizer FEA optimization.

The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks

Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path

Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]

Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""


# ==================== Extractor Code Generation ====================


class ExtractorGenerationRequest(BaseModel):
    """Request model for extractor code generation"""

    prompt: str  # User's description
    study_id: Optional[str] = None  # Study context
    existing_code: Optional[str] = None  # Current code to improve
    output_names: List[str] = []  # Expected outputs


class ExtractorGenerationResponse(BaseModel):
    """Response model for generated code"""

    code: str  # Generated Python code
    outputs: List[str]  # Detected output names
    explanation: Optional[str] = None  # Brief explanation


class CodeValidationRequest(BaseModel):
    """Request model for code validation"""

    code: str


class CodeValidationResponse(BaseModel):
    """Response model for validation result"""

    valid: bool
    error: Optional[str] = None


def _build_user_prompt(request: ExtractorGenerationRequest) -> str:
    """Compose the user prompt sent to the CLI from the request's prompt, code and outputs."""
    user_prompt = f"Generate a custom extractor that: {request.prompt}"
    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )
    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )
    return user_prompt


async def _spawn_claude_cli(system_prompt: str) -> asyncio.subprocess.Process:
    """Start ``claude --print`` with *system_prompt*; stdin/stdout/stderr are all piped."""
    return await asyncio.create_subprocess_exec(
        "claude",
        "--print",
        "--system-prompt",
        system_prompt,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=str(ATOMIZER_ROOT),
        env={
            **os.environ,
            "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
        },
    )


def _kill_if_running(process: Optional[asyncio.subprocess.Process]) -> None:
    """Kill a CLI child that is still running (timeout, cancellation, client disconnect).

    Deliberately synchronous: it is called from ``finally`` blocks that may be
    running under cancellation, where a further ``await`` could be interrupted.
    """
    if process is None or process.returncode is not None:
        return
    try:
        process.kill()
    except OSError:
        # ProcessLookupError (an OSError): the child exited between the check and kill().
        pass


def _extract_code(output: str) -> Optional[str]:
    """Return the generated code from raw CLI *output*, or None if none is recognisable.

    Prefers the first ```python fenced block; if there is none but the text
    contains ``def extract(``, the whole stripped output is returned.
    """
    match = _CODE_BLOCK_RE.search(output)
    if match:
        return match.group(1).strip()
    if "def extract(" in output:
        # Claude might not use code blocks
        return output.strip()
    return None


def _detect_output_names(code: str) -> List[str]:
    """Return the quoted keys of the first ``return {...}`` literal in *code*.

    Textual heuristic (not a parse): keys are taken in order of appearance and
    duplicates are dropped. Returns an empty list when no such literal exists.
    """
    match = _RETURN_DICT_RE.search(code)
    if not match:
        return []
    return list(dict.fromkeys(_DICT_KEY_RE.findall(match.group(1))))


def _sse(payload: dict) -> str:
    """Format *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


@router.post("/generate-extractor", response_model=ExtractorGenerationResponse)
async def generate_extractor_code(request: ExtractorGenerationRequest):
    """
    Generate Python extractor code using Claude Code CLI.

    Uses --print mode for single-shot generation (no session state).
    Focused system prompt for fast, accurate results.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        ExtractorGenerationResponse with generated code and detected outputs

    Raises:
        HTTPException: 500 if the CLI fails or its output contains no
            recognisable code; 504 if the CLI does not finish within 60
            seconds. The CLI child process is killed if it is still running.
    """
    user_prompt = _build_user_prompt(request)
    process: Optional[asyncio.subprocess.Process] = None

    try:
        # Single-shot call: no session, prompt on stdin.
        process = await _spawn_claude_cli(_EXTRACTOR_SYSTEM_PROMPT)
        stdout, stderr = await asyncio.wait_for(
            process.communicate(user_prompt.encode("utf-8")), timeout=_CLI_TIMEOUT_S
        )

        if process.returncode != 0:
            error_text = stderr.decode("utf-8", errors="replace")
            raise HTTPException(status_code=500, detail=f"Claude CLI error: {error_text[:500]}")

        output = stdout.decode("utf-8", errors="replace")
        code = _extract_code(output)
        if code is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to parse generated code - no Python code block found",
            )

        # Use detected outputs or fall back to requested ones
        final_outputs = _detect_output_names(code) or request.output_names

        # Any text before the code block is treated as an explanation (300 chars max).
        explanation = None
        parts = output.split("```python")
        if len(parts) > 1 and parts[0].strip():
            explanation = parts[0].strip()[:300]

        return ExtractorGenerationResponse(
            code=code, outputs=final_outputs, explanation=explanation
        )

    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504, detail="Code generation timed out (60s limit). Try a simpler prompt."
        ) from None
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
    finally:
        # wait_for() cancels communicate() on timeout but leaves the child running.
        _kill_if_running(process)


class DependencyCheckResponse(BaseModel):
    """Response model for dependency check"""

    imports: List[str]
    available: List[str]
    missing: List[str]
    warnings: List[str]


# Known available packages in the atomizer environment
KNOWN_PACKAGES = {
    "pyNastran": ["pyNastran", "pyNastran.op2", "pyNastran.bdf"],
    "numpy": ["numpy", "np"],
    "scipy": ["scipy"],
    "pandas": ["pandas", "pd"],
    "pathlib": ["pathlib", "Path"],
    "json": ["json"],
    "os": ["os"],
    "re": ["re"],
    "math": ["math"],
    "typing": ["typing"],
    "collections": ["collections"],
    "itertools": ["itertools"],
    "functools": ["functools"],
}


def extract_imports(code: str) -> List[str]:
    """
    Return the sorted, de-duplicated top-level module names imported by *code*.

    Uses the AST when the code parses. Relative imports (``from . import x``,
    ``from .pkg import x``) are skipped because they do not name an
    installable package. On a SyntaxError, falls back to a regex matching
    ``import x`` / ``from x import ...`` at the start of each line.
    """
    imports = set()
    try:
        tree = ast.parse(code)
    except SyntaxError:
        for line in code.split("\n"):
            match = _IMPORT_LINE_RE.match(line.strip())
            if match:
                imports.add(match.group(1) or match.group(2))
    else:
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imports.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module and not node.level:
                imports.add(node.module.split(".")[0])
    return sorted(imports)


@router.post("/check-dependencies", response_model=DependencyCheckResponse)
async def check_code_dependencies(request: CodeValidationRequest):
    """
    Check which imports in the code are available in the atomizer environment.

    A module counts as available if it is listed in KNOWN_PACKAGES or if
    ``importlib.util.find_spec`` can locate it in this server's environment
    (locating a top-level module does not execute it).

    Args:
        request: CodeValidationRequest with code to check

    Returns:
        DependencyCheckResponse with available and missing packages
    """
    imports = extract_imports(request.code)
    available: List[str] = []
    missing: List[str] = []
    warnings: List[str] = []

    # Known available in atomizer (top-level names only)
    known_available = {
        alias.split(".")[0] for aliases in KNOWN_PACKAGES.values() for alias in aliases
    }

    for imp in imports:
        if imp in known_available:
            available.append(imp)
            continue
        try:
            spec = importlib.util.find_spec(imp)
        except (ImportError, ValueError):
            # ValueError: module already in sys.modules with __spec__ set to None.
            spec = None
        (available if spec is not None else missing).append(imp)

    # Add warnings for potentially problematic imports
    if "matplotlib" in imports:
        warnings.append("matplotlib may cause issues in headless NX environment")
    if "tensorflow" in imports or "torch" in imports:
        warnings.append("Deep learning frameworks may cause memory issues during optimization")

    return DependencyCheckResponse(
        imports=imports, available=available, missing=missing, warnings=warnings
    )


def _is_dict_like_return(value: ast.expr) -> bool:
    """True if a returned expression is (or plausibly is) a dict.

    Accepts dict literals, dict comprehensions, ``dict(...)`` calls, and bare
    names (a variable could hold a dict).
    """
    if isinstance(value, (ast.Dict, ast.DictComp, ast.Name)):
        return True
    return (
        isinstance(value, ast.Call)
        and isinstance(value.func, ast.Name)
        and value.func.id == "dict"
    )


@router.post("/validate-extractor", response_model=CodeValidationResponse)
async def validate_extractor_code(request: CodeValidationRequest):
    """
    Validate Python extractor code syntax and structure.

    The code must parse, define a function named ``extract``, and that
    function must contain at least one ``return`` whose value looks like a
    dict (see _is_dict_like_return).

    Args:
        request: CodeValidationRequest with code to validate

    Returns:
        CodeValidationResponse with valid flag and optional error message
    """
    try:
        tree = ast.parse(request.code)
    except SyntaxError as e:
        return CodeValidationResponse(valid=False, error=f"Line {e.lineno}: {e.msg}")
    except Exception as e:
        # e.g. ValueError for source containing null bytes
        return CodeValidationResponse(valid=False, error=str(e))

    extract_defs = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name == "extract"
    ]
    if not extract_defs:
        return CodeValidationResponse(
            valid=False, error="Code must define a function named 'extract'"
        )

    returns_dict = any(
        isinstance(child, ast.Return)
        and child.value is not None
        and _is_dict_like_return(child.value)
        for fn in extract_defs
        for child in ast.walk(fn)
    )
    if not returns_dict:
        return CodeValidationResponse(
            valid=False, error="extract() function should return a dict"
        )

    return CodeValidationResponse(valid=True, error=None)


# ==================== Live Preview / Test Execution ====================


class TestExtractorRequest(BaseModel):
    """Request model for testing extractor code"""

    code: str
    study_id: Optional[str] = None
    subcase_id: int = 1


class TestExtractorResponse(BaseModel):
    """Response model for extractor test"""

    success: bool
    outputs: Optional[Dict[str, float]] = None
    error: Optional[str] = None
    execution_time_ms: Optional[float] = None


class _ExtractorRunError(Exception):
    """Raised when code cannot be run as an extractor; the message is user-facing."""


def _trial_sort_key(trial_dir: Path) -> Tuple[int, str]:
    """Order ``trial_<n>`` folders by n numerically, so trial_10 sorts after trial_9.

    Folders without a numeric suffix sort before all numbered ones.
    """
    match = re.match(r"trial_(\d+)", trial_dir.name)
    return (int(match.group(1)) if match else -1, trial_dir.name)


def _resolve_study_dir(study_id: str) -> Optional[Path]:
    """Find ``studies/<study_id>`` directly or one level nested; None if absent."""
    studies_root = ATOMIZER_ROOT / "studies"
    direct = studies_root / study_id
    if direct.exists():
        return direct
    if not studies_root.is_dir():
        return None
    for parent in studies_root.iterdir():
        if parent.is_dir():
            nested = parent / study_id
            if nested.exists():
                return nested
    return None


def _find_latest_trial_files(
    study_id: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (op2_path, fem_path, trial_dir) for the newest trial that has an OP2 file.

    Looks in ``<study>/2_iterations/trial_*``. fem_path is None when that trial
    has no .fem file; all three are None when no OP2 file is found.
    """
    study_path = _resolve_study_dir(study_id)
    if study_path is None:
        return None, None, None
    iterations_dir = study_path / "2_iterations"
    if not iterations_dir.exists():
        return None, None, None

    trial_folders = sorted(
        (d for d in iterations_dir.iterdir() if d.is_dir() and d.name.startswith("trial_")),
        key=_trial_sort_key,
        reverse=True,
    )
    for trial_dir in trial_folders:
        op2_files = list(trial_dir.glob("*.op2"))
        if op2_files:
            fem_files = list(trial_dir.glob("*.fem"))
            fem_path = str(fem_files[0]) if fem_files else None
            return str(op2_files[0]), fem_path, str(trial_dir)
    return None, None, None


def _uses_trial_dir_signature(fn) -> bool:
    """True if *fn* looks like ``extract(trial_dir, config, context)`` rather than the OP2 form."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False
    return "trial_dir" in params and "op2_path" not in params


def _load_and_run_extractor(
    module_path: str, op2_path: str, fem_path: str, trial_dir: Optional[str], subcase_id: int
):
    """Import *module_path* as a module and call its ``extract`` function.

    Runs in a worker thread. Supports both signatures the generator advertises:
    ``extract(op2_path, fem_path, params, subcase_id)`` and
    ``extract(trial_dir, config, context)``; params/config/context are empty
    dicts. Raises _ExtractorRunError for structural problems; exceptions raised
    by the user code propagate unchanged.
    """
    spec = importlib.util.spec_from_file_location("temp_extractor", module_path)
    if spec is None or spec.loader is None:
        raise _ExtractorRunError("Failed to load code as module")
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
        if not hasattr(module, "extract"):
            raise _ExtractorRunError("Code does not define an 'extract' function")
        extract_fn = module.extract
        if _uses_trial_dir_signature(extract_fn):
            return extract_fn(trial_dir=trial_dir, config={}, context={})
        return extract_fn(
            op2_path=op2_path,
            fem_path=fem_path,
            params={},  # Empty params for testing
            subcase_id=subcase_id,
        )
    except SystemExit as exc:
        # Never let user code's sys.exit()/exit() escape into the server.
        raise _ExtractorRunError(f"Extractor code called sys.exit({exc.code!r})") from None


def _format_extractor_exception(exc: BaseException, module_path: Optional[str]) -> str:
    """Describe *exc*, appending the innermost traceback frame from the user's code if any."""
    error_msg = f"{type(exc).__name__}: {str(exc)}"
    if module_path:
        user_frames = [
            frame
            for frame in traceback.extract_tb(exc.__traceback__)
            if frame.filename == module_path
        ]
        if user_frames:
            frame = user_frames[-1]
            error_msg += f'\n  File "temp_extractor.py", line {frame.lineno}, in {frame.name}'
    return error_msg


@router.post("/test-extractor", response_model=TestExtractorResponse)
async def test_extractor_code(request: TestExtractorRequest):
    """
    Test extractor code against a study's most recent OP2 file.

    The code is written to a temporary file, imported and its ``extract``
    function called in a worker thread so the event loop is not blocked.
    NOTE: this is NOT a sandbox: the code runs in the server process with the
    server's permissions.

    A study_id is required to locate an OP2 file; without one, or if the
    study has no trial with an OP2 file yet, an unsuccessful response is
    returned (no mock data is used).

    Args:
        request: TestExtractorRequest with code and optional study context

    Returns:
        TestExtractorResponse with extracted outputs or error. Output values
        are converted with float(); values that cannot be converted become
        0.0, and non-finite values (NaN/inf) are reported as an error.
    """
    start_time = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - start_time) * 1000

    op2_path: Optional[str] = None
    fem_path: Optional[str] = None
    trial_dir: Optional[str] = None
    if request.study_id:
        op2_path, fem_path, trial_dir = _find_latest_trial_files(request.study_id)

    if not op2_path:
        return TestExtractorResponse(
            success=False,
            error="No OP2 file available for testing. Run at least one optimization trial first.",
            execution_time_ms=elapsed_ms(),
        )

    module_path: Optional[str] = None
    try:
        # Explicit UTF-8: importlib reads source as UTF-8 regardless of locale.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            module_path = f.name
            f.write(request.code)

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            _load_and_run_extractor,
            module_path,
            op2_path,
            fem_path or "",
            trial_dir,
            request.subcase_id,
        )

        if not isinstance(result, dict):
            return TestExtractorResponse(
                success=False,
                error=f"extract() returned {type(result).__name__}, expected dict",
                execution_time_ms=elapsed_ms(),
            )

        outputs: Dict[str, float] = {}
        for key, value in result.items():
            try:
                number = float(value)
            except (TypeError, ValueError):
                number = 0.0  # Can't convert, use 0
            if not math.isfinite(number):
                # NaN/inf cannot be serialized into the JSON response.
                return TestExtractorResponse(
                    success=False,
                    error=f"extract() returned a non-finite value for '{key}': {number}",
                    execution_time_ms=elapsed_ms(),
                )
            outputs[str(key)] = number

        return TestExtractorResponse(
            success=True, outputs=outputs, execution_time_ms=elapsed_ms()
        )

    except _ExtractorRunError as e:
        return TestExtractorResponse(
            success=False, error=str(e), execution_time_ms=elapsed_ms()
        )
    except Exception as e:
        return TestExtractorResponse(
            success=False,
            error=_format_extractor_exception(e, module_path),
            execution_time_ms=elapsed_ms(),
        )
    finally:
        if module_path:
            try:
                os.unlink(module_path)
            except OSError:
                pass  # best-effort cleanup of the temp file


# ==================== Streaming Generation ====================


@router.post("/generate-extractor/stream")
async def generate_extractor_code_stream(request: ExtractorGenerationRequest):
    """
    Stream Python extractor code generation using Claude Code CLI.

    Uses Server-Sent Events (SSE) to stream tokens as they arrive.

    Event types:
    - data: {"type": "token", "content": "..."} - Partial code token
    - data: {"type": "done", "code": "...", "outputs": [...]} - Final result
    - data: {"type": "error", "message": "..."} - Error occurred

    The 60 second timeout applies to each read from the CLI (an inactivity
    timeout), not to the whole generation. If the stream ends early for any
    reason (error, timeout, client disconnect) the CLI process is killed.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        StreamingResponse with text/event-stream content type
    """
    user_prompt = _build_user_prompt(request)

    async def generate():
        full_output = ""
        process: Optional[asyncio.subprocess.Process] = None
        stderr_task = None
        # Incremental decoding keeps multi-byte UTF-8 characters that straddle
        # a 256-byte read boundary intact.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            process = await _spawn_claude_cli(_STREAMING_EXTRACTOR_SYSTEM_PROMPT)
            # Drain stderr concurrently so a chatty CLI cannot fill the pipe and stall.
            stderr_task = asyncio.ensure_future(process.stderr.read())

            # Write prompt to stdin and close
            process.stdin.write(user_prompt.encode("utf-8"))
            await process.stdin.drain()
            process.stdin.close()

            # Stream stdout chunks as they arrive
            while True:
                chunk = await asyncio.wait_for(
                    process.stdout.read(256),  # Small chunks for responsiveness
                    timeout=_CLI_TIMEOUT_S,
                )
                decoded = decoder.decode(chunk, final=not chunk)
                if decoded:
                    full_output += decoded
                    yield _sse({"type": "token", "content": decoded})
                if not chunk:
                    break

            await asyncio.wait_for(process.wait(), timeout=_CLI_TIMEOUT_S)
            stderr = await stderr_task

            if process.returncode != 0:
                error_text = stderr.decode("utf-8", errors="replace")
                yield _sse({"type": "error", "message": f"Claude CLI error: {error_text[:500]}"})
                return

            code = _extract_code(full_output)
            if code is None:
                yield _sse({"type": "error", "message": "Failed to parse generated code"})
                return

            final_outputs = _detect_output_names(code) or request.output_names
            yield _sse({"type": "done", "code": code, "outputs": final_outputs})

        except asyncio.TimeoutError:
            yield _sse({"type": "error", "message": "Generation timed out (60s limit)"})
        except Exception as e:
            yield _sse({"type": "error", "message": str(e)})
        finally:
            # No awaits here: this may run while the request task is being cancelled.
            if stderr_task is not None and not stderr_task.done():
                stderr_task.cancel()
            _kill_if_running(process)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )


# ==================== Session Management ====================

# Store active WebSocket connections, keyed by session_id
_active_connections: Dict[str, WebSocket] = {}


@router.post("/sessions")
async def create_claude_code_session(study_id: Optional[str] = None):
    """
    Create a new Claude Code session.

    Args:
        study_id: Optional study to provide context

    Returns:
        Session info including session_id
    """
    try:
        manager = get_claude_code_manager()
        session = manager.create_session(study_id)
        return {
            "session_id": session.session_id,
            "study_id": session.study_id,
            "working_dir": str(session.working_dir),
            "message": "Claude Code session created. Connect via WebSocket to chat.",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/sessions/{session_id}")
async def get_claude_code_session(session_id: str):
    """Return info for *session_id*; 404 if the manager does not know it."""
    manager = get_claude_code_manager()
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return {
        "session_id": session.session_id,
        "study_id": session.study_id,
        "working_dir": str(session.working_dir),
        "has_canvas_state": session.canvas_state is not None,
        "conversation_length": len(session.conversation_history),
    }


@router.delete("/sessions/{session_id}")
async def delete_claude_code_session(session_id: str):
    """Ask the manager to remove *session_id* (no existence check is made here)."""
    manager = get_claude_code_manager()
    manager.remove_session(session_id)
    return {"message": "Session deleted"}


async def _send_error_quietly(websocket: WebSocket, content: str) -> None:
    """Try to send an error frame; ignore failure if the socket is already unusable."""
    try:
        await websocket.send_json({"type": "error", "content": content})
    except Exception:
        pass


async def _receive_message(websocket: WebSocket) -> Optional[dict]:
    """Receive one frame; return it if it is a JSON object, else report an error and return None.

    WebSocketDisconnect propagates to the caller.
    """
    try:
        data = await websocket.receive_json()
    except ValueError:  # json.JSONDecodeError
        await websocket.send_json({"type": "error", "content": "Invalid message: expected JSON"})
        return None
    if not isinstance(data, dict):
        await websocket.send_json(
            {"type": "error", "content": "Invalid message: expected a JSON object"}
        )
        return None
    return data


async def _stream_session_reply(
    websocket: WebSocket, session: ClaudeCodeSession, data: dict
) -> None:
    """Handle a ``message`` frame: apply optional canvas state, then relay CLI chunks."""
    content = data.get("content", "")
    if not content:
        return
    # Update canvas state if provided with message
    if data.get("canvas_state"):
        session.set_canvas_state(data["canvas_state"])
    # Stream response from Claude Code CLI
    async for chunk in session.send_message(content):
        await websocket.send_json(chunk)


@router.websocket("/ws")
async def claude_code_websocket(websocket: WebSocket):
    """
    WebSocket for full Claude Code CLI access (no session required).

    This is a simplified endpoint that creates a session per connection.
    A "message" sent before "init" auto-creates a session without a study;
    a repeated "init" replaces the current session.

    Message formats (client -> server):
        {"type": "init", "study_id": "optional_study_name"}
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "canvas_updated"}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print("[ClaudeCode WS] Connection attempt received")
    await websocket.accept()
    print("[ClaudeCode WS] WebSocket accepted")

    manager = get_claude_code_manager()
    session: Optional[ClaudeCodeSession] = None

    try:
        while True:
            data = await _receive_message(websocket)
            if data is None:
                continue
            msg_type = data.get("type")

            if msg_type == "init":
                # Re-init replaces the session: drop the stale connection entry first.
                if session is not None:
                    _active_connections.pop(session.session_id, None)
                session = manager.create_session(data.get("study_id"))
                _active_connections[session.session_id] = websocket
                await websocket.send_json(
                    {
                        "type": "initialized",
                        "session_id": session.session_id,
                        "study_id": session.study_id,
                        "working_dir": str(session.working_dir),
                    }
                )

            elif msg_type == "message":
                if session is None:
                    # Auto-create session if not initialized
                    session = manager.create_session()
                    _active_connections[session.session_id] = websocket
                await _stream_session_reply(websocket, session, data)

            elif msg_type == "set_canvas":
                # Only acknowledged when there is a session to apply it to.
                if session is not None:
                    session.set_canvas_state(data.get("canvas_state", {}))
                    await websocket.send_json({"type": "canvas_updated"})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass  # Keep session in manager for potential reconnect
    except Exception as e:
        await _send_error_quietly(websocket, str(e))
    finally:
        if session is not None:
            _active_connections.pop(session.session_id, None)


@router.websocket("/ws/{study_id:path}")
async def claude_code_websocket_with_study(websocket: WebSocket, study_id: str):
    """
    WebSocket for Claude Code CLI with study context.

    Same as /ws but automatically initializes with the given study.

    Message formats (client -> server):
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "canvas_updated"}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print(f"[ClaudeCode WS] Connection attempt received for study: {study_id}")
    await websocket.accept()
    print(f"[ClaudeCode WS] WebSocket accepted for study: {study_id}")

    manager = get_claude_code_manager()
    session: Optional[ClaudeCodeSession] = None

    try:
        session = manager.create_session(study_id)
        _active_connections[session.session_id] = websocket

        # Send initialization message
        await websocket.send_json(
            {
                "type": "initialized",
                "session_id": session.session_id,
                "study_id": session.study_id,
                "working_dir": str(session.working_dir),
            }
        )

        while True:
            data = await _receive_message(websocket)
            if data is None:
                continue
            msg_type = data.get("type")

            if msg_type == "message":
                await _stream_session_reply(websocket, session, data)
            elif msg_type == "set_canvas":
                session.set_canvas_state(data.get("canvas_state", {}))
                await websocket.send_json({"type": "canvas_updated"})
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass  # Keep session in manager for potential reconnect
    except Exception as e:
        await _send_error_quietly(websocket, str(e))
    finally:
        if session is not None:
            _active_connections.pop(session.session_id, None)