feat: Major update with validators, skills, dashboard, and docs reorganization
- Add validation framework (config, model, results, study validators) - Add Claude Code skills (create-study, run-optimization, generate-report, troubleshoot, analyze-model) - Add Atomizer Dashboard (React frontend + FastAPI backend) - Reorganize docs into structured directories (00-09) - Add neural surrogate modules and training infrastructure - Add multi-objective optimization support 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
196
atomizer-dashboard/backend/api/websocket/optimization_stream.py
Normal file
196
atomizer-dashboard/backend/api/websocket/optimization_stream.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
import aiofiles
|
||||
|
||||
# Shared FastAPI router; mounted by the backend app to expose the
# optimization WebSocket endpoint defined below.
router = APIRouter()

# Base studies directory
# Resolved relative to this file: five levels up from
# backend/api/websocket/optimization_stream.py to the project root,
# then into "studies/". NOTE(review): path-depth is fragile if this
# module is ever moved — confirm against the repo layout.
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
|
||||
|
||||
class OptimizationFileHandler(FileSystemEventHandler):
    """Forward optimization file changes to an async broadcast callback.

    Instances run inside the watchdog observer *thread*: ``on_modified``
    fires off the event loop, so coroutines are handed to the loop
    captured at construction time via ``run_coroutine_threadsafe``.
    """

    def __init__(self, study_id: str, callback, loop=None):
        """
        Args:
            study_id: Study identifier (also used as the Optuna study name).
            callback: Async callable receiving one message dict per update.
            loop: Event loop that owns the WebSocket connections. Defaults
                to the loop running when the handler is constructed
                (handlers are created from the async connect path, so one
                is normally available).
        """
        self.study_id = study_id
        self.callback = callback
        # High-water marks so only *new* entries of append-only files
        # are broadcast on each modification event.
        self.last_trial_count = 0
        self.last_pruned_count = 0
        self.last_pareto_count = 0
        self.last_state_timestamp = ""
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None  # constructed outside a loop; see _submit fallback
        self._loop = loop

    def _submit(self, coro):
        """Schedule *coro* on the server's event loop from the watchdog thread.

        BUG FIX: the original called ``asyncio.run()`` here, which creates a
        brand-new event loop per file event and then touches WebSockets that
        belong to the server's loop — cross-loop WebSocket use is invalid,
        and the per-event loop churn is wasteful. ``run_coroutine_threadsafe``
        submits the coroutine to the loop that owns the connections.
        """
        if self._loop is not None and self._loop.is_running():
            asyncio.run_coroutine_threadsafe(coro, self._loop)
        else:
            # No captured/running loop: preserve the old best-effort behavior.
            asyncio.run(coro)

    def on_modified(self, event):
        """Route a filesystem modification event to the matching processor."""
        path = event.src_path
        if path.endswith("optimization_history_incremental.json"):
            self._submit(self.process_history_update(path))
        elif path.endswith("pruning_history.json"):
            self._submit(self.process_pruning_update(path))
        elif path.endswith("study.db"):  # Watch for Optuna DB changes (Pareto front)
            self._submit(self.process_pareto_update(path))
        elif path.endswith("optimizer_state.json"):
            self._submit(self.process_state_update(path))

    async def process_history_update(self, file_path):
        """Broadcast every completed trial appended since the previous read."""
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            history = json.loads(content)

            current_count = len(history)
            if current_count > self.last_trial_count:
                # New trials added
                for trial in history[self.last_trial_count:]:
                    await self.callback({
                        "type": "trial_completed",
                        "data": trial
                    })
                self.last_trial_count = current_count
        except Exception as e:
            print(f"Error processing history update: {e}")

    async def process_pruning_update(self, file_path):
        """Broadcast every pruned trial appended since the previous read."""
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            history = json.loads(content)

            current_count = len(history)
            if current_count > self.last_pruned_count:
                # New pruned trials
                for trial in history[self.last_pruned_count:]:
                    await self.callback({
                        "type": "trial_pruned",
                        "data": trial
                    })
                self.last_pruned_count = current_count
        except Exception as e:
            print(f"Error processing pruning update: {e}")

    async def process_pareto_update(self, file_path):
        """Reload the Pareto front from the Optuna DB and broadcast it.

        study.db is a binary SQLite file, so instead of parsing it directly
        the study is reloaded through Optuna. A count comparison debounces
        repeated events; transient errors (e.g. a locked DB) are ignored.
        """
        try:
            # Import here to avoid circular imports or heavy load at startup
            import optuna

            # Connect to DB
            storage = optuna.storages.RDBStorage(f"sqlite:///{file_path}")
            study = optuna.load_study(study_name=self.study_id, storage=storage)

            # Pareto fronts only exist for multi-objective studies
            if len(study.directions) > 1:
                pareto_trials = study.best_trials

                # Only broadcast if count changed (simple heuristic)
                # In a real app, we might check content hash
                if len(pareto_trials) != self.last_pareto_count:
                    pareto_data = [
                        {
                            "trial_number": t.number,
                            "values": t.values,
                            "params": t.params,
                            "user_attrs": dict(t.user_attrs),
                            "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True)
                        }
                        for t in pareto_trials
                    ]

                    await self.callback({
                        "type": "pareto_front",
                        "data": {
                            "pareto_front": pareto_data,
                            "count": len(pareto_trials)
                        }
                    })
                    self.last_pareto_count = len(pareto_trials)

        except Exception:
            # DB might be locked, ignore transient errors
            pass

    async def process_state_update(self, file_path):
        """Broadcast the optimizer state when its timestamp changes."""
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            state = json.loads(content)

            # Check timestamp to avoid duplicate broadcasts
            if state.get("timestamp") != self.last_state_timestamp:
                await self.callback({
                    "type": "optimizer_state",
                    "data": state
                })
                self.last_state_timestamp = state.get("timestamp")
        except Exception as e:
            print(f"Error processing state update: {e}")
||||
class ConnectionManager:
    """Track WebSocket subscribers per study and manage file watchers.

    A watchdog Observer is started for a study when its first client
    connects and stopped again when the last client disconnects.
    """

    def __init__(self):
        # study_id -> set of connected WebSocket clients
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # study_id -> running watchdog Observer for that study's results dir
        self.observers: Dict[str, Observer] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Accept *websocket* and subscribe it to updates for *study_id*."""
        await websocket.accept()
        if study_id not in self.active_connections:
            self.active_connections[study_id] = set()
            self.start_watching(study_id)  # first subscriber: begin watching
        self.active_connections[study_id].add(websocket)

    def disconnect(self, websocket: WebSocket, study_id: str):
        """Unsubscribe *websocket*; stop watching when no clients remain.

        ROBUSTNESS FIX: uses ``discard`` instead of ``remove`` so a double
        disconnect (the endpoint's generic-exception path can call this
        after the socket was already removed) does not raise KeyError.
        """
        if study_id in self.active_connections:
            self.active_connections[study_id].discard(websocket)
            if not self.active_connections[study_id]:
                del self.active_connections[study_id]
                self.stop_watching(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Send *message* to every client currently subscribed to *study_id*.

        BUG FIX: iterates over a snapshot of the connection set. The
        original iterated the live set across ``await`` points, so a
        concurrent connect/disconnect could mutate it mid-iteration and
        raise ``RuntimeError: set changed size during iteration``.
        """
        for connection in list(self.active_connections.get(study_id, ())):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Error broadcasting to client: {e}")

    def start_watching(self, study_id: str):
        """Start a recursive watchdog observer on the study's results dir."""
        study_dir = STUDIES_DIR / study_id / "2_results"
        if not study_dir.exists():
            return  # nothing to watch yet; client stays connected regardless

        async def callback(message):
            await self.broadcast(message, study_id)

        event_handler = OptimizationFileHandler(study_id, callback)
        observer = Observer()
        observer.schedule(event_handler, str(study_dir), recursive=True)
        observer.start()
        self.observers[study_id] = observer

    def stop_watching(self, study_id: str):
        """Stop and join the study's observer thread, if one is running."""
        if study_id in self.observers:
            self.observers[study_id].stop()
            self.observers[study_id].join()
            del self.observers[study_id]
|
||||
|
||||
# Module-level singleton shared by every WebSocket connection.
manager = ConnectionManager()


@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """Stream live optimization updates for *study_id* to one client.

    Registers the socket with the connection manager (which starts file
    watching for the study if this is its first subscriber), sends a
    greeting message, then idles on ``receive_text`` until the client
    disconnects or an error occurs.
    """
    await manager.connect(websocket, study_id)
    try:
        greeting = {
            "type": "connected",
            "data": {"message": f"Connected to stream for study {study_id}"}
        }
        await websocket.send_json(greeting)
        # Keep connection alive and handle incoming messages if needed.
        while True:
            _ = await websocket.receive_text()
            # We could handle client commands here (e.g., "pause", "stop")
    except WebSocketDisconnect:
        manager.disconnect(websocket, study_id)
    except Exception as e:
        print(f"WebSocket error: {e}")
        manager.disconnect(websocket, study_id)
|
||||
Reference in New Issue
Block a user