- Add validation framework (config, model, results, study validators)
- Add Claude Code skills (create-study, run-optimization, generate-report, troubleshoot, analyze-model)
- Add Atomizer Dashboard (React frontend + FastAPI backend)
- Reorganize docs into structured directories (00-09)
- Add neural surrogate modules and training infrastructure
- Add multi-objective optimization support

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
197 lines
7.9 KiB
Python
197 lines
7.9 KiB
Python
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, Set
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from watchdog.observers import Observer
|
|
from watchdog.events import FileSystemEventHandler
|
|
import aiofiles
|
|
|
|
# Router for the live-optimization WebSocket endpoint; mounted by the app.
router = APIRouter()

# Base studies directory.
# NOTE(review): five .parent hops assume a fixed depth from this file to the
# repo root — brittle if the package layout changes; confirm against layout.
STUDIES_DIR = Path(__file__).parent.parent.parent.parent.parent / "studies"
|
|
|
|
class OptimizationFileHandler(FileSystemEventHandler):
    """Watchdog handler that tails one study's optimization artifacts and
    forwards incremental updates to an async broadcast callback.

    Watchdog invokes ``on_modified`` on its *own observer thread*, not on the
    asyncio event loop that owns the WebSocket connections. Awaiting
    ``websocket.send_json`` from a foreign, per-event loop (the old
    ``asyncio.run`` approach) is unsafe, so coroutines are handed back to the
    owning loop with ``asyncio.run_coroutine_threadsafe``.
    """

    def __init__(self, study_id: str, callback, loop=None):
        """
        Args:
            study_id: Optuna study name; also used to load the study from
                the RDB storage for Pareto-front updates.
            callback: async callable ``callback(message: dict)`` that
                broadcasts a message to subscribers.
            loop: event loop owning the WebSockets. Defaults to the loop
                running at construction time (the handler is created from
                inside the FastAPI loop in ``ConnectionManager.start_watching``).
        """
        self.study_id = study_id
        self.callback = callback
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # Constructed outside a running loop; _dispatch falls back
                # to asyncio.run (original behavior).
                loop = None
        self._loop = loop
        # High-water marks so each file change only broadcasts *new* entries.
        self.last_trial_count = 0
        self.last_pruned_count = 0
        self.last_pareto_count = 0
        self.last_state_timestamp = ""

    def _dispatch(self, coro):
        """Schedule *coro* on the owning event loop from the watchdog thread."""
        if self._loop is not None and self._loop.is_running():
            # Thread-safe hand-off to the loop that owns the websockets.
            asyncio.run_coroutine_threadsafe(coro, self._loop)
        else:
            # Fallback when no loop was captured: run in a private loop.
            asyncio.run(coro)

    def on_modified(self, event):
        """Route file-modified events (observer thread) to the right processor."""
        path = event.src_path
        if path.endswith("optimization_history_incremental.json"):
            self._dispatch(self.process_history_update(path))
        elif path.endswith("pruning_history.json"):
            self._dispatch(self.process_pruning_update(path))
        elif path.endswith("study.db"):  # Watch for Optuna DB changes (Pareto front)
            self._dispatch(self.process_pareto_update(path))
        elif path.endswith("optimizer_state.json"):
            self._dispatch(self.process_state_update(path))

    async def _process_incremental(self, file_path, counter_attr: str,
                                   event_type: str, label: str):
        """Read a JSON-array file and broadcast entries added since last read.

        Args:
            file_path: path to the JSON file (a list of trial dicts).
            counter_attr: name of the instance attribute tracking how many
                entries were already broadcast.
            event_type: value for the outgoing message's "type" field.
            label: human-readable name used in the error log message.
        """
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            history = json.loads(content)

            last_count = getattr(self, counter_attr)
            if len(history) > last_count:
                # Broadcast only the entries appended since the last event.
                for trial in history[last_count:]:
                    await self.callback({
                        "type": event_type,
                        "data": trial
                    })
                setattr(self, counter_attr, len(history))
        except Exception as e:
            # The writer may be mid-write (partial JSON); the next modify
            # event will retry, so log and continue.
            print(f"Error processing {label} update: {e}")

    async def process_history_update(self, file_path):
        """Broadcast newly completed trials as "trial_completed" messages."""
        await self._process_incremental(
            file_path, "last_trial_count", "trial_completed", "history")

    async def process_pruning_update(self, file_path):
        """Broadcast newly pruned trials as "trial_pruned" messages."""
        await self._process_incremental(
            file_path, "last_pruned_count", "trial_pruned", "pruning")

    async def process_pareto_update(self, file_path):
        """Re-fetch the Pareto front from the Optuna DB and broadcast it.

        study.db is binary, so instead of reading it directly we reload the
        study through Optuna. Broadcasts only when the Pareto-front *count*
        changes (simple heuristic; a content hash would catch replacements
        that keep the count constant).
        """
        try:
            # Import here to avoid circular imports or heavy load at startup.
            import optuna

            storage = optuna.storages.RDBStorage(f"sqlite:///{file_path}")
            study = optuna.load_study(study_name=self.study_id, storage=storage)

            # Pareto fronts only exist for multi-objective studies.
            if len(study.directions) > 1:
                pareto_trials = study.best_trials

                if len(pareto_trials) != self.last_pareto_count:
                    pareto_data = [
                        {
                            "trial_number": t.number,
                            "values": t.values,
                            "params": t.params,
                            "user_attrs": dict(t.user_attrs),
                            "constraint_satisfied": t.user_attrs.get(
                                "constraint_satisfied", True)
                        }
                        for t in pareto_trials
                    ]

                    await self.callback({
                        "type": "pareto_front",
                        "data": {
                            "pareto_front": pareto_data,
                            "count": len(pareto_trials)
                        }
                    })
                    self.last_pareto_count = len(pareto_trials)

        except Exception:
            # Deliberate best-effort: the SQLite DB is frequently locked by
            # the optimizer mid-trial; transient errors are expected and the
            # next modify event retries.
            pass

    async def process_state_update(self, file_path):
        """Broadcast the optimizer state file as an "optimizer_state" message.

        Uses the state's "timestamp" field to suppress duplicate broadcasts
        when the file is rewritten with unchanged content.
        """
        try:
            async with aiofiles.open(file_path, mode='r') as f:
                content = await f.read()
            state = json.loads(content)

            if state.get("timestamp") != self.last_state_timestamp:
                await self.callback({
                    "type": "optimizer_state",
                    "data": state
                })
                self.last_state_timestamp = state.get("timestamp")
        except Exception as e:
            print(f"Error processing state update: {e}")
|
|
|
|
class ConnectionManager:
    """Tracks WebSocket subscribers per study and the filesystem observers
    that feed them live optimization updates.

    A watchdog observer is started when the first client subscribes to a
    study and stopped when the last one leaves.
    """

    def __init__(self):
        # study_id -> set of connected websockets for that study.
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # study_id -> running watchdog observer (at most one per study).
        self.observers: Dict[str, Observer] = {}

    async def connect(self, websocket: WebSocket, study_id: str):
        """Accept *websocket* and subscribe it to *study_id*.

        Starts the file watcher lazily on the first subscriber.
        """
        await websocket.accept()
        if study_id not in self.active_connections:
            self.active_connections[study_id] = set()
            self.start_watching(study_id)
        self.active_connections[study_id].add(websocket)

    def disconnect(self, websocket: WebSocket, study_id: str):
        """Unsubscribe *websocket*; stop watching when the last one leaves.

        Idempotent: safe to call more than once for the same socket (the
        endpoint's two except branches can both reach here), so ``discard``
        is used instead of ``remove`` to avoid a KeyError.
        """
        connections = self.active_connections.get(study_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[study_id]
            self.stop_watching(study_id)

    async def broadcast(self, message: dict, study_id: str):
        """Send *message* to every subscriber of *study_id*.

        Iterates over a snapshot because the set may be mutated by
        connect/disconnect while we await. Sockets whose send fails are
        disconnected so dead connections don't accumulate.
        """
        connections = self.active_connections.get(study_id)
        if not connections:
            return
        dead = []
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Error broadcasting to client: {e}")
                dead.append(connection)
        for connection in dead:
            self.disconnect(connection, study_id)

    def start_watching(self, study_id: str):
        """Start a recursive watchdog observer on the study's results dir.

        No-op when the directory does not exist (e.g. study not yet run).
        """
        study_dir = STUDIES_DIR / study_id / "2_results"
        if not study_dir.exists():
            return

        async def callback(message):
            await self.broadcast(message, study_id)

        event_handler = OptimizationFileHandler(study_id, callback)
        observer = Observer()
        observer.schedule(event_handler, str(study_dir), recursive=True)
        observer.start()
        self.observers[study_id] = observer

    def stop_watching(self, study_id: str):
        """Stop and join the observer for *study_id*, if one is running."""
        observer = self.observers.pop(study_id, None)
        if observer is not None:
            observer.stop()
            observer.join()
|
|
|
|
# Module-level singleton shared by all WebSocket connections to this router.
manager = ConnectionManager()
|
|
|
|
@router.websocket("/optimization/{study_id}")
async def optimization_stream(websocket: WebSocket, study_id: str):
    """WebSocket endpoint streaming live optimization events for a study.

    Sends a "connected" acknowledgement, then holds the socket open; all
    subsequent outbound traffic is pushed by the ConnectionManager's
    broadcast (driven by the filesystem watcher). Incoming client text is
    read only to keep the connection alive and is currently ignored.
    """
    await manager.connect(websocket, study_id)
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {"message": f"Connected to stream for study {study_id}"}
        })
        while True:
            # Keep the connection alive; incoming messages could later carry
            # client commands (e.g. "pause", "stop").
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(websocket, study_id)
    except Exception as e:
        print(f"WebSocket error: {e}")
        manager.disconnect(websocket, study_id)
|