"""
Atomizer Claude Agent Service

Provides Claude AI integration with Atomizer-specific tools for:
- Analyzing optimization results
- Querying trial data
- Modifying configurations
- Creating new studies
- Explaining FEA/Zernike concepts
"""

import os
import json
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Optional, List, Dict, Any, AsyncGenerator
from datetime import datetime

import anthropic

# Base studies directory (repo-relative; this file sits five levels below the root)
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
ATOMIZER_ROOT = Path(__file__).parent.parent.parent.parent.parent


class AtomizerClaudeAgent:
    """Claude agent with Atomizer-specific tools and context.

    Wraps the Anthropic Messages API with a set of read-only tools that
    inspect a study's configuration files and its Optuna SQLite database.
    """

    def __init__(self, study_id: Optional[str] = None):
        """Create an agent, optionally bound to a specific study.

        Args:
            study_id: Directory name of the study under STUDIES_DIR, or
                None for a study-less session (the user can select one).
        """
        self.client = anthropic.Anthropic()
        self.study_id = study_id
        self.study_dir = STUDIES_DIR / study_id if study_id else None
        self.tools = self._define_tools()
        self.system_prompt = self._build_system_prompt()

    def _build_system_prompt(self) -> str:
        """Build a context-aware system prompt for Atomizer.

        Returns the static identity/capability prompt, plus a summary of
        the currently selected study when one is bound and exists on disk.
        """
        base_prompt = """You are the Atomizer Assistant - the intelligent assistant built into Atomizer for FEA optimization.

## Your Identity
- You are "Atomizer Assistant" - never mention AI, language models, or any underlying technology
- Respond as if you're a knowledgeable expert system built specifically for Atomizer
- When asked about yourself, explain that you're the Atomizer Assistant designed to help with FEA optimization

## Your Capabilities
You help engineers with structural optimization using NX Nastran simulations:
1. **Analyze Results** - Interpret optimization progress, identify trends, explain convergence
2. **Query Data** - Fetch trial data, compare configurations, find best designs
3. **Modify Settings** - Update design variable bounds, objectives, constraints
4. **Explain Concepts** - FEA, Zernike polynomials, wavefront error, stress analysis
5. **Troubleshoot** - Debug failed trials, identify issues, suggest fixes
6. **Create Studies** - Guide users through setting up new optimization studies

## Atomizer Context
- Atomizer uses Optuna for Bayesian optimization
- Studies can use FEA-only or hybrid FEA/Neural surrogate approaches
- Results are stored in SQLite databases (study.db)
- Design variables are NX expressions in CAD models
- Objectives include stress, displacement, frequency, Zernike WFE

## Communication Style
- Be concise but thorough
- Use technical language appropriate for engineers
- When showing data, format it clearly (tables, lists)
- If uncertain, say so and suggest how to verify
- Proactively suggest next steps or insights
- Sound confident and professional - you're a specialized expert system
"""

        # Add study-specific context if available
        if self.study_id and self.study_dir and self.study_dir.exists():
            context = self._get_study_context()
            base_prompt += f"\n## Current Study: {self.study_id}\n{context}\n"
        else:
            base_prompt += "\n## Current Study: None selected\nAsk the user to select a study or help them create a new one.\n"

        return base_prompt

    def _get_study_context(self) -> str:
        """Summarize the current study (variables, objectives, trial stats).

        Best-effort: any unreadable config or database is silently skipped
        so prompt construction never fails.
        """
        context_parts = []

        # Try to load config (new layout first, then legacy flat layout)
        config_path = self.study_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = self.study_dir / "optimization_config.json"

        if config_path.exists():
            try:
                with open(config_path) as f:
                    config = json.load(f)

                # Design variables (first five names only, to keep the prompt short)
                dvs = config.get('design_variables', [])
                if dvs:
                    context_parts.append(
                        f"**Design Variables ({len(dvs)})**: "
                        + ", ".join(dv['name'] for dv in dvs[:5])
                        + ("..." if len(dvs) > 5 else "")
                    )

                # Objectives
                objs = config.get('objectives', [])
                if objs:
                    context_parts.append(
                        f"**Objectives ({len(objs)})**: "
                        + ", ".join(f"{o['name']} ({o.get('direction', 'minimize')})" for o in objs)
                    )

                # Constraints
                constraints = config.get('constraints', [])
                if constraints:
                    context_parts.append(
                        f"**Constraints**: " + ", ".join(c['name'] for c in constraints)
                    )
            except Exception:
                # Deliberate best-effort: a malformed config just yields less context
                pass

        # Try to get trial count from database (layout varies between studies)
        results_dir = self.study_dir / "2_results"
        if not results_dir.exists():
            results_dir = self.study_dir / "3_results"

        db_path = results_dir / "study.db" if results_dir.exists() else None
        if db_path and db_path.exists():
            try:
                # closing() guarantees the handle is released even if a query raises
                with closing(sqlite3.connect(str(db_path))) as conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT COUNT(*) FROM trials WHERE state='COMPLETE'")
                    trial_count = cursor.fetchone()[0]
                    context_parts.append(f"**Completed Trials**: {trial_count}")

                    # Get best value
                    cursor.execute("""
                        SELECT MIN(value) FROM trial_values
                        WHERE trial_id IN (SELECT trial_id FROM trials WHERE state='COMPLETE')
                    """)
                    best = cursor.fetchone()[0]
                    if best is not None:
                        context_parts.append(f"**Best Objective**: {best:.6f}")
            except Exception:
                # Deliberate best-effort: an unreadable DB just yields less context
                pass

        return "\n".join(context_parts) if context_parts else "No configuration found."

    def _define_tools(self) -> List[Dict[str, Any]]:
        """Define the Atomizer-specific tool schemas passed to the model."""
        return [
            {
                "name": "read_study_config",
                "description": "Read the optimization configuration for the current or specified study. Returns design variables, objectives, constraints, and algorithm settings.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID to read config from. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "query_trials",
                "description": "Query trial data from the Optuna database. Can filter by state, source (FEA/NN), objective value range, or parameter values.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID to query. Uses current study if not specified."
                        },
                        "state": {
                            "type": "string",
                            "enum": ["COMPLETE", "PRUNED", "FAIL", "RUNNING", "all"],
                            "description": "Filter by trial state. Default: COMPLETE"
                        },
                        "source": {
                            "type": "string",
                            "enum": ["fea", "nn", "all"],
                            "description": "Filter by trial source (FEA simulation or Neural Network). Default: all"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum number of trials to return. Default: 20"
                        },
                        "order_by": {
                            "type": "string",
                            "enum": ["value_asc", "value_desc", "trial_id_asc", "trial_id_desc"],
                            "description": "Sort order. Default: value_asc (best first)"
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "get_trial_details",
                "description": "Get detailed information about a specific trial including all parameters, objective values, and user attributes.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        },
                        "trial_id": {
                            "type": "integer",
                            "description": "The trial number to get details for."
                        }
                    },
                    "required": ["trial_id"]
                }
            },
            {
                "name": "compare_trials",
                "description": "Compare two or more trials side-by-side, showing parameter differences and objective values.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        },
                        "trial_ids": {
                            "type": "array",
                            "items": {"type": "integer"},
                            "description": "List of trial IDs to compare (2-5 trials)."
                        }
                    },
                    "required": ["trial_ids"]
                }
            },
            {
                "name": "get_optimization_summary",
                "description": "Get a high-level summary of the optimization progress including trial counts, convergence status, best designs, and parameter sensitivity.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "read_study_readme",
                "description": "Read the README.md documentation for a study, which contains the engineering problem description, mathematical formulation, and methodology.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "study_id": {
                            "type": "string",
                            "description": "Study ID. Uses current study if not specified."
                        }
                    },
                    "required": []
                }
            },
            {
                "name": "list_studies",
                "description": "List all available optimization studies with their status and trial counts.",
                "input_schema": {
                    "type": "object",
                    "properties": {},
                    "required": []
                }
            }
        ]

    def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str:
        """Execute an Atomizer tool by name and return its text result.

        Any exception is converted to an error string so the model can
        recover and report the failure instead of the request dying.
        """
        try:
            if tool_name == "read_study_config":
                return self._tool_read_config(tool_input.get('study_id'))
            elif tool_name == "query_trials":
                return self._tool_query_trials(tool_input)
            elif tool_name == "get_trial_details":
                return self._tool_get_trial_details(tool_input)
            elif tool_name == "compare_trials":
                return self._tool_compare_trials(tool_input)
            elif tool_name == "get_optimization_summary":
                return self._tool_get_summary(tool_input.get('study_id'))
            elif tool_name == "read_study_readme":
                return self._tool_read_readme(tool_input.get('study_id'))
            elif tool_name == "list_studies":
                return self._tool_list_studies()
            else:
                return f"Unknown tool: {tool_name}"
        except Exception as e:
            return f"Error executing {tool_name}: {str(e)}"

    def _get_study_dir(self, study_id: Optional[str]) -> Path:
        """Resolve a study directory, falling back to the bound study.

        Raises:
            ValueError: if no study is specified/bound or it does not exist.
        """
        sid = study_id or self.study_id
        if not sid:
            raise ValueError("No study specified and no current study selected")

        study_dir = STUDIES_DIR / sid
        if not study_dir.exists():
            raise ValueError(f"Study '{sid}' not found")

        return study_dir

    def _get_db_path(self, study_id: Optional[str]) -> Path:
        """Locate the Optuna study.db for a study (either results layout).

        Raises:
            ValueError: if neither results directory contains a database.
        """
        study_dir = self._get_study_dir(study_id)
        for results_dir_name in ["2_results", "3_results"]:
            db_path = study_dir / results_dir_name / "study.db"
            if db_path.exists():
                return db_path
        # Fixed: original raised an f-string with no placeholder, dropping the id
        raise ValueError(f"No database found for study '{study_dir.name}'")

    def _tool_read_config(self, study_id: Optional[str]) -> str:
        """Read a study configuration and render it as markdown tables."""
        study_dir = self._get_study_dir(study_id)

        config_path = study_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = study_dir / "optimization_config.json"
        if not config_path.exists():
            return "No configuration file found for this study."

        with open(config_path) as f:
            config = json.load(f)

        # Format nicely
        result = [f"# Configuration for {study_id or self.study_id}\n"]

        # Design variables
        dvs = config.get('design_variables', [])
        if dvs:
            result.append("## Design Variables")
            result.append("| Name | Min | Max | Baseline | Units |")
            result.append("|------|-----|-----|----------|-------|")
            for dv in dvs:
                result.append(
                    f"| {dv['name']} | {dv.get('min', '-')} | {dv.get('max', '-')} | "
                    f"{dv.get('baseline', '-')} | {dv.get('units', '-')} |"
                )

        # Objectives
        objs = config.get('objectives', [])
        if objs:
            result.append("\n## Objectives")
            result.append("| Name | Direction | Weight | Target | Units |")
            result.append("|------|-----------|--------|--------|-------|")
            for obj in objs:
                result.append(
                    f"| {obj['name']} | {obj.get('direction', 'minimize')} | "
                    f"{obj.get('weight', 1.0)} | {obj.get('target', '-')} | {obj.get('units', '-')} |"
                )

        # Constraints
        constraints = config.get('constraints', [])
        if constraints:
            result.append("\n## Constraints")
            for c in constraints:
                result.append(
                    f"- **{c['name']}**: {c.get('type', 'bound')} "
                    f"{c.get('max_value', c.get('min_value', ''))} {c.get('units', '')}"
                )

        return "\n".join(result)

    def _tool_query_trials(self, params: Dict[str, Any]) -> str:
        """Query trials from the database and format them as a markdown table."""
        db_path = self._get_db_path(params.get('study_id'))

        state = params.get('state', 'COMPLETE')
        source = params.get('source', 'all')
        # Coerce to int: limit is interpolated into the SQL text below
        limit = int(params.get('limit', 20))
        order_by = params.get('order_by', 'value_asc')

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Build query; filters are bound as parameters, never interpolated
            query = """
                SELECT t.trial_id, t.state, tv.value,
                       GROUP_CONCAT(tp.param_name || '=' || ROUND(tp.param_value, 4), ', ') as params
                FROM trials t
                LEFT JOIN trial_values tv ON t.trial_id = tv.trial_id
                LEFT JOIN trial_params tp ON t.trial_id = tp.trial_id
            """
            bind_values: List[Any] = []
            if state != 'all':
                query += " WHERE t.state = ?"
                bind_values.append(state)

            query += " GROUP BY t.trial_id"

            # Order
            if order_by == 'value_asc':
                query += " ORDER BY tv.value ASC"
            elif order_by == 'value_desc':
                query += " ORDER BY tv.value DESC"
            elif order_by == 'trial_id_desc':
                query += " ORDER BY t.trial_id DESC"
            else:
                query += " ORDER BY t.trial_id ASC"

            query += f" LIMIT {limit}"

            cursor.execute(query, bind_values)
            rows = cursor.fetchall()

        if not rows:
            return "No trials found matching the criteria."

        # Filter by source if needed (check user_attrs)
        if source != 'all':
            # TODO: would need another query to filter by trial_source attr
            pass

        # Format results
        result = [f"# Trials (showing {len(rows)}/{limit} max)\n"]
        result.append("| Trial | State | Objective | Parameters |")
        result.append("|-------|-------|-----------|------------|")

        for row in rows:
            # 'is not None' so a legitimate objective of 0.0 is not shown as N/A
            value = f"{row['value']:.6f}" if row['value'] is not None else "N/A"
            params_str = (
                row['params'][:50] + "..."
                if row['params'] and len(row['params']) > 50
                else (row['params'] or "")
            )
            result.append(f"| {row['trial_id']} | {row['state']} | {value} | {params_str} |")

        return "\n".join(result)

    def _tool_get_trial_details(self, params: Dict[str, Any]) -> str:
        """Get parameters, objective value, and user attributes for one trial."""
        db_path = self._get_db_path(params.get('study_id'))
        trial_id = params['trial_id']

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get trial info
            cursor.execute("SELECT * FROM trials WHERE trial_id = ?", (trial_id,))
            trial = cursor.fetchone()
            if not trial:
                return f"Trial {trial_id} not found."

            result = [f"# Trial {trial_id} Details\n"]
            result.append(f"**State**: {trial['state']}")

            # Get objective value (may be NULL for failed/running trials)
            cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (trial_id,))
            value_row = cursor.fetchone()
            if value_row and value_row['value'] is not None:
                result.append(f"**Objective Value**: {value_row['value']:.6f}")

            # Get parameters
            cursor.execute(
                "SELECT param_name, param_value FROM trial_params WHERE trial_id = ? ORDER BY param_name",
                (trial_id,),
            )
            params_rows = cursor.fetchall()
            if params_rows:
                result.append("\n## Parameters")
                result.append("| Parameter | Value |")
                result.append("|-----------|-------|")
                for p in params_rows:
                    result.append(f"| {p['param_name']} | {p['param_value']:.6f} |")

            # Get user attributes (values are JSON-encoded by Optuna)
            cursor.execute(
                "SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ?",
                (trial_id,),
            )
            attrs = cursor.fetchall()
            if attrs:
                result.append("\n## Attributes")
                for attr in attrs:
                    try:
                        value = json.loads(attr['value_json'])
                        if isinstance(value, float):
                            result.append(f"- **{attr['key']}**: {value:.6f}")
                        else:
                            result.append(f"- **{attr['key']}**: {value}")
                    except (TypeError, ValueError):
                        # Fall back to the raw JSON text if it does not parse
                        result.append(f"- **{attr['key']}**: {attr['value_json']}")

        return "\n".join(result)

    def _tool_compare_trials(self, params: Dict[str, Any]) -> str:
        """Compare 2-5 trials side-by-side in a markdown table."""
        db_path = self._get_db_path(params.get('study_id'))
        trial_ids = params['trial_ids']

        if len(trial_ids) < 2:
            return "Need at least 2 trials to compare."
        if len(trial_ids) > 5:
            return "Maximum 5 trials for comparison."

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            result = ["# Trial Comparison\n"]

            # Get all parameter names
            cursor.execute("SELECT DISTINCT param_name FROM trial_params ORDER BY param_name")
            param_names = [row['param_name'] for row in cursor.fetchall()]

            # Build comparison table header
            header = "| Parameter | " + " | ".join(f"Trial {tid}" for tid in trial_ids) + " |"
            separator = "|-----------|" + "|".join("-" * 10 for _ in trial_ids) + "|"
            result.append(header)
            result.append(separator)

            # Objective values row (guard NULL values from failed trials)
            obj_values = []
            for tid in trial_ids:
                cursor.execute("SELECT value FROM trial_values WHERE trial_id = ?", (tid,))
                row = cursor.fetchone()
                obj_values.append(
                    f"{row['value']:.4f}" if row and row['value'] is not None else "N/A"
                )
            result.append("| **Objective** | " + " | ".join(obj_values) + " |")

            # Parameter rows
            for pname in param_names:
                values = []
                for tid in trial_ids:
                    cursor.execute(
                        "SELECT param_value FROM trial_params WHERE trial_id = ? AND param_name = ?",
                        (tid, pname),
                    )
                    row = cursor.fetchone()
                    values.append(f"{row['param_value']:.4f}" if row else "N/A")
                result.append(f"| {pname} | " + " | ".join(values) + " |")

        return "\n".join(result)

    def _tool_get_summary(self, study_id: Optional[str]) -> str:
        """Summarize optimization progress: counts, best trial, FEA/NN split."""
        db_path = self._get_db_path(study_id)

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            result = [f"# Optimization Summary\n"]

            # Trial counts by state
            cursor.execute("SELECT state, COUNT(*) as count FROM trials GROUP BY state")
            states = {row['state']: row['count'] for row in cursor.fetchall()}

            result.append("## Trial Counts")
            total = sum(states.values())
            result.append(f"- **Total**: {total}")
            for state, count in states.items():
                result.append(f"- {state}: {count}")

            # Best trial. SQLite sorts NULL first under ASC, so exclude NULL
            # objective values or the :.6f format below would crash.
            cursor.execute("""
                SELECT t.trial_id, tv.value
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE' AND tv.value IS NOT NULL
                ORDER BY tv.value ASC
                LIMIT 1
            """)
            best = cursor.fetchone()
            if best:
                result.append(f"\n## Best Trial")
                result.append(f"- **Trial ID**: {best['trial_id']}")
                result.append(f"- **Objective**: {best['value']:.6f}")

            # FEA vs NN counts
            cursor.execute("""
                SELECT value_json, COUNT(*) as count
                FROM trial_user_attributes
                WHERE key = 'trial_source'
                GROUP BY value_json
            """)
            sources = cursor.fetchall()
            if sources:
                result.append("\n## Trial Sources")
                for src in sources:
                    source_name = json.loads(src['value_json']) if src['value_json'] else 'unknown'
                    result.append(f"- **{source_name}**: {src['count']}")

        return "\n".join(result)

    def _tool_read_readme(self, study_id: Optional[str]) -> str:
        """Read a study's README.md, truncated to keep the context small."""
        study_dir = self._get_study_dir(study_id)
        readme_path = study_dir / "README.md"

        if not readme_path.exists():
            return "No README.md found for this study."

        content = readme_path.read_text(encoding='utf-8')

        # Truncate if too long
        if len(content) > 8000:
            content = content[:8000] + "\n\n... (truncated)"

        return content

    def _tool_list_studies(self) -> str:
        """List all studies with a status flag and completed-trial count."""
        if not STUDIES_DIR.exists():
            return "Studies directory not found."

        result = ["# Available Studies\n"]
        result.append("| Study | Status | Trials |")
        result.append("|-------|--------|--------|")

        for study_dir in sorted(STUDIES_DIR.iterdir()):
            if not study_dir.is_dir():
                continue

            study_id = study_dir.name

            # Check for database (first matching results layout wins)
            trial_count = 0
            for results_dir_name in ["2_results", "3_results"]:
                db_path = study_dir / results_dir_name / "study.db"
                if db_path.exists():
                    try:
                        with closing(sqlite3.connect(str(db_path))) as conn:
                            cursor = conn.cursor()
                            cursor.execute(
                                "SELECT COUNT(*) FROM trials WHERE state='COMPLETE'"
                            )
                            trial_count = cursor.fetchone()[0]
                    except sqlite3.Error:
                        # Unreadable DB: report the study with zero trials
                        pass
                    break

            # Determine status
            status = "ready" if trial_count > 0 else "not_started"
            result.append(f"| {study_id} | {status} | {trial_count} |")

        return "\n".join(result)

    async def chat(self, message: str, conversation_history: Optional[List[Dict]] = None) -> Dict[str, Any]:
        """
        Process a chat message with tool use support.

        Args:
            message: User's message
            conversation_history: Previous messages for context

        Returns:
            Dict with response text, any tool calls made, and the full
            conversation (for the caller to persist).
        """
        messages = conversation_history.copy() if conversation_history else []
        messages.append({"role": "user", "content": message})

        tool_calls_made = []

        # Bounded loop: the original `while True` could spin forever if the
        # model keeps requesting tools; 25 rounds is far beyond normal use.
        max_tool_rounds = 25
        for _ in range(max_tool_rounds):
            response = self.client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4096,
                system=self.system_prompt,
                tools=self.tools,
                messages=messages
            )

            # Check if we need to handle tool use
            if response.stop_reason == "tool_use":
                # Process tool calls
                assistant_content = response.content
                tool_results = []

                for block in assistant_content:
                    if block.type == "tool_use":
                        tool_name = block.name
                        tool_input = block.input
                        tool_id = block.id

                        # Execute the tool
                        result = self._execute_tool(tool_name, tool_input)

                        tool_calls_made.append({
                            "tool": tool_name,
                            "input": tool_input,
                            "result_preview": result[:200] + "..." if len(result) > 200 else result
                        })

                        tool_results.append({
                            "type": "tool_result",
                            "tool_use_id": tool_id,
                            "content": result
                        })

                # Add assistant response and tool results to messages
                messages.append({"role": "assistant", "content": assistant_content})
                messages.append({"role": "user", "content": tool_results})
            else:
                # No more tool use, extract final response
                final_text = ""
                for block in response.content:
                    if hasattr(block, 'text'):
                        final_text += block.text

                return {
                    "response": final_text,
                    "tool_calls": tool_calls_made,
                    "conversation": messages + [{"role": "assistant", "content": response.content}]
                }

        # Safety-cap fallback: surface what happened instead of looping forever
        return {
            "response": "Tool-use limit reached before a final answer was produced.",
            "tool_calls": tool_calls_made,
            "conversation": messages,
        }

    async def chat_stream(self, message: str, conversation_history: Optional[List[Dict]] = None) -> AsyncGenerator[str, None]:
        """
        Stream a chat response token by token.

        Args:
            message: User's message
            conversation_history: Previous messages

        Yields:
            Response tokens as they arrive
        """
        messages = conversation_history.copy() if conversation_history else []
        messages.append({"role": "user", "content": message})

        # For streaming, we'll do a simpler approach without tool use for now
        # (Tool use with streaming is more complex)
        with self.client.messages.stream(
            model="claude-sonnet-4-20250514",
            max_tokens=4096,
            system=self.system_prompt,
            messages=messages
        ) as stream:
            for text in stream.text_stream:
                yield text