- Add embedded Claude Code terminal with xterm.js for full CLI experience - Create WebSocket PTY backend for real-time terminal communication - Add terminal status endpoint to check CLI availability - Update dashboard to use Claude Code terminal instead of API chat - Add optimization control panel with start/stop/validate actions - Add study context provider for global state management - Update frontend with new dependencies (xterm.js addons) - Comprehensive README documentation for all new features 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
716 lines
28 KiB
Python
716 lines
28 KiB
Python
"""
Atomizer Claude Agent Service

Provides Claude AI integration with Atomizer-specific tools for:
- Analyzing optimization results
- Querying trial data
- Modifying configurations (not yet exposed as a tool; the current toolset is read-only)
- Creating new studies (not yet exposed as a tool)
- Explaining FEA/Zernike concepts
"""
|
|
|
|
import asyncio
import json
import os
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any, AsyncGenerator

import anthropic
|
|
|
|
# Repository root: five directory levels above this module.
ATOMIZER_ROOT = Path(__file__).parent.parent.parent.parent.parent

# Base studies directory; each study is a sub-folder named by its study id.
STUDIES_DIR = ATOMIZER_ROOT / "studies"
|
|
|
|
|
|
class AtomizerClaudeAgent:
    """Claude agent with Atomizer-specific tools and context.

    Wraps a synchronous ``anthropic.Anthropic`` client, exposes read-only tools
    over a study's files (config JSON, README) and its SQLite results database,
    and builds a system prompt that summarizes the selected study when its
    directory exists.
    """

    # Sub-directories searched, in order, for a study's ``study.db``.
    _RESULTS_DIR_NAMES = ("2_results", "3_results")

    # Trial states accepted by the ``query_trials`` tool (besides "all");
    # mirrors the enum in that tool's input schema.
    _TRIAL_STATES = ("COMPLETE", "PRUNED", "FAIL", "RUNNING")

    def __init__(self, study_id: Optional[str] = None):
        """Create an agent, optionally bound to ``study_id`` under STUDIES_DIR."""
        self.client = anthropic.Anthropic()
        self.study_id = study_id
        self.study_dir = STUDIES_DIR / study_id if study_id else None
        self.tools = self._define_tools()
        self.system_prompt = self._build_system_prompt()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _find_config_path(study_dir: Path) -> Optional[Path]:
        """Return ``1_setup/optimization_config.json`` or the root-level one, else None."""
        for candidate in (study_dir / "1_setup" / "optimization_config.json",
                          study_dir / "optimization_config.json"):
            if candidate.exists():
                return candidate
        return None

    @classmethod
    def _find_db_path(cls, study_dir: Path) -> Optional[Path]:
        """Return the first existing ``<results dir>/study.db`` for the study, else None."""
        for results_dir_name in cls._RESULTS_DIR_NAMES:
            db_path = study_dir / results_dir_name / "study.db"
            if db_path.exists():
                return db_path
        return None

    @staticmethod
    def _study_direction(cursor: sqlite3.Cursor) -> Optional[str]:
        """Return "MAXIMIZE" or "MINIMIZE" for the first objective, or None if unknown.

        Reads the ``study_directions`` table of the Optuna RDB schema. Databases
        without that table (or column) yield None, and callers then fall back to
        treating lower values as better.
        """
        try:
            cursor.execute("SELECT direction FROM study_directions ORDER BY objective LIMIT 1")
            row = cursor.fetchone()
        except sqlite3.Error:
            return None
        if row is None or row[0] is None:
            return None
        # NOTE(review): assumes the stored value is the StudyDirection name
        # (e.g. "MAXIMIZE"); a substring test tolerates "StudyDirection.MAXIMIZE".
        direction = str(row[0]).upper()
        if "MAXIMIZE" in direction:
            return "MAXIMIZE"
        if "MINIMIZE" in direction:
            return "MINIMIZE"
        return None

    @staticmethod
    def _fmt_num(value: Optional[float], digits: int = 6) -> str:
        """Format ``value`` with ``digits`` decimals, or return "N/A" when it is None."""
        return "N/A" if value is None else f"{value:.{digits}f}"

    # ------------------------------------------------------------------
    # System prompt
    # ------------------------------------------------------------------

    def _build_system_prompt(self) -> str:
        """Build the system prompt: fixed Atomizer guidance plus current-study context."""
        base_prompt = """You are Claude Code embedded in the Atomizer FEA optimization dashboard.

## Your Role
You help engineers with structural optimization using NX Nastran simulations. You can:
1. **Analyze Results** - Interpret optimization progress, identify trends, explain convergence
2. **Query Data** - Fetch trial data, compare configurations, find best designs
3. **Modify Settings** - Update design variable bounds, objectives, constraints
4. **Explain Concepts** - FEA, Zernike polynomials, wavefront error, stress analysis
5. **Troubleshoot** - Debug failed trials, identify issues, suggest fixes

## Atomizer Context
- Atomizer uses Optuna for Bayesian optimization
- Studies can use FEA-only or hybrid FEA/Neural surrogate approaches
- Results are stored in SQLite databases (study.db)
- Design variables are NX expressions in CAD models
- Objectives include stress, displacement, frequency, Zernike WFE

## Guidelines
- Be concise but thorough
- Use technical language appropriate for engineers
- When showing data, format it clearly (tables, lists)
- If uncertain, say so and suggest how to verify
- Proactively suggest next steps or insights

"""

        # Add study-specific context if available
        if self.study_id and self.study_dir and self.study_dir.exists():
            context = self._get_study_context()
            base_prompt += f"\n## Current Study: {self.study_id}\n{context}\n"
        else:
            base_prompt += "\n## Current Study: None selected\nAsk the user to select a study or help them create a new one.\n"

        return base_prompt

    def _get_study_context(self) -> str:
        """Summarize the current study's config and database for the system prompt.

        Best-effort: problems reading the config or the database are ignored so a
        damaged study never prevents the agent from being constructed.

        Returns:
            Markdown lines (design variables, objectives, constraints, completed
            trial count, best objective), or "No configuration found." when
            nothing could be read.
        """
        context_parts = []

        config_path = self._find_config_path(self.study_dir)
        if config_path is not None:
            try:
                with open(config_path, encoding="utf-8") as f:
                    config = json.load(f)

                # Design variables (first five names only, to keep the prompt short)
                dvs = config.get('design_variables', [])
                if dvs:
                    context_parts.append(f"**Design Variables ({len(dvs)})**: " +
                                         ", ".join(dv['name'] for dv in dvs[:5]) +
                                         ("..." if len(dvs) > 5 else ""))

                # Objectives
                objs = config.get('objectives', [])
                if objs:
                    context_parts.append(f"**Objectives ({len(objs)})**: " +
                                         ", ".join(f"{o['name']} ({o.get('direction', 'minimize')})"
                                                   for o in objs))

                # Constraints
                constraints = config.get('constraints', [])
                if constraints:
                    context_parts.append("**Constraints**: " +
                                         ", ".join(c['name'] for c in constraints))
            except Exception:
                # Deliberately best-effort (see docstring).
                pass

        db_path = self._find_db_path(self.study_dir)
        if db_path is not None:
            try:
                with closing(sqlite3.connect(str(db_path))) as conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT COUNT(*) FROM trials WHERE state='COMPLETE'")
                    trial_count = cursor.fetchone()[0]
                    context_parts.append(f"**Completed Trials**: {trial_count}")

                    # "Best" depends on the optimization direction; default to minimize.
                    aggregate = "MAX" if self._study_direction(cursor) == "MAXIMIZE" else "MIN"
                    cursor.execute(f"""
                        SELECT {aggregate}(value) FROM trial_values
                        WHERE trial_id IN (SELECT trial_id FROM trials WHERE state='COMPLETE')
                    """)
                    best = cursor.fetchone()[0]
                    if best is not None:
                        context_parts.append(f"**Best Objective**: {best:.6f}")
            except Exception:
                # Deliberately best-effort (see docstring).
                pass

        return "\n".join(context_parts) if context_parts else "No configuration found."

    # ------------------------------------------------------------------
    # Tool definitions and dispatch
    # ------------------------------------------------------------------

    def _define_tools(self) -> List[Dict[str, Any]]:
        """Return the Anthropic tool schemas exposed to Claude (all read-only)."""
        return [
            {
                "name": "read_study_config",
                "description": "Read the optimization configuration for the current or specified study. Returns design variables, objectives, constraints, and algorithm settings.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID to read config from. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "query_trials",
                "description": "Query trial data from the Optuna database. Can filter by state, source (FEA/NN), objective value range, or parameter values.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID to query. Uses current study if not specified."
                        },
                        "state": {
                            "type": "string",
                            "enum": ["COMPLETE", "PRUNED", "FAIL", "RUNNING", "all"],
                            "description": "Filter by trial state. Default: COMPLETE"
                        },
                        "source": {
                            "type": "string",
                            "enum": ["fea", "nn", "all"],
                            "description": "Filter by trial source (FEA simulation or Neural Network). Default: all"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum number of trials to return. Default: 20"
                        },
                        "order_by": {
                            "type": "string",
                            "enum": ["value_asc", "value_desc", "trial_id_asc", "trial_id_desc"],
                            "description": "Sort order. Default: value_asc (best first)"
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "get_trial_details",
                "description": "Get detailed information about a specific trial including all parameters, objective values, and user attributes.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        },
                        "trial_id": {
                            "type": "integer",
                            "description": "The trial number to get details for."
                        }
                    },
                    "required": ["trial_id"]
                }
            },
            {
                "name": "compare_trials",
                "description": "Compare two or more trials side-by-side, showing parameter differences and objective values.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        },
                        "trial_ids": {
                            "type": "array",
                            "items": {"type": "integer"},
                            "description": "List of trial IDs to compare (2-5 trials)."
                        }
                    },
                    "required": ["trial_ids"]
                }
            },
            {
                "name": "get_optimization_summary",
                "description": "Get a high-level summary of the optimization progress including trial counts, convergence status, best designs, and parameter sensitivity.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "read_study_readme",
                "description": "Read the README.md documentation for a study, which contains the engineering problem description, mathematical formulation, and methodology.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "list_studies",
                "description": "List all available optimization studies with their status and trial counts.",
                "input_schema": {
                    "type": "object",
                    "properties": {},
                    "required": []
                }
            }
        ]

    def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str:
        """Run the named tool and return its text result.

        Any exception is converted into an error string so Claude receives it as
        the tool result instead of the whole chat request failing.
        """
        try:
            if tool_name == "read_study_config":
                return self._tool_read_config(tool_input.get('study_id'))
            elif tool_name == "query_trials":
                return self._tool_query_trials(tool_input)
            elif tool_name == "get_trial_details":
                return self._tool_get_trial_details(tool_input)
            elif tool_name == "compare_trials":
                return self._tool_compare_trials(tool_input)
            elif tool_name == "get_optimization_summary":
                return self._tool_get_summary(tool_input.get('study_id'))
            elif tool_name == "read_study_readme":
                return self._tool_read_readme(tool_input.get('study_id'))
            elif tool_name == "list_studies":
                return self._tool_list_studies()
            else:
                return f"Unknown tool: {tool_name}"
        except Exception as e:
            return f"Error executing {tool_name}: {str(e)}"

    def _get_study_dir(self, study_id: Optional[str]) -> Path:
        """Resolve a study directory, defaulting to the current study.

        Raises:
            ValueError: if no study is given or selected, if the id would escape
                STUDIES_DIR (e.g. "../x" or an absolute path), or if the
                directory does not exist.
        """
        sid = study_id or self.study_id
        if not sid:
            raise ValueError("No study specified and no current study selected")

        # Lexical check (symlinks deliberately not resolved): tool input comes
        # from the model, so refuse ids that point outside the studies folder.
        root = Path(os.path.normpath(STUDIES_DIR))
        study_dir = Path(os.path.normpath(STUDIES_DIR / sid))
        if root not in study_dir.parents:
            raise ValueError(f"Invalid study id '{sid}'")

        if not study_dir.exists():
            raise ValueError(f"Study '{sid}' not found")
        return study_dir

    def _get_db_path(self, study_id: Optional[str]) -> Path:
        """Return the study's ``study.db``; raises ValueError if study or DB is missing."""
        study_dir = self._get_study_dir(study_id)
        db_path = self._find_db_path(study_dir)
        if db_path is None:
            raise ValueError(f"No database found for study '{study_id or self.study_id}'")
        return db_path

    # ------------------------------------------------------------------
    # Tool implementations
    # ------------------------------------------------------------------

    def _tool_read_config(self, study_id: Optional[str]) -> str:
        """Render the study's optimization_config.json as Markdown tables."""
        study_dir = self._get_study_dir(study_id)

        config_path = self._find_config_path(study_dir)
        if config_path is None:
            return "No configuration file found for this study."

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # Format nicely
        result = [f"# Configuration for {study_id or self.study_id}\n"]

        # Design variables
        dvs = config.get('design_variables', [])
        if dvs:
            result.append("## Design Variables")
            result.append("| Name | Min | Max | Baseline | Units |")
            result.append("|------|-----|-----|----------|-------|")
            for dv in dvs:
                result.append(f"| {dv['name']} | {dv.get('min', '-')} | {dv.get('max', '-')} | {dv.get('baseline', '-')} | {dv.get('units', '-')} |")

        # Objectives
        objs = config.get('objectives', [])
        if objs:
            result.append("\n## Objectives")
            result.append("| Name | Direction | Weight | Target | Units |")
            result.append("|------|-----------|--------|--------|-------|")
            for obj in objs:
                result.append(f"| {obj['name']} | {obj.get('direction', 'minimize')} | {obj.get('weight', 1.0)} | {obj.get('target', '-')} | {obj.get('units', '-')} |")

        # Constraints
        constraints = config.get('constraints', [])
        if constraints:
            result.append("\n## Constraints")
            for c in constraints:
                result.append(f"- **{c['name']}**: {c.get('type', 'bound')} {c.get('max_value', c.get('min_value', ''))} {c.get('units', '')}")

        return "\n".join(result)

    def _tool_query_trials(self, params: Dict[str, Any]) -> str:
        """Return a Markdown table of trials matching the ``query_trials`` tool input.

        ``state`` and ``limit`` are validated and bound as SQL parameters rather
        than interpolated, because tool input is model/user controlled. Trials
        without an objective value sort last under the value orderings.
        """
        db_path = self._get_db_path(params.get('study_id'))

        state = params.get('state', 'COMPLETE')
        source = params.get('source', 'all')
        order_by = params.get('order_by', 'value_asc')

        if state != 'all' and state not in self._TRIAL_STATES:
            return (f"Invalid state: {state!r}. Expected one of: "
                    + ", ".join(self._TRIAL_STATES + ('all',)))
        try:
            limit = max(1, int(params.get('limit', 20)))
        except (TypeError, ValueError):
            return f"Invalid limit: {params.get('limit')!r}. Expected an integer."

        # Parameters are aggregated in a correlated subquery so extra rows from
        # the trial_values join (one per objective) cannot duplicate them.
        query = """
            SELECT t.trial_id, t.state, tv.value,
                   (SELECT GROUP_CONCAT(tp.param_name || '=' || ROUND(tp.param_value, 4), ', ')
                    FROM trial_params tp
                    WHERE tp.trial_id = t.trial_id) AS params
            FROM trials t
            LEFT JOIN trial_values tv ON t.trial_id = tv.trial_id
        """
        args: List[Any] = []
        if state != 'all':
            query += " WHERE t.state = ?"
            args.append(state)

        query += " GROUP BY t.trial_id"

        # `tv.value IS NULL` pushes value-less trials to the end; SQLite would
        # otherwise sort NULLs first in ascending order.
        if order_by == 'value_asc':
            query += " ORDER BY tv.value IS NULL, tv.value ASC"
        elif order_by == 'value_desc':
            query += " ORDER BY tv.value IS NULL, tv.value DESC"
        elif order_by == 'trial_id_desc':
            query += " ORDER BY t.trial_id DESC"
        else:
            query += " ORDER BY t.trial_id ASC"

        query += " LIMIT ?"
        args.append(limit)

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query, args).fetchall()

        if not rows:
            return "No trials found matching the criteria."

        # Format results
        result = [f"# Trials (showing {len(rows)}/{limit} max)\n"]
        result.append("| Trial | State | Objective | Parameters |")
        result.append("|-------|-------|-----------|------------|")

        for row in rows:
            value = self._fmt_num(row['value'])
            raw_params = row['params'] or ""
            params_text = raw_params[:50] + "..." if len(raw_params) > 50 else raw_params
            result.append(f"| {row['trial_id']} | {row['state']} | {value} | {params_text} |")

        if source != 'all':
            # TODO: filter on the 'trial_source' user attribute once its stored
            # values are confirmed; until then, say so rather than silently
            # returning unfiltered data.
            result.append(f"\n_Note: the source filter ({source!r}) is not supported yet; "
                          "results include trials from all sources._")

        return "\n".join(result)

    def _tool_get_trial_details(self, params: Dict[str, Any]) -> str:
        """Return state, objective value, parameters and user attributes of one trial."""
        db_path = self._get_db_path(params.get('study_id'))
        trial_id = params['trial_id']

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get trial info
            cursor.execute("SELECT * FROM trials WHERE trial_id = ?", (trial_id,))
            trial = cursor.fetchone()
            if not trial:
                return f"Trial {trial_id} not found."

            result = [f"# Trial {trial_id} Details\n"]
            result.append(f"**State**: {trial['state']}")

            # Objective value (first row only); NULL is shown as N/A
            cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (trial_id,))
            value_row = cursor.fetchone()
            if value_row:
                result.append(f"**Objective Value**: {self._fmt_num(value_row['value'])}")

            # Parameters
            cursor.execute("SELECT param_name, param_value FROM trial_params WHERE trial_id = ? ORDER BY param_name", (trial_id,))
            params_rows = cursor.fetchall()
            if params_rows:
                result.append("\n## Parameters")
                result.append("| Parameter | Value |")
                result.append("|-----------|-------|")
                for p in params_rows:
                    result.append(f"| {p['param_name']} | {self._fmt_num(p['param_value'])} |")

            # User attributes (stored as JSON text)
            cursor.execute("SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ?", (trial_id,))
            attrs = cursor.fetchall()
            if attrs:
                result.append("\n## Attributes")
                for attr in attrs:
                    try:
                        value = json.loads(attr['value_json'])
                    except (TypeError, ValueError):
                        # Not valid JSON: show the raw stored text.
                        result.append(f"- **{attr['key']}**: {attr['value_json']}")
                        continue
                    if isinstance(value, float):
                        result.append(f"- **{attr['key']}**: {value:.6f}")
                    else:
                        result.append(f"- **{attr['key']}**: {value}")

        return "\n".join(result)

    def _tool_compare_trials(self, params: Dict[str, Any]) -> str:
        """Return a side-by-side table of objective and parameter values for 2-5 trials."""
        db_path = self._get_db_path(params.get('study_id'))
        trial_ids = params['trial_ids']

        if len(trial_ids) < 2:
            return "Need at least 2 trials to compare."
        if len(trial_ids) > 5:
            return "Maximum 5 trials for comparison."

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            result = ["# Trial Comparison\n"]

            # Parameter names used by the compared trials only
            placeholders = ", ".join("?" for _ in trial_ids)
            cursor.execute(
                f"SELECT DISTINCT param_name FROM trial_params WHERE trial_id IN ({placeholders}) ORDER BY param_name",
                list(trial_ids),
            )
            param_names = [row['param_name'] for row in cursor.fetchall()]

            # Comparison table header
            header = "| Parameter | " + " | ".join(f"Trial {tid}" for tid in trial_ids) + " |"
            separator = "|-----------|" + "|".join("-" * 10 for _ in trial_ids) + "|"
            result.append(header)
            result.append(separator)

            # Objective values row
            obj_values = []
            for tid in trial_ids:
                cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (tid,))
                row = cursor.fetchone()
                obj_values.append(self._fmt_num(row['value'] if row else None, 4))
            result.append("| **Objective** | " + " | ".join(obj_values) + " |")

            # Parameter rows
            for pname in param_names:
                values = []
                for tid in trial_ids:
                    cursor.execute("SELECT param_value FROM trial_params WHERE trial_id = ? AND param_name = ?", (tid, pname))
                    row = cursor.fetchone()
                    values.append(self._fmt_num(row['param_value'] if row else None, 4))
                result.append(f"| {pname} | " + " | ".join(values) + " |")

        return "\n".join(result)

    def _tool_get_summary(self, study_id: Optional[str]) -> str:
        """Return trial counts by state, the best completed trial, and FEA/NN source counts."""
        db_path = self._get_db_path(study_id)

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            result = ["# Optimization Summary\n"]

            # Trial counts by state
            cursor.execute("SELECT state, COUNT(*) as count FROM trials GROUP BY state")
            states = {row['state']: row['count'] for row in cursor.fetchall()}

            result.append("## Trial Counts")
            total = sum(states.values())
            result.append(f"- **Total**: {total}")
            for state, count in states.items():
                result.append(f"- {state}: {count}")

            # Best trial: direction-aware, ignoring NULL values (which SQLite
            # would otherwise sort first and which cannot be formatted).
            direction = self._study_direction(cursor)
            order = "DESC" if direction == "MAXIMIZE" else "ASC"
            cursor.execute(f"""
                SELECT t.trial_id, tv.value
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE' AND tv.value IS NOT NULL
                ORDER BY tv.value {order} LIMIT 1
            """)
            best = cursor.fetchone()
            if best:
                result.append("\n## Best Trial")
                result.append(f"- **Trial ID**: {best['trial_id']}")
                result.append(f"- **Objective**: {best['value']:.6f}")
                if direction:
                    result.append(f"- **Direction**: {direction.lower()}")

            # FEA vs NN counts (from the 'trial_source' user attribute)
            cursor.execute("""
                SELECT value_json, COUNT(*) as count
                FROM trial_user_attributes
                WHERE key = 'trial_source'
                GROUP BY value_json
            """)
            sources = cursor.fetchall()
            if sources:
                result.append("\n## Trial Sources")
                for src in sources:
                    raw = src['value_json']
                    try:
                        source_name = json.loads(raw) if raw else 'unknown'
                    except ValueError:
                        source_name = raw
                    result.append(f"- **{source_name}**: {src['count']}")

        return "\n".join(result)

    def _tool_read_readme(self, study_id: Optional[str]) -> str:
        """Return the study's README.md, truncated to 8000 characters."""
        study_dir = self._get_study_dir(study_id)
        readme_path = study_dir / "README.md"

        if not readme_path.exists():
            return "No README.md found for this study."

        # errors='replace' so a stray non-UTF-8 byte does not hide the whole README.
        content = readme_path.read_text(encoding='utf-8', errors='replace')
        # Truncate if too long
        if len(content) > 8000:
            content = content[:8000] + "\n\n... (truncated)"

        return content

    def _tool_list_studies(self) -> str:
        """List every study directory with a ready/not_started status and completed-trial count."""
        if not STUDIES_DIR.exists():
            return "Studies directory not found."

        result = ["# Available Studies\n"]
        result.append("| Study | Status | Trials |")
        result.append("|-------|--------|--------|")

        for study_dir in sorted(STUDIES_DIR.iterdir()):
            if not study_dir.is_dir():
                continue

            trial_count = 0
            db_path = self._find_db_path(study_dir)
            if db_path is not None:
                try:
                    with closing(sqlite3.connect(str(db_path))) as conn:
                        trial_count = conn.execute(
                            "SELECT COUNT(*) FROM trials WHERE state='COMPLETE'"
                        ).fetchone()[0]
                except sqlite3.Error:
                    # Unreadable or non-Optuna DB: list the study with 0 trials.
                    pass

            status = "ready" if trial_count > 0 else "not_started"
            result.append(f"| {study_dir.name} | {status} | {trial_count} |")

        return "\n".join(result)

    # ------------------------------------------------------------------
    # Chat entry points
    # ------------------------------------------------------------------

    async def chat(self, message: str, conversation_history: Optional[List[Dict]] = None) -> Dict[str, Any]:
        """
        Process a chat message with tool use support.

        Calls the Messages API repeatedly. While the response stops with
        ``stop_reason == "tool_use"``, every requested tool is executed locally
        and its result is sent back. The first response without tool use ends
        the loop.

        Args:
            message: User's message
            conversation_history: Previous messages for context (not mutated;
                a shallow copy is extended)

        Returns:
            Dict with the final text ("response"), a preview of every tool call
            made ("tool_calls"), and the full message list including the final
            assistant turn ("conversation").
        """
        messages = conversation_history.copy() if conversation_history else []
        messages.append({"role": "user", "content": message})

        tool_calls_made = []
        loop = asyncio.get_running_loop()

        # NOTE(review): there is no cap on tool-use rounds; consider one if
        # runaway loops are ever observed.
        while True:
            # The SDK client is synchronous; run the request in the default
            # executor so the event loop is not blocked while waiting on the API.
            response = await loop.run_in_executor(
                None,
                lambda: self.client.messages.create(
                    model="claude-sonnet-4-20250514",
                    max_tokens=4096,
                    system=self.system_prompt,
                    tools=self.tools,
                    messages=messages,
                ),
            )

            if response.stop_reason == "tool_use":
                assistant_content = response.content
                tool_results = []

                for block in assistant_content:
                    if block.type == "tool_use":
                        result = self._execute_tool(block.name, block.input)

                        tool_calls_made.append({
                            "tool": block.name,
                            "input": block.input,
                            "result_preview": result[:200] + "..." if len(result) > 200 else result
                        })

                        tool_results.append({
                            "type": "tool_result",
                            "tool_use_id": block.id,
                            "content": result
                        })

                # Add assistant response and tool results, then ask again.
                messages.append({"role": "assistant", "content": assistant_content})
                messages.append({"role": "user", "content": tool_results})
            else:
                # No more tool use: concatenate the text blocks of the final answer.
                final_text = ""
                for block in response.content:
                    if hasattr(block, 'text'):
                        final_text += block.text

                return {
                    "response": final_text,
                    "tool_calls": tool_calls_made,
                    "conversation": messages + [{"role": "assistant", "content": response.content}]
                }

    async def chat_stream(self, message: str, conversation_history: Optional[List[Dict]] = None) -> AsyncGenerator[str, None]:
        """
        Stream a chat response as text chunks (no tool use).

        Args:
            message: User's message
            conversation_history: Previous messages (not mutated)

        Yields:
            Text deltas as they arrive from the API
        """
        messages = conversation_history.copy() if conversation_history else []
        messages.append({"role": "user", "content": message})

        # For streaming, we'll do a simpler approach without tool use for now
        # (Tool use with streaming is more complex)
        # NOTE(review): the synchronous stream is iterated inside this async
        # generator, so the event loop is blocked between chunks; consider
        # anthropic.AsyncAnthropic or a worker thread feeding an asyncio.Queue.
        with self.client.messages.stream(
            model="claude-sonnet-4-20250514",
            max_tokens=4096,
            system=self.system_prompt,
            messages=messages
        ) as stream:
            for text in stream.text_stream:
                yield text
|