# Atomizer/atomizer-dashboard/backend/api/websocket/optimization_stream.py
import asyncio
import json
import sqlite3
from pathlib import Path
from typing import Dict, Set, Optional, Any, List
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
import aiofiles
router = APIRouter()

# Base studies directory: five levels up from this file, then "studies/".
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
def get_results_dir(study_dir: Path) -> Path:
    """Get the results directory for a study, supporting both 2_results and 3_results."""
    primary = study_dir / "2_results"
    # Fall back to the "3_results" layout when "2_results" is absent.
    return primary if primary.exists() else study_dir / "3_results"
class DatabasePoller:
    """
    Polls the Optuna SQLite database for changes.
    More reliable than file watching, especially on Windows.

    Every poll (2 s interval) detects newly completed trials, a new best
    objective value, Pareto-front changes (multi-objective studies), and
    always emits a progress message.  Messages are delivered through
    ``callback``, an async callable receiving a single message dict.
    """

    def __init__(self, study_id: str, callback):
        self.study_id = study_id
        self.callback = callback
        # Incremental bookkeeping so each poll only reports deltas.
        self.last_trial_count = 0
        self.last_trial_id = 0
        self.last_best_value: Optional[float] = None
        self.last_pareto_count = 0
        self.running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the polling loop"""
        self.running = True
        self._task = asyncio.create_task(self._poll_loop())

    async def stop(self):
        """Stop the polling loop"""
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _poll_loop(self):
        """Main polling loop - checks database every 2 seconds"""
        study_dir = STUDIES_DIR / self.study_id
        results_dir = get_results_dir(study_dir)
        db_path = results_dir / "study.db"
        while self.running:
            try:
                if db_path.exists():
                    await self._check_database(db_path)
                await asyncio.sleep(2)  # Poll every 2 seconds
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"[WebSocket] Polling error for {self.study_id}: {e}")
                await asyncio.sleep(5)  # Back off on error

    async def _check_database(self, db_path: Path):
        """Check database for new trials and updates"""
        try:
            # Use timeout to avoid blocking on locked databases.
            # NOTE: sqlite3 is synchronous; the short timeout keeps any
            # event-loop stall brief.
            conn = sqlite3.connect(str(db_path), timeout=2.0)
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                # Get total completed trial count
                cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
                total_count = cursor.fetchone()[0]
                # Check for new trials
                if total_count > self.last_trial_count:
                    await self._process_new_trials(cursor, total_count)
                # Check for new best value
                await self._check_best_value(cursor)
                # Check Pareto front for multi-objective
                await self._check_pareto_front(cursor, db_path)
                # Send progress update
                await self._send_progress(cursor, total_count)
            finally:
                # Always release the connection, even if one of the checks
                # raises (previously the connection leaked on error).
                conn.close()
        except sqlite3.OperationalError:
            # Database locked - skip this poll
            pass
        except Exception as e:
            print(f"[WebSocket] Database check error: {e}")

    async def _process_new_trials(self, cursor, total_count: int):
        """Process and broadcast new trials"""
        # Get new trials since last check
        cursor.execute("""
            SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE' AND t.trial_id > ?
            ORDER BY t.trial_id ASC
        """, (self.last_trial_id,))
        new_trials = cursor.fetchall()
        for row in new_trials:
            trial_id = row['trial_id']
            trial_data = await self._build_trial_data(cursor, row)
            await self.callback({
                "type": "trial_completed",
                "data": trial_data
            })
            # Advance the high-water mark per trial so a failure mid-batch
            # does not re-send already-broadcast trials next poll.
            self.last_trial_id = trial_id
        self.last_trial_count = total_count

    async def _build_trial_data(self, cursor, row) -> Dict[str, Any]:
        """Build trial data dictionary from database row"""
        trial_id = row['trial_id']
        # Get objectives
        cursor.execute("""
            SELECT value FROM trial_values
            WHERE trial_id = ? ORDER BY objective
        """, (trial_id,))
        values = [r[0] for r in cursor.fetchall()]
        # Get parameters
        cursor.execute("""
            SELECT param_name, param_value FROM trial_params
            WHERE trial_id = ?
        """, (trial_id,))
        params = {}
        for r in cursor.fetchall():
            try:
                params[r[0]] = float(r[1]) if r[1] is not None else None
            except (ValueError, TypeError):
                # Non-numeric (e.g. categorical) parameter: keep raw value.
                params[r[0]] = r[1]
        # Get user attributes (stored as JSON strings by Optuna)
        cursor.execute("""
            SELECT key, value_json FROM trial_user_attributes
            WHERE trial_id = ?
        """, (trial_id,))
        user_attrs = {}
        for r in cursor.fetchall():
            try:
                user_attrs[r[0]] = json.loads(r[1])
            except (ValueError, TypeError):
                user_attrs[r[0]] = r[1]
        # Extract source and design vars
        source = user_attrs.get("source", "FEA")
        design_vars = user_attrs.get("design_vars", params)
        return {
            "trial_number": trial_id,  # Use trial_id for uniqueness
            "trial_num": row['number'],
            "objective": values[0] if values else None,
            "values": values,
            "params": design_vars,
            "user_attrs": user_attrs,
            "source": source,
            "start_time": row['datetime_start'],
            "end_time": row['datetime_complete'],
            "study_name": row['study_name'],
            "constraint_satisfied": user_attrs.get("constraint_satisfied", True)
        }

    async def _check_best_value(self, cursor):
        """Check for new best value and broadcast if changed"""
        # NOTE(review): MIN() on objective 0 assumes a minimization study;
        # maximization directions would need direction-aware logic — confirm.
        cursor.execute("""
            SELECT MIN(tv.value) as best_value
            FROM trial_values tv
            JOIN trials t ON tv.trial_id = t.trial_id
            WHERE t.state = 'COMPLETE' AND tv.objective = 0
        """)
        result = cursor.fetchone()
        if result and result['best_value'] is not None:
            best_value = result['best_value']
            if self.last_best_value is None or best_value < self.last_best_value:
                # Get the best trial details
                cursor.execute("""
                    SELECT t.trial_id, t.number
                    FROM trials t
                    JOIN trial_values tv ON t.trial_id = tv.trial_id
                    WHERE t.state = 'COMPLETE' AND tv.objective = 0 AND tv.value = ?
                    LIMIT 1
                """, (best_value,))
                best_row = cursor.fetchone()
                if best_row:
                    # Get params for best trial
                    cursor.execute("""
                        SELECT param_name, param_value FROM trial_params
                        WHERE trial_id = ?
                    """, (best_row['trial_id'],))
                    params = {r[0]: r[1] for r in cursor.fetchall()}
                    await self.callback({
                        "type": "new_best",
                        "data": {
                            "trial_number": best_row['trial_id'],
                            "value": best_value,
                            "params": params,
                            # Falsy check also guards the division when the
                            # previous best is None or exactly 0.
                            "improvement": (
                                ((self.last_best_value - best_value) / abs(self.last_best_value) * 100)
                                if self.last_best_value else 0
                            )
                        }
                    })
                self.last_best_value = best_value

    async def _check_pareto_front(self, cursor, db_path: Path):
        """Check for Pareto front updates in multi-objective studies"""
        try:
            # Check if multi-objective by counting distinct objectives
            cursor.execute("""
                SELECT COUNT(DISTINCT objective) as obj_count
                FROM trial_values
                WHERE trial_id IN (SELECT trial_id FROM trials WHERE state = 'COMPLETE')
            """)
            result = cursor.fetchone()
            if result and result['obj_count'] > 1:
                # Multi-objective - compute Pareto front via Optuna itself.
                import optuna
                storage = optuna.storages.RDBStorage(f"sqlite:///{db_path}")
                # Get all study names
                cursor.execute("SELECT study_name FROM studies")
                study_names = [r[0] for r in cursor.fetchall()]
                for study_name in study_names:
                    try:
                        study = optuna.load_study(study_name=study_name, storage=storage)
                        if len(study.directions) > 1:
                            pareto_trials = study.best_trials
                            if len(pareto_trials) != self.last_pareto_count:
                                pareto_data = [
                                    {
                                        "trial_number": t.number,
                                        "values": t.values,
                                        "params": t.params,
                                        "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True),
                                        "source": t.user_attrs.get("source", "FEA")
                                    }
                                    for t in pareto_trials
                                ]
                                await self.callback({
                                    "type": "pareto_update",
                                    "data": {
                                        "pareto_front": pareto_data,
                                        "count": len(pareto_trials)
                                    }
                                })
                                self.last_pareto_count = len(pareto_trials)
                            # Only the first multi-objective study is reported.
                            break
                    except Exception:
                        # This study name failed to load - try the next one.
                        continue
        except Exception:
            # Non-critical - skip Pareto check
            pass

    async def _send_progress(self, cursor, total_count: int):
        """Send progress update"""
        # Get total from config if available
        study_dir = STUDIES_DIR / self.study_id
        config_path = study_dir / "1_setup" / "optimization_config.json"
        if not config_path.exists():
            config_path = study_dir / "optimization_config.json"
        total_target = 100  # Default
        if config_path.exists():
            try:
                with open(config_path) as f:
                    config = json.load(f)
                total_target = config.get('optimization_settings', {}).get('n_trials', 100)
            except Exception:
                # Best-effort: unreadable/malformed config keeps the default.
                pass
        # Count FEA vs NN trials (NN sub-studies are named with an "_nn" tag)
        cursor.execute("""
            SELECT
                COUNT(CASE WHEN s.study_name LIKE '%_nn%' THEN 1 END) as nn_count,
                COUNT(CASE WHEN s.study_name NOT LIKE '%_nn%' THEN 1 END) as fea_count
            FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
        """)
        counts = cursor.fetchone()
        await self.callback({
            "type": "progress",
            "data": {
                "current": total_count,
                "total": total_target,
                "percentage": min(100, (total_count / total_target * 100)) if total_target > 0 else 0,
                "fea_count": counts['fea_count'] if counts else total_count,
                "nn_count": counts['nn_count'] if counts else 0,
                "timestamp": datetime.now().isoformat()
            }
        })
class ConnectionManager:
    """
    Manages WebSocket connections and database pollers for real-time updates.
    Uses database polling instead of file watching for better reliability on Windows.

    One DatabasePoller runs per study while at least one client is connected;
    it is torn down when the last client for that study disconnects.
    """

    def __init__(self):
        # study_id -> set of connected WebSocket clients
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # study_id -> poller feeding those clients
        self.pollers: Dict[str, DatabasePoller] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Connect a new WebSocket client"""
        await websocket.accept()
        if study_id not in self.active_connections:
            self.active_connections[study_id] = set()
        self.active_connections[study_id].add(websocket)
        # Start polling if not already running
        if study_id not in self.pollers:
            await self._start_polling(study_id)

    async def disconnect(self, websocket: WebSocket, study_id: str):
        """Disconnect a WebSocket client"""
        if study_id in self.active_connections:
            self.active_connections[study_id].discard(websocket)
            # Stop polling if no more connections
            if not self.active_connections[study_id]:
                del self.active_connections[study_id]
                await self._stop_polling(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Broadcast message to all connected clients for a study"""
        if study_id not in self.active_connections:
            return
        disconnected = []
        # Iterate over a snapshot: send_json awaits, which can let a
        # concurrent disconnect() mutate the underlying set mid-iteration
        # and raise "set changed size during iteration".
        for connection in list(self.active_connections[study_id]):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"[WebSocket] Error broadcasting to client: {e}")
                disconnected.append(connection)
        # Clean up disconnected clients
        for conn in disconnected:
            self.active_connections[study_id].discard(conn)

    async def _start_polling(self, study_id: str):
        """Start database polling for a study"""
        async def callback(message):
            await self.broadcast(message, study_id)

        poller = DatabasePoller(study_id, callback)
        self.pollers[study_id] = poller
        await poller.start()
        print(f"[WebSocket] Started polling for study: {study_id}")

    async def _stop_polling(self, study_id: str):
        """Stop database polling for a study"""
        if study_id in self.pollers:
            await self.pollers[study_id].stop()
            del self.pollers[study_id]
            print(f"[WebSocket] Stopped polling for study: {study_id}")
# Module-level singleton shared by every WebSocket connection in this app.
manager = ConnectionManager()
@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """
    WebSocket endpoint for real-time optimization updates.

    Sends messages:
    - connected: Initial connection confirmation
    - trial_completed: New trial completed with full data
    - new_best: New best value found
    - progress: Progress update (current/total, FEA/NN counts)
    - pareto_update: Pareto front updated (multi-objective)
    """
    await manager.connect(websocket, study_id)
    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "study_id": study_id,
                "message": f"Connected to real-time stream for study {study_id}",
                "timestamp": datetime.now().isoformat()
            }
        })
        # Keep connection alive
        while True:
            try:
                # Wait for messages (ping/pong or commands)
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=30.0  # 30 second timeout for ping
                )
                # Handle client commands
                try:
                    msg = json.loads(data)
                    if msg.get("type") == "ping":
                        await websocket.send_json({"type": "pong"})
                except json.JSONDecodeError:
                    # Non-JSON client payloads are ignored.
                    pass
            except asyncio.TimeoutError:
                # No client traffic for 30 s - send heartbeat to detect
                # half-open connections.
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    # Heartbeat failed: client is gone, exit the loop.
                    # (Was a bare `except:`, which would also swallow
                    # CancelledError/KeyboardInterrupt.)
                    break
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WebSocket] Connection error for {study_id}: {e}")
    finally:
        # Always unregister, so the study's poller can shut down.
        await manager.disconnect(websocket, study_id)