feat: Add Claude Code terminal integration to dashboard
- Add embedded Claude Code terminal with xterm.js for full CLI experience - Create WebSocket PTY backend for real-time terminal communication - Add terminal status endpoint to check CLI availability - Update dashboard to use Claude Code terminal instead of API chat - Add optimization control panel with start/stop/validate actions - Add study context provider for global state management - Update frontend with new dependencies (xterm.js addons) - Comprehensive README documentation for all new features 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
7
atomizer-dashboard/backend/api/services/__init__.py
Normal file
7
atomizer-dashboard/backend/api/services/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Atomizer Dashboard Services
|
||||
"""
|
||||
|
||||
from .claude_agent import AtomizerClaudeAgent
|
||||
|
||||
__all__ = ['AtomizerClaudeAgent']
|
||||
715
atomizer-dashboard/backend/api/services/claude_agent.py
Normal file
715
atomizer-dashboard/backend/api/services/claude_agent.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""
|
||||
Atomizer Claude Agent Service
|
||||
|
||||
Provides Claude AI integration with Atomizer-specific tools for:
|
||||
- Analyzing optimization results
|
||||
- Querying trial data
|
||||
- Modifying configurations
|
||||
- Creating new studies
|
||||
- Explaining FEA/Zernike concepts
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, AsyncGenerator
|
||||
from datetime import datetime
|
||||
import anthropic
|
||||
|
||||
# Base studies directory
|
||||
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
|
||||
ATOMIZER_ROOT = Path(__file__).parent.parent.parent.parent.parent
|
||||
|
||||
|
||||
class AtomizerClaudeAgent:
|
||||
"""Claude agent with Atomizer-specific tools and context"""
|
||||
|
||||
def __init__(self, study_id: Optional[str] = None):
|
||||
self.client = anthropic.Anthropic()
|
||||
self.study_id = study_id
|
||||
self.study_dir = STUDIES_DIR / study_id if study_id else None
|
||||
self.tools = self._define_tools()
|
||||
self.system_prompt = self._build_system_prompt()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build context-aware system prompt for Atomizer"""
|
||||
base_prompt = """You are Claude Code embedded in the Atomizer FEA optimization dashboard.
|
||||
|
||||
## Your Role
|
||||
You help engineers with structural optimization using NX Nastran simulations. You can:
|
||||
1. **Analyze Results** - Interpret optimization progress, identify trends, explain convergence
|
||||
2. **Query Data** - Fetch trial data, compare configurations, find best designs
|
||||
3. **Modify Settings** - Update design variable bounds, objectives, constraints
|
||||
4. **Explain Concepts** - FEA, Zernike polynomials, wavefront error, stress analysis
|
||||
5. **Troubleshoot** - Debug failed trials, identify issues, suggest fixes
|
||||
|
||||
## Atomizer Context
|
||||
- Atomizer uses Optuna for Bayesian optimization
|
||||
- Studies can use FEA-only or hybrid FEA/Neural surrogate approaches
|
||||
- Results are stored in SQLite databases (study.db)
|
||||
- Design variables are NX expressions in CAD models
|
||||
- Objectives include stress, displacement, frequency, Zernike WFE
|
||||
|
||||
## Guidelines
|
||||
- Be concise but thorough
|
||||
- Use technical language appropriate for engineers
|
||||
- When showing data, format it clearly (tables, lists)
|
||||
- If uncertain, say so and suggest how to verify
|
||||
- Proactively suggest next steps or insights
|
||||
|
||||
"""
|
||||
|
||||
# Add study-specific context if available
|
||||
if self.study_id and self.study_dir and self.study_dir.exists():
|
||||
context = self._get_study_context()
|
||||
base_prompt += f"\n## Current Study: {self.study_id}\n{context}\n"
|
||||
else:
|
||||
base_prompt += "\n## Current Study: None selected\nAsk the user to select a study or help them create a new one.\n"
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _get_study_context(self) -> str:
|
||||
"""Get context information about the current study"""
|
||||
context_parts = []
|
||||
|
||||
# Try to load config
|
||||
config_path = self.study_dir / "1_setup" / "optimization_config.json"
|
||||
if not config_path.exists():
|
||||
config_path = self.study_dir / "optimization_config.json"
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Design variables
|
||||
dvs = config.get('design_variables', [])
|
||||
if dvs:
|
||||
context_parts.append(f"**Design Variables ({len(dvs)})**: " +
|
||||
", ".join(dv['name'] for dv in dvs[:5]) +
|
||||
("..." if len(dvs) > 5 else ""))
|
||||
|
||||
# Objectives
|
||||
objs = config.get('objectives', [])
|
||||
if objs:
|
||||
context_parts.append(f"**Objectives ({len(objs)})**: " +
|
||||
", ".join(f"{o['name']} ({o.get('direction', 'minimize')})"
|
||||
for o in objs))
|
||||
|
||||
# Constraints
|
||||
constraints = config.get('constraints', [])
|
||||
if constraints:
|
||||
context_parts.append(f"**Constraints**: " +
|
||||
", ".join(c['name'] for c in constraints))
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get trial count from database
|
||||
results_dir = self.study_dir / "2_results"
|
||||
if not results_dir.exists():
|
||||
results_dir = self.study_dir / "3_results"
|
||||
|
||||
db_path = results_dir / "study.db" if results_dir.exists() else None
|
||||
if db_path and db_path.exists():
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM trials WHERE state='COMPLETE'")
|
||||
trial_count = cursor.fetchone()[0]
|
||||
context_parts.append(f"**Completed Trials**: {trial_count}")
|
||||
|
||||
# Get best value
|
||||
cursor.execute("""
|
||||
SELECT MIN(value) FROM trial_values
|
||||
WHERE trial_id IN (SELECT trial_id FROM trials WHERE state='COMPLETE')
|
||||
""")
|
||||
best = cursor.fetchone()[0]
|
||||
if best is not None:
|
||||
context_parts.append(f"**Best Objective**: {best:.6f}")
|
||||
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "\n".join(context_parts) if context_parts else "No configuration found."
|
||||
|
||||
def _define_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Define Atomizer-specific tools for Claude"""
|
||||
return [
|
||||
{
|
||||
"name": "read_study_config",
|
||||
"description": "Read the optimization configuration for the current or specified study. Returns design variables, objectives, constraints, and algorithm settings.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID to read config from. Uses current study if not specified."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "query_trials",
|
||||
"description": "Query trial data from the Optuna database. Can filter by state, source (FEA/NN), objective value range, or parameter values.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID to query. Uses current study if not specified."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"enum": ["COMPLETE", "PRUNED", "FAIL", "RUNNING", "all"],
|
||||
"description": "Filter by trial state. Default: COMPLETE"
|
||||
},
|
||||
"source": {
|
||||
"type": "string",
|
||||
"enum": ["fea", "nn", "all"],
|
||||
"description": "Filter by trial source (FEA simulation or Neural Network). Default: all"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of trials to return. Default: 20"
|
||||
},
|
||||
"order_by": {
|
||||
"type": "string",
|
||||
"enum": ["value_asc", "value_desc", "trial_id_asc", "trial_id_desc"],
|
||||
"description": "Sort order. Default: value_asc (best first)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_trial_details",
|
||||
"description": "Get detailed information about a specific trial including all parameters, objective values, and user attributes.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID. Uses current study if not specified."
|
||||
},
|
||||
"trial_id": {
|
||||
"type": "integer",
|
||||
"description": "The trial number to get details for."
|
||||
}
|
||||
},
|
||||
"required": ["trial_id"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "compare_trials",
|
||||
"description": "Compare two or more trials side-by-side, showing parameter differences and objective values.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID. Uses current study if not specified."
|
||||
},
|
||||
"trial_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "List of trial IDs to compare (2-5 trials)."
|
||||
}
|
||||
},
|
||||
"required": ["trial_ids"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_optimization_summary",
|
||||
"description": "Get a high-level summary of the optimization progress including trial counts, convergence status, best designs, and parameter sensitivity.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID. Uses current study if not specified."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "read_study_readme",
|
||||
"description": "Read the README.md documentation for a study, which contains the engineering problem description, mathematical formulation, and methodology.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"study_id": {
|
||||
"type": "string",
|
||||
"description": "Study ID. Uses current study if not specified."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "list_studies",
|
||||
"description": "List all available optimization studies with their status and trial counts.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str:
|
||||
"""Execute an Atomizer tool and return the result"""
|
||||
try:
|
||||
if tool_name == "read_study_config":
|
||||
return self._tool_read_config(tool_input.get('study_id'))
|
||||
elif tool_name == "query_trials":
|
||||
return self._tool_query_trials(tool_input)
|
||||
elif tool_name == "get_trial_details":
|
||||
return self._tool_get_trial_details(tool_input)
|
||||
elif tool_name == "compare_trials":
|
||||
return self._tool_compare_trials(tool_input)
|
||||
elif tool_name == "get_optimization_summary":
|
||||
return self._tool_get_summary(tool_input.get('study_id'))
|
||||
elif tool_name == "read_study_readme":
|
||||
return self._tool_read_readme(tool_input.get('study_id'))
|
||||
elif tool_name == "list_studies":
|
||||
return self._tool_list_studies()
|
||||
else:
|
||||
return f"Unknown tool: {tool_name}"
|
||||
except Exception as e:
|
||||
return f"Error executing {tool_name}: {str(e)}"
|
||||
|
||||
def _get_study_dir(self, study_id: Optional[str]) -> Path:
|
||||
"""Get study directory, using current study if not specified"""
|
||||
sid = study_id or self.study_id
|
||||
if not sid:
|
||||
raise ValueError("No study specified and no current study selected")
|
||||
study_dir = STUDIES_DIR / sid
|
||||
if not study_dir.exists():
|
||||
raise ValueError(f"Study '{sid}' not found")
|
||||
return study_dir
|
||||
|
||||
def _get_db_path(self, study_id: Optional[str]) -> Path:
|
||||
"""Get database path for a study"""
|
||||
study_dir = self._get_study_dir(study_id)
|
||||
for results_dir_name in ["2_results", "3_results"]:
|
||||
db_path = study_dir / results_dir_name / "study.db"
|
||||
if db_path.exists():
|
||||
return db_path
|
||||
raise ValueError(f"No database found for study")
|
||||
|
||||
def _tool_read_config(self, study_id: Optional[str]) -> str:
|
||||
"""Read study configuration"""
|
||||
study_dir = self._get_study_dir(study_id)
|
||||
|
||||
config_path = study_dir / "1_setup" / "optimization_config.json"
|
||||
if not config_path.exists():
|
||||
config_path = study_dir / "optimization_config.json"
|
||||
|
||||
if not config_path.exists():
|
||||
return "No configuration file found for this study."
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Format nicely
|
||||
result = [f"# Configuration for {study_id or self.study_id}\n"]
|
||||
|
||||
# Design variables
|
||||
dvs = config.get('design_variables', [])
|
||||
if dvs:
|
||||
result.append("## Design Variables")
|
||||
result.append("| Name | Min | Max | Baseline | Units |")
|
||||
result.append("|------|-----|-----|----------|-------|")
|
||||
for dv in dvs:
|
||||
result.append(f"| {dv['name']} | {dv.get('min', '-')} | {dv.get('max', '-')} | {dv.get('baseline', '-')} | {dv.get('units', '-')} |")
|
||||
|
||||
# Objectives
|
||||
objs = config.get('objectives', [])
|
||||
if objs:
|
||||
result.append("\n## Objectives")
|
||||
result.append("| Name | Direction | Weight | Target | Units |")
|
||||
result.append("|------|-----------|--------|--------|-------|")
|
||||
for obj in objs:
|
||||
result.append(f"| {obj['name']} | {obj.get('direction', 'minimize')} | {obj.get('weight', 1.0)} | {obj.get('target', '-')} | {obj.get('units', '-')} |")
|
||||
|
||||
# Constraints
|
||||
constraints = config.get('constraints', [])
|
||||
if constraints:
|
||||
result.append("\n## Constraints")
|
||||
for c in constraints:
|
||||
result.append(f"- **{c['name']}**: {c.get('type', 'bound')} {c.get('max_value', c.get('min_value', ''))} {c.get('units', '')}")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def _tool_query_trials(self, params: Dict[str, Any]) -> str:
|
||||
"""Query trials from database"""
|
||||
db_path = self._get_db_path(params.get('study_id'))
|
||||
|
||||
state = params.get('state', 'COMPLETE')
|
||||
source = params.get('source', 'all')
|
||||
limit = params.get('limit', 20)
|
||||
order_by = params.get('order_by', 'value_asc')
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build query
|
||||
query = """
|
||||
SELECT t.trial_id, t.state, tv.value,
|
||||
GROUP_CONCAT(tp.param_name || '=' || ROUND(tp.param_value, 4), ', ') as params
|
||||
FROM trials t
|
||||
LEFT JOIN trial_values tv ON t.trial_id = tv.trial_id
|
||||
LEFT JOIN trial_params tp ON t.trial_id = tp.trial_id
|
||||
"""
|
||||
|
||||
conditions = []
|
||||
if state != 'all':
|
||||
conditions.append(f"t.state = '{state}'")
|
||||
|
||||
if conditions:
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
query += " GROUP BY t.trial_id"
|
||||
|
||||
# Order
|
||||
if order_by == 'value_asc':
|
||||
query += " ORDER BY tv.value ASC"
|
||||
elif order_by == 'value_desc':
|
||||
query += " ORDER BY tv.value DESC"
|
||||
elif order_by == 'trial_id_desc':
|
||||
query += " ORDER BY t.trial_id DESC"
|
||||
else:
|
||||
query += " ORDER BY t.trial_id ASC"
|
||||
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return "No trials found matching the criteria."
|
||||
|
||||
# Filter by source if needed (check user_attrs)
|
||||
if source != 'all':
|
||||
# Would need another query to filter by trial_source attr
|
||||
pass
|
||||
|
||||
# Format results
|
||||
result = [f"# Trials (showing {len(rows)}/{limit} max)\n"]
|
||||
result.append("| Trial | State | Objective | Parameters |")
|
||||
result.append("|-------|-------|-----------|------------|")
|
||||
|
||||
for row in rows:
|
||||
value = f"{row['value']:.6f}" if row['value'] else "N/A"
|
||||
params = row['params'][:50] + "..." if row['params'] and len(row['params']) > 50 else (row['params'] or "")
|
||||
result.append(f"| {row['trial_id']} | {row['state']} | {value} | {params} |")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def _tool_get_trial_details(self, params: Dict[str, Any]) -> str:
|
||||
"""Get detailed trial information"""
|
||||
db_path = self._get_db_path(params.get('study_id'))
|
||||
trial_id = params['trial_id']
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get trial info
|
||||
cursor.execute("SELECT * FROM trials WHERE trial_id = ?", (trial_id,))
|
||||
trial = cursor.fetchone()
|
||||
|
||||
if not trial:
|
||||
conn.close()
|
||||
return f"Trial {trial_id} not found."
|
||||
|
||||
result = [f"# Trial {trial_id} Details\n"]
|
||||
result.append(f"**State**: {trial['state']}")
|
||||
|
||||
# Get objective value
|
||||
cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (trial_id,))
|
||||
value_row = cursor.fetchone()
|
||||
if value_row:
|
||||
result.append(f"**Objective Value**: {value_row['value']:.6f}")
|
||||
|
||||
# Get parameters
|
||||
cursor.execute("SELECT param_name, param_value FROM trial_params WHERE trial_id = ? ORDER BY param_name", (trial_id,))
|
||||
params_rows = cursor.fetchall()
|
||||
|
||||
if params_rows:
|
||||
result.append("\n## Parameters")
|
||||
result.append("| Parameter | Value |")
|
||||
result.append("|-----------|-------|")
|
||||
for p in params_rows:
|
||||
result.append(f"| {p['param_name']} | {p['param_value']:.6f} |")
|
||||
|
||||
# Get user attributes
|
||||
cursor.execute("SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ?", (trial_id,))
|
||||
attrs = cursor.fetchall()
|
||||
|
||||
if attrs:
|
||||
result.append("\n## Attributes")
|
||||
for attr in attrs:
|
||||
try:
|
||||
value = json.loads(attr['value_json'])
|
||||
if isinstance(value, float):
|
||||
result.append(f"- **{attr['key']}**: {value:.6f}")
|
||||
else:
|
||||
result.append(f"- **{attr['key']}**: {value}")
|
||||
except:
|
||||
result.append(f"- **{attr['key']}**: {attr['value_json']}")
|
||||
|
||||
conn.close()
|
||||
return "\n".join(result)
|
||||
|
||||
def _tool_compare_trials(self, params: Dict[str, Any]) -> str:
|
||||
"""Compare multiple trials"""
|
||||
db_path = self._get_db_path(params.get('study_id'))
|
||||
trial_ids = params['trial_ids']
|
||||
|
||||
if len(trial_ids) < 2:
|
||||
return "Need at least 2 trials to compare."
|
||||
if len(trial_ids) > 5:
|
||||
return "Maximum 5 trials for comparison."
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
result = ["# Trial Comparison\n"]
|
||||
|
||||
# Get all parameter names
|
||||
cursor.execute("SELECT DISTINCT param_name FROM trial_params ORDER BY param_name")
|
||||
param_names = [row['param_name'] for row in cursor.fetchall()]
|
||||
|
||||
# Build comparison table header
|
||||
header = "| Parameter | " + " | ".join(f"Trial {tid}" for tid in trial_ids) + " |"
|
||||
separator = "|-----------|" + "|".join("-" * 10 for _ in trial_ids) + "|"
|
||||
|
||||
result.append(header)
|
||||
result.append(separator)
|
||||
|
||||
# Objective values row
|
||||
obj_values = []
|
||||
for tid in trial_ids:
|
||||
cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (tid,))
|
||||
row = cursor.fetchone()
|
||||
obj_values.append(f"{row['value']:.4f}" if row else "N/A")
|
||||
result.append("| **Objective** | " + " | ".join(obj_values) + " |")
|
||||
|
||||
# Parameter rows
|
||||
for pname in param_names:
|
||||
values = []
|
||||
for tid in trial_ids:
|
||||
cursor.execute("SELECT param_value FROM trial_params WHERE trial_id = ? AND param_name = ?", (tid, pname))
|
||||
row = cursor.fetchone()
|
||||
values.append(f"{row['param_value']:.4f}" if row else "N/A")
|
||||
result.append(f"| {pname} | " + " | ".join(values) + " |")
|
||||
|
||||
conn.close()
|
||||
return "\n".join(result)
|
||||
|
||||
def _tool_get_summary(self, study_id: Optional[str]) -> str:
|
||||
"""Get optimization summary"""
|
||||
db_path = self._get_db_path(study_id)
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
result = [f"# Optimization Summary\n"]
|
||||
|
||||
# Trial counts by state
|
||||
cursor.execute("SELECT state, COUNT(*) as count FROM trials GROUP BY state")
|
||||
states = {row['state']: row['count'] for row in cursor.fetchall()}
|
||||
|
||||
result.append("## Trial Counts")
|
||||
total = sum(states.values())
|
||||
result.append(f"- **Total**: {total}")
|
||||
for state, count in states.items():
|
||||
result.append(f"- {state}: {count}")
|
||||
|
||||
# Best trial
|
||||
cursor.execute("""
|
||||
SELECT t.trial_id, tv.value
|
||||
FROM trials t
|
||||
JOIN trial_values tv ON t.trial_id = tv.trial_id
|
||||
WHERE t.state = 'COMPLETE'
|
||||
ORDER BY tv.value ASC LIMIT 1
|
||||
""")
|
||||
best = cursor.fetchone()
|
||||
if best:
|
||||
result.append(f"\n## Best Trial")
|
||||
result.append(f"- **Trial ID**: {best['trial_id']}")
|
||||
result.append(f"- **Objective**: {best['value']:.6f}")
|
||||
|
||||
# FEA vs NN counts
|
||||
cursor.execute("""
|
||||
SELECT value_json, COUNT(*) as count
|
||||
FROM trial_user_attributes
|
||||
WHERE key = 'trial_source'
|
||||
GROUP BY value_json
|
||||
""")
|
||||
sources = cursor.fetchall()
|
||||
if sources:
|
||||
result.append("\n## Trial Sources")
|
||||
for src in sources:
|
||||
source_name = json.loads(src['value_json']) if src['value_json'] else 'unknown'
|
||||
result.append(f"- **{source_name}**: {src['count']}")
|
||||
|
||||
conn.close()
|
||||
return "\n".join(result)
|
||||
|
||||
def _tool_read_readme(self, study_id: Optional[str]) -> str:
|
||||
"""Read study README"""
|
||||
study_dir = self._get_study_dir(study_id)
|
||||
readme_path = study_dir / "README.md"
|
||||
|
||||
if not readme_path.exists():
|
||||
return "No README.md found for this study."
|
||||
|
||||
content = readme_path.read_text(encoding='utf-8')
|
||||
# Truncate if too long
|
||||
if len(content) > 8000:
|
||||
content = content[:8000] + "\n\n... (truncated)"
|
||||
|
||||
return content
|
||||
|
||||
def _tool_list_studies(self) -> str:
|
||||
"""List all studies"""
|
||||
if not STUDIES_DIR.exists():
|
||||
return "Studies directory not found."
|
||||
|
||||
result = ["# Available Studies\n"]
|
||||
result.append("| Study | Status | Trials |")
|
||||
result.append("|-------|--------|--------|")
|
||||
|
||||
for study_dir in sorted(STUDIES_DIR.iterdir()):
|
||||
if not study_dir.is_dir():
|
||||
continue
|
||||
|
||||
study_id = study_dir.name
|
||||
|
||||
# Check for database
|
||||
trial_count = 0
|
||||
for results_dir_name in ["2_results", "3_results"]:
|
||||
db_path = study_dir / results_dir_name / "study.db"
|
||||
if db_path.exists():
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM trials WHERE state='COMPLETE'")
|
||||
trial_count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
except:
|
||||
pass
|
||||
break
|
||||
|
||||
# Determine status
|
||||
status = "ready" if trial_count > 0 else "not_started"
|
||||
|
||||
result.append(f"| {study_id} | {status} | {trial_count} |")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
async def chat(self, message: str, conversation_history: Optional[List[Dict]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a chat message with tool use support
|
||||
|
||||
Args:
|
||||
message: User's message
|
||||
conversation_history: Previous messages for context
|
||||
|
||||
Returns:
|
||||
Dict with response text and any tool calls made
|
||||
"""
|
||||
messages = conversation_history.copy() if conversation_history else []
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
tool_calls_made = []
|
||||
|
||||
# Loop to handle tool use
|
||||
while True:
|
||||
response = self.client.messages.create(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4096,
|
||||
system=self.system_prompt,
|
||||
tools=self.tools,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
# Check if we need to handle tool use
|
||||
if response.stop_reason == "tool_use":
|
||||
# Process tool calls
|
||||
assistant_content = response.content
|
||||
tool_results = []
|
||||
|
||||
for block in assistant_content:
|
||||
if block.type == "tool_use":
|
||||
tool_name = block.name
|
||||
tool_input = block.input
|
||||
tool_id = block.id
|
||||
|
||||
# Execute the tool
|
||||
result = self._execute_tool(tool_name, tool_input)
|
||||
|
||||
tool_calls_made.append({
|
||||
"tool": tool_name,
|
||||
"input": tool_input,
|
||||
"result_preview": result[:200] + "..." if len(result) > 200 else result
|
||||
})
|
||||
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_id,
|
||||
"content": result
|
||||
})
|
||||
|
||||
# Add assistant response and tool results to messages
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
messages.append({"role": "user", "content": tool_results})
|
||||
|
||||
else:
|
||||
# No more tool use, extract final response
|
||||
final_text = ""
|
||||
for block in response.content:
|
||||
if hasattr(block, 'text'):
|
||||
final_text += block.text
|
||||
|
||||
return {
|
||||
"response": final_text,
|
||||
"tool_calls": tool_calls_made,
|
||||
"conversation": messages + [{"role": "assistant", "content": response.content}]
|
||||
}
|
||||
|
||||
async def chat_stream(self, message: str, conversation_history: Optional[List[Dict]] = None) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream a chat response token by token
|
||||
|
||||
Args:
|
||||
message: User's message
|
||||
conversation_history: Previous messages
|
||||
|
||||
Yields:
|
||||
Response tokens as they arrive
|
||||
"""
|
||||
messages = conversation_history.copy() if conversation_history else []
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
# For streaming, we'll do a simpler approach without tool use for now
|
||||
# (Tool use with streaming is more complex)
|
||||
with self.client.messages.stream(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4096,
|
||||
system=self.system_prompt,
|
||||
messages=messages
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
yield text
|
||||
Reference in New Issue
Block a user