Drag-drop fixes: - Fix Objective default data: use nested 'source' object with extractor_id/output_name - Fix Constraint default data: use 'type' field (not constraint_type), 'threshold' (not limit) Undo/Redo fixes: - Remove dependency on isDirty flag (which is always false due to auto-save) - Record snapshots based on actual spec changes via deep comparison Code generation improvements: - Update system prompt to support multiple extractor types: * OP2-based extractors for FEA results (stress, displacement, frequency) * Expression-based extractors for NX model values (dimensions, volumes) * Computed extractors for derived values (no FEA needed) - Claude will now choose appropriate signature based on user's description
921 lines
31 KiB
Python
921 lines
31 KiB
Python
"""
|
|
Claude Code WebSocket Routes
|
|
|
|
Provides WebSocket endpoint that connects to actual Claude Code CLI.
|
|
This gives dashboard users the same power as terminal Claude Code users.
|
|
|
|
Unlike the MCP-based approach in claude.py:
|
|
- Spawns actual Claude Code CLI processes
|
|
- Full file editing capabilities
|
|
- Full command execution
|
|
- Opus 4.5 model with unlimited tool use
|
|
|
|
Also provides single-shot endpoints for code generation:
|
|
- POST /generate-extractor: Generate Python extractor code
|
|
- POST /validate-extractor: Validate Python syntax
|
|
"""
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Body
|
|
from pydantic import BaseModel
|
|
from typing import Dict, Optional, List
|
|
import json
|
|
import asyncio
|
|
import re
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from api.services.claude_code_session import (
|
|
get_claude_code_manager,
|
|
ClaudeCodeSession,
|
|
ATOMIZER_ROOT,
|
|
)
|
|
|
|
router = APIRouter(prefix="/claude-code", tags=["Claude Code"])
|
|
|
|
|
|
# ==================== Extractor Code Generation ====================
|
|
|
|
|
|
class ExtractorGenerationRequest(BaseModel):
    """Request model for extractor code generation"""

    # User's natural-language description of what the extractor should do
    prompt: str  # User's description
    # Optional study id so generation can use study context
    study_id: Optional[str] = None  # Study context
    # When set, Claude is asked to improve this code instead of starting fresh
    existing_code: Optional[str] = None  # Current code to improve
    # Output keys the generated extract() function is expected to return
    # (pydantic copies the default, so the mutable [] default is safe here)
    output_names: List[str] = []  # Expected outputs
|
|
|
|
|
|
class ExtractorGenerationResponse(BaseModel):
    """Response model for generated code"""

    # Generated Python source containing an extract() function
    code: str  # Generated Python code
    # Output keys detected from the code's return statement (or the
    # requested names if detection found nothing)
    outputs: List[str]  # Detected output names
    # Prose Claude emitted before the code block, truncated server-side
    explanation: Optional[str] = None  # Brief explanation
|
|
|
|
|
|
class CodeValidationRequest(BaseModel):
    """Request model for code validation"""

    # Python source to validate / inspect (also reused by
    # /check-dependencies for import scanning)
    code: str
|
|
|
|
|
|
class CodeValidationResponse(BaseModel):
    """Response model for validation result"""

    # True when the code parses and defines a conforming extract() function
    valid: bool
    # Human-readable reason when valid is False; None on success
    error: Optional[str] = None
|
|
|
|
|
|
@router.post("/generate-extractor", response_model=ExtractorGenerationResponse)
async def generate_extractor_code(request: ExtractorGenerationRequest):
    """
    Generate Python extractor code using Claude Code CLI.

    Uses --print mode for single-shot generation (no session state).
    Focused system prompt for fast, accurate results.

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        ExtractorGenerationResponse with generated code and detected outputs

    Raises:
        HTTPException: 500 on CLI or parse failure, 504 on the 60s timeout.
    """
    # Build focused system prompt for extractor generation
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.

IMPORTANT: Choose the appropriate function signature based on what data is needed:

## Option 1: FEA Results (OP2) - Use for stresses, displacements, frequencies, forces
```python
def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict:
    from pyNastran.op2.op2 import OP2
    op2 = OP2()
    op2.read_op2(op2_path)
    # Access: op2.displacements[subcase_id], op2.cquad4_stress[subcase_id], etc.
    return {"max_stress": value}
```

## Option 2: Expression/Computed Values (no FEA needed) - Use for dimensions, volumes, derived values
```python
def extract(trial_dir: str, config: dict, context: dict) -> dict:
    import json
    from pathlib import Path

    # Read mass properties (if available from model introspection)
    mass_file = Path(trial_dir) / "mass_properties.json"
    if mass_file.exists():
        with open(mass_file) as f:
            props = json.load(f)
        mass = props.get("mass_kg", 0)

    # Or use config values directly (e.g., expression values)
    length_mm = config.get("length_expression", 100)

    # context has results from other extractors
    other_value = context.get("other_extractor_output", 0)

    return {"computed_value": length_mm * 2}
```

Available imports: pyNastran.op2.op2.OP2, numpy, pathlib.Path, json

Common OP2 patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]
- Mass: op2.grid_point_weight (if available)

Return ONLY the complete Python code wrapped in ```python ... ```. No explanations."""

    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"

    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )

    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )

    try:
        # Call Claude CLI with focused prompt (single-shot, no session)
        process = await asyncio.create_subprocess_exec(
            "claude",
            "--print",
            "--system-prompt",
            system_prompt,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(ATOMIZER_ROOT),
            env={
                **os.environ,
                "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
            },
        )

        # Send prompt and wait for response (60 second timeout)
        stdout, stderr = await asyncio.wait_for(
            process.communicate(user_prompt.encode("utf-8")), timeout=60.0
        )

        if process.returncode != 0:
            error_text = stderr.decode("utf-8", errors="replace")
            raise HTTPException(status_code=500, detail=f"Claude CLI error: {error_text[:500]}")

        output = stdout.decode("utf-8", errors="replace")

        # Extract Python code from markdown code block
        code_match = re.search(r"```python\s*(.*?)\s*```", output, re.DOTALL)
        if code_match:
            code = code_match.group(1).strip()
        elif "def extract(" in output:
            # Claude occasionally omits the fenced block.  Slice from the
            # first "def extract(" so any prose preceding the function is
            # dropped (previously the whole output was kept, which yielded
            # unparseable "code" whenever Claude added an intro sentence).
            code = output[output.index("def extract("):].strip()
        else:
            raise HTTPException(
                status_code=500,
                detail="Failed to parse generated code - no Python code block found",
            )

        # Detect output names from return statement
        detected_outputs: List[str] = []
        return_match = re.search(r"return\s*\{([^}]+)\}", code)
        if return_match:
            # Parse dict keys like 'max_stress': ... or "mass": ...
            key_matches = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))
            detected_outputs = key_matches

        # Use detected outputs or fall back to requested ones
        final_outputs = detected_outputs if detected_outputs else request.output_names

        # Extract any explanation text before the code block
        explanation = None
        parts = output.split("```python")
        if len(parts) > 1 and parts[0].strip():
            explanation = parts[0].strip()[:300]  # First 300 chars max

        return ExtractorGenerationResponse(
            code=code, outputs=final_outputs, explanation=explanation
        )

    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504, detail="Code generation timed out (60s limit). Try a simpler prompt."
        )
    except HTTPException:
        # Re-raise our own structured errors unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
|
class DependencyCheckResponse(BaseModel):
    """Response model for dependency check"""

    # All root package names imported by the submitted code
    imports: List[str]
    # Subset of `imports` resolvable in the atomizer environment
    available: List[str]
    # Subset of `imports` that could not be resolved
    missing: List[str]
    # Advisory notes about packages that resolve but are known to misbehave
    warnings: List[str]
|
|
|
|
|
|
# Known available packages in the atomizer environment.
# Maps package name -> aliases/submodules commonly seen in import
# statements; only the root segment (before the first dot) is actually
# matched by check_code_dependencies.
KNOWN_PACKAGES = {
    "pyNastran": ["pyNastran", "pyNastran.op2", "pyNastran.bdf"],
    "numpy": ["numpy", "np"],
    "scipy": ["scipy"],
    "pandas": ["pandas", "pd"],
    "pathlib": ["pathlib", "Path"],
    "json": ["json"],
    "os": ["os"],
    "re": ["re"],
    "math": ["math"],
    "typing": ["typing"],
    "collections": ["collections"],
    "itertools": ["itertools"],
    "functools": ["functools"],
}
|
|
|
|
|
|
def extract_imports(code: str) -> List[str]:
    """Extract the root package names imported by *code*.

    Parses the code with ``ast`` and collects the first dotted segment of
    every ``import x`` / ``from x import y`` statement.  If the code does
    not parse (e.g. it is mid-edit), falls back to a line-based regex scan
    so dependencies can still be reported.

    Args:
        code: Python source to inspect.

    Returns:
        De-duplicated root package names in deterministic, first-seen
        order (the previous ``list(set(...))`` returned arbitrary order).
    """
    import ast

    imports: List[str] = []
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    imports.append(alias.name.split(".")[0])
            elif isinstance(node, ast.ImportFrom):
                # node.module is None for relative imports like "from . import x"
                if node.module:
                    imports.append(node.module.split(".")[0])
    except SyntaxError:
        # Fall back to a regex scan if the AST parse fails.
        # (`re` is imported at module level; the old function-local
        # re-import was redundant.)
        import_pattern = r"^(?:from\s+(\w+)|import\s+(\w+))"
        for line in code.split("\n"):
            match = re.match(import_pattern, line.strip())
            if match:
                imports.append(match.group(1) or match.group(2))

    # dict.fromkeys de-duplicates while preserving first-seen order,
    # making the endpoint's output stable across calls.
    return list(dict.fromkeys(imports))
|
|
|
|
|
|
@router.post("/check-dependencies", response_model=DependencyCheckResponse)
async def check_code_dependencies(request: CodeValidationRequest):
    """
    Check which imports in the code are available in the atomizer environment.

    Args:
        request: CodeValidationRequest with code to check

    Returns:
        DependencyCheckResponse with available and missing packages
    """
    imports = extract_imports(request.code)

    # Root names we already know ship with the atomizer environment.
    known_available = {
        alias.split(".")[0] for aliases in KNOWN_PACKAGES.values() for alias in aliases
    }

    available: List[str] = []
    missing: List[str] = []

    for name in imports:
        if name in known_available:
            available.append(name)
            continue
        # Unknown package: probe the server's own interpreter for it.
        try:
            import importlib.util

            found = importlib.util.find_spec(name) is not None
        except (ImportError, ModuleNotFoundError):
            found = False
        (available if found else missing).append(name)

    # Advisory notes for packages that resolve but are known to misbehave
    # in this deployment.
    warnings: List[str] = []
    if "matplotlib" in imports:
        warnings.append("matplotlib may cause issues in headless NX environment")
    if "tensorflow" in imports or "torch" in imports:
        warnings.append("Deep learning frameworks may cause memory issues during optimization")

    return DependencyCheckResponse(
        imports=imports, available=available, missing=missing, warnings=warnings
    )
|
|
|
|
|
|
@router.post("/validate-extractor", response_model=CodeValidationResponse)
async def validate_extractor_code(request: CodeValidationRequest):
    """
    Validate Python extractor code syntax and structure.

    Checks that the code parses, defines an ``extract`` function (sync or
    async), and that the function returns something that can plausibly be
    a dict.  The return type cannot be proven statically, so any dict
    literal, variable, call, attribute, or subscript return is accepted;
    only code with no such return is rejected.

    Args:
        request: CodeValidationRequest with code to validate

    Returns:
        CodeValidationResponse with valid flag and optional error message
    """
    import ast

    try:
        tree = ast.parse(request.code)

        has_extract = False
        extract_returns_dict = False

        for node in ast.walk(tree):
            # Accept async extractors too; checking only ast.FunctionDef
            # rejected "async def extract" with a misleading error.
            if (
                isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == "extract"
            ):
                has_extract = True
                # Check if it has a return statement that could yield a dict
                for child in ast.walk(node):
                    if isinstance(child, ast.Return) and child.value:
                        if isinstance(
                            child.value,
                            (ast.Dict, ast.Name, ast.Call, ast.Attribute, ast.Subscript, ast.IfExp),
                        ):
                            # Dict literal, or an expression (variable,
                            # function call, attribute/subscript lookup,
                            # conditional) that may evaluate to a dict.
                            # Previously only Dict/Name were accepted, so
                            # "return build_results()" was falsely rejected.
                            extract_returns_dict = True

        if not has_extract:
            return CodeValidationResponse(
                valid=False, error="Code must define a function named 'extract'"
            )

        if not extract_returns_dict:
            return CodeValidationResponse(
                valid=False, error="extract() function should return a dict"
            )

        return CodeValidationResponse(valid=True, error=None)

    except SyntaxError as e:
        return CodeValidationResponse(valid=False, error=f"Line {e.lineno}: {e.msg}")
    except Exception as e:
        return CodeValidationResponse(valid=False, error=str(e))
|
|
|
|
|
|
# ==================== Live Preview / Test Execution ====================
|
|
|
|
|
|
class TestExtractorRequest(BaseModel):
    """Request model for testing extractor code"""

    # Extractor source containing an extract() function to execute
    code: str
    # Study whose most recent trial OP2 file is used as test input
    study_id: Optional[str] = None
    # OP2 subcase id forwarded to extract()
    subcase_id: int = 1
|
|
|
|
|
|
class TestExtractorResponse(BaseModel):
    """Response model for extractor test"""

    # True when extract() ran and returned a dict
    success: bool
    # extract() results coerced to floats (non-numeric values become 0.0)
    outputs: Optional[Dict[str, float]] = None
    # Failure description, possibly with a trimmed traceback line
    error: Optional[str] = None
    # Wall-clock duration of the whole test, in milliseconds
    execution_time_ms: Optional[float] = None
|
|
|
|
|
|
@router.post("/test-extractor", response_model=TestExtractorResponse)
async def test_extractor_code(request: TestExtractorRequest):
    """
    Test extractor code against a sample or study OP2 file.

    This executes the code in a sandboxed environment and returns the results.
    If a study_id is provided, it uses the most recent trial's OP2 file.
    Otherwise, it uses mock data for testing.

    NOTE(review): only the OP2-style signature
    ``extract(op2_path, fem_path, params, subcase_id)`` is exercised here;
    extractors written with a different signature will fail with a
    TypeError from the keyword call below — confirm whether other
    signatures should be supported.

    Args:
        request: TestExtractorRequest with code and optional study context

    Returns:
        TestExtractorResponse with extracted outputs or error
    """
    import time
    import tempfile
    import traceback

    start_time = time.time()

    # Find OP2 file to test against
    op2_path = None
    fem_path = None

    if request.study_id:
        # Look for the most recent trial's OP2 file
        from pathlib import Path

        study_path = ATOMIZER_ROOT / "studies" / request.study_id

        if not study_path.exists():
            # Try one level of nesting: studies/<parent>/<study_id>
            for parent in (ATOMIZER_ROOT / "studies").iterdir():
                if parent.is_dir():
                    nested = parent / request.study_id
                    if nested.exists():
                        study_path = nested
                        break

        if study_path.exists():
            # Look in 2_iterations for trial folders
            iterations_dir = study_path / "2_iterations"
            if iterations_dir.exists():
                # Sort trial_* folders descending by name so the newest
                # trial is tried first; stop at the first one with an OP2.
                trial_folders = sorted(
                    [
                        d
                        for d in iterations_dir.iterdir()
                        if d.is_dir() and d.name.startswith("trial_")
                    ],
                    reverse=True,
                )
                for trial_dir in trial_folders:
                    op2_files = list(trial_dir.glob("*.op2"))
                    fem_files = list(trial_dir.glob("*.fem"))
                    if op2_files:
                        op2_path = str(op2_files[0])
                        if fem_files:
                            fem_path = str(fem_files[0])
                        break

    if not op2_path:
        # No OP2 file available - nothing real to test against.
        return TestExtractorResponse(
            success=False,
            error="No OP2 file available for testing. Run at least one optimization trial first.",
            execution_time_ms=(time.time() - start_time) * 1000,
        )

    # Execute the code by writing it to a temp file and importing it as a
    # module.  NOTE(review): this runs user code in-process with full
    # privileges — "sandboxed" here only means isolated in a temp module,
    # not a security boundary.
    try:
        # Create a temporary module file (delete=False so it survives the
        # `with` block; cleaned up in the finally below).
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(request.code)
            temp_file = f.name

        try:
            # Import the temp file as a throwaway module
            import importlib.util

            spec = importlib.util.spec_from_file_location("temp_extractor", temp_file)
            if spec is None or spec.loader is None:
                return TestExtractorResponse(
                    success=False,
                    error="Failed to load code as module",
                    execution_time_ms=(time.time() - start_time) * 1000,
                )

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # The contract requires a top-level extract() function
            if not hasattr(module, "extract"):
                return TestExtractorResponse(
                    success=False,
                    error="Code does not define an 'extract' function",
                    execution_time_ms=(time.time() - start_time) * 1000,
                )

            # Call the extract function with the discovered files
            extract_fn = module.extract
            result = extract_fn(
                op2_path=op2_path,
                fem_path=fem_path or "",
                params={},  # Empty params for testing
                subcase_id=request.subcase_id,
            )

            if not isinstance(result, dict):
                return TestExtractorResponse(
                    success=False,
                    error=f"extract() returned {type(result).__name__}, expected dict",
                    execution_time_ms=(time.time() - start_time) * 1000,
                )

            # Convert all values to float for JSON serialization
            outputs = {}
            for k, v in result.items():
                try:
                    outputs[k] = float(v)
                except (TypeError, ValueError):
                    outputs[k] = 0.0  # Can't convert, use 0

            return TestExtractorResponse(
                success=True, outputs=outputs, execution_time_ms=(time.time() - start_time) * 1000
            )

        finally:
            # Clean up temp file regardless of outcome
            import os

            try:
                os.unlink(temp_file)
            except:
                # NOTE(review): bare except also swallows KeyboardInterrupt;
                # `except OSError:` would be the narrower choice here.
                pass

    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        tb = traceback.format_exc()
        # Try to append the most relevant traceback line for the user.
        # NOTE(review): the traceback shows the temp *file* path
        # (e.g. /tmp/tmpXXXX.py), not the module name "temp_extractor.py",
        # so this filter likely never matches — verify.
        if "temp_extractor.py" in tb:
            lines = tb.split("\n")
            relevant = [l for l in lines if "temp_extractor.py" in l or "line" in l.lower()]
            if relevant:
                error_msg += f"\n{relevant[-1]}"

        return TestExtractorResponse(
            success=False, error=error_msg, execution_time_ms=(time.time() - start_time) * 1000
        )
|
|
|
|
|
|
# ==================== Streaming Generation ====================
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
|
@router.post("/generate-extractor/stream")
async def generate_extractor_code_stream(request: ExtractorGenerationRequest):
    """
    Stream Python extractor code generation using Claude Code CLI.

    Uses Server-Sent Events (SSE) to stream tokens as they arrive.

    Event types:
    - data: {"type": "token", "content": "..."} - Partial code token
    - data: {"type": "done", "code": "...", "outputs": [...]} - Final result
    - data: {"type": "error", "message": "..."} - Error occurred

    Args:
        request: ExtractorGenerationRequest with prompt and context

    Returns:
        StreamingResponse with text/event-stream content type
    """
    # Build focused system prompt for extractor generation.
    # NOTE(review): unlike /generate-extractor, this prompt documents only
    # the OP2-based signature (no expression/computed variant) — confirm
    # whether the two prompts should be kept in sync.
    system_prompt = """You are generating a Python custom extractor function for Atomizer FEA optimization.

The function MUST:
1. Have signature: def extract(op2_path: str, fem_path: str, params: dict, subcase_id: int = 1) -> dict
2. Return a dict with extracted values (e.g., {"max_stress": 150.5, "mass": 2.3})
3. Use pyNastran.op2.op2.OP2 for reading OP2 results
4. Handle missing data gracefully with try/except blocks

Available imports (already available, just use them):
- from pyNastran.op2.op2 import OP2
- import numpy as np
- from pathlib import Path

Common patterns:
- Displacement: op2.displacements[subcase_id].data[0, :, 1:4] (x,y,z components)
- Stress: op2.cquad4_stress[subcase_id] or op2.ctria3_stress[subcase_id]
- Eigenvalues: op2.eigenvalues[subcase_id]

Return ONLY the complete Python code wrapped in ```python ... ```. No explanations outside the code block."""

    # Build user prompt with context
    user_prompt = f"Generate a custom extractor that: {request.prompt}"

    if request.existing_code:
        user_prompt += (
            f"\n\nImprove or modify this existing code:\n```python\n{request.existing_code}\n```"
        )

    if request.output_names:
        user_prompt += (
            f"\n\nThe function should output these keys: {', '.join(request.output_names)}"
        )

    async def generate():
        # Accumulates the full CLI output so the final "done" event can
        # carry the parsed code, even though tokens are streamed raw.
        full_output = ""

        try:
            # Call Claude CLI with streaming output
            process = await asyncio.create_subprocess_exec(
                "claude",
                "--print",
                "--system-prompt",
                system_prompt,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(ATOMIZER_ROOT),
                env={
                    **os.environ,
                    "ATOMIZER_ROOT": str(ATOMIZER_ROOT),
                },
            )

            # Write prompt to stdin and close so the CLI starts generating
            process.stdin.write(user_prompt.encode("utf-8"))
            await process.stdin.drain()
            process.stdin.close()

            # Stream stdout chunks as they arrive
            while True:
                chunk = await asyncio.wait_for(
                    process.stdout.read(256),  # Read in small chunks for responsiveness
                    timeout=60.0,
                )
                if not chunk:
                    # EOF: CLI closed stdout
                    break

                decoded = chunk.decode("utf-8", errors="replace")
                full_output += decoded

                # Send token event
                yield f"data: {json.dumps({'type': 'token', 'content': decoded})}\n\n"

            # Wait for process to complete
            await process.wait()

            # Check for errors
            if process.returncode != 0:
                stderr = await process.stderr.read()
                error_text = stderr.decode("utf-8", errors="replace")
                yield f"data: {json.dumps({'type': 'error', 'message': f'Claude CLI error: {error_text[:500]}'})}\n\n"
                return

            # Parse the complete output to extract code
            code_match = re.search(r"```python\s*(.*?)\s*```", full_output, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
            elif "def extract(" in full_output:
                # NOTE(review): keeps the whole output, including any prose
                # before "def extract(" — same latent issue as the
                # non-streaming endpoint's fallback; verify.
                code = full_output.strip()
            else:
                yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse generated code'})}\n\n"
                return

            # Detect output names from the return statement's dict keys
            detected_outputs: List[str] = []
            return_match = re.search(r"return\s*\{([^}]+)\}", code)
            if return_match:
                key_matches = re.findall(r"['\"]([^'\"]+)['\"]:", return_match.group(1))
                detected_outputs = key_matches

            final_outputs = detected_outputs if detected_outputs else request.output_names

            # Send completion event with parsed code
            yield f"data: {json.dumps({'type': 'done', 'code': code, 'outputs': final_outputs})}\n\n"

        except asyncio.TimeoutError:
            # NOTE(review): the subprocess is not killed on timeout, which
            # can leave an orphaned CLI process — confirm cleanup policy.
            yield f"data: {json.dumps({'type': 'error', 'message': 'Generation timed out (60s limit)'})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
|
|
|
|
|
|
# ==================== Session Management ====================

# Store active WebSocket connections.
# Maps session_id -> WebSocket; entries are added on init/connect and
# popped by the WebSocket handlers on disconnect or error.
_active_connections: Dict[str, WebSocket] = {}
|
|
|
|
|
|
@router.post("/sessions")
async def create_claude_code_session(study_id: Optional[str] = None):
    """
    Create a new Claude Code session.

    Args:
        study_id: Optional study to provide context

    Returns:
        Session info including session_id
    """
    try:
        new_session = get_claude_code_manager().create_session(study_id)
        payload = {
            "session_id": new_session.session_id,
            "study_id": new_session.study_id,
            "working_dir": str(new_session.working_dir),
            "message": "Claude Code session created. Connect via WebSocket to chat.",
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return payload
|
|
|
|
|
|
@router.get("/sessions/{session_id}")
async def get_claude_code_session(session_id: str):
    """Return summary info for an existing session; 404 if unknown."""
    sess = get_claude_code_manager().get_session(session_id)

    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")

    summary = {
        "session_id": sess.session_id,
        "study_id": sess.study_id,
        "working_dir": str(sess.working_dir),
        "has_canvas_state": sess.canvas_state is not None,
        "conversation_length": len(sess.conversation_history),
    }
    return summary
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def delete_claude_code_session(session_id: str):
    """Delete the given Claude Code session."""
    get_claude_code_manager().remove_session(session_id)
    return {"message": "Session deleted"}
|
|
|
|
|
|
@router.websocket("/ws")
async def claude_code_websocket(websocket: WebSocket):
    """
    WebSocket for full Claude Code CLI access (no session required).

    This is a simplified endpoint that creates a session per connection.

    Message formats (client -> server):
        {"type": "init", "study_id": "optional_study_name"}
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print("[ClaudeCode WS] Connection attempt received")
    await websocket.accept()
    print("[ClaudeCode WS] WebSocket accepted")

    manager = get_claude_code_manager()
    session: Optional[ClaudeCodeSession] = None

    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")

            if msg_type == "init":
                # Create or reinitialize session
                study_id = data.get("study_id")
                session = manager.create_session(study_id)
                _active_connections[session.session_id] = websocket

                await websocket.send_json(
                    {
                        "type": "initialized",
                        "session_id": session.session_id,
                        "study_id": session.study_id,
                        "working_dir": str(session.working_dir),
                    }
                )

            elif msg_type == "message":
                if not session:
                    # Auto-create session if the client skipped "init"
                    session = manager.create_session()
                    _active_connections[session.session_id] = websocket

                content = data.get("content", "")
                if not content:
                    continue

                # Update canvas state if provided with message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])

                # Stream response chunks from Claude Code CLI to the client
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)

            elif msg_type == "set_canvas":
                if session:
                    session.set_canvas_state(data.get("canvas_state", {}))
                    await websocket.send_json(
                        {
                            "type": "canvas_updated",
                        }
                    )

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        # Clean up on disconnect
        if session:
            _active_connections.pop(session.session_id, None)
            # Keep session in manager for potential reconnect
    except Exception as e:
        # Best-effort error report; the socket may already be closed, in
        # which case send_json itself raises and we drop the report.
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "content": str(e),
                }
            )
        except Exception:
            # Was a bare `except:`, which would also swallow
            # CancelledError/KeyboardInterrupt during shutdown.
            pass
        if session:
            _active_connections.pop(session.session_id, None)
|
|
|
|
|
|
@router.websocket("/ws/{study_id:path}")
async def claude_code_websocket_with_study(websocket: WebSocket, study_id: str):
    """
    WebSocket for Claude Code CLI with study context.

    Same as /ws but automatically initializes with the given study.

    Message formats (client -> server):
        {"type": "message", "content": "user message"}
        {"type": "set_canvas", "canvas_state": {...}}
        {"type": "ping"}

    Message formats (server -> client):
        {"type": "initialized", "session_id": "...", "study_id": "..."}
        {"type": "text", "content": "..."}
        {"type": "done"}
        {"type": "refresh_canvas", "study_id": "...", "reason": "..."}
        {"type": "error", "content": "..."}
        {"type": "pong"}
    """
    print(f"[ClaudeCode WS] Connection attempt received for study: {study_id}")
    await websocket.accept()
    print(f"[ClaudeCode WS] WebSocket accepted for study: {study_id}")

    manager = get_claude_code_manager()
    session = manager.create_session(study_id)
    _active_connections[session.session_id] = websocket

    # Send initialization message immediately (no "init" round-trip needed)
    await websocket.send_json(
        {
            "type": "initialized",
            "session_id": session.session_id,
            "study_id": session.study_id,
            "working_dir": str(session.working_dir),
        }
    )

    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")

            if msg_type == "message":
                content = data.get("content", "")
                if not content:
                    continue

                # Update canvas state if provided with message
                if data.get("canvas_state"):
                    session.set_canvas_state(data["canvas_state"])

                # Stream response chunks from Claude Code CLI to the client
                async for chunk in session.send_message(content):
                    await websocket.send_json(chunk)

            elif msg_type == "set_canvas":
                session.set_canvas_state(data.get("canvas_state", {}))
                await websocket.send_json(
                    {
                        "type": "canvas_updated",
                    }
                )

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        _active_connections.pop(session.session_id, None)
    except Exception as e:
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "content": str(e),
                }
            )
        except Exception:
            # Was a bare `except:`, which would also swallow
            # CancelledError/KeyboardInterrupt during shutdown.
            pass
        _active_connections.pop(session.session_id, None)
|