"""WebSocket route that streams optimization progress to connected clients.

Progress is detected by polling each study's Optuna SQLite database
(``study.db``) and broadcasting the differences as JSON messages.
"""

import asyncio
import json
import math
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set

import aiofiles
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

router = APIRouter()

# Base studies directory
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"

# Trial target reported in progress messages when the study config does not
# provide a usable ``optimization_settings.n_trials`` value.
DEFAULT_TOTAL_TRIALS = 100


def get_results_dir(study_dir: Path) -> Path:
    """Get the results directory for a study, supporting both 2_results and 3_results.

    Returns ``study_dir / "2_results"`` when that directory exists; otherwise
    returns ``study_dir / "3_results"`` without checking that it exists.
    """
    results_dir = study_dir / "2_results"
    if not results_dir.exists():
        results_dir = study_dir / "3_results"
    return results_dir


class DatabasePoller:
    """
    Polls the Optuna SQLite database for changes.
    More reliable than file watching, especially on Windows.

    Each detected change is passed to ``callback`` as a dict of the form
    ``{"type": ..., "data": ...}``; the callback is awaited.
    """

    def __init__(
        self,
        study_id: str,
        callback: Callable[[Dict[str, Any]], Awaitable[None]],
        poll_interval: float = 2.0,
        error_backoff: float = 5.0,
    ):
        """
        Args:
            study_id: Name of the study directory under ``STUDIES_DIR``.
            callback: Async callable awaited with every message dict.
            poll_interval: Seconds to sleep between successful polls.
            error_backoff: Seconds to sleep after a poll raised an exception.
        """
        self.study_id = study_id
        self.callback = callback
        self.poll_interval = poll_interval
        self.error_backoff = error_backoff
        self.last_trial_count = 0
        self.last_trial_id = 0
        self.last_best_value: Optional[float] = None
        self.last_pareto_count = 0
        self.running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the polling loop"""
        self.running = True
        self._task = asyncio.create_task(self._poll_loop())

    async def stop(self):
        """Stop the polling loop and wait for its task to finish."""
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _poll_loop(self):
        """Main polling loop - checks the database every ``poll_interval`` seconds"""
        study_dir = STUDIES_DIR / self.study_id

        while self.running:
            try:
                # Resolve the path on every iteration so a results directory
                # created after the client connected is still picked up.
                db_path = get_results_dir(study_dir) / "study.db"
                if db_path.exists():
                    await self._check_database(db_path)
                await asyncio.sleep(self.poll_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"[WebSocket] Polling error for {self.study_id}: {e}")
                await asyncio.sleep(self.error_backoff)  # Back off on error

    async def _check_database(self, db_path: Path):
        """Check database for new trials and updates.

        ``sqlite3.OperationalError`` (e.g. a locked database) skips the poll
        silently; any other exception is printed and swallowed.
        """
        try:
            # Use timeout to avoid blocking on locked databases. ``closing``
            # guarantees the connection is released even when a query or the
            # callback raises (or the task is cancelled mid-poll).
            with closing(sqlite3.connect(str(db_path), timeout=2.0)) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()

                # Get total completed trial count
                cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
                total_count = cursor.fetchone()[0]

                # Check for new trials
                if total_count > self.last_trial_count:
                    await self._process_new_trials(cursor, total_count)

                # Check for new best value
                await self._check_best_value(cursor)

                # Check Pareto front for multi-objective
                await self._check_pareto_front(cursor, db_path)

                # Send progress update
                await self._send_progress(cursor, total_count)
        except sqlite3.OperationalError:
            # Database locked - skip this poll
            pass
        except Exception as e:
            print(f"[WebSocket] Database check error: {e}")

    async def _process_new_trials(self, cursor, total_count: int):
        """Process and broadcast new trials.

        Sends one ``trial_completed`` message per completed trial whose
        ``trial_id`` is greater than the last one already sent.
        """
        # NOTE(review): a trial that completes after a trial with a higher
        # trial_id has already been sent is not picked up by this query.
        cursor.execute("""
            SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE' AND t.trial_id > ?
            ORDER BY t.trial_id ASC
        """, (self.last_trial_id,))
        new_trials = cursor.fetchall()

        for row in new_trials:
            trial_id = row['trial_id']
            trial_data = await self._build_trial_data(cursor, row)

            await self.callback({
                "type": "trial_completed",
                "data": trial_data
            })

            self.last_trial_id = trial_id

        self.last_trial_count = total_count

    async def _build_trial_data(self, cursor, row) -> Dict[str, Any]:
        """Build trial data dictionary from database row.

        ``row`` must expose ``trial_id``, ``number``, ``datetime_start``,
        ``datetime_complete`` and ``study_name``.
        """
        trial_id = row['trial_id']

        # Get objectives
        cursor.execute("""
            SELECT value FROM trial_values
            WHERE trial_id = ? ORDER BY objective
        """, (trial_id,))
        values = [r[0] for r in cursor.fetchall()]

        # Get parameters; keep the raw value when it cannot be read as a float
        cursor.execute("""
            SELECT param_name, param_value FROM trial_params
            WHERE trial_id = ?
        """, (trial_id,))
        params = {}
        for r in cursor.fetchall():
            try:
                params[r[0]] = float(r[1]) if r[1] is not None else None
            except (ValueError, TypeError):
                params[r[0]] = r[1]

        # Get user attributes; keep the raw value when it is not valid JSON
        cursor.execute("""
            SELECT key, value_json FROM trial_user_attributes
            WHERE trial_id = ?
        """, (trial_id,))
        user_attrs = {}
        for r in cursor.fetchall():
            try:
                user_attrs[r[0]] = json.loads(r[1])
            except (ValueError, TypeError):
                user_attrs[r[0]] = r[1]

        # Extract source and design vars
        source = user_attrs.get("source", "FEA")
        design_vars = user_attrs.get("design_vars", params)

        return {
            "trial_number": trial_id,  # Use trial_id for uniqueness
            "trial_num": row['number'],
            "objective": values[0] if values else None,
            "values": values,
            "params": design_vars,
            "user_attrs": user_attrs,
            "source": source,
            "start_time": row['datetime_start'],
            "end_time": row['datetime_complete'],
            "study_name": row['study_name'],
            "constraint_satisfied": user_attrs.get("constraint_satisfied", True)
        }

    async def _check_best_value(self, cursor):
        """Check for new best value and broadcast if changed.

        "Best" is the minimum value of objective 0 over completed trials.
        """
        cursor.execute("""
            SELECT MIN(tv.value) as best_value
            FROM trial_values tv
            JOIN trials t ON tv.trial_id = t.trial_id
            WHERE t.state = 'COMPLETE' AND tv.objective = 0
        """)
        result = cursor.fetchone()
        if result is None or result['best_value'] is None:
            return

        best_value = result['best_value']
        previous = self.last_best_value
        if previous is not None and best_value >= previous:
            return

        # Get the best trial details
        cursor.execute("""
            SELECT t.trial_id, t.number
            FROM trials t
            JOIN trial_values tv ON t.trial_id = tv.trial_id
            WHERE t.state = 'COMPLETE' AND tv.objective = 0 AND tv.value = ?
            LIMIT 1
        """, (best_value,))
        best_row = cursor.fetchone()
        if best_row is None:
            return

        # Get params for best trial
        cursor.execute("""
            SELECT param_name, param_value FROM trial_params
            WHERE trial_id = ?
        """, (best_row['trial_id'],))
        params = {r[0]: r[1] for r in cursor.fetchall()}

        # Relative improvement in percent; 0 when there is no previous best
        # or the previous best is 0 (which would divide by zero).
        if previous is not None and previous != 0:
            improvement = (previous - best_value) / abs(previous) * 100
        else:
            improvement = 0

        await self.callback({
            "type": "new_best",
            "data": {
                "trial_number": best_row['trial_id'],
                "value": best_value,
                "params": params,
                "improvement": improvement
            }
        })
        self.last_best_value = best_value

    async def _check_pareto_front(self, cursor, db_path: Path):
        """Check for Pareto front updates in multi-objective studies.

        Sends a ``pareto_update`` message when the number of Pareto-optimal
        trials differs from the last reported count. Failures are treated as
        non-critical and skipped.
        """
        try:
            # Check if multi-objective by counting distinct objectives
            cursor.execute("""
                SELECT COUNT(DISTINCT objective) as obj_count
                FROM trial_values
                WHERE trial_id IN (SELECT trial_id FROM trials WHERE state = 'COMPLETE')
            """)
            result = cursor.fetchone()
            if result is None or result['obj_count'] <= 1:
                return

            # Multi-objective - compute Pareto front
            import optuna

            storage = optuna.storages.RDBStorage(f"sqlite:///{db_path}")

            # Get all study names
            cursor.execute("SELECT study_name FROM studies")
            study_names = [r[0] for r in cursor.fetchall()]

            for study_name in study_names:
                try:
                    study = optuna.load_study(study_name=study_name, storage=storage)
                    if len(study.directions) > 1:
                        pareto_trials = study.best_trials

                        if len(pareto_trials) != self.last_pareto_count:
                            pareto_data = [
                                {
                                    "trial_number": t.number,
                                    "values": t.values,
                                    "params": t.params,
                                    "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True),
                                    "source": t.user_attrs.get("source", "FEA")
                                }
                                for t in pareto_trials
                            ]

                            await self.callback({
                                "type": "pareto_update",
                                "data": {
                                    "pareto_front": pareto_data,
                                    "count": len(pareto_trials)
                                }
                            })
                            self.last_pareto_count = len(pareto_trials)
                        # Only the first multi-objective study is reported
                        break
                except Exception:
                    # ``Exception`` (not a bare except) so task cancellation
                    # raised while awaiting the callback still propagates.
                    continue
        except Exception:
            # Non-critical - skip Pareto check
            pass

    def _load_total_target(self) -> float:
        """Return the configured trial target, or ``DEFAULT_TOTAL_TRIALS``.

        Reads ``optimization_settings.n_trials`` from
        ``1_setup/optimization_config.json``, falling back to
        ``optimization_config.json`` in the study directory. An unreadable,
        malformed, non-numeric or non-finite setting yields the default.
        """
        study_dir = STUDIES_DIR / self.study_id
        config_path = study_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = study_dir / "optimization_config.json"
        if not config_path.exists():
            return DEFAULT_TOTAL_TRIALS

        try:
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
            n_trials = config.get('optimization_settings', {}).get('n_trials', DEFAULT_TOTAL_TRIALS)
        except (OSError, ValueError, AttributeError, TypeError):
            return DEFAULT_TOTAL_TRIALS

        # ``bool`` is an ``int`` subclass, so it is rejected explicitly. NaN and
        # Infinity are rejected too: ``json.load`` accepts them, but they would
        # put a non-finite ``total`` into every progress message.
        if (
            isinstance(n_trials, (int, float))
            and not isinstance(n_trials, bool)
            and math.isfinite(n_trials)
        ):
            return n_trials
        return DEFAULT_TOTAL_TRIALS

    async def _send_progress(self, cursor, total_count: int):
        """Send progress update.

        Args:
            cursor: Open cursor on the study database.
            total_count: Number of completed trials found by this poll.
        """
        # Get total from config if available
        total_target = self._load_total_target()

        # Count FEA vs NN trials
        cursor.execute("""
            SELECT
                COUNT(CASE WHEN s.study_name LIKE '%_nn%' THEN 1 END) as nn_count,
                COUNT(CASE WHEN s.study_name NOT LIKE '%_nn%' THEN 1 END) as fea_count
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
        """)
        counts = cursor.fetchone()

        await self.callback({
            "type": "progress",
            "data": {
                "current": total_count,
                "total": total_target,
                "percentage": min(100, (total_count / total_target * 100)) if total_target > 0 else 0,
                "fea_count": counts['fea_count'] if counts else total_count,
                "nn_count": counts['nn_count'] if counts else 0,
                "timestamp": datetime.now().isoformat()
            }
        })


class ConnectionManager:
    """
    Manages WebSocket connections and database pollers for real-time updates.
    Uses database polling instead of file watching for better reliability on Windows.

    One ``DatabasePoller`` runs per study while at least one client is
    connected to that study.
    """

    def __init__(self):
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        self.pollers: Dict[str, DatabasePoller] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Accept a new WebSocket client and make sure its study is polled."""
        await websocket.accept()

        self.active_connections.setdefault(study_id, set()).add(websocket)

        # Start polling if not already running
        if study_id not in self.pollers:
            await self._start_polling(study_id)

    async def disconnect(self, websocket: WebSocket, study_id: str):
        """Disconnect a WebSocket client"""
        connections = self.active_connections.get(study_id)
        if connections is None:
            return

        connections.discard(websocket)

        # Stop polling if no more connections
        if not connections:
            del self.active_connections[study_id]
            await self._stop_polling(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Broadcast message to all connected clients for a study.

        Clients whose send fails are dropped from the study's connection set.
        """
        connections = self.active_connections.get(study_id)
        if not connections:
            return

        disconnected = []
        # Iterate over a snapshot: clients may connect or disconnect while a
        # send is awaited, and mutating a set during iteration raises.
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"[WebSocket] Error broadcasting to client: {e}")
                disconnected.append(connection)

        # Clean up disconnected clients. Look the set up again because the
        # study entry may have been removed while the sends were awaited.
        remaining = self.active_connections.get(study_id)
        if remaining is not None:
            for conn in disconnected:
                remaining.discard(conn)

    async def _start_polling(self, study_id: str):
        """Start database polling for a study"""
        async def callback(message):
            await self.broadcast(message, study_id)

        poller = DatabasePoller(study_id, callback)
        self.pollers[study_id] = poller
        await poller.start()
        print(f"[WebSocket] Started polling for study: {study_id}")

    async def _stop_polling(self, study_id: str):
        """Stop database polling for a study"""
        # Unregister before awaiting, so a client connecting while the old
        # poller shuts down starts a fresh one instead of being left without.
        poller = self.pollers.pop(study_id, None)
        if poller is not None:
            await poller.stop()
            print(f"[WebSocket] Stopped polling for study: {study_id}")


manager = ConnectionManager()


@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """
    WebSocket endpoint for real-time optimization updates.

    Sends messages:
    - connected: Initial connection confirmation
    - trial_completed: New trial completed with full data
    - new_best: New best value found
    - progress: Progress update (current/total, FEA/NN counts)
    - pareto_update: Pareto front updated (multi-objective)

    Also replies ``pong`` to a client ``ping`` message and sends a
    ``heartbeat`` whenever no client message arrives within 30 seconds.
    """
    await manager.connect(websocket, study_id)

    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "study_id": study_id,
                "message": f"Connected to real-time stream for study {study_id}",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Keep connection alive
        while True:
            try:
                # Wait for messages (ping/pong or commands)
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=30.0  # 30 second timeout for ping
                )

                # Handle client commands
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    continue

                # Valid JSON that is not an object (e.g. "1" or "[]") has no
                # ``.get``; ignore it instead of tearing down the connection.
                if isinstance(msg, dict) and msg.get("type") == "ping":
                    await websocket.send_json({"type": "pong"})

            except asyncio.TimeoutError:
                # Send heartbeat
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    break

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WebSocket] Connection error for {study_id}: {e}")
    finally:
        await manager.disconnect(websocket, study_id)