feat: Add Analysis page, run comparison, notifications, and config editor

Dashboard enhancements:
- Add Analysis page with tabs: Overview, Parameters, Pareto, Correlations, Constraints, Surrogate, Runs
- Add PlotlyCorrelationHeatmap for parameter-objective correlation analysis
- Add PlotlyFeasibilityChart for constraint satisfaction visualization
- Add PlotlySurrogateQuality for FEA vs NN prediction comparison
- Add PlotlyRunComparison for comparing optimization runs within a study

Real-time improvements:
- Replace watchdog file-watching with SQLite database polling for better Windows reliability
- Add DatabasePoller class with 2-second polling interval
- Enhanced WebSocket messages: trial_completed, new_best, pareto_update, progress

Desktop notifications:
- Add useNotifications hook using Web Notifications API
- Add NotificationSettings toggle component
- Notify users when new best solutions are found

Config editor:
- Add PUT /studies/{study_id}/config endpoint with auto-backup
- Add ConfigEditor modal with tabs: General, Variables, Objectives, Settings, JSON
- Prevent config edits while an optimization is running

Enhanced Pareto visualization:
- Add dark mode styling with transparent backgrounds
- Add stats bar showing Pareto, FEA, NN, and infeasible counts
- Add Pareto front connecting line for 2D view
- Add table showing top 10 Pareto-optimal solutions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Antoine
2025-12-05 19:57:20 -05:00
parent 5c660ff270
commit 5fb94fdf01
27 changed files with 5878 additions and 722 deletions

View File

@@ -1794,3 +1794,563 @@ run_server("{storage_url}", host="0.0.0.0", port={port})
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to launch Optuna dashboard: {str(e)}")
# ============================================================================
# Model Files Endpoint
# ============================================================================
@router.get("/studies/{study_id}/model-files")
async def get_model_files(study_id: str):
    """
    Get list of NX/FEA model files (.prt, .sim, .fem, .bdf, .dat, .op2, .f06, .inp) for a study

    Every candidate model directory that exists is scanned (non-recursively),
    so files found in several locations all appear in the same listing.

    Args:
        study_id: Study identifier

    Returns:
        JSON with list of model files and their paths. ``model_dir`` is the
        first scanned directory that contained a matching file, or the
        conventional ``1_setup/model`` path when nothing was found.

    Raises:
        HTTPException: 404 if the study directory does not exist, 500 on any
            unexpected error while scanning.
    """
    try:
        study_dir = STUDIES_DIR / study_id
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Look for model directory (check multiple locations)
        model_dirs = [
            study_dir / "1_setup" / "model",
            study_dir / "model",
            study_dir / "1_setup",
            study_dir
        ]

        # Recognised NX and FEA extensions mapped to their display rank
        # (prt first, then sim, fem, etc.). The keys double as the filter so
        # the two can never drift apart.
        extension_order = {'.prt': 0, '.sim': 1, '.fem': 2, '.bdf': 3, '.dat': 4, '.op2': 5, '.f06': 6, '.inp': 7}

        model_files = []
        model_dir_path = None

        for model_dir in model_dirs:
            # is_dir() is already False for a path that does not exist
            if not model_dir.is_dir():
                continue
            for file_path in model_dir.iterdir():
                extension = file_path.suffix.lower()
                if extension not in extension_order or not file_path.is_file():
                    continue
                # One stat() per file keeps size_bytes, size_display and
                # modified consistent with each other.
                file_stat = file_path.stat()
                model_files.append({
                    "name": file_path.name,
                    "path": str(file_path),
                    "extension": extension,
                    "size_bytes": file_stat.st_size,
                    "size_display": _format_file_size(file_stat.st_size),
                    "modified": datetime.fromtimestamp(file_stat.st_mtime).isoformat()
                })
                if model_dir_path is None:
                    model_dir_path = str(model_dir)

        # Sort by extension rank, then by name
        model_files.sort(key=lambda x: (extension_order.get(x['extension'], 99), x['name']))

        return {
            "study_id": study_id,
            "model_dir": model_dir_path or str(study_dir / "1_setup" / "model"),
            "files": model_files,
            "count": len(model_files)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model files: {str(e)}")
def _format_file_size(size_bytes: int) -> str:
    """
    Render a byte count as a short human-readable string.

    Values below 1024 are shown as whole bytes, KB and MB use one decimal
    place, and everything from 1024**3 bytes upwards is shown in GB with two
    decimal places.
    """
    kib = 1024
    if size_bytes < kib:
        return f"{size_bytes} B"

    # (exclusive upper bound, divisor, unit label)
    scaled_units = (
        (kib ** 2, kib, "KB"),
        (kib ** 3, kib ** 2, "MB"),
    )
    for ceiling, divisor, label in scaled_units:
        if size_bytes < ceiling:
            return f"{size_bytes / divisor:.1f} {label}"

    return f"{size_bytes / kib ** 3:.2f} GB"
@router.post("/studies/{study_id}/open-folder")
async def open_model_folder(study_id: str, folder_type: str = "model"):
    """
    Open one of a study's folders in the system file explorer.

    Args:
        study_id: Study identifier
        folder_type: Which folder to open: "model", "results" or "setup".
            Any other value opens the study directory itself.

    Returns:
        JSON with ``success``, ``message`` and the ``path`` that was targeted.
        A failure to launch the file explorer is reported with
        ``success: False`` rather than as an HTTP error.
    """
    import os
    import platform

    try:
        study_root = STUDIES_DIR / study_id
        if not study_root.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Resolve the requested folder
        setup_dir = study_root / "1_setup"
        if folder_type == "model":
            model_dir = setup_dir / "model"
            target_dir = model_dir if model_dir.exists() else setup_dir
        elif folder_type == "results":
            target_dir = get_results_dir(study_root)
        elif folder_type == "setup":
            target_dir = setup_dir
        else:
            target_dir = study_root

        # Whatever was chosen, fall back to the study directory if it is missing
        if not target_dir.exists():
            target_dir = study_root

        target = str(target_dir)
        os_name = platform.system()

        try:
            if os_name == "Windows":
                os.startfile(target)
            else:
                # macOS ships "open"; everything else is assumed to have xdg-open
                opener = "open" if os_name == "Darwin" else "xdg-open"
                subprocess.Popen([opener, target])
        except Exception as e:
            return {
                "success": False,
                "message": f"Failed to open folder: {str(e)}",
                "path": target
            }

        return {
            "success": True,
            "message": f"Opened {target_dir}",
            "path": target
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to open folder: {str(e)}")
@router.get("/studies/{study_id}/best-solution")
async def get_best_solution(study_id: str):
    """
    Get the best trial for a study together with improvement metrics.

    "Best" is the completed trial with the lowest value of the first
    objective (``trial_values.objective = 0``), i.e. minimization is assumed.
    The improvement figures compare it with the first completed trial
    (lowest trial number).

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``best_trial``, ``first_trial``, ``improvements`` and
        ``total_trials``. When the study has no ``study.db`` yet the trial
        entries are ``None`` and the count is 0.

    Raises:
        HTTPException: 404 if the study directory does not exist, 500 on any
            unexpected error.
    """
    try:
        study_dir = STUDIES_DIR / study_id
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study '{study_id}' not found")

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            return {
                "study_id": study_id,
                "best_trial": None,
                "first_trial": None,
                "improvements": {},
                "total_trials": 0
            }

        best_trial = None
        first_trial = None
        improvements = {}

        conn = sqlite3.connect(str(db_path))
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            def fetch_params(trial_id):
                """Return {param_name: param_value} for one trial."""
                cursor.execute("""
                    SELECT param_name, param_value
                    FROM trial_params
                    WHERE trial_id = ?
                """, (trial_id,))
                return {row['param_name']: row['param_value'] for row in cursor.fetchall()}

            # Best trial on the first objective. NULL values are excluded
            # because SQLite sorts NULL ahead of every number in ASC order.
            # The timestamp is the trial's completion time.
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value AS objective,
                       t.datetime_complete AS timestamp
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE' AND tv.objective = 0
                  AND tv.value IS NOT NULL
                ORDER BY tv.value ASC
                LIMIT 1
            """)
            best_row = cursor.fetchone()

            # First completed trial, used as the baseline for comparison
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value AS objective
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE' AND tv.objective = 0
                ORDER BY t.number ASC
                LIMIT 1
            """)
            first_row = cursor.fetchone()

            cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
            total_trials = cursor.fetchone()[0]

            if best_row:
                best_trial_id = best_row['trial_id']

                # User attributes (including results); values that are not
                # valid JSON are passed through unchanged.
                cursor.execute("""
                    SELECT key, value_json
                    FROM trial_user_attributes
                    WHERE trial_id = ?
                """, (best_trial_id,))
                user_attrs = {}
                for row in cursor.fetchall():
                    try:
                        user_attrs[row['key']] = json.loads(row['value_json'])
                    except (ValueError, TypeError):
                        user_attrs[row['key']] = row['value_json']

                best_trial = {
                    "trial_number": best_row['number'],
                    "objective": best_row['objective'],
                    "design_variables": fetch_params(best_trial_id),
                    "user_attrs": user_attrs,
                    "timestamp": best_row['timestamp']
                }

            if first_row:
                first_trial = {
                    "trial_number": first_row['number'],
                    "objective": first_row['objective'],
                    "design_variables": fetch_params(first_row['trial_id'])
                }

                # A percentage needs a usable, non-zero baseline
                initial = first_row['objective']
                final = best_row['objective'] if best_row else None
                if initial is not None and final is not None and initial != 0:
                    improvement_pct = ((initial - final) / abs(initial)) * 100
                    improvements["objective"] = {
                        "initial": initial,
                        "final": final,
                        "improvement_pct": round(improvement_pct, 2),
                        "absolute_change": round(initial - final, 6)
                    }
        finally:
            # Release the database handle even when a query fails
            conn.close()

        return {
            "study_id": study_id,
            "best_trial": best_trial,
            "first_trial": first_trial,
            "improvements": improvements,
            "total_trials": total_trials
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get best solution: {str(e)}")
@router.get("/studies/{study_id}/runs")
async def get_study_runs(study_id: str):
    """
    Get all optimization runs/studies in the database for comparison.

    Many studies have multiple Optuna studies (e.g., v11_fea, v11_iter1_nn, v11_iter2_nn).
    This endpoint returns metrics for each sub-study that has at least one
    completed trial.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``runs`` (one entry per sub-study), ``total_runs`` and
        ``study_id``. ``best_value`` is the minimum of the first objective,
        so minimization is assumed.

    Raises:
        HTTPException: 404 if the study directory does not exist, 500 on any
            unexpected error.
    """
    try:
        study_dir = STUDIES_DIR / study_id
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study '{study_id}' not found")

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            return {"runs": [], "total_runs": 0, "study_id": study_id}

        runs = []
        conn = sqlite3.connect(str(db_path))
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get all Optuna studies in this database
            cursor.execute("""
                SELECT study_id, study_name
                FROM studies
                ORDER BY study_id
            """)

            for study_row in cursor.fetchall():
                optuna_study_id = study_row['study_id']
                study_name = study_row['study_name']

                # Trial count and time range come from the same table, so
                # fetch them together.
                cursor.execute("""
                    SELECT COUNT(*) AS trial_count,
                           MIN(datetime_start) AS first_trial,
                           MAX(datetime_complete) AS last_trial
                    FROM trials
                    WHERE study_id = ? AND state = 'COMPLETE'
                """, (optuna_study_id,))
                summary = cursor.fetchone()
                if summary['trial_count'] == 0:
                    continue

                # Best and average value of the first objective
                cursor.execute("""
                    SELECT MIN(tv.value) AS best_value, AVG(tv.value) AS avg_value
                    FROM trial_values tv
                    JOIN trials t ON tv.trial_id = t.trial_id
                    WHERE t.study_id = ? AND t.state = 'COMPLETE' AND tv.objective = 0
                """, (optuna_study_id,))
                stats = cursor.fetchone()

                runs.append({
                    "run_id": optuna_study_id,
                    "name": study_name,
                    # Source type (FEA or NN) is inferred from the study name
                    "source": "NN" if "_nn" in study_name.lower() else "FEA",
                    "trial_count": summary['trial_count'],
                    "best_value": stats['best_value'],
                    "avg_value": stats['avg_value'],
                    "first_trial": summary['first_trial'],
                    "last_trial": summary['last_trial']
                })
        finally:
            # Release the database handle even when a query fails
            conn.close()

        return {
            "runs": runs,
            "total_runs": len(runs),
            "study_id": study_id
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get runs: {str(e)}")
# Request body accepted by the PUT /studies/{study_id}/config endpoint.
class UpdateConfigRequest(BaseModel):
# Complete replacement configuration; the endpoint writes this dict to
# optimization_config.json as JSON.
config: dict
@router.put("/studies/{study_id}/config")
async def update_study_config(study_id: str, request: UpdateConfigRequest):
    """
    Update the optimization_config.json for a study

    The existing file is copied to ``optimization_config.json.backup`` first.
    The new content is written to a temporary file in the same directory and
    then swapped into place, so a failed write never leaves a truncated
    config behind.

    Args:
        study_id: Study identifier
        request: New configuration data

    Returns:
        JSON with success status, the config path and the backup path

    Raises:
        HTTPException: 404 if the study or its config file is missing, 409 if
            an optimization is currently running, 500 on any unexpected error.
    """
    import os

    try:
        study_dir = STUDIES_DIR / study_id
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Check if optimization is running - don't allow config changes while running
        if is_optimization_running(study_id):
            raise HTTPException(
                status_code=409,
                detail="Cannot modify config while optimization is running. Stop the optimization first."
            )

        # Find config file location
        config_file = study_dir / "1_setup" / "optimization_config.json"
        if not config_file.exists():
            config_file = study_dir / "optimization_config.json"
        if not config_file.exists():
            raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}")

        # Backup existing config
        backup_file = config_file.with_suffix('.json.backup')
        shutil.copy(config_file, backup_file)

        # Write the new config next to the real one, then replace it in a
        # single step so readers never observe a half-written file.
        temp_file = config_file.with_name(config_file.name + ".tmp")
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(request.config, f, indent=2)
            os.replace(temp_file, config_file)
        finally:
            # Only still present when the write or the swap failed
            if temp_file.exists():
                temp_file.unlink()

        return {
            "success": True,
            "message": "Configuration updated successfully",
            "path": str(config_file),
            "backup_path": str(backup_file)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update config: {str(e)}")
@router.get("/studies/{study_id}/export/{format}")
async def export_study_data(study_id: str, format: str):
    """
    Export study data in one of the supported formats: json, csv or config.

    * ``json``   - all completed trials with params and user attributes
    * ``csv``    - trial number, objective and one column per parameter,
                   returned as a string inside a JSON envelope
    * ``config`` - the study's optimization_config.json; does not need the
                   results database

    Args:
        study_id: Study identifier
        format: Export format (case-insensitive)

    Returns:
        JSONResponse whose shape depends on the format (see above).

    Raises:
        HTTPException: 400 for an unsupported format, 404 if the study, its
            database or its config file is missing, 500 on any unexpected
            error.
    """
    try:
        study_dir = STUDIES_DIR / study_id
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study '{study_id}' not found")

        # Reject unknown formats before doing any database work
        export_format = format.lower()
        if export_format not in ("json", "csv", "config"):
            raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")

        if export_format == "config":
            # Same lookup order as the config update endpoint
            config_path = study_dir / "1_setup" / "optimization_config.json"
            if not config_path.exists():
                config_path = study_dir / "optimization_config.json"
            if not config_path.exists():
                raise HTTPException(status_code=404, detail="Config file not found")

            with open(config_path, 'r') as f:
                config = json.load(f)
            return JSONResponse(content={
                "filename": f"{study_id}_config.json",
                "content": json.dumps(config, indent=2),
                "content_type": "application/json"
            })

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            raise HTTPException(status_code=404, detail="No study data available")

        trials_data = []
        conn = sqlite3.connect(str(db_path))
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get all completed trials with their params and values.
            # NOTE(review): a multi-objective trial has one trial_values row
            # per objective and therefore appears once per objective here.
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value as objective
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE'
                ORDER BY t.number
            """)

            # fetchall() first: the cursor is reused for the per-trial lookups
            for row in cursor.fetchall():
                trial_id = row['trial_id']

                cursor.execute("""
                    SELECT param_name, param_value
                    FROM trial_params
                    WHERE trial_id = ?
                """, (trial_id,))
                params = {r['param_name']: r['param_value'] for r in cursor.fetchall()}

                cursor.execute("""
                    SELECT key, value_json
                    FROM trial_user_attributes
                    WHERE trial_id = ?
                """, (trial_id,))
                user_attrs = {}
                for r in cursor.fetchall():
                    try:
                        user_attrs[r['key']] = json.loads(r['value_json'])
                    except (ValueError, TypeError):
                        # Not valid JSON - pass the stored text through
                        user_attrs[r['key']] = r['value_json']

                trials_data.append({
                    "trial_number": row['number'],
                    "objective": row['objective'],
                    "params": params,
                    "user_attrs": user_attrs
                })
        finally:
            # Release the database handle even when a query fails
            conn.close()

        if export_format == "json":
            return JSONResponse(content={
                "study_id": study_id,
                "total_trials": len(trials_data),
                "trials": trials_data
            })

        # Remaining supported format: csv
        import io
        import csv

        if not trials_data:
            return JSONResponse(content={"error": "No data to export"})

        # Column set is the union of parameter names over all trials
        param_names = sorted({
            key for trial in trials_data
            for key in trial['params']
        })
        fieldnames = ['trial_number', 'objective'] + param_names

        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()
        for trial in trials_data:
            row_data = {
                'trial_number': trial['trial_number'],
                'objective': trial['objective']
            }
            row_data.update(trial['params'])
            writer.writerow(row_data)

        return JSONResponse(content={
            "filename": f"{study_id}_data.csv",
            "content": output.getvalue(),
            "content_type": "text/csv"
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to export data: {str(e)}")

View File

@@ -19,6 +19,82 @@ router = APIRouter()
# Store active terminal sessions, keyed by the session id string
_terminal_sessions: dict = {}
# Path to Atomizer root (for loading prompts).
# abspath(__file__) is this module's own file; the five nested dirname()
# calls strip the file name and then four further path components.
ATOMIZER_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)))
def get_session_prompt(study_name: str = None) -> str:
    """
    Build the Markdown context prompt for a new Claude session.

    The prompt always contains the bootstrap reading list, the protocol
    table and the key principles; the section in between depends on
    ``study_name``.

    Args:
        study_name: Study to describe. Any falsy value (``None`` or ``""``)
            yields the generic "No Study Selected" section instead.

    Returns:
        The prompt as a single newline-joined string (no trailing newline).
    """
    preamble = (
        "# Atomizer Session Context",
        "",
        "You are assisting with **Atomizer** - an LLM-first FEA optimization framework.",
        "",
        "## Bootstrap (READ FIRST)",
        "",
        "Read these files to understand how to help:",
        "- `.claude/skills/00_BOOTSTRAP.md` - Task classification and routing",
        "- `.claude/skills/01_CHEATSHEET.md` - Quick reference (I want X → Use Y)",
        "- `.claude/skills/02_CONTEXT_LOADER.md` - What to load per task",
        "",
        "## Protocol System",
        "",
        "| Layer | Location | Purpose |",
        "|-------|----------|---------|",
        "| Operations | `docs/protocols/operations/OP_*.md` | How-to guides |",
        "| System | `docs/protocols/system/SYS_*.md` | Core specs |",
        "| Extensions | `docs/protocols/extensions/EXT_*.md` | Adding features |",
        "",
    )

    if not study_name:
        study_section = (
            "## No Study Selected",
            "",
            "No specific study context. You can:",
            "- List studies: `ls studies/`",
            "- Create new study: Ask user what they want to optimize",
            "- Load context: Read `.claude/skills/core/study-creation-core.md`",
            "",
        )
    else:
        study_root = f"studies/{study_name}"
        study_section = (
            f"## Current Study: `{study_name}`",
            "",
            f"**Directory**: `{study_root}/`",
            "",
            "Key files:",
            f"- `{study_root}/1_setup/optimization_config.json` - Configuration",
            f"- `{study_root}/2_results/study.db` - Optuna database",
            f"- `{study_root}/README.md` - Study documentation",
            "",
            "Quick status check:",
            "```bash",
            f"python -c \"import optuna; s=optuna.load_study('{study_name}', 'sqlite:///{study_root}/2_results/study.db'); print(f'Trials: {{len(s.trials)}}, Best: {{s.best_value}}')\"",
            "```",
            "",
        )

    closing = (
        "## Key Principles",
        "",
        "1. **Read bootstrap first** - Follow task routing from 00_BOOTSTRAP.md",
        "2. **Use centralized extractors** - Check `optimization_engine/extractors/`",
        "3. **Never modify master models** - Work on copies",
        "4. **Python env**: Always use `conda activate atomizer`",
        "",
        "---",
        "*Session launched from Atomizer Dashboard*",
    )

    return "\n".join((*preamble, *study_section, *closing))
# Check if winpty is available (for Windows)
try:
from winpty import PtyProcess
@@ -44,9 +120,6 @@ class TerminalSession:
self.websocket = websocket
self._running = True
# Determine the claude command
claude_cmd = "claude"
try:
if self._use_winpty:
# Use winpty for proper PTY on Windows
@@ -306,14 +379,13 @@ async def claude_terminal(websocket: WebSocket, working_dir: str = None, study_i
{"type": "output", "data": "terminal output"}
{"type": "exit", "code": 0}
{"type": "error", "message": "..."}
{"type": "context", "prompt": "..."} # Initial context prompt
"""
await websocket.accept()
# Default to Atomizer root directory
if not working_dir:
working_dir = str(os.path.dirname(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
))))
working_dir = ATOMIZER_ROOT
# Create session
session_id = f"claude-{id(websocket)}"
@@ -321,13 +393,24 @@ async def claude_terminal(websocket: WebSocket, working_dir: str = None, study_i
_terminal_sessions[session_id] = session
try:
# Send context prompt to frontend (for display/reference)
context_prompt = get_session_prompt(study_id)
await websocket.send_json({
"type": "context",
"prompt": context_prompt,
"study_id": study_id
})
# Start Claude Code
await session.start(websocket)
# Note: Claude is started in Atomizer root directory so it has access to:
# - CLAUDE.md (system instructions)
# - .claude/skills/ (skill definitions)
# The study_id is available for the user to reference in their prompts
# If study_id provided, send initial context to Claude after startup
if study_id:
# Wait a moment for Claude to initialize
await asyncio.sleep(1.0)
# Send the context as the first message
initial_message = f"I'm working with the Atomizer study '{study_id}'. Please read .claude/skills/00_BOOTSTRAP.md first to understand the Protocol Operating System, then help me with this study.\n"
await session.write(initial_message)
# Handle incoming messages
while session._running:
@@ -370,3 +453,31 @@ async def terminal_status():
"winpty_available": HAS_WINPTY,
"message": "Claude Code CLI is available" if claude_path else "Claude Code CLI not found. Install with: npm install -g @anthropic-ai/claude-code"
}
@router.get("/context")
async def get_context(study_id: str = None):
    """
    Return the Claude session context prompt without opening a terminal.

    Intended for showing the context in the UI or for preparing prompts.

    Query params:
        study_id: Optional study ID; when given, study-specific context and
            the study's key file paths are included.
    """
    prompt = get_session_prompt(study_id)

    study_files = []
    if study_id:
        study_files = [
            f"studies/{study_id}/1_setup/optimization_config.json",
            f"studies/{study_id}/2_results/study.db",
            f"studies/{study_id}/README.md",
        ]

    bootstrap_files = [
        ".claude/skills/00_BOOTSTRAP.md",
        ".claude/skills/01_CHEATSHEET.md",
        ".claude/skills/02_CONTEXT_LOADER.md",
    ]

    return {
        "study_id": study_id,
        "prompt": prompt,
        "bootstrap_files": bootstrap_files,
        "study_files": study_files
    }

View File

@@ -1,10 +1,10 @@
import asyncio
import json
import sqlite3
from pathlib import Path
from typing import Dict, Set
from typing import Dict, Set, Optional, Any, List
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import aiofiles
router = APIRouter()
@@ -12,185 +12,440 @@ router = APIRouter()
# Base studies directory
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
class OptimizationFileHandler(FileSystemEventHandler):
def get_results_dir(study_dir: Path) -> Path:
    """
    Locate a study's results directory.

    ``2_results`` is preferred; ``3_results`` is returned whenever
    ``2_results`` does not exist (without checking that ``3_results`` does).
    """
    preferred = study_dir / "2_results"
    return preferred if preferred.exists() else study_dir / "3_results"
class DatabasePoller:
"""
Polls the Optuna SQLite database for changes.
More reliable than file watching, especially on Windows.
"""
def __init__(self, study_id: str, callback):
self.study_id = study_id
self.callback = callback
self.last_trial_count = 0
self.last_pruned_count = 0
self.last_trial_id = 0
self.last_best_value: Optional[float] = None
self.last_pareto_count = 0
self.last_state_timestamp = ""
self.running = False
self._task: Optional[asyncio.Task] = None
def on_modified(self, event):
    """
    Route a file-modified event to the processor that matches the file name.

    Each processor is a coroutine and is run to completion with
    ``asyncio.run``. Events for any other file are ignored.
    """
    changed_path = event.src_path
    # (file-name suffix, name of the coroutine method that handles it)
    routes = (
        ("optimization_history_incremental.json", "process_history_update"),
        ("pruning_history.json", "process_pruning_update"),
        ("study.db", "process_pareto_update"),  # Optuna DB changes (Pareto front)
        ("optimizer_state.json", "process_state_update"),
    )
    for suffix, method_name in routes:
        if changed_path.endswith(suffix):
            asyncio.run(getattr(self, method_name)(changed_path))
            break
async def start(self):
"""Start the polling loop"""
self.running = True
self._task = asyncio.create_task(self._poll_loop())
async def process_history_update(self, file_path):
async def stop(self):
"""Stop the polling loop"""
self.running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
async def _poll_loop(self):
    """
    Main polling loop - checks database every 2 seconds.

    The database path is resolved on every iteration rather than once up
    front, so a results directory created after the client connected is
    still picked up. Runs until ``self.running`` is cleared or the task is
    cancelled; any other error is logged and followed by a 5 second back-off.
    """
    study_dir = STUDIES_DIR / self.study_id

    while self.running:
        try:
            db_path = get_results_dir(study_dir) / "study.db"
            if db_path.exists():
                await self._check_database(db_path)
            await asyncio.sleep(2)  # Poll every 2 seconds
        except asyncio.CancelledError:
            break
        except Exception as e:
            print(f"[WebSocket] Polling error for {self.study_id}: {e}")
            await asyncio.sleep(5)  # Back off on error
async def _check_database(self, db_path: Path):
"""Check database for new trials and updates"""
try:
async with aiofiles.open(file_path, mode='r') as f:
content = await f.read()
history = json.loads(content)
current_count = len(history)
if current_count > self.last_trial_count:
# New trials added
new_trials = history[self.last_trial_count:]
for trial in new_trials:
await self.callback({
"type": "trial_completed",
"data": trial
})
self.last_trial_count = current_count
# Use timeout to avoid blocking on locked databases
conn = sqlite3.connect(str(db_path), timeout=2.0)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# Get total completed trial count
cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
total_count = cursor.fetchone()[0]
# Check for new trials
if total_count > self.last_trial_count:
await self._process_new_trials(cursor, total_count)
# Check for new best value
await self._check_best_value(cursor)
# Check Pareto front for multi-objective
await self._check_pareto_front(cursor, db_path)
# Send progress update
await self._send_progress(cursor, total_count)
conn.close()
except sqlite3.OperationalError as e:
# Database locked - skip this poll
pass
except Exception as e:
print(f"Error processing history update: {e}")
print(f"[WebSocket] Database check error: {e}")
async def process_pruning_update(self, file_path):
    """
    Broadcast pruned trials that were appended to the pruning history file.

    The file is parsed as a JSON sequence; entries beyond
    ``self.last_pruned_count`` are sent as ``trial_pruned`` messages and the
    counter is then advanced. Any error is printed and otherwise ignored.
    """
    try:
        async with aiofiles.open(file_path, mode='r') as handle:
            raw = await handle.read()
        history = json.loads(raw)

        already_sent = self.last_pruned_count
        entry_count = len(history)
        if entry_count > already_sent:
            # New pruned trials
            for pruned in history[already_sent:]:
                await self.callback({
                    "type": "trial_pruned",
                    "data": pruned
                })
            self.last_pruned_count = entry_count
    except Exception as e:
        print(f"Error processing pruning update: {e}")
async def _process_new_trials(self, cursor, total_count: int):
"""Process and broadcast new trials"""
# Get new trials since last check
cursor.execute("""
SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
FROM trials t
JOIN studies s ON t.study_id = s.study_id
WHERE t.state = 'COMPLETE' AND t.trial_id > ?
ORDER BY t.trial_id ASC
""", (self.last_trial_id,))
new_trials = cursor.fetchall()
for row in new_trials:
trial_id = row['trial_id']
trial_data = await self._build_trial_data(cursor, row)
await self.callback({
"type": "trial_completed",
"data": trial_data
})
self.last_trial_id = trial_id
self.last_trial_count = total_count
async def _build_trial_data(self, cursor, row) -> Dict[str, Any]:
"""Build trial data dictionary from database row"""
trial_id = row['trial_id']
# Get objectives
cursor.execute("""
SELECT value FROM trial_values
WHERE trial_id = ? ORDER BY objective
""", (trial_id,))
values = [r[0] for r in cursor.fetchall()]
# Get parameters
cursor.execute("""
SELECT param_name, param_value FROM trial_params
WHERE trial_id = ?
""", (trial_id,))
params = {}
for r in cursor.fetchall():
try:
params[r[0]] = float(r[1]) if r[1] is not None else None
except (ValueError, TypeError):
params[r[0]] = r[1]
# Get user attributes
cursor.execute("""
SELECT key, value_json FROM trial_user_attributes
WHERE trial_id = ?
""", (trial_id,))
user_attrs = {}
for r in cursor.fetchall():
try:
user_attrs[r[0]] = json.loads(r[1])
except (ValueError, TypeError):
user_attrs[r[0]] = r[1]
# Extract source and design vars
source = user_attrs.get("source", "FEA")
design_vars = user_attrs.get("design_vars", params)
return {
"trial_number": trial_id, # Use trial_id for uniqueness
"trial_num": row['number'],
"objective": values[0] if values else None,
"values": values,
"params": design_vars,
"user_attrs": user_attrs,
"source": source,
"start_time": row['datetime_start'],
"end_time": row['datetime_complete'],
"study_name": row['study_name'],
"constraint_satisfied": user_attrs.get("constraint_satisfied", True)
}
async def _check_best_value(self, cursor):
"""Check for new best value and broadcast if changed"""
cursor.execute("""
SELECT MIN(tv.value) as best_value
FROM trial_values tv
JOIN trials t ON tv.trial_id = t.trial_id
WHERE t.state = 'COMPLETE' AND tv.objective = 0
""")
result = cursor.fetchone()
if result and result['best_value'] is not None:
best_value = result['best_value']
if self.last_best_value is None or best_value < self.last_best_value:
# Get the best trial details
cursor.execute("""
SELECT t.trial_id, t.number
FROM trials t
JOIN trial_values tv ON t.trial_id = tv.trial_id
WHERE t.state = 'COMPLETE' AND tv.objective = 0 AND tv.value = ?
LIMIT 1
""", (best_value,))
best_row = cursor.fetchone()
if best_row:
# Get params for best trial
cursor.execute("""
SELECT param_name, param_value FROM trial_params
WHERE trial_id = ?
""", (best_row['trial_id'],))
params = {r[0]: r[1] for r in cursor.fetchall()}
async def process_pareto_update(self, file_path):
# This is tricky because study.db is binary.
# Instead of reading it directly, we'll trigger a re-fetch of the Pareto front via Optuna
# We debounce this to avoid excessive reads
try:
# Import here to avoid circular imports or heavy load at startup
import optuna
# Connect to DB
storage = optuna.storages.RDBStorage(f"sqlite:///{file_path}")
study = optuna.load_study(study_name=self.study_id, storage=storage)
# Check if multi-objective
if len(study.directions) > 1:
pareto_trials = study.best_trials
# Only broadcast if count changed (simple heuristic)
# In a real app, we might check content hash
if len(pareto_trials) != self.last_pareto_count:
pareto_data = [
{
"trial_number": t.number,
"values": t.values,
"params": t.params,
"user_attrs": dict(t.user_attrs),
"constraint_satisfied": t.user_attrs.get("constraint_satisfied", True)
}
for t in pareto_trials
]
await self.callback({
"type": "pareto_front",
"type": "new_best",
"data": {
"pareto_front": pareto_data,
"count": len(pareto_trials)
"trial_number": best_row['trial_id'],
"value": best_value,
"params": params,
"improvement": (
((self.last_best_value - best_value) / abs(self.last_best_value) * 100)
if self.last_best_value else 0
)
}
})
self.last_pareto_count = len(pareto_trials)
self.last_best_value = best_value
async def _check_pareto_front(self, cursor, db_path: Path):
"""Check for Pareto front updates in multi-objective studies"""
try:
# Check if multi-objective by counting distinct objectives
cursor.execute("""
SELECT COUNT(DISTINCT objective) as obj_count
FROM trial_values
WHERE trial_id IN (SELECT trial_id FROM trials WHERE state = 'COMPLETE')
""")
result = cursor.fetchone()
if result and result['obj_count'] > 1:
# Multi-objective - compute Pareto front
import optuna
storage = optuna.storages.RDBStorage(f"sqlite:///{db_path}")
# Get all study names
cursor.execute("SELECT study_name FROM studies")
study_names = [r[0] for r in cursor.fetchall()]
for study_name in study_names:
try:
study = optuna.load_study(study_name=study_name, storage=storage)
if len(study.directions) > 1:
pareto_trials = study.best_trials
if len(pareto_trials) != self.last_pareto_count:
pareto_data = [
{
"trial_number": t.number,
"values": t.values,
"params": t.params,
"constraint_satisfied": t.user_attrs.get("constraint_satisfied", True),
"source": t.user_attrs.get("source", "FEA")
}
for t in pareto_trials
]
await self.callback({
"type": "pareto_update",
"data": {
"pareto_front": pareto_data,
"count": len(pareto_trials)
}
})
self.last_pareto_count = len(pareto_trials)
break
except:
continue
except Exception as e:
# DB might be locked, ignore transient errors
# Non-critical - skip Pareto check
pass
async def process_state_update(self, file_path):
    """
    Broadcast the optimizer state file when its timestamp has changed.

    The file is parsed as JSON; if its "timestamp" entry differs from the
    last one broadcast, the whole state is sent as an ``optimizer_state``
    message. Any error is printed and otherwise ignored.
    """
    try:
        async with aiofiles.open(file_path, mode='r') as handle:
            raw = await handle.read()
        state = json.loads(raw)

        # Same timestamp as last time means a duplicate event - nothing to send
        if state.get("timestamp") == self.last_state_timestamp:
            return

        await self.callback({
            "type": "optimizer_state",
            "data": state
        })
        self.last_state_timestamp = state.get("timestamp")
    except Exception as e:
        print(f"Error processing state update: {e}")
async def _send_progress(self, cursor, total_count: int):
    """Send a ``progress`` message describing how far the study has advanced.

    The target trial count is read from the study's
    ``optimization_config.json`` (looked up under ``1_setup/`` first, then in
    the study directory itself) and defaults to 100 when the file is missing,
    unreadable, or does not hold a numeric ``n_trials``. COMPLETE trials are
    split into NN and FEA counts depending on whether the owning study's name
    contains the literal substring ``_nn``.

    Args:
        cursor: Database cursor whose rows can be indexed by column name.
        total_count: Number of trials to report as ``current``.
    """
    default_target = 100

    # Get total from config if available
    study_dir = STUDIES_DIR / self.study_id
    config_path = study_dir / "1_setup" / "optimization_config.json"
    if not config_path.exists():
        config_path = study_dir / "optimization_config.json"

    total_target = default_target
    if config_path.exists():
        try:
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
            configured = config.get('optimization_settings', {}).get('n_trials', default_target)
            # A null or string n_trials would make the ``> 0`` comparison
            # below raise TypeError, so only real numbers are accepted.
            if isinstance(configured, (int, float)) and not isinstance(configured, bool):
                total_target = configured
        except (OSError, ValueError, AttributeError):
            # Unreadable file, invalid JSON, or unexpected structure:
            # keep the default target.
            pass

    # Count FEA vs NN trials. "_" is a single-character wildcard in SQL LIKE,
    # so it is escaped to match a literal underscore; unescaped, a name such
    # as "tunnel_beam" would be counted as an NN study.
    cursor.execute("""
        SELECT
            COUNT(CASE WHEN s.study_name LIKE '%!_nn%' ESCAPE '!' THEN 1 END) as nn_count,
            COUNT(CASE WHEN s.study_name NOT LIKE '%!_nn%' ESCAPE '!' THEN 1 END) as fea_count
        FROM trials t
        JOIN studies s ON t.study_id = s.study_id
        WHERE t.state = 'COMPLETE'
    """)
    counts = cursor.fetchone()

    await self.callback({
        "type": "progress",
        "data": {
            "current": total_count,
            "total": total_target,
            "percentage": min(100, (total_count / total_target * 100)) if total_target > 0 else 0,
            "fea_count": counts['fea_count'] if counts else total_count,
            "nn_count": counts['nn_count'] if counts else 0,
            "timestamp": datetime.now().isoformat()
        }
    })
class ConnectionManager:
    """
    Manages WebSocket connections and database pollers for real-time updates.
    Uses database polling instead of file watching for better reliability on Windows.

    Connections are grouped by study id. The first client of a study starts a
    DatabasePoller for it; when the last client leaves, the poller is stopped.
    The watchdog-based ``start_watching`` / ``stop_watching`` methods are kept
    for backward compatibility but are not started by ``connect``; running
    both mechanisms at once would deliver every update twice.
    """

    def __init__(self):
        # study_id -> sockets currently subscribed to that study
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # study_id -> legacy watchdog observer (only filled by start_watching)
        self.observers: Dict[str, Observer] = {}
        # study_id -> running database poller
        self.pollers: Dict[str, DatabasePoller] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Connect a new WebSocket client"""
        await websocket.accept()
        self.active_connections.setdefault(study_id, set()).add(websocket)

        # Start polling if not already running
        if study_id not in self.pollers:
            await self._start_polling(study_id)

    async def disconnect(self, websocket: WebSocket, study_id: str):
        """Disconnect a WebSocket client.

        Safe to call more than once for the same socket: ``discard`` is used
        so an already-removed socket does not raise ``KeyError``.
        """
        connections = self.active_connections.get(study_id)
        if connections is None:
            return

        connections.discard(websocket)

        # Stop polling if no more connections
        if not connections:
            del self.active_connections[study_id]
            # No-op unless a legacy watcher was started for this study.
            self.stop_watching(study_id)
            await self._stop_polling(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Broadcast message to all connected clients for a study"""
        if study_id not in self.active_connections:
            return

        disconnected = []
        # Iterate over a snapshot: clients may connect or disconnect while a
        # send is awaited, and a set must not change size during iteration.
        for connection in list(self.active_connections[study_id]):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"[WebSocket] Error broadcasting to client: {e}")
                disconnected.append(connection)

        # Clean up disconnected clients. The study entry may have been
        # removed while sends were awaited, hence the lookup with ``get``.
        remaining = self.active_connections.get(study_id)
        if remaining is not None:
            for conn in disconnected:
                remaining.discard(conn)

    async def _start_polling(self, study_id: str):
        """Start database polling for a study"""
        async def callback(message):
            await self.broadcast(message, study_id)

        poller = DatabasePoller(study_id, callback)
        # Registered before start() is awaited so a concurrent connect()
        # sees the poller and does not create a second one.
        self.pollers[study_id] = poller
        await poller.start()
        print(f"[WebSocket] Started polling for study: {study_id}")

    async def _stop_polling(self, study_id: str):
        """Stop database polling for a study"""
        # Unregister before awaiting stop() so a client connecting meanwhile
        # starts a fresh poller instead of relying on one that is shutting down.
        poller = self.pollers.pop(study_id, None)
        if poller is not None:
            await poller.stop()
            print(f"[WebSocket] Stopped polling for study: {study_id}")

    def start_watching(self, study_id: str):
        """Legacy: watch the study's ``2_results`` directory with watchdog.

        Does nothing when the directory does not exist.
        """
        study_dir = STUDIES_DIR / study_id / "2_results"
        if not study_dir.exists():
            return

        async def callback(message):
            await self.broadcast(message, study_id)

        event_handler = OptimizationFileHandler(study_id, callback)
        observer = Observer()
        observer.schedule(event_handler, str(study_dir), recursive=True)
        observer.start()
        self.observers[study_id] = observer

    def stop_watching(self, study_id: str):
        """Legacy: stop and join the watchdog observer of a study, if any."""
        if study_id in self.observers:
            self.observers[study_id].stop()
            self.observers[study_id].join()
            del self.observers[study_id]
# Single module-level ConnectionManager instance used by the WebSocket endpoint below.
manager = ConnectionManager()
@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """
    WebSocket endpoint for real-time optimization updates.

    Sends messages:
    - connected: Initial connection confirmation
    - trial_completed: New trial completed with full data
    - new_best: New best value found
    - progress: Progress update (current/total, FEA/NN counts)
    - pareto_update: Pareto front updated (multi-objective)
    - pong: Reply to a client message of the form {"type": "ping"}
    - heartbeat: Sent when no client message arrives within 30 seconds

    The client is always unregistered from the manager exactly once, in the
    ``finally`` clause, whichever way the connection ends.
    """
    await manager.connect(websocket, study_id)

    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "study_id": study_id,
                "message": f"Connected to real-time stream for study {study_id}",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Keep connection alive
        while True:
            try:
                # Wait for messages (ping/pong or commands)
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=30.0  # 30 second timeout for ping
                )
            except asyncio.TimeoutError:
                # Send heartbeat; a failed send means the client is gone.
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    break
                continue

            # Handle client commands
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                continue

            # Valid JSON that is not an object (e.g. 42 or "ping") has no
            # .get(); ignore it rather than tearing the connection down.
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WebSocket] Connection error for {study_id}: {e}")
    finally:
        await manager.disconnect(websocket, study_id)