feat: Add Analysis page, run comparison, notifications, and config editor

Dashboard enhancements:
- Add Analysis page with tabs: Overview, Parameters, Pareto, Correlations, Constraints, Surrogate, Runs
- Add PlotlyCorrelationHeatmap for parameter-objective correlation analysis
- Add PlotlyFeasibilityChart for constraint satisfaction visualization
- Add PlotlySurrogateQuality for FEA vs NN prediction comparison
- Add PlotlyRunComparison for comparing optimization runs within a study

Real-time improvements:
- Replace watchdog file-watching with SQLite database polling for better Windows reliability
- Add DatabasePoller class with 2-second polling interval
- Enhanced WebSocket messages: trial_completed, new_best, pareto_update, progress

Desktop notifications:
- Add useNotifications hook using Web Notifications API
- Add NotificationSettings toggle component
- Notify users when new best solutions are found

Config editor:
- Add PUT /studies/{study_id}/config endpoint with auto-backup
- Add ConfigEditor modal with tabs: General, Variables, Objectives, Settings, JSON
- Prevent editing while an optimization run is in progress

Enhanced Pareto visualization:
- Add dark mode styling with transparent backgrounds
- Add stats bar showing Pareto, FEA, NN, and infeasible counts
- Add Pareto front connecting line for 2D view
- Add table showing top 10 Pareto-optimal solutions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Antoine
2025-12-05 19:57:20 -05:00
parent 5c660ff270
commit 5fb94fdf01
27 changed files with 5878 additions and 722 deletions

View File

@@ -1,10 +1,10 @@
import asyncio
import json
import sqlite3
from pathlib import Path
from typing import Dict, Set
from typing import Dict, Set, Optional, Any, List
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import aiofiles
router = APIRouter()
@@ -12,185 +12,440 @@ router = APIRouter()
# Base studies directory: resolved five parent directories above this
# module, then into "studies" (each study lives at STUDIES_DIR / study_id).
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
class OptimizationFileHandler(FileSystemEventHandler):
def get_results_dir(study_dir: Path) -> Path:
    """Resolve a study's results directory.

    Prefers the ``2_results`` layout and falls back to ``3_results``
    when the former is absent. The fallback path is returned even if
    it does not exist either; callers are expected to check existence.
    """
    preferred = study_dir / "2_results"
    return preferred if preferred.exists() else study_dir / "3_results"
class DatabasePoller:
    """
    Polls the Optuna SQLite database for changes.
    More reliable than file watching, especially on Windows.
    """

    def __init__(self, study_id: str, callback):
        # Study folder name under STUDIES_DIR; also used in log messages.
        self.study_id = study_id
        # Async callable invoked with each outgoing message dict
        # (broadcast to the study's WebSocket clients by the caller).
        self.callback = callback
        # High-water marks so each poll only broadcasts what is new.
        self.last_trial_count = 0
        self.last_pruned_count = 0
        self.last_trial_id = 0
        self.last_best_value: Optional[float] = None
        self.last_pareto_count = 0
        self.last_state_timestamp = ""
        # Poll-loop lifecycle: flag checked by the loop, plus the task
        # handle so stop() can cancel it.
        self.running = False
        self._task: Optional[asyncio.Task] = None
def on_modified(self, event):
if event.src_path.endswith("optimization_history_incremental.json"):
asyncio.run(self.process_history_update(event.src_path))
elif event.src_path.endswith("pruning_history.json"):
asyncio.run(self.process_pruning_update(event.src_path))
elif event.src_path.endswith("study.db"): # Watch for Optuna DB changes (Pareto front)
asyncio.run(self.process_pareto_update(event.src_path))
elif event.src_path.endswith("optimizer_state.json"):
asyncio.run(self.process_state_update(event.src_path))
async def start(self):
    """Start the polling loop"""
    self.running = True
    # Run _poll_loop concurrently; keep the handle so stop() can cancel it.
    self._task = asyncio.create_task(self._poll_loop())
async def process_history_update(self, file_path):
async def stop(self):
    """Stop the polling loop"""
    # Flag the loop to exit, then cancel the background task (if one
    # was started) and wait for the cancellation to be processed.
    self.running = False
    task = self._task
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
async def _poll_loop(self):
    """Main polling loop - checks database every 2 seconds.

    Runs until stop() clears ``self.running`` or cancels the task.
    The database path is resolved once up front; the loop tolerates the
    file not existing yet (the optimization may not have started).
    """
    study_dir = STUDIES_DIR / self.study_id
    results_dir = get_results_dir(study_dir)
    db_path = results_dir / "study.db"
    while self.running:
        try:
            if db_path.exists():
                await self._check_database(db_path)
            await asyncio.sleep(2)  # Poll every 2 seconds
        except asyncio.CancelledError:
            # Raised when stop() cancels the task — exit cleanly.
            break
        except Exception as e:
            print(f"[WebSocket] Polling error for {self.study_id}: {e}")
            await asyncio.sleep(5)  # Back off on error
async def _check_database(self, db_path: Path):
"""Check database for new trials and updates"""
try:
async with aiofiles.open(file_path, mode='r') as f:
content = await f.read()
history = json.loads(content)
current_count = len(history)
if current_count > self.last_trial_count:
# New trials added
new_trials = history[self.last_trial_count:]
for trial in new_trials:
await self.callback({
"type": "trial_completed",
"data": trial
})
self.last_trial_count = current_count
# Use timeout to avoid blocking on locked databases
conn = sqlite3.connect(str(db_path), timeout=2.0)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# Get total completed trial count
cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
total_count = cursor.fetchone()[0]
# Check for new trials
if total_count > self.last_trial_count:
await self._process_new_trials(cursor, total_count)
# Check for new best value
await self._check_best_value(cursor)
# Check Pareto front for multi-objective
await self._check_pareto_front(cursor, db_path)
# Send progress update
await self._send_progress(cursor, total_count)
conn.close()
except sqlite3.OperationalError as e:
# Database locked - skip this poll
pass
except Exception as e:
print(f"Error processing history update: {e}")
print(f"[WebSocket] Database check error: {e}")
async def process_pruning_update(self, file_path):
try:
async with aiofiles.open(file_path, mode='r') as f:
content = await f.read()
history = json.loads(content)
current_count = len(history)
if current_count > self.last_pruned_count:
# New pruned trials
new_pruned = history[self.last_pruned_count:]
for trial in new_pruned:
await self.callback({
"type": "trial_pruned",
"data": trial
})
self.last_pruned_count = current_count
except Exception as e:
print(f"Error processing pruning update: {e}")
async def _process_new_trials(self, cursor, total_count: int):
    """Fetch trials completed since the last poll and broadcast each one."""
    cursor.execute("""
        SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
        FROM trials t
        JOIN studies s ON t.study_id = s.study_id
        WHERE t.state = 'COMPLETE' AND t.trial_id > ?
        ORDER BY t.trial_id ASC
    """, (self.last_trial_id,))
    # Materialize the result set first: _build_trial_data reuses the
    # same cursor for its own queries.
    for db_row in cursor.fetchall():
        payload = await self._build_trial_data(cursor, db_row)
        await self.callback({
            "type": "trial_completed",
            "data": payload
        })
        # Advance the high-water mark per trial so a mid-loop failure
        # does not re-broadcast already-sent trials on the next poll.
        self.last_trial_id = db_row['trial_id']
    self.last_trial_count = total_count
async def _build_trial_data(self, cursor, row) -> Dict[str, Any]:
"""Build trial data dictionary from database row"""
trial_id = row['trial_id']
# Get objectives
cursor.execute("""
SELECT value FROM trial_values
WHERE trial_id = ? ORDER BY objective
""", (trial_id,))
values = [r[0] for r in cursor.fetchall()]
# Get parameters
cursor.execute("""
SELECT param_name, param_value FROM trial_params
WHERE trial_id = ?
""", (trial_id,))
params = {}
for r in cursor.fetchall():
try:
params[r[0]] = float(r[1]) if r[1] is not None else None
except (ValueError, TypeError):
params[r[0]] = r[1]
# Get user attributes
cursor.execute("""
SELECT key, value_json FROM trial_user_attributes
WHERE trial_id = ?
""", (trial_id,))
user_attrs = {}
for r in cursor.fetchall():
try:
user_attrs[r[0]] = json.loads(r[1])
except (ValueError, TypeError):
user_attrs[r[0]] = r[1]
# Extract source and design vars
source = user_attrs.get("source", "FEA")
design_vars = user_attrs.get("design_vars", params)
return {
"trial_number": trial_id, # Use trial_id for uniqueness
"trial_num": row['number'],
"objective": values[0] if values else None,
"values": values,
"params": design_vars,
"user_attrs": user_attrs,
"source": source,
"start_time": row['datetime_start'],
"end_time": row['datetime_complete'],
"study_name": row['study_name'],
"constraint_satisfied": user_attrs.get("constraint_satisfied", True)
}
async def _check_best_value(self, cursor):
"""Check for new best value and broadcast if changed"""
cursor.execute("""
SELECT MIN(tv.value) as best_value
FROM trial_values tv
JOIN trials t ON tv.trial_id = t.trial_id
WHERE t.state = 'COMPLETE' AND tv.objective = 0
""")
result = cursor.fetchone()
if result and result['best_value'] is not None:
best_value = result['best_value']
if self.last_best_value is None or best_value < self.last_best_value:
# Get the best trial details
cursor.execute("""
SELECT t.trial_id, t.number
FROM trials t
JOIN trial_values tv ON t.trial_id = tv.trial_id
WHERE t.state = 'COMPLETE' AND tv.objective = 0 AND tv.value = ?
LIMIT 1
""", (best_value,))
best_row = cursor.fetchone()
if best_row:
# Get params for best trial
cursor.execute("""
SELECT param_name, param_value FROM trial_params
WHERE trial_id = ?
""", (best_row['trial_id'],))
params = {r[0]: r[1] for r in cursor.fetchall()}
async def process_pareto_update(self, file_path):
# This is tricky because study.db is binary.
# Instead of reading it directly, we'll trigger a re-fetch of the Pareto front via Optuna
# We debounce this to avoid excessive reads
try:
# Import here to avoid circular imports or heavy load at startup
import optuna
# Connect to DB
storage = optuna.storages.RDBStorage(f"sqlite:///{file_path}")
study = optuna.load_study(study_name=self.study_id, storage=storage)
# Check if multi-objective
if len(study.directions) > 1:
pareto_trials = study.best_trials
# Only broadcast if count changed (simple heuristic)
# In a real app, we might check content hash
if len(pareto_trials) != self.last_pareto_count:
pareto_data = [
{
"trial_number": t.number,
"values": t.values,
"params": t.params,
"user_attrs": dict(t.user_attrs),
"constraint_satisfied": t.user_attrs.get("constraint_satisfied", True)
}
for t in pareto_trials
]
await self.callback({
"type": "pareto_front",
"type": "new_best",
"data": {
"pareto_front": pareto_data,
"count": len(pareto_trials)
"trial_number": best_row['trial_id'],
"value": best_value,
"params": params,
"improvement": (
((self.last_best_value - best_value) / abs(self.last_best_value) * 100)
if self.last_best_value else 0
)
}
})
self.last_pareto_count = len(pareto_trials)
self.last_best_value = best_value
async def _check_pareto_front(self, cursor, db_path: Path):
    """Check for Pareto front updates in multi-objective studies.

    Broadcasts a "pareto_update" message when the number of trials on
    the Pareto front changed since the last poll. Count change is a
    heuristic: a front of the same size with different members is not
    detected. Best-effort: all database errors are swallowed and the
    check is retried on the next poll.
    """
    try:
        # A study is multi-objective iff completed trials carry values
        # for more than one objective index.
        cursor.execute("""
            SELECT COUNT(DISTINCT objective) as obj_count
            FROM trial_values
            WHERE trial_id IN (SELECT trial_id FROM trials WHERE state = 'COMPLETE')
        """)
        result = cursor.fetchone()
        if result and result['obj_count'] > 1:
            # Imported lazily: optuna is heavy and only needed here.
            import optuna
            storage = optuna.storages.RDBStorage(f"sqlite:///{db_path}")
            # The DB may hold several studies (e.g. FEA and NN variants);
            # use the first multi-objective one found.
            cursor.execute("SELECT study_name FROM studies")
            study_names = [r[0] for r in cursor.fetchall()]
            for study_name in study_names:
                try:
                    study = optuna.load_study(study_name=study_name, storage=storage)
                    if len(study.directions) > 1:
                        pareto_trials = study.best_trials
                        if len(pareto_trials) != self.last_pareto_count:
                            pareto_data = [
                                {
                                    "trial_number": t.number,
                                    "values": t.values,
                                    "params": t.params,
                                    "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True),
                                    "source": t.user_attrs.get("source", "FEA")
                                }
                                for t in pareto_trials
                            ]
                            await self.callback({
                                "type": "pareto_update",
                                "data": {
                                    "pareto_front": pareto_data,
                                    "count": len(pareto_trials)
                                }
                            })
                            self.last_pareto_count = len(pareto_trials)
                        break
                except Exception:
                    # Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Try the next study name.
                    continue
    except Exception:
        # Non-critical - DB may be locked or mid-write; skip this Pareto
        # check and retry on the next poll.
        pass
async def process_state_update(self, file_path):
try:
async with aiofiles.open(file_path, mode='r') as f:
content = await f.read()
state = json.loads(content)
# Check timestamp to avoid duplicate broadcasts
if state.get("timestamp") != self.last_state_timestamp:
await self.callback({
"type": "optimizer_state",
"data": state
})
self.last_state_timestamp = state.get("timestamp")
except Exception as e:
print(f"Error processing state update: {e}")
async def _send_progress(self, cursor, total_count: int):
    """Broadcast a "progress" message: current/target trial counts plus
    the FEA vs NN split (by study-name suffix) and a timestamp."""
    # Resolve the configured trial target; support both study layouts.
    study_dir = STUDIES_DIR / self.study_id
    config_path = study_dir / "1_setup" / "optimization_config.json"
    if not config_path.exists():
        config_path = study_dir / "optimization_config.json"
    total_target = 100  # Default when no config is readable
    if config_path.exists():
        try:
            with open(config_path) as f:
                config = json.load(f)
            total_target = config.get('optimization_settings', {}).get('n_trials', 100)
        except (OSError, ValueError):
            # Was a bare `except: pass`; narrowed to read/parse errors
            # (json.JSONDecodeError subclasses ValueError) so programming
            # errors are no longer silently swallowed. Keep the default.
            pass
    # Count FEA vs NN trials — NN-surrogate studies are named '*_nn*'.
    cursor.execute("""
        SELECT
            COUNT(CASE WHEN s.study_name LIKE '%_nn%' THEN 1 END) as nn_count,
            COUNT(CASE WHEN s.study_name NOT LIKE '%_nn%' THEN 1 END) as fea_count
        FROM trials t
        JOIN studies s ON t.study_id = s.study_id
        WHERE t.state = 'COMPLETE'
    """)
    counts = cursor.fetchone()
    await self.callback({
        "type": "progress",
        "data": {
            "current": total_count,
            "total": total_target,
            "percentage": min(100, (total_count / total_target * 100)) if total_target > 0 else 0,
            "fea_count": counts['fea_count'] if counts else total_count,
            "nn_count": counts['nn_count'] if counts else 0,
            "timestamp": datetime.now().isoformat()
        }
    })
class ConnectionManager:
"""
Manages WebSocket connections and database pollers for real-time updates.
Uses database polling instead of file watching for better reliability on Windows.
"""
def __init__(self):
self.active_connections: Dict[str, Set[WebSocket]] = {}
self.observers: Dict[str, Observer] = {}
self.pollers: Dict[str, DatabasePoller] = {}
async def connect(self, websocket: WebSocket, study_id: str):
"""Connect a new WebSocket client"""
await websocket.accept()
if study_id not in self.active_connections:
self.active_connections[study_id] = set()
self.start_watching(study_id)
self.active_connections[study_id].add(websocket)
def disconnect(self, websocket: WebSocket, study_id: str):
# Start polling if not already running
if study_id not in self.pollers:
await self._start_polling(study_id)
async def disconnect(self, websocket: WebSocket, study_id: str):
"""Disconnect a WebSocket client"""
if study_id in self.active_connections:
self.active_connections[study_id].remove(websocket)
self.active_connections[study_id].discard(websocket)
# Stop polling if no more connections
if not self.active_connections[study_id]:
del self.active_connections[study_id]
self.stop_watching(study_id)
await self._stop_polling(study_id)
async def broadcast(self, message: dict, study_id: str):
if study_id in self.active_connections:
for connection in self.active_connections[study_id]:
try:
await connection.send_json(message)
except Exception as e:
print(f"Error broadcasting to client: {e}")
def start_watching(self, study_id: str):
study_dir = STUDIES_DIR / study_id / "2_results"
if not study_dir.exists():
"""Broadcast message to all connected clients for a study"""
if study_id not in self.active_connections:
return
disconnected = []
for connection in self.active_connections[study_id]:
try:
await connection.send_json(message)
except Exception as e:
print(f"[WebSocket] Error broadcasting to client: {e}")
disconnected.append(connection)
# Clean up disconnected clients
for conn in disconnected:
self.active_connections[study_id].discard(conn)
async def _start_polling(self, study_id: str):
"""Start database polling for a study"""
async def callback(message):
await self.broadcast(message, study_id)
event_handler = OptimizationFileHandler(study_id, callback)
observer = Observer()
observer.schedule(event_handler, str(study_dir), recursive=True)
observer.start()
self.observers[study_id] = observer
poller = DatabasePoller(study_id, callback)
self.pollers[study_id] = poller
await poller.start()
print(f"[WebSocket] Started polling for study: {study_id}")
async def _stop_polling(self, study_id: str):
"""Stop database polling for a study"""
if study_id in self.pollers:
await self.pollers[study_id].stop()
del self.pollers[study_id]
print(f"[WebSocket] Stopped polling for study: {study_id}")
def stop_watching(self, study_id: str):
if study_id in self.observers:
self.observers[study_id].stop()
self.observers[study_id].join()
del self.observers[study_id]
manager = ConnectionManager()
@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
"""
WebSocket endpoint for real-time optimization updates.
Sends messages:
- connected: Initial connection confirmation
- trial_completed: New trial completed with full data
- new_best: New best value found
- progress: Progress update (current/total, FEA/NN counts)
- pareto_update: Pareto front updated (multi-objective)
"""
await manager.connect(websocket, study_id)
try:
# Send initial connection message
await websocket.send_json({
"type": "connected",
"data": {"message": f"Connected to stream for study {study_id}"}
"data": {
"study_id": study_id,
"message": f"Connected to real-time stream for study {study_id}",
"timestamp": datetime.now().isoformat()
}
})
# Keep connection alive
while True:
# Keep connection alive and handle incoming messages if needed
data = await websocket.receive_text()
# We could handle client commands here (e.g., "pause", "stop")
try:
# Wait for messages (ping/pong or commands)
data = await asyncio.wait_for(
websocket.receive_text(),
timeout=30.0 # 30 second timeout for ping
)
# Handle client commands
try:
msg = json.loads(data)
if msg.get("type") == "ping":
await websocket.send_json({"type": "pong"})
except json.JSONDecodeError:
pass
except asyncio.TimeoutError:
# Send heartbeat
try:
await websocket.send_json({"type": "heartbeat"})
except:
break
except WebSocketDisconnect:
manager.disconnect(websocket, study_id)
pass
except Exception as e:
print(f"WebSocket error: {e}")
manager.disconnect(websocket, study_id)
print(f"[WebSocket] Connection error for {study_id}: {e}")
finally:
await manager.disconnect(websocket, study_id)