Dashboard enhancements:
- Add Analysis page with tabs: Overview, Parameters, Pareto, Correlations, Constraints, Surrogate, Runs
- Add PlotlyCorrelationHeatmap for parameter-objective correlation analysis
- Add PlotlyFeasibilityChart for constraint satisfaction visualization
- Add PlotlySurrogateQuality for FEA vs NN prediction comparison
- Add PlotlyRunComparison for comparing optimization runs within a study
Real-time improvements:
- Replace watchdog file-watching with SQLite database polling for better Windows reliability
- Add DatabasePoller class with 2-second polling interval
- Enhanced WebSocket messages: trial_completed, new_best, pareto_update, progress
Desktop notifications:
- Add useNotifications hook using Web Notifications API
- Add NotificationSettings toggle component
- Notify users when new best solutions are found
Config editor:
- Add PUT /studies/{study_id}/config endpoint with auto-backup
- Add ConfigEditor modal with tabs: General, Variables, Objectives, Settings, JSON
- Prevent config edits while an optimization is running
Enhanced Pareto visualization:
- Add dark mode styling with transparent backgrounds
- Add stats bar showing Pareto, FEA, NN, and infeasible counts
- Add Pareto front connecting line for 2D view
- Add table showing top 10 Pareto-optimal solutions
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
452 lines
16 KiB
Python
import asyncio
|
|
import json
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Dict, Set, Optional, Any, List
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
import aiofiles
|
|
|
|
# Router that carries this module's WebSocket endpoint
# (see the @router.websocket decorator on optimization_stream).
router = APIRouter()


# Base studies directory: the "studies" folder reached by taking `.parent`
# five times from this module's own path.
STUDIES_DIR = (
    Path(__file__)
    .parent.parent.parent.parent.parent
    .joinpath("studies")
)


def get_results_dir(study_dir: Path) -> Path:
    """Return the results folder of *study_dir*.

    ``2_results`` wins when it exists on disk; otherwise ``3_results`` is
    returned, whether or not that folder exists either.
    """
    preferred = study_dir / "2_results"
    return preferred if preferred.exists() else study_dir / "3_results"


class DatabasePoller:
    """
    Polls the Optuna SQLite database for changes.

    More reliable than file watching, especially on Windows.

    Each poll opens a short-lived sqlite3 connection to the study's
    ``study.db`` and awaits ``callback`` with message dicts of type
    ``trial_completed``, ``new_best``, ``pareto_update`` and ``progress``.

    NOTE(review): the sqlite3 and Optuna calls below are synchronous and run
    on the event-loop thread, so a slow poll delays every other coroutine.
    Consider moving the queries to a worker thread if that becomes noticeable.
    """

    def __init__(self, study_id: str, callback):
        """
        Args:
            study_id: Study folder name, resolved under ``STUDIES_DIR``.
            callback: Coroutine function awaited with each outgoing message.
        """
        self.study_id = study_id
        self.callback = callback
        # Completed-trial count at the last scan; growth triggers a new scan.
        self.last_trial_count = 0
        # Highest trial_id broadcast so far.
        self.last_trial_id = 0
        # Every completed trial_id already broadcast. Trials can complete out
        # of trial_id order (e.g. parallel workers, or several studies sharing
        # one database file), so a "trial_id > last_trial_id" watermark alone
        # would skip them.
        self._seen_trial_ids: Set[int] = set()
        self.last_best_value: Optional[float] = None
        self.last_pareto_count = 0
        self.running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the polling loop as a background task."""
        self.running = True
        self._task = asyncio.create_task(self._poll_loop())

    async def stop(self):
        """Stop the polling loop and wait for the task to finish cancelling."""
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _poll_loop(self):
        """Main polling loop - checks the database every 2 seconds.

        Waits 5 seconds instead after an unexpected error.
        """
        study_dir = STUDIES_DIR / self.study_id
        results_dir = get_results_dir(study_dir)
        db_path = results_dir / "study.db"

        while self.running:
            try:
                if db_path.exists():
                    await self._check_database(db_path)
                await asyncio.sleep(2)  # Poll every 2 seconds
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"[WebSocket] Polling error for {self.study_id}: {e}")
                await asyncio.sleep(5)  # Back off on error

    async def _check_database(self, db_path: Path):
        """Run one poll against ``db_path`` and broadcast any changes.

        The connection is closed in every case, including when a query or a
        callback raises (or the task is cancelled mid-poll).
        """
        conn = None
        try:
            # Use timeout to avoid blocking on locked databases
            conn = sqlite3.connect(str(db_path), timeout=2.0)
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get total completed trial count
            cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
            total_count = cursor.fetchone()[0]

            # Check for new trials
            if total_count > self.last_trial_count:
                await self._process_new_trials(cursor, total_count)

            # Check for new best value
            await self._check_best_value(cursor)

            # Check Pareto front for multi-objective
            await self._check_pareto_front(cursor, db_path)

            # Send progress update
            await self._send_progress(cursor, total_count)

        except sqlite3.OperationalError:
            # e.g. database locked, or tables not created yet - skip this poll
            pass
        except Exception as e:
            print(f"[WebSocket] Database check error: {e}")
        finally:
            if conn is not None:
                conn.close()

    async def _process_new_trials(self, cursor, total_count: int):
        """Broadcast ``trial_completed`` for every completed trial not yet sent.

        Args:
            cursor: Open cursor on a connection using ``sqlite3.Row``.
            total_count: Current number of completed trials; stored so the
                next poll only rescans when it grows.
        """
        cursor.execute("""
            SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
            ORDER BY t.trial_id ASC
        """)
        # Filter against the seen set rather than "trial_id > last_trial_id":
        # a trial finishing after a higher-numbered one must still be reported.
        # fetchall() materialises the rows before the cursor is reused below.
        new_trials = [
            row for row in cursor.fetchall()
            if row['trial_id'] not in self._seen_trial_ids
        ]

        for row in new_trials:
            trial_id = row['trial_id']
            trial_data = await self._build_trial_data(cursor, row)

            await self.callback({
                "type": "trial_completed",
                "data": trial_data
            })

            self._seen_trial_ids.add(trial_id)
            self.last_trial_id = max(self.last_trial_id, trial_id)

        self.last_trial_count = total_count

    async def _build_trial_data(self, cursor, row) -> Dict[str, Any]:
        """Build the ``trial_completed`` payload for one trial row.

        Reads the trial's objective values, parameters and user attributes.
        ``params`` is the ``design_vars`` user attribute when present, else the
        raw Optuna parameters (converted to float where possible).
        """
        trial_id = row['trial_id']

        # Get objectives
        cursor.execute("""
            SELECT value FROM trial_values
            WHERE trial_id = ? ORDER BY objective
        """, (trial_id,))
        values = [r[0] for r in cursor.fetchall()]

        # Get parameters
        cursor.execute("""
            SELECT param_name, param_value FROM trial_params
            WHERE trial_id = ?
        """, (trial_id,))
        params = {}
        for r in cursor.fetchall():
            try:
                params[r[0]] = float(r[1]) if r[1] is not None else None
            except (ValueError, TypeError):
                params[r[0]] = r[1]

        # Get user attributes (stored as JSON text; kept raw if not decodable)
        cursor.execute("""
            SELECT key, value_json FROM trial_user_attributes
            WHERE trial_id = ?
        """, (trial_id,))
        user_attrs = {}
        for r in cursor.fetchall():
            try:
                user_attrs[r[0]] = json.loads(r[1])
            except (ValueError, TypeError):
                user_attrs[r[0]] = r[1]

        # Extract source and design vars
        source = user_attrs.get("source", "FEA")
        design_vars = user_attrs.get("design_vars", params)

        return {
            "trial_number": trial_id,  # Use trial_id for uniqueness
            "trial_num": row['number'],
            "objective": values[0] if values else None,
            "values": values,
            "params": design_vars,
            "user_attrs": user_attrs,
            "source": source,
            "start_time": row['datetime_start'],
            "end_time": row['datetime_complete'],
            "study_name": row['study_name'],
            "constraint_satisfied": user_attrs.get("constraint_satisfied", True)
        }

    async def _check_best_value(self, cursor):
        """Broadcast ``new_best`` when the smallest objective-0 value drops.

        NOTE(review): assumes objective 0 is minimized (``MIN`` / ``<``) and
        looks across every study in the database file; for a maximize study
        this tracks the worst value instead - confirm against the study
        directions.
        """
        cursor.execute("""
            SELECT MIN(tv.value) as best_value
            FROM trial_values tv
            JOIN trials t ON tv.trial_id = t.trial_id
            WHERE t.state = 'COMPLETE' AND tv.objective = 0
        """)
        result = cursor.fetchone()

        if result and result['best_value'] is not None:
            best_value = result['best_value']

            if self.last_best_value is None or best_value < self.last_best_value:
                # Get the best trial details
                cursor.execute("""
                    SELECT t.trial_id, t.number
                    FROM trials t
                    JOIN trial_values tv ON t.trial_id = tv.trial_id
                    WHERE t.state = 'COMPLETE' AND tv.objective = 0 AND tv.value = ?
                    LIMIT 1
                """, (best_value,))
                best_row = cursor.fetchone()

                if best_row:
                    # Get params for best trial
                    cursor.execute("""
                        SELECT param_name, param_value FROM trial_params
                        WHERE trial_id = ?
                    """, (best_row['trial_id'],))
                    params = {r[0]: r[1] for r in cursor.fetchall()}

                    await self.callback({
                        "type": "new_best",
                        "data": {
                            "trial_number": best_row['trial_id'],
                            "value": best_value,
                            "params": params,
                            # Percent improvement; 0 on the first report or
                            # when the previous best was 0 (avoids dividing by 0).
                            "improvement": (
                                ((self.last_best_value - best_value) / abs(self.last_best_value) * 100)
                                if self.last_best_value else 0
                            )
                        }
                    })

                self.last_best_value = best_value

    async def _check_pareto_front(self, cursor, db_path: Path):
        """Broadcast ``pareto_update`` when a multi-objective front changes size.

        Only the first study in the file that reports more than one direction
        is inspected. Failures here are non-critical and are ignored.
        """
        try:
            # Check if multi-objective by counting distinct objectives
            cursor.execute("""
                SELECT COUNT(DISTINCT objective) as obj_count
                FROM trial_values
                WHERE trial_id IN (SELECT trial_id FROM trials WHERE state = 'COMPLETE')
            """)
            result = cursor.fetchone()

            if result and result['obj_count'] > 1:
                # Multi-objective - compute Pareto front via Optuna.
                # NOTE(review): a new RDBStorage and load_study() run on every
                # poll; consider caching if this becomes a bottleneck.
                import optuna

                storage = optuna.storages.RDBStorage(f"sqlite:///{db_path}")

                # Get all study names
                cursor.execute("SELECT study_name FROM studies")
                study_names = [r[0] for r in cursor.fetchall()]

                for study_name in study_names:
                    try:
                        study = optuna.load_study(study_name=study_name, storage=storage)
                        if len(study.directions) > 1:
                            pareto_trials = study.best_trials

                            if len(pareto_trials) != self.last_pareto_count:
                                pareto_data = [
                                    {
                                        "trial_number": t.number,
                                        "values": t.values,
                                        "params": t.params,
                                        "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True),
                                        "source": t.user_attrs.get("source", "FEA")
                                    }
                                    for t in pareto_trials
                                ]

                                await self.callback({
                                    "type": "pareto_update",
                                    "data": {
                                        "pareto_front": pareto_data,
                                        "count": len(pareto_trials)
                                    }
                                })

                                self.last_pareto_count = len(pareto_trials)

                            # Only the first multi-objective study is tracked.
                            break
                    except Exception:
                        # Not a bare except: task cancellation must propagate.
                        continue

        except Exception:
            # Non-critical - skip Pareto check
            pass

    def _read_target_trials(self):
        """Return ``optimization_settings.n_trials`` from the study config, or 100.

        Looks in ``1_setup/optimization_config.json`` first, then in the study
        root. A missing, unreadable or malformed config yields the default.
        """
        study_dir = STUDIES_DIR / self.study_id
        config_path = study_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = study_dir / "optimization_config.json"

        total_target = 100  # Default
        if config_path.exists():
            try:
                # Explicit UTF-8 (BOM tolerated): the platform default, e.g.
                # cp1252 on Windows, can fail on non-ASCII text and silently
                # drop the configured trial count.
                with open(config_path, encoding="utf-8-sig") as f:
                    config = json.load(f)
                total_target = config.get('optimization_settings', {}).get('n_trials', 100)
            except (OSError, ValueError, AttributeError):
                # Unreadable file, invalid JSON, or unexpected structure.
                pass

        # A null or non-numeric n_trials would break the percentage maths.
        if not isinstance(total_target, (int, float)):
            total_target = 100
        return total_target

    def _count_sources(self, cursor):
        """Return a row with ``nn_count`` / ``fea_count`` for completed trials.

        A trial is counted as NN when its study name contains a literal
        ``_nn`` (LIKE is case-insensitive for ASCII), otherwise as FEA.
        """
        # '_' is a LIKE wildcard, so it is escaped with '!' to match a literal
        # underscore; unescaped, names such as "annulus" were counted as NN.
        cursor.execute("""
            SELECT
                COUNT(CASE WHEN s.study_name LIKE '%!_nn%' ESCAPE '!' THEN 1 END) as nn_count,
                COUNT(CASE WHEN s.study_name NOT LIKE '%!_nn%' ESCAPE '!' THEN 1 END) as fea_count
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
        """)
        return cursor.fetchone()

    async def _send_progress(self, cursor, total_count: int):
        """Broadcast a ``progress`` message: completed vs. target trials, FEA/NN split."""
        total_target = self._read_target_trials()
        counts = self._count_sources(cursor)

        await self.callback({
            "type": "progress",
            "data": {
                "current": total_count,
                "total": total_target,
                "percentage": min(100, (total_count / total_target * 100)) if total_target > 0 else 0,
                "fea_count": counts['fea_count'] if counts else total_count,
                "nn_count": counts['nn_count'] if counts else 0,
                "timestamp": datetime.now().isoformat()
            }
        })


class ConnectionManager:
    """
    Manages WebSocket connections and database pollers for real-time updates.

    Uses database polling instead of file watching for better reliability on Windows.
    One DatabasePoller runs per study while that study has at least one
    subscribed socket; every message it produces is broadcast to them.
    """

    def __init__(self):
        # study_id -> sockets currently subscribed to that study
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # study_id -> poller feeding those sockets
        self.pollers: Dict[str, DatabasePoller] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Accept ``websocket``, subscribe it to ``study_id`` and ensure a poller runs."""
        await websocket.accept()

        self.active_connections.setdefault(study_id, set()).add(websocket)

        # Start polling if not already running
        if study_id not in self.pollers:
            await self._start_polling(study_id)

    async def disconnect(self, websocket: WebSocket, study_id: str):
        """Unsubscribe ``websocket``; stop the study's poller once nobody is left."""
        connections = self.active_connections.get(study_id)
        if connections is None:
            return

        connections.discard(websocket)

        # Stop polling if no more connections
        if not connections:
            del self.active_connections[study_id]
            await self._stop_polling(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Send ``message`` to every socket subscribed to ``study_id``.

        Sockets whose send fails are dropped from the subscription set.
        """
        connections = self.active_connections.get(study_id)
        if not connections:
            return

        disconnected = []
        # Iterate a snapshot: each await yields to the event loop, where
        # connect()/disconnect() may add or remove sockets from the live set.
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"[WebSocket] Error broadcasting to client: {e}")
                disconnected.append(connection)

        # Clean up disconnected clients. Use the captured set: the dict entry
        # may have been deleted by disconnect() while we were awaiting.
        for conn in disconnected:
            connections.discard(conn)

    async def _start_polling(self, study_id: str):
        """Create, register and start a DatabasePoller that broadcasts to ``study_id``."""
        async def callback(message):
            await self.broadcast(message, study_id)

        poller = DatabasePoller(study_id, callback)
        # Registered before the first await so concurrent connects see it.
        self.pollers[study_id] = poller
        await poller.start()
        print(f"[WebSocket] Started polling for study: {study_id}")

    async def _stop_polling(self, study_id: str):
        """Unregister and stop the poller for ``study_id``, if any."""
        # Unregister before awaiting stop(): a client connecting while the old
        # poller shuts down must see no poller and start a fresh one.
        poller = self.pollers.pop(study_id, None)
        if poller is not None:
            await poller.stop()
            print(f"[WebSocket] Stopped polling for study: {study_id}")


# Module-level connection manager used by optimization_stream below.
manager = ConnectionManager()


@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """
    WebSocket endpoint for real-time optimization updates.

    Sends messages:
    - connected: Initial connection confirmation
    - trial_completed: New trial completed with full data
    - new_best: New best value found
    - progress: Progress update (current/total, FEA/NN counts)
    - pareto_update: Pareto front updated (multi-objective)
    - pong: Reply to a client {"type": "ping"} message
    - heartbeat: Sent after 30 s without any client message
    """
    await manager.connect(websocket, study_id)

    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "study_id": study_id,
                "message": f"Connected to real-time stream for study {study_id}",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Keep connection alive
        while True:
            try:
                # Wait for messages (ping/pong or commands)
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=30.0  # 30 second timeout for ping
                )
            except asyncio.TimeoutError:
                # No client traffic: send a heartbeat. A failed send means the
                # socket is gone. `except Exception` (not bare) so task
                # cancellation still propagates.
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    break
                continue

            # Handle client commands; anything that is not a JSON object is ignored.
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                continue
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WebSocket] Connection error for {study_id}: {e}")
    finally:
        await manager.disconnect(websocket, study_id)