import asyncio
import json
from pathlib import Path
from typing import Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import aiofiles

router = APIRouter()

# Base studies directory (five levels above this module — verify if the file moves)
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"


class OptimizationFileHandler(FileSystemEventHandler):
    """Watchdog handler that tails a study's optimization artifacts.

    ``on_modified`` runs on the watchdog observer thread.  Async work is
    marshalled onto the server's event loop via
    ``asyncio.run_coroutine_threadsafe`` when ``loop`` is supplied; the old
    per-event ``asyncio.run`` created a throwaway loop on the observer thread
    and then awaited websocket sends bound to the server loop, which is not
    safe across loops.
    """

    def __init__(self, study_id: str, callback,
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        self.study_id = study_id
        self.callback = callback  # async callable taking a message dict
        self.loop = loop          # server loop; None falls back to asyncio.run
        # Incremental cursors so each change event only broadcasts the delta.
        self.last_trial_count = 0
        self.last_pruned_count = 0
        self.last_pareto_count = 0
        self.last_state_timestamp = ""

    def _dispatch(self, coro) -> None:
        """Schedule *coro* on the server loop; fall back to a private loop."""
        if self.loop is not None and self.loop.is_running():
            # Thread-safe hand-off from the watchdog thread to the server loop.
            asyncio.run_coroutine_threadsafe(coro, self.loop)
        else:
            # Legacy path (no loop supplied) — matches the original behavior.
            asyncio.run(coro)

    def on_modified(self, event):
        """Route a file-change event to the matching incremental processor."""
        path = event.src_path
        if path.endswith("optimization_history_incremental.json"):
            self._dispatch(self.process_history_update(path))
        elif path.endswith("pruning_history.json"):
            self._dispatch(self.process_pruning_update(path))
        elif path.endswith("study.db"):
            # Optuna SQLite DB changed -> the Pareto front may have changed.
            self._dispatch(self.process_pareto_update(path))
        elif path.endswith("optimizer_state.json"):
            self._dispatch(self.process_state_update(path))

    async def _broadcast_new_entries(self, file_path, cursor_attr: str,
                                     message_type: str, label: str) -> None:
        """Read a JSON-array history file and emit one message per new entry.

        ``cursor_attr`` names the instance attribute holding how many entries
        were already broadcast; it advances only after a successful read, so a
        partially-written file is retried on the next fs event.
        """
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            history = json.loads(content)
            seen = getattr(self, cursor_attr)
            if len(history) > seen:
                for entry in history[seen:]:
                    await self.callback({"type": message_type, "data": entry})
                setattr(self, cursor_attr, len(history))
        except Exception as e:
            # File may be mid-write (truncated JSON); log and wait for retry.
            print(f"Error processing {label} update: {e}")

    async def process_history_update(self, file_path):
        """Broadcast trials appended to the incremental history file."""
        await self._broadcast_new_entries(
            file_path, "last_trial_count", "trial_completed", "history")

    async def process_pruning_update(self, file_path):
        """Broadcast trials appended to the pruning history file."""
        await self._broadcast_new_entries(
            file_path, "last_pruned_count", "trial_pruned", "pruning")

    async def process_pareto_update(self, file_path):
        """Reload the Optuna study and broadcast the Pareto front if its size changed.

        study.db is binary SQLite, so we go through Optuna instead of reading
        it directly.  The size comparison is a heuristic: content changes that
        keep the trial count constant are not detected.
        """
        try:
            # Deferred import: optuna is heavy and not needed at startup.
            import optuna

            storage = optuna.storages.RDBStorage(f"sqlite:///{file_path}")
            study = optuna.load_study(study_name=self.study_id, storage=storage)
            if len(study.directions) > 1:  # Pareto front is multi-objective only
                pareto_trials = study.best_trials
                if len(pareto_trials) != self.last_pareto_count:
                    pareto_data = [
                        {
                            "trial_number": t.number,
                            "values": t.values,
                            "params": t.params,
                            "user_attrs": dict(t.user_attrs),
                            "constraint_satisfied": t.user_attrs.get(
                                "constraint_satisfied", True),
                        }
                        for t in pareto_trials
                    ]
                    await self.callback({
                        "type": "pareto_front",
                        "data": {
                            "pareto_front": pareto_data,
                            "count": len(pareto_trials),
                        },
                    })
                    self.last_pareto_count = len(pareto_trials)
        except Exception:
            # DB may be locked mid-write by the optimizer; transient, ignore.
            pass

    async def process_state_update(self, file_path):
        """Broadcast the optimizer state file when its timestamp changes."""
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            state = json.loads(content)
            # De-duplicate on timestamp so repeated fs events don't re-broadcast.
            if state.get("timestamp") != self.last_state_timestamp:
                await self.callback({"type": "optimizer_state", "data": state})
                self.last_state_timestamp = state.get("timestamp")
        except Exception as e:
            print(f"Error processing state update: {e}")


class ConnectionManager:
    """Tracks websocket subscribers per study and one file observer per study."""

    def __init__(self):
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        self.observers: Dict[str, Observer] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Accept the socket; start watching the study on its first subscriber."""
        await websocket.accept()
        if study_id not in self.active_connections:
            self.active_connections[study_id] = set()
            # Capture the running loop so the watchdog thread can schedule
            # broadcasts back onto it.
            self.start_watching(study_id, loop=asyncio.get_running_loop())
        self.active_connections[study_id].add(websocket)

    def disconnect(self, websocket: WebSocket, study_id: str):
        """Drop the socket; stop the observer when the last subscriber leaves."""
        if study_id in self.active_connections:
            # discard() tolerates a double-disconnect (the endpoint's generic
            # error path may call this after WebSocketDisconnect already did;
            # the original remove() raised KeyError there).
            self.active_connections[study_id].discard(websocket)
            if not self.active_connections[study_id]:
                del self.active_connections[study_id]
                self.stop_watching(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Send *message* to every subscriber of *study_id*, best-effort."""
        # Iterate a snapshot: a failed send can trigger disconnect(), which
        # mutates the underlying set during iteration.
        for connection in list(self.active_connections.get(study_id, ())):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Error broadcasting to client: {e}")

    def start_watching(self, study_id: str,
                       loop: Optional[asyncio.AbstractEventLoop] = None):
        """Start a recursive watchdog observer over the study's results dir.

        ``loop`` (new, optional) is the event loop broadcasts should run on;
        omitted, the handler falls back to the legacy asyncio.run path.
        """
        study_dir = STUDIES_DIR / study_id / "2_results"
        if not study_dir.exists():
            return

        async def callback(message):
            await self.broadcast(message, study_id)

        event_handler = OptimizationFileHandler(study_id, callback, loop=loop)
        observer = Observer()
        observer.schedule(event_handler, str(study_dir), recursive=True)
        observer.start()
        self.observers[study_id] = observer

    def stop_watching(self, study_id: str):
        """Stop and join the observer thread for *study_id*, if one exists."""
        # pop() is race-tolerant if called twice for the same study.
        observer = self.observers.pop(study_id, None)
        if observer is not None:
            observer.stop()
            observer.join()


manager = ConnectionManager()


@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """Stream live optimization events for one study over a websocket."""
    await manager.connect(websocket, study_id)
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {"message": f"Connected to stream for study {study_id}"}
        })
        while True:
            # Keep the connection alive; client commands (e.g. "pause",
            # "stop") could be handled here.
            data = await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(websocket, study_id)
    except Exception as e:
        print(f"WebSocket error: {e}")
        manager.disconnect(websocket, study_id)