"""
Optimization API endpoints
Handles study status, history retrieval, and control operations
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
from pathlib import Path
from typing import List, Dict, Optional
import json
import sys
import sqlite3
import shutil
import subprocess
import psutil
import signal
from datetime import datetime
# Project root: five `.parent` steps up from this file. Project packages such
# as optimization_engine and the "studies" tree live beneath it.
_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.parent

# Add project root to path so project packages can be imported from here.
sys.path.append(str(_PROJECT_ROOT))

router = APIRouter()

# Base studies directory (flat studies/<study>/ or nested studies/<topic>/<study>/)
STUDIES_DIR = _PROJECT_ROOT / "studies"
def get_results_dir(study_dir: Path) -> Path:
    """Return the results directory of a study.

    ``2_results`` is preferred. When it does not exist, ``3_results`` is
    returned instead, whether or not that folder exists.
    """
    preferred = study_dir / "2_results"
    return preferred if preferred.exists() else study_dir / "3_results"
def resolve_study_path(study_id: str) -> Path:
    """Find a study folder by its short name.

    Both layouts are supported:
      - flat:   studies/<study_id>/          (checked first, backwards compatible)
      - nested: studies/<Topic>/<study_id>/  (every non-hidden topic folder is scanned)

    A directory only counts as a study if it contains ``1_setup`` or
    ``optimization_config.json``.

    Args:
        study_id: Short study name (e.g. 'm1_mirror_adaptive_V14'). The
            endpoints pass it straight from the request URL.

    Returns:
        The full path to the study directory.

    Raises:
        HTTPException: 404 if no study matches. Also 404 if ``study_id`` is
            not a single plain path component (empty, '.', '..', or
            containing '/', '\\' or ':'), so URL input cannot point outside
            STUDIES_DIR.
        FileNotFoundError: if STUDIES_DIR itself does not exist (raised by
            ``iterdir``). The endpoint handlers catch it and return 404.
    """
    def _looks_like_study(path: Path) -> bool:
        # Same marker check for flat and nested candidates.
        return path.is_dir() and (
            (path / "1_setup").exists()
            or (path / "optimization_config.json").exists()
        )

    # study_id is untrusted URL input: accept only a single path component so
    # "..", path separators or a Windows drive prefix cannot escape STUDIES_DIR.
    if (
        not study_id
        or study_id in (".", "..")
        or any(ch in study_id for ch in ("/", "\\", ":"))
    ):
        raise HTTPException(status_code=404, detail=f"Study not found: {study_id}")

    # First check direct path (backwards compatibility for flat structure)
    direct_path = STUDIES_DIR / study_id
    if _looks_like_study(direct_path):
        return direct_path

    # Scan topic folders for nested structure (hidden folders are skipped)
    for topic_dir in STUDIES_DIR.iterdir():
        if topic_dir.is_dir() and not topic_dir.name.startswith('.'):
            candidate = topic_dir / study_id
            if _looks_like_study(candidate):
                return candidate

    raise HTTPException(status_code=404, detail=f"Study not found: {study_id}")
def get_study_topic(study_dir: Path) -> Optional[str]:
    """Return the name of the topic folder that contains a study.

    A study is nested when its parent folder sits directly inside
    STUDIES_DIR. Studies placed directly in STUDIES_DIR, or anywhere else,
    yield None.
    """
    container = study_dir.parent
    nested_in_topic = container != STUDIES_DIR and container.parent == STUDIES_DIR
    return container.name if nested_in_topic else None
def is_optimization_running(study_id: str) -> bool:
    """Check whether an optimization process is currently running for a study.

    Scans live processes for a command line that contains "python" (case
    insensitive) and "run_optimization", and that also mentions either the
    study id or the study's resolved directory path.

    Args:
        study_id: Short study name, resolved via ``resolve_study_path``.

    Returns:
        True if a matching process is found. False otherwise, including when
        the study cannot be resolved.

    NOTE(review): the study id is matched as a plain substring, so an id that
    is a prefix of another (e.g. '..._V1' vs '..._V14') can give a false
    positive. Confirm how run_optimization.py is invoked before tightening.
    """
    try:
        study_dir = resolve_study_path(study_id)
    except HTTPException:
        return False

    study_dir_str = str(study_dir)  # loop-invariant
    # Only the command line is inspected. Requesting just 'cmdline' avoids
    # fetching unused per-process attributes such as 'cwd'.
    for proc in psutil.process_iter(['cmdline']):
        try:
            cmdline = proc.info.get('cmdline') or []
            cmdline_str = ' '.join(cmdline)
            # Python process running run_optimization.py for this study?
            if 'python' in cmdline_str.lower() and 'run_optimization' in cmdline_str:
                if study_id in cmdline_str or study_dir_str in cmdline_str:
                    return True
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Process vanished or is not inspectable; skip it.
            continue
    return False
def get_accurate_study_status(study_id: str, trial_count: int, total_trials: int, has_db: bool) -> str:
    """Classify a study as not_started, completed, running or paused.

    Checks are applied in this order:
      1. no results store (``has_db`` false) or zero completed trials -> "not_started"
      2. completed trials reached the target                        -> "completed"
      3. a matching optimization process is alive                   -> "running"
      4. otherwise                                                   -> "paused"

    Args:
        study_id: The study identifier (used only for process detection).
        trial_count: Number of completed trials.
        total_trials: Target number of trials from the config.
        has_db: Whether the study has a results store (database or history).

    Returns:
        One of "not_started", "running", "paused" or "completed".
    """
    if has_db and trial_count != 0:
        if trial_count >= total_trials:
            return "completed"
        # Process detection is only needed for unfinished studies.
        return "running" if is_optimization_running(study_id) else "paused"
    return "not_started"
def _load_study_info(study_dir: Path, topic: Optional[str] = None) -> Optional[dict]:
    """Build the summary record for one study directory.

    The config is read from ``optimization_config.json``, falling back to
    ``1_setup/optimization_config.json``. Progress comes from the Optuna
    SQLite database ``<results>/study.db`` when present. Otherwise it comes
    from the legacy ``<results>/optimization_history_incremental.json``.

    Args:
        study_dir: The study directory.
        topic: Name of the enclosing topic folder, or None for a flat study.

    Returns:
        A dict with id, name, topic, status, progress, best_value, target,
        path, created_at and last_modified, or None if no config file exists.

    Raises:
        json.JSONDecodeError / OSError: if the config or the legacy history
        file cannot be read or parsed. Database read errors are only printed.
    """
    # Look for optimization config (check multiple locations)
    config_file = study_dir / "optimization_config.json"
    if not config_file.exists():
        config_file = study_dir / "1_setup" / "optimization_config.json"
    if not config_file.exists():
        return None

    with open(config_file) as f:
        config = json.load(f)

    # Supports both 2_results and 3_results layouts.
    results_dir = get_results_dir(study_dir)
    study_db = results_dir / "study.db"
    history_file = results_dir / "optimization_history_incremental.json"

    trial_count = 0
    best_value = None
    has_db = False

    if study_db.exists():
        # Protocol 10: read from the Optuna SQLite database.
        has_db = True
        conn = None
        try:
            # Short timeout so a database locked by a running optimizer does
            # not block the listing.
            conn = sqlite3.connect(str(study_db), timeout=2.0)
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
            trial_count = cursor.fetchone()[0]
            if trial_count > 0:
                # Best = lowest value of the FIRST objective (objective = 0)
                # over completed trials. Without the objective filter,
                # multi-objective studies mixed all objectives together.
                # NULLs are excluded because SQLite sorts NULL before every
                # other value in ascending order.
                # NOTE(review): assumes minimization of objective 0.
                cursor.execute("""
                    SELECT value FROM trial_values
                    WHERE objective = 0
                      AND value IS NOT NULL
                      AND trial_id IN (
                          SELECT trial_id FROM trials WHERE state = 'COMPLETE'
                      )
                    ORDER BY value ASC
                    LIMIT 1
                """)
                row = cursor.fetchone()
                if row:
                    best_value = row[0]
        except Exception as e:
            print(f"Warning: Failed to read Optuna database for {study_dir.name}: {e}")
        finally:
            # Always release the handle, even when a query failed.
            if conn is not None:
                conn.close()
    elif history_file.exists():
        # Legacy protocols: JSON list of trial records with an 'objective' key.
        has_db = True
        with open(history_file) as f:
            history = json.load(f)
        trial_count = len(history)
        if history:
            best_value = min(history, key=lambda x: x['objective'])['objective']

    # Target trial count; supports several config layouts.
    total_trials = (
        config.get('optimization_settings', {}).get('n_trials') or
        config.get('optimization', {}).get('n_trials') or
        config.get('trials', {}).get('n_trials', 50)
    )

    # Status uses process detection for unfinished studies.
    status = get_accurate_study_status(study_dir.name, trial_count, total_trials, has_db)

    created_at = None
    try:
        # NOTE(review): when the database exists this is its *modification*
        # time, not a creation time.
        if study_db.exists():
            created_at = datetime.fromtimestamp(study_db.stat().st_mtime).isoformat()
        elif config_file.exists():
            created_at = datetime.fromtimestamp(config_file.stat().st_mtime).isoformat()
        else:
            created_at = datetime.fromtimestamp(study_dir.stat().st_ctime).isoformat()
    except (OSError, ValueError, OverflowError):
        created_at = None

    last_modified = None
    try:
        if study_db.exists():
            last_modified = datetime.fromtimestamp(study_db.stat().st_mtime).isoformat()
        elif history_file.exists():
            last_modified = datetime.fromtimestamp(history_file.stat().st_mtime).isoformat()
    except (OSError, ValueError, OverflowError):
        last_modified = None

    return {
        "id": study_dir.name,
        "name": study_dir.name.replace("_", " ").title(),
        "topic": topic,  # topic folder name, for frontend grouping
        "status": status,
        "progress": {
            "current": trial_count,
            "total": total_trials
        },
        "best_value": best_value,
        "target": config.get('target', {}).get('value'),
        "path": str(study_dir),
        "created_at": created_at,
        "last_modified": last_modified
    }
@router.get("/studies")
async def list_studies():
    """List all available optimization studies.

    Supports both flat and nested folder structures:
    - Flat: studies/study_name/
    - Nested: studies/Topic/study_name/

    Hidden folders (leading '.') are skipped. A study whose files cannot be
    read or parsed is skipped with a printed warning, so one broken study
    does not hide the others. Returns studies with a 'topic' field for
    frontend grouping.
    """
    def _has_study_markers(path: Path) -> bool:
        return (path / "1_setup").exists() or (path / "optimization_config.json").exists()

    def _try_load(study_dir: Path, topic: Optional[str]) -> Optional[dict]:
        # Isolate per-study failures (malformed JSON, unreadable files, ...).
        try:
            return _load_study_info(study_dir, topic=topic)
        except Exception as e:
            print(f"Warning: Skipping study {study_dir.name}: {e}")
            return None

    try:
        studies = []
        if not STUDIES_DIR.exists():
            return {"studies": []}

        for item in STUDIES_DIR.iterdir():
            if not item.is_dir() or item.name.startswith('.'):
                continue

            if _has_study_markers(item):
                # Flat structure: study directly in studies/
                info = _try_load(item, None)
                if info:
                    studies.append(info)
                continue

            # Otherwise treat it as a topic folder and look one level deeper.
            try:
                children = list(item.iterdir())
            except OSError as e:
                print(f"Warning: Cannot read topic folder {item.name}: {e}")
                continue
            for sub_item in children:
                if not sub_item.is_dir() or sub_item.name.startswith('.'):
                    continue
                if _has_study_markers(sub_item):
                    info = _try_load(sub_item, item.name)
                    if info:
                        studies.append(info)

        return {"studies": studies}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list studies: {str(e)}")
@router.get("/studies/{study_id}/status")
async def get_study_status(study_id: str):
    """Get detailed status of a specific study.

    Reads the Optuna database (``study.db``) when present, otherwise the
    legacy JSON history. The target trial count and the status use the same
    rules as the study list, so both views agree.

    Raises:
        HTTPException: 404 if the study or its config cannot be found, 500 on
        any other failure. HTTPExceptions raised inside are re-raised
        unchanged rather than turned into 500s.
    """
    try:
        study_dir = resolve_study_path(study_id)

        # Load config (check multiple locations)
        config_file = study_dir / "optimization_config.json"
        if not config_file.exists():
            config_file = study_dir / "1_setup" / "optimization_config.json"
        with open(config_file) as f:
            config = json.load(f)

        # Target trial count: same lookup order as the study list.
        total_trials = (
            config.get('optimization_settings', {}).get('n_trials') or
            config.get('optimization', {}).get('n_trials') or
            config.get('trials', {}).get('n_trials', 50)
        )

        # Check for results (support both 2_results and 3_results)
        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"
        history_file = results_dir / "optimization_history_incremental.json"

        # Protocol 10: Read from Optuna database
        if study_db.exists():
            conn = sqlite3.connect(str(study_db))
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
                trial_count = cursor.fetchone()[0]
                cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'PRUNED'")
                pruned_count = cursor.fetchone()[0]

                # Best trial = lowest value of the first objective (objective = 0).
                # NULLs are excluded because SQLite sorts them first ascending.
                # NOTE(review): assumes minimization.
                best_trial = None
                if trial_count > 0:
                    cursor.execute("""
                        SELECT t.trial_id, t.number, tv.value
                        FROM trials t
                        JOIN trial_values tv ON t.trial_id = tv.trial_id
                        WHERE t.state = 'COMPLETE'
                          AND tv.objective = 0
                          AND tv.value IS NOT NULL
                        ORDER BY tv.value ASC
                        LIMIT 1
                    """)
                    result = cursor.fetchone()
                    if result:
                        trial_id, trial_number, best_value = result
                        cursor.execute("""
                            SELECT param_name, param_value
                            FROM trial_params
                            WHERE trial_id = ?
                        """, (trial_id,))
                        params = {row[0]: row[1] for row in cursor.fetchall()}
                        best_trial = {
                            "trial_number": trial_number,
                            "objective": best_value,
                            "design_variables": params,
                            "results": {"first_frequency": best_value}
                        }
            finally:
                conn.close()

            status = get_accurate_study_status(study_id, trial_count, total_trials, True)
            return {
                "study_id": study_id,
                "status": status,
                "progress": {
                    "current": trial_count,
                    "total": total_trials,
                    "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0
                },
                "best_trial": best_trial,
                "pruned_trials": pruned_count,
                "config": config
            }

        # Legacy: Read from JSON history
        if not history_file.exists():
            return {
                "study_id": study_id,
                "status": "not_started",
                "progress": {"current": 0, "total": total_trials},
                "config": config
            }

        with open(history_file) as f:
            history = json.load(f)
        trial_count = len(history)

        best_trial = min(history, key=lambda x: x['objective']) if history else None

        # Pruning data for legacy studies lives in a separate JSON file.
        pruning_file = results_dir / "pruning_history.json"
        pruned_count = 0
        if pruning_file.exists():
            with open(pruning_file) as f:
                pruned_count = len(json.load(f))

        # Same status rules as the study list (process detection included).
        status = get_accurate_study_status(study_id, trial_count, total_trials, True)
        return {
            "study_id": study_id,
            "status": status,
            "progress": {
                "current": trial_count,
                "total": total_trials,
                "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0
            },
            "best_trial": best_trial,
            "pruned_trials": pruned_count,
            "config": config
        }
    except HTTPException:
        # Preserve deliberate 4xx responses (e.g. 404 from resolve_study_path).
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get study status: {str(e)}")
@router.get("/studies/{study_id}/history")
async def get_optimization_history(study_id: str, limit: Optional[int] = None):
    """Get optimization history (all completed trials).

    Args:
        study_id: Short study name.
        limit: If a positive integer, return at most this many trials (most
            recent first for the database, the last ``limit`` entries for the
            legacy JSON history). None, 0 or a negative value means no limit.

    Raises:
        HTTPException: 404 if the study is not found, 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"
        history_file = results_dir / "optimization_history_incremental.json"

        apply_limit = limit is not None and limit > 0

        # Protocol 10: Read from Optuna database
        if study_db.exists():
            conn = sqlite3.connect(str(study_db))
            try:
                cursor = conn.cursor()
                # Completed trials from ALL Optuna studies in the database.
                # This handles adaptive optimizations that create several
                # studies (e.g. v11_fea for FEA trials, v11_iter1_nn for NN trials).
                query = """
                    SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name
                    FROM trials t
                    JOIN studies s ON t.study_id = s.study_id
                    WHERE t.state = 'COMPLETE'
                    ORDER BY t.datetime_start DESC
                """
                query_params: tuple = ()
                if apply_limit:
                    query += " LIMIT ?"  # bound parameter, not string-formatted
                    query_params = (limit,)
                cursor.execute(query, query_params)
                trial_rows = cursor.fetchall()

                trials = []
                for trial_id, trial_num, start_time, end_time, study_name in trial_rows:
                    # Objective values ordered by objective index
                    cursor.execute("""
                        SELECT value
                        FROM trial_values
                        WHERE trial_id = ?
                        ORDER BY objective
                    """, (trial_id,))
                    values = [row[0] for row in cursor.fetchall()]

                    # Parameters, converted to float where possible
                    cursor.execute("""
                        SELECT param_name, param_value
                        FROM trial_params
                        WHERE trial_id = ?
                    """, (trial_id,))
                    params = {}
                    for param_name, param_value in cursor.fetchall():
                        try:
                            params[param_name] = float(param_value) if param_value is not None else None
                        except (ValueError, TypeError):
                            params[param_name] = param_value

                    # User attributes (JSON-decoded when possible)
                    cursor.execute("""
                        SELECT key, value_json
                        FROM trial_user_attributes
                        WHERE trial_id = ?
                    """, (trial_id,))
                    user_attrs = {}
                    for key, value_json in cursor.fetchall():
                        try:
                            user_attrs[key] = json.loads(value_json)
                        except (ValueError, TypeError):
                            user_attrs[key] = value_json

                    # Numeric metrics from user_attrs (scalars, numeric lists,
                    # and the nested 'objectives' dict) become 'results'.
                    results = {}
                    excluded_keys = {"design_vars", "constraint_satisfied", "constraint_violations"}
                    for key, val in user_attrs.items():
                        if key in excluded_keys:
                            continue
                        if isinstance(val, (int, float)):
                            results[key] = val
                        elif isinstance(val, list) and len(val) > 0 and isinstance(val[0], (int, float)):
                            results[key] = val
                        elif key == "objectives" and isinstance(val, dict):
                            for obj_key, obj_val in val.items():
                                if isinstance(obj_val, (int, float)):
                                    results[obj_key] = obj_val

                    # Fallback: expose the first objective value
                    if not results and len(values) > 0:
                        results["first_frequency"] = values[0]

                    # user_attrs["design_vars"] (when present) overrides Optuna params
                    design_vars_from_attrs = user_attrs.get("design_vars", {})
                    final_design_vars = {**params, **design_vars_from_attrs} if design_vars_from_attrs else params

                    source = user_attrs.get("source", "FEA")  # Default to FEA for legacy studies
                    iter_num = user_attrs.get("iter_num", None)
                    # trial_id is unique across all studies in the database,
                    # so it is the fallback identifier when iter_num is absent.
                    unique_trial_num = iter_num if iter_num is not None else trial_id

                    trials.append({
                        "trial_number": unique_trial_num,
                        "trial_id": trial_id,
                        "optuna_trial_num": trial_num,
                        "objective": values[0] if len(values) > 0 else None,
                        "objectives": values if len(values) > 1 else None,
                        "design_variables": final_design_vars,
                        "results": results,
                        "user_attrs": user_attrs,
                        "source": source,
                        "start_time": start_time,
                        "end_time": end_time,
                        "study_name": study_name
                    })
            finally:
                conn.close()
            return {"trials": trials}

        # Legacy: Read from JSON history
        if not history_file.exists():
            return {"trials": []}

        with open(history_file) as f:
            history = json.load(f)

        if apply_limit:
            history = history[-limit:]

        return {"trials": history}
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get history: {str(e)}")
@router.get("/studies/{study_id}/pruning")
async def get_pruning_history(study_id: str):
    """Get pruning diagnostics from the Optuna database or legacy JSON file.

    Returns ``{"pruned_trials": [...], "count": n}``. For the database, each
    entry has trial_number, params, pruning_cause (from the 'pruning_cause'
    user attribute, else "Unknown"), start_time and end_time, newest first.

    Raises:
        HTTPException: 404 if the study is not found, 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"
        pruning_file = results_dir / "pruning_history.json"

        # Protocol 10+: Read from Optuna database
        if study_db.exists():
            conn = sqlite3.connect(str(study_db))
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete
                    FROM trials t
                    WHERE t.state = 'PRUNED'
                    ORDER BY t.number DESC
                """)
                pruned_rows = cursor.fetchall()

                pruned_trials = []
                for trial_id, trial_num, start_time, end_time in pruned_rows:
                    cursor.execute("""
                        SELECT param_name, param_value
                        FROM trial_params
                        WHERE trial_id = ?
                    """, (trial_id,))
                    params = {row[0]: row[1] for row in cursor.fetchall()}

                    # User attributes may carry the pruning cause
                    cursor.execute("""
                        SELECT key, value_json
                        FROM trial_user_attributes
                        WHERE trial_id = ?
                    """, (trial_id,))
                    user_attrs = {}
                    for key, value_json in cursor.fetchall():
                        try:
                            user_attrs[key] = json.loads(value_json)
                        except (ValueError, TypeError):
                            user_attrs[key] = value_json

                    pruned_trials.append({
                        "trial_number": trial_num,
                        "params": params,
                        "pruning_cause": user_attrs.get("pruning_cause", "Unknown"),
                        "start_time": start_time,
                        "end_time": end_time
                    })
            finally:
                conn.close()
            return {"pruned_trials": pruned_trials, "count": len(pruned_trials)}

        # Legacy: Read from JSON history
        if not pruning_file.exists():
            return {"pruned_trials": [], "count": 0}

        with open(pruning_file) as f:
            pruning_history = json.load(f)

        return {"pruned_trials": pruning_history, "count": len(pruning_history)}
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get pruning history: {str(e)}")
def _infer_objective_unit(objective: Dict) -> str:
    """Infer a display unit for an objective from its name and description.

    Keyword checks run first, against the lower-cased ``name`` and
    ``description``, in this order: Hz, N/mm, kg, MPa, mm, N, %. Unit
    abbreviations in the description only match when they are not glued to
    other letters. That stops "pa" firing on "compact" and "mm" on "summary",
    while "200mpa" or "5mm" still match. If nothing matches, the first
    parenthesised text in the description (e.g. "(N/mm)") is returned with
    its original capitalisation.

    Args:
        objective: Objective dict; 'name' and 'description' may be missing or None.

    Returns:
        The inferred unit string, or "" if none is found.
    """
    import re

    name = (objective.get("name") or "").lower()
    raw_desc = objective.get("description") or ""
    desc = raw_desc.lower()

    def has_unit(token: str) -> bool:
        # Not preceded or followed by a letter; digits/punctuation are fine.
        return re.search(rf"(?<![a-z]){re.escape(token)}(?![a-z])", desc) is not None

    # Common unit patterns
    if "frequency" in name or has_unit("hz"):
        return "Hz"
    if "stiffness" in name or has_unit("n/mm"):
        return "N/mm"
    if "mass" in name or has_unit("kg"):
        return "kg"
    if "stress" in name or has_unit("mpa") or has_unit("pa"):
        return "MPa"
    if "displacement" in name or has_unit("mm"):
        return "mm"
    if "force" in name or "newton" in desc:
        return "N"
    if "%" in desc or "percent" in desc:
        return "%"

    # Explicit unit in parentheses, e.g. "(N/mm)"; keep original case.
    unit_match = re.search(r"\(([^)]+)\)", raw_desc)
    if unit_match:
        return unit_match.group(1)

    return ""  # No unit found
@router.get("/studies/{study_id}/metadata")
async def get_study_metadata(study_id: str):
    """Read optimization_config.json for objectives, design vars, units (Protocol 13).

    Objectives without a 'unit' get one inferred from their name and
    description. The configured algorithm name is mapped, case-insensitively,
    to an Optuna sampler class name for display. Unknown names pass through
    unchanged.

    Raises:
        HTTPException: 404 if the study or its config file is missing, 500 on
        other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)

        # Load config (check multiple locations)
        config_file = study_dir / "optimization_config.json"
        if not config_file.exists():
            config_file = study_dir / "1_setup" / "optimization_config.json"
        if not config_file.exists():
            raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}")

        with open(config_file) as f:
            config = json.load(f)

        # Enhance objectives with inferred units if not present
        objectives = config.get("objectives", [])
        for obj in objectives:
            if "unit" not in obj or not obj["unit"]:
                obj["unit"] = _infer_objective_unit(obj)

        optimization = config.get("optimization", {})
        algorithm = optimization.get("algorithm", "TPE")

        # Keys are lower-case; lookup lower-cases the configured name so any
        # capitalisation ("TPE", "tpe", "Tpe") maps to the same sampler.
        sampler_map = {
            "cma-es": "CmaEsSampler",
            "cmaes": "CmaEsSampler",
            "tpe": "TPESampler",
            "nsga-ii": "NSGAIISampler",
            "nsga-iii": "NSGAIIISampler",
            "random": "RandomSampler",
        }
        sampler = sampler_map.get(str(algorithm).lower(), algorithm)

        return {
            "objectives": objectives,
            "design_variables": config.get("design_variables", []),
            "constraints": config.get("constraints", []),
            "study_name": config.get("study_name", study_id),
            "description": config.get("description", ""),
            "sampler": sampler,
            "algorithm": algorithm,
            "n_trials": optimization.get("n_trials", 100)
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get study metadata: {str(e)}")
@router.get("/studies/{study_id}/optimizer-state")
async def get_optimizer_state(study_id: str):
    """Read realtime optimizer state from intelligent_optimizer/ (Protocol 13).

    Returns ``{"available": False}`` when
    ``<results>/intelligent_optimizer/optimizer_state.json`` does not exist.
    Otherwise returns ``{"available": True}`` merged with the file's contents.

    Raises:
        HTTPException: 404 if the study is not found, 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        state_file = results_dir / "intelligent_optimizer" / "optimizer_state.json"

        if not state_file.exists():
            return {"available": False}

        with open(state_file) as f:
            state = json.load(f)

        return {"available": True, **state}
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get optimizer state: {str(e)}")
@router.get("/studies/{study_id}/pareto-front")
async def get_pareto_front(study_id: str):
    """Get Pareto-optimal solutions for multi-objective studies (Protocol 13).

    Loads the Optuna study named ``study_id``. If no such study exists but
    the database holds exactly one study, that study is used, because the
    Optuna study name can differ from the folder name. Returns an empty,
    non-multi-objective response when there is no database, the database
    has no studies, or the study has a single objective.

    Raises:
        HTTPException: 404 if the study folder is not found, 500 on other
        failures (including an unknown Optuna study name in a multi-study DB).
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"

        if not study_db.exists():
            return {"is_multi_objective": False, "pareto_front": []}

        # Import optuna here to avoid loading it for all endpoints
        import optuna

        storage = optuna.storages.RDBStorage(f"sqlite:///{study_db}")
        study_names = [s.study_name for s in storage.get_all_studies()]
        if not study_names:
            return {"is_multi_objective": False, "pareto_front": []}

        optuna_study_name = study_id
        if study_id not in study_names and len(study_names) == 1:
            optuna_study_name = study_names[0]
        study = optuna.load_study(study_name=optuna_study_name, storage=storage)

        if len(study.directions) == 1:
            return {"is_multi_objective": False, "pareto_front": []}

        pareto_trials = study.best_trials

        return {
            "is_multi_objective": True,
            "pareto_front": [
                {
                    "trial_number": t.number,
                    "values": t.values,
                    "params": t.params,
                    "user_attrs": dict(t.user_attrs),
                    "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True)
                }
                for t in pareto_trials
            ]
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get Pareto front: {str(e)}")
@router.get("/studies/{study_id}/nn-pareto-front")
async def get_nn_pareto_front(study_id: str):
    """Get the NN surrogate Pareto front from nn_pareto_front.json.

    Each entry is reshaped to the same layout as the Optuna Pareto endpoint.
    'values' is ``[mass, frequency]``, and both top-level and user_attrs
    'source' are set to "NN".

    Raises:
        HTTPException: 404 if the study is not found, 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        nn_pareto_file = results_dir / "nn_pareto_front.json"

        if not nn_pareto_file.exists():
            return {"has_nn_results": False, "pareto_front": []}

        with open(nn_pareto_file) as f:
            nn_pareto = json.load(f)

        # Transform to match Trial interface format
        transformed = [
            {
                "trial_number": trial.get("trial_number"),
                "values": [trial.get("mass"), trial.get("frequency")],
                "params": trial.get("params", {}),
                "user_attrs": {
                    "source": "NN",
                    "feasible": trial.get("feasible", False),
                    "predicted_stress": trial.get("predicted_stress"),
                    "predicted_displacement": trial.get("predicted_displacement"),
                    "mass": trial.get("mass"),
                    "frequency": trial.get("frequency")
                },
                "constraint_satisfied": trial.get("feasible", False),
                "source": "NN"
            }
            for trial in nn_pareto
        ]

        return {
            "has_nn_results": True,
            "pareto_front": transformed,
            "count": len(transformed)
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get NN Pareto front: {str(e)}")
@router.get("/studies/{study_id}/nn-state")
async def get_nn_optimization_state(study_id: str):
    """Get NN optimization state/summary from nn_optimization_state.json.

    Returns ``{"has_nn_state": False}`` when the file does not exist.
    Otherwise returns a fixed subset of its keys; the counts default to 0.

    Raises:
        HTTPException: 404 if the study is not found, 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        nn_state_file = results_dir / "nn_optimization_state.json"

        if not nn_state_file.exists():
            return {"has_nn_state": False}

        with open(nn_state_file) as f:
            state = json.load(f)

        return {
            "has_nn_state": True,
            "total_fea_count": state.get("total_fea_count", 0),
            "total_nn_count": state.get("total_nn_count", 0),
            "pareto_front_size": state.get("pareto_front_size", 0),
            "best_mass": state.get("best_mass"),
            "best_frequency": state.get("best_frequency"),
            "timestamp": state.get("timestamp")
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get NN state: {str(e)}")
@router.post("/studies")
async def create_study(
    config: str = Form(...),
    prt_file: Optional[UploadFile] = File(None),
    sim_file: Optional[UploadFile] = File(None),
    fem_file: Optional[UploadFile] = File(None)
):
    """
    Create a new optimization study

    Creates ``studies/<name>/1_setup/model`` and ``studies/<name>/2_results``,
    writes the config to ``1_setup/optimization_config.json`` and stores any
    uploaded files in ``1_setup/model``.

    Accepts:
    - config: JSON object string with study configuration; must contain "name"
    - prt_file: NX part file (optional if using existing study)
    - sim_file: NX simulation file (optional)
    - fem_file: NX FEM file (optional)

    Raises:
        HTTPException 400: invalid JSON, missing or unsafe name, unsafe upload
            filename, or a study directory with this name already exists.
        HTTPException 500: any other failure. A partially created study
            directory is removed, but only if this request created it, so an
            existing study is never deleted.
    """
    def _is_unsafe_component(name: str) -> bool:
        # Must be a single plain path component: not empty, not '.'/'..',
        # no separators, no ':' (Windows drive prefix / alternate stream).
        return (
            not name
            or name in (".", "..")
            or any(ch in name for ch in ("/", "\\", ":"))
        )

    study_dir: Optional[Path] = None
    created = False  # True only once THIS request created study_dir

    def _cleanup() -> None:
        # Best-effort: never let cleanup errors mask the original failure.
        if created and study_dir is not None and study_dir.exists():
            shutil.rmtree(study_dir, ignore_errors=True)

    try:
        config_data = json.loads(config)
        if not isinstance(config_data, dict):
            raise HTTPException(status_code=400, detail="Config must be a JSON object")

        study_name = config_data.get("name")  # Frontend sends "name" (not "study_name")
        if not study_name:
            raise HTTPException(status_code=400, detail="name is required in config")
        if not isinstance(study_name, str) or _is_unsafe_component(study_name):
            raise HTTPException(status_code=400, detail=f"Invalid study name: {study_name!r}")

        # Validate client-supplied filenames before touching the filesystem.
        # Only the last component is kept (either separator style), so
        # "../../x" or an absolute path cannot write outside model_dir.
        uploads = []
        for field_name, upload in (("prt_file", prt_file), ("sim_file", sim_file), ("fem_file", fem_file)):
            if not upload:
                continue
            safe_name = (upload.filename or "").replace("\\", "/").rsplit("/", 1)[-1]
            if _is_unsafe_component(safe_name):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid filename for {field_name}: {upload.filename!r}",
                )
            uploads.append((field_name, upload, safe_name))

        study_dir = STUDIES_DIR / study_name
        if study_dir.exists():
            raise HTTPException(status_code=400, detail=f"Study {study_name} already exists")

        setup_dir = study_dir / "1_setup"
        model_dir = setup_dir / "model"
        results_dir = study_dir / "2_results"

        # exist_ok=False: if a concurrent request created it in the meantime,
        # fail here instead of claiming (and later deleting) its directory.
        study_dir.mkdir(parents=True, exist_ok=False)
        created = True
        setup_dir.mkdir(parents=True, exist_ok=True)
        model_dir.mkdir(parents=True, exist_ok=True)
        results_dir.mkdir(parents=True, exist_ok=True)

        config_file = setup_dir / "optimization_config.json"
        with open(config_file, 'w') as f:
            json.dump(config_data, f, indent=2)

        files_saved = {}
        for field_name, upload, safe_name in uploads:
            target_path = model_dir / safe_name
            content = await upload.read()
            with open(target_path, 'wb') as f:
                f.write(content)
            files_saved[field_name] = str(target_path)

        return JSONResponse(
            status_code=201,
            content={
                "status": "created",
                "study_id": study_name,
                "study_path": str(study_dir),
                "config_path": str(config_file),
                "files_saved": files_saved,
                "message": f"Study {study_name} created successfully. Ready to run optimization."
            }
        )
    except HTTPException:
        _cleanup()
        raise
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON in config: {str(e)}")
    except Exception as e:
        _cleanup()
        raise HTTPException(status_code=500, detail=f"Failed to create study: {str(e)}")
@router.post("/studies/{study_id}/convert-mesh")
async def convert_study_mesh(study_id: str):
    """
    Convert study mesh to GLTF for 3D visualization

    Delegates to ``optimization_engine.mesh_converter.convert_study_mesh``
    (per the original design: a web-viewable 3D model with FEA results as
    vertex colors). Returns the written path and the URLs this router serves
    it under.

    Raises:
        HTTPException: 404 if the study is not found; 500 if conversion
        produced no file or failed.
    """
    try:
        study_dir = resolve_study_path(study_id)

        # Ensure the project root is importable without growing sys.path on
        # every request.
        project_root = str(Path(__file__).parent.parent.parent.parent.parent)
        if project_root not in sys.path:
            sys.path.append(project_root)
        # Aliased so it does not shadow this endpoint's own name.
        from optimization_engine.mesh_converter import convert_study_mesh as _convert_mesh

        output_path = _convert_mesh(study_dir)

        if not (output_path and output_path.exists()):
            raise HTTPException(status_code=500, detail="Mesh conversion failed")

        return {
            "status": "success",
            "gltf_path": str(output_path),
            "gltf_url": f"/api/optimization/studies/{study_id}/mesh/model.gltf",
            "metadata_url": f"/api/optimization/studies/{study_id}/mesh/model.json",
            "message": "Mesh converted successfully"
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to convert mesh: {str(e)}")
@router.get("/studies/{study_id}/mesh/{filename}")
async def get_mesh_file(study_id: str, filename: str):
    """
    Serve GLTF mesh files and metadata

    Files come from ``<study>/3_visualization``. The content type is chosen
    from the suffix (.gltf, .bin, .json, .glb); anything else is served as
    application/octet-stream.

    Raises:
        HTTPException: 400 for a filename containing '..' or path
        separators; 404 if the study or a regular file with that name does
        not exist; 500 on other failures.
    """
    try:
        # Validate filename to prevent directory traversal
        if '..' in filename or '/' in filename or '\\' in filename:
            raise HTTPException(status_code=400, detail="Invalid filename")

        study_dir = resolve_study_path(study_id)
        visualization_dir = study_dir / "3_visualization"
        file_path = visualization_dir / filename

        # is_file() rather than exists(): directories (e.g. '.') are not served.
        if not file_path.is_file():
            raise HTTPException(status_code=404, detail=f"File {filename} not found")

        content_types = {
            '.gltf': 'model/gltf+json',
            '.bin': 'application/octet-stream',
            '.json': 'application/json',
            '.glb': 'model/gltf-binary'
        }
        content_type = content_types.get(file_path.suffix.lower(), 'application/octet-stream')

        return FileResponse(
            path=str(file_path),
            media_type=content_type,
            filename=filename
        )
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="File not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to serve mesh file: {str(e)}")
@router.get("/studies/{study_id}/optuna-url")
async def get_optuna_dashboard_url(study_id: str):
    """
    Get the Optuna dashboard URL for a specific study.

    Returns the URL to access the study in the Optuna dashboard, which is
    expected at http://localhost:8081. ``database_path`` is the study's
    actual database path relative to the project root (the parent of
    STUDIES_DIR), so nested topic folders and 3_results layouts are
    reported correctly. The dashboard should be started with
    ``sqlite:///<database_path>``, as stated in the returned "note".

    When the database holds several Optuna studies, the one named like the
    folder is preferred; otherwise the first one is used.

    Raises:
        HTTPException: 404 if the study, its database, or any Optuna study
        in it is missing; 500 on other failures.
    """
    try:
        study_dir = resolve_study_path(study_id)
        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"

        if not study_db.exists():
            raise HTTPException(status_code=404, detail=f"No Optuna database found for study {study_id}")

        # The Optuna study name may differ from the folder name.
        import optuna
        storage = optuna.storages.RDBStorage(f"sqlite:///{study_db}")
        studies = storage.get_all_studies()

        if not studies:
            raise HTTPException(status_code=404, detail=f"No Optuna study found in database for {study_id}")

        selected = next((s for s in studies if s.study_name == study_id), studies[0])

        try:
            db_rel_path = study_db.relative_to(STUDIES_DIR.parent).as_posix()
        except ValueError:
            db_rel_path = study_db.as_posix()

        return {
            "study_id": study_id,
            "optuna_study_name": selected.study_name,
            "database_path": db_rel_path,
            "dashboard_url": f"http://localhost:8081/dashboard/studies/{selected._study_id}",
            "dashboard_base": "http://localhost:8081",
            "note": f"Optuna dashboard must be started with: sqlite:///{db_rel_path}"
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get Optuna URL: {str(e)}")
@router.post("/studies/{study_id}/generate-report")
async def generate_report(
    study_id: str,
    format: str = "markdown",
    include_llm_summary: bool = False
):
    """
    Generate an optimization report in the specified format.

    Args:
        study_id: Study identifier
        format: Report format ('markdown', 'md', 'html', or 'pdf'; case-insensitive)
        include_llm_summary: Whether to include LLM-generated executive summary

    Returns:
        Information about the generated report including download URL

    Raises:
        HTTPException: 400 for an unsupported format, 404 if the study is
            missing, 500 if generation fails.
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Validate format
        valid_formats = ['markdown', 'md', 'html', 'pdf']
        output_format = format.lower()
        if output_format not in valid_formats:
            raise HTTPException(status_code=400, detail=f"Invalid format. Must be one of: {', '.join(valid_formats)}")

        # The project root is appended to sys.path at import time; only re-add it
        # if it has gone missing, so repeated requests do not keep growing sys.path.
        project_root = str(Path(__file__).parent.parent.parent.parent.parent)
        if project_root not in sys.path:
            sys.path.append(project_root)
        from optimization_engine.report_generator import generate_study_report

        output_path = generate_study_report(
            study_dir=study_dir,
            output_format=output_format,
            include_llm_summary=include_llm_summary
        )
        if not output_path or not output_path.exists():
            raise HTTPException(status_code=500, detail="Report generation failed")

        return {
            "status": "success",
            "format": format,
            "file_path": str(output_path),
            # NOTE(review): download_report serves files from the results directory;
            # this URL only resolves if generate_study_report writes there - confirm.
            "download_url": f"/api/optimization/studies/{study_id}/reports/{output_path.name}",
            "file_size": output_path.stat().st_size,
            "message": f"Report generated successfully in {format} format"
        }
    except HTTPException:
        # Preserve the 400/404/500 raised above instead of re-wrapping as a generic 500.
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Study {study_id} not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}")
@router.get("/studies/{study_id}/reports/{filename}")
async def download_report(study_id: str, filename: str):
    """
    Download a generated report file from the study's results directory.

    Args:
        study_id: Study identifier
        filename: Bare report filename (no directory components)

    Returns:
        Report file as an attachment download

    Raises:
        HTTPException: 400 for a filename with path components, 404 if the
            study or file is missing, 500 on any other failure.
    """
    try:
        # Validate filename to prevent directory traversal. The Path(...).name check
        # also rejects drive-qualified names such as "C:evil" on Windows.
        if ('..' in filename or '/' in filename or '\\' in filename
                or Path(filename).name != filename):
            raise HTTPException(status_code=400, detail="Invalid filename")

        study_dir = resolve_study_path(study_id)
        file_path = get_results_dir(study_dir) / filename
        if not file_path.is_file():
            raise HTTPException(status_code=404, detail=f"Report file {filename} not found")

        content_types = {
            '.md': 'text/markdown',
            '.html': 'text/html',
            '.pdf': 'application/pdf',
            '.json': 'application/json'
        }
        content_type = content_types.get(file_path.suffix.lower(), 'application/octet-stream')

        # Passing filename= makes FileResponse emit a correctly quoted/encoded
        # "attachment" Content-Disposition; a hand-built header would override it
        # with an unquoted name that breaks on spaces and non-ASCII characters.
        return FileResponse(
            path=str(file_path),
            media_type=content_type,
            filename=filename
        )
    except HTTPException:
        # Keep the 400/404 above from being re-wrapped as 500.
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Report file not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to download report: {str(e)}")
@router.get("/studies/{study_id}/console")
async def get_console_output(study_id: str, lines: int = 200):
    """
    Get the latest console output/logs from the optimization run.

    The first existing file among ``optimization.log``,
    ``2_results/optimization.log``, ``3_results/optimization.log`` and
    ``run.log`` (relative to the study directory) is read as UTF-8, with
    undecodable bytes replaced, and its last ``lines`` lines are returned.

    Args:
        study_id: Study identifier
        lines: Number of trailing lines to return (default: 200). Values <= 0
            return no lines.

    Returns:
        JSON with console output lines, total line count, the log file used and
        a timestamp; or an empty result with a message when no log file exists.
    """
    from collections import deque

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Look for log files in various locations (first match wins)
        log_paths = [
            study_dir / "optimization.log",
            study_dir / "2_results" / "optimization.log",
            study_dir / "3_results" / "optimization.log",
            study_dir / "run.log",
        ]
        log_path_used = next((path for path in log_paths if path.is_file()), None)

        if log_path_used is None:
            return {
                "lines": [],
                "total_lines": 0,
                "log_file": None,
                "message": "No log file found. Optimization may not have started yet."
            }

        # Stream the file and keep only a bounded tail, so memory is O(lines)
        # regardless of log size. max(..., 0) ensures a zero/negative request
        # yields nothing (slicing with [-0:] would return the entire file).
        tail = deque(maxlen=max(lines, 0))
        total_lines = 0
        with open(log_path_used, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                total_lines += 1
                tail.append(line)

        # Clean up lines (remove trailing newlines)
        last_lines = [line.rstrip('\n\r') for line in tail]

        return {
            "lines": last_lines,
            "total_lines": total_lines,
            "displayed_lines": len(last_lines),
            "log_file": str(log_path_used),
            "timestamp": datetime.now().isoformat()
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read console output: {str(e)}")
@router.get("/studies/{study_id}/report")
async def get_study_report(study_id: str):
    """
    Return the markdown of a study's top-level STUDY_REPORT.md.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``content`` (the markdown text), ``path`` and ``study_id``.

    Raises:
        HTTPException: 404 if the study or its report is missing; 500 if the
            file cannot be read.
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        report_file = study_dir / "STUDY_REPORT.md"
        if not report_file.exists():
            raise HTTPException(status_code=404, detail="No STUDY_REPORT.md found for this study")

        markdown = report_file.read_text(encoding='utf-8')
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read study report: {str(e)}")

    return {
        "content": markdown,
        "path": str(report_file),
        "study_id": study_id
    }
# ============================================================================
# Study README and Config Endpoints
# ============================================================================
def _build_readme_from_config(study_dir: Path, study_id: str) -> str:
    """
    Render a fallback markdown README for a study that has none.

    Reads ``1_setup/optimization_config.json`` (or ``optimization_config.json``
    in the study root) and lists the title, description, design-variable
    ranges and objectives. Returns a short placeholder when no config exists.
    """
    config_file = study_dir / "1_setup" / "optimization_config.json"
    if not config_file.exists():
        config_file = study_dir / "optimization_config.json"
    if not config_file.exists():
        return f"# {study_id}\n\nNo README or configuration found for this study."

    with open(config_file, encoding='utf-8') as f:
        config = json.load(f)

    # .get() with '?' placeholders: a malformed entry degrades the README
    # instead of failing the whole request with a KeyError.
    design_variable_lines = [
        f"- **{dv.get('name', '?')}**: {dv.get('min', '?')} - {dv.get('max', '?')} {dv.get('units', '')}"
        for dv in config.get('design_variables', [])
    ]
    objective_lines = [
        f"- **{obj.get('name', '?')}**: {obj.get('description', '')} ({obj.get('direction', 'minimize')})"
        for obj in config.get('objectives', [])
    ]
    sections = [
        f"# {config.get('study_name', study_id)}",
        "",
        str(config.get('description', 'No description available.')),
        "",
        "## Design Variables",
        "",
        *design_variable_lines,
        "",
        "## Objectives",
        "",
        *objective_lines,
    ]
    return "\n".join(sections) + "\n"


@router.get("/studies/{study_id}/readme")
async def get_study_readme(study_id: str):
    """
    Get the README.md content for a study.

    Looks for ``README.md`` in the study root, then ``1_setup/README.md``, then
    ``readme.md``. If none exists, a README is generated from the study's
    optimization config (``path`` is then None).

    Args:
        study_id: Study identifier

    Returns:
        JSON with the markdown ``content``, its ``path`` and the ``study_id``.
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        readme_candidates = [
            study_dir / "README.md",
            study_dir / "1_setup" / "README.md",
            study_dir / "readme.md",
        ]
        readme_path = next((path for path in readme_candidates if path.is_file()), None)

        if readme_path is not None:
            readme_content = readme_path.read_text(encoding='utf-8')
        else:
            readme_content = _build_readme_from_config(study_dir, study_id)

        return {
            "content": readme_content,
            "path": str(readme_path) if readme_path else None,
            "study_id": study_id
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read README: {str(e)}")
@router.get("/studies/{study_id}/image/{image_path:path}")
async def get_study_image(study_id: str, image_path: str):
    """
    Serve an image (or other file) from a study directory.

    The relative ``image_path`` is tried, in order, under:
    - study_dir/
    - study_dir/1_setup/
    - study_dir/3_results/
    - study_dir/2_results/
    - study_dir/assets/
    Candidates that would land outside the study directory (``..`` segments,
    absolute or drive-qualified paths) are ignored.

    Args:
        study_id: Study identifier
        image_path: Relative path to the image within the study

    Returns:
        FileResponse with the image (media type from the file extension)
    """
    import os

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Strip leading separators so the path is always joined as a relative path.
        relative_path = image_path.lstrip('/\\')
        study_root = Path(os.path.abspath(study_dir))

        search_roots = [
            study_dir,
            study_dir / "1_setup",
            study_dir / "3_results",
            study_dir / "2_results",
            study_dir / "assets",
        ]

        image_file = None
        for root in search_roots:
            # abspath collapses ".." lexically; an absolute/drive path replaces root
            # entirely. Either way, the containment check below rejects escapes.
            candidate = Path(os.path.abspath(root / relative_path))
            try:
                candidate.relative_to(study_root)
            except ValueError:
                continue
            if candidate.is_file():
                image_file = candidate
                break

        if image_file is None:
            raise HTTPException(status_code=404, detail=f"Image not found: {relative_path}")

        media_types = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.svg': 'image/svg+xml',
            '.webp': 'image/webp',
        }
        media_type = media_types.get(image_file.suffix.lower(), 'application/octet-stream')
        return FileResponse(image_file, media_type=media_type)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to serve image: {str(e)}")
@router.get("/studies/{study_id}/config")
async def get_study_config(study_id: str):
    """
    Get the full optimization_config.json for a study.

    Looks in ``1_setup/`` first, then in the study root.

    Args:
        study_id: Study identifier

    Returns:
        JSON with the complete ``config``, its ``path`` and the ``study_id``.

    Raises:
        HTTPException: 404 if the study or config is missing; 500 if it cannot
            be read or parsed.
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        config_file = study_dir / "1_setup" / "optimization_config.json"
        if not config_file.exists():
            config_file = study_dir / "optimization_config.json"
        if not config_file.exists():
            raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}")

        # JSON is UTF-8 by spec; the locale default (e.g. cp1252 on Windows) would
        # garble or reject non-ASCII text such as unit symbols.
        with open(config_file, encoding='utf-8') as f:
            config = json.load(f)

        return {
            "config": config,
            "path": str(config_file),
            "study_id": study_id
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read config: {str(e)}")
# ============================================================================
# Process Control Endpoints
# ============================================================================
# Track running processes by study_id (study_id -> PID).
# In-process memory only. start_optimization records the launched PID here and
# stop_optimization removes the entry; the liveness checks in this section use
# _find_optimization_process (a psutil scan) rather than this dict.
_running_processes: Dict[str, int] = {}
def _find_optimization_process(study_id: str) -> Optional[psutil.Process]:
    """
    Find a running optimization process for a given study.

    A process matches when its command line looks like a Python invocation of
    ``run_optimization`` AND it refers to this study. That means either:
    - the command line contains the study directory path, or
    - the command line contains the study id as a whole token, or
    - the process's working directory is the study directory.

    Args:
        study_id: Study identifier

    Returns:
        The first matching ``psutil.Process``, or None.

    Raises:
        HTTPException: 404 (from resolve_study_path) if the study does not exist.
    """
    import re

    study_dir = resolve_study_path(study_id)
    # Whole-token match: a bare substring test would let study "..._V1" match a
    # process belonging to "..._V14", and stop would then kill the wrong run.
    study_token = re.compile(r'(?<![\w-])' + re.escape(study_id) + r'(?![\w-])')
    try:
        study_dir_resolved = study_dir.resolve()
    except OSError:
        study_dir_resolved = study_dir

    for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'cwd']):
        try:
            cmdline = proc.info.get('cmdline') or []
            cmdline_str = ' '.join(cmdline)

            # Only Python processes running run_optimization.py are candidates
            if 'python' not in cmdline_str.lower() or 'run_optimization' not in cmdline_str:
                continue

            if str(study_dir) in cmdline_str or study_token.search(cmdline_str):
                return proc

            # Covers "python run_optimization.py" launched from inside the study folder.
            cwd = proc.info.get('cwd')  # None when psutil is denied access
            if cwd and Path(cwd).resolve() == study_dir_resolved:
                return proc
        except (psutil.NoSuchProcess, psutil.AccessDenied, OSError):
            continue
    return None
def _read_process_db_counts(study_db: Path) -> tuple:
    """
    Read ``(fea_count, nn_count, iteration)`` from an Optuna SQLite database.

    FEA trials are COMPLETE trials of studies whose name ends in ``_fea`` or
    does not contain ``_nn``. NN trials are COMPLETE trials of studies whose
    name contains ``_nn``. ``iteration`` is the largest N from ``iterN`` in a
    study name containing ``_iter``, or None if there is none.
    """
    import re
    from contextlib import closing

    with closing(sqlite3.connect(str(study_db))) as conn:
        cursor = conn.cursor()

        # '_' is a single-character wildcard in LIKE. It is escaped with '\' so it
        # matches a literal underscore; otherwise e.g. "annulus" would count as NN.
        cursor.execute(r"""
            SELECT COUNT(*) FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
              AND (s.study_name LIKE '%\_fea' ESCAPE '\'
                   OR s.study_name NOT LIKE '%\_nn%' ESCAPE '\')
        """)
        fea_count = cursor.fetchone()[0]

        cursor.execute(r"""
            SELECT COUNT(*) FROM trials t
            JOIN studies s ON t.study_id = s.study_id
            WHERE t.state = 'COMPLETE'
              AND s.study_name LIKE '%\_nn%' ESCAPE '\'
        """)
        nn_count = cursor.fetchone()[0]

        # Compare iteration numbers numerically; ORDER BY study_name would rank
        # "iter9" above "iter10".
        cursor.execute(r"SELECT study_name FROM studies WHERE study_name LIKE '%\_iter%' ESCAPE '\'")
        iterations = []
        for (name,) in cursor.fetchall():
            match = re.search(r'iter(\d+)', name)
            if match:
                iterations.append(int(match.group(1)))

    return fea_count, nn_count, (max(iterations) if iterations else None)


@router.get("/studies/{study_id}/process")
async def get_process_status(study_id: str):
    """
    Get the process status for a study's optimization run.

    Args:
        study_id: Study identifier

    Returns:
        JSON with process status (is_running, pid) and, when the study database
        is readable, FEA/NN trial counts and the latest iteration number.
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        proc = _find_optimization_process(study_id)

        study_db = get_results_dir(study_dir) / "study.db"
        fea_count, nn_count, iteration = 0, 0, None
        if study_db.exists():
            try:
                fea_count, nn_count, iteration = _read_process_db_counts(study_db)
            except Exception as e:
                # Best-effort: the running/pid status is still useful without counts.
                print(f"Warning: Failed to read database for process status: {e}")

        return {
            "is_running": proc is not None,
            "pid": proc.pid if proc else None,
            "iteration": iteration,
            "fea_count": fea_count,
            "nn_count": nn_count,
            "study_id": study_id
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get process status: {str(e)}")
class StartOptimizationRequest(BaseModel):
    """Optional body for POST /studies/{study_id}/start; fields map to run_optimization.py flags."""
    freshStart: bool = False  # adds --fresh
    maxIterations: int = 100  # NOTE(review): accepted but not forwarded to the command by start_optimization
    feaBatchSize: int = 5  # --fea-batch
    tuneTrials: int = 30  # --tune-trials
    ensembleSize: int = 3  # --ensemble-size
    patience: int = 5  # --patience
@router.post("/studies/{study_id}/start")
async def start_optimization(study_id: str, request: StartOptimizationRequest = None):
    """
    Start the optimization process for a study.

    Launches ``run_optimization.py --start`` from the study directory, in its
    own session, with the current interpreter. Its stdout/stderr are appended
    to ``run.log`` in the study directory, which is one of the files the
    console endpoint reads.

    Args:
        study_id: Study identifier
        request: Optional start options (flags passed to run_optimization.py)

    Returns:
        JSON with ``success``, ``message``, ``pid``, the ``command`` run and
        the ``log_file`` receiving its output. If a run is already active,
        ``success`` is False and its PID is returned.
    """
    import os

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        existing_proc = _find_optimization_process(study_id)
        if existing_proc:
            return {
                "success": False,
                "message": f"Optimization already running (PID: {existing_proc.pid})",
                "pid": existing_proc.pid
            }

        run_script = study_dir / "run_optimization.py"
        if not run_script.exists():
            raise HTTPException(status_code=404, detail=f"run_optimization.py not found for study {study_id}")

        cmd = [sys.executable, str(run_script), "--start"]
        if request:
            if request.freshStart:
                cmd.append("--fresh")
            cmd.extend(["--fea-batch", str(request.feaBatchSize)])
            cmd.extend(["--tune-trials", str(request.tuneTrials)])
            cmd.extend(["--ensemble-size", str(request.ensembleSize)])
            cmd.extend(["--patience", str(request.patience)])
            # NOTE(review): request.maxIterations is not forwarded; confirm the
            # run_optimization.py flag name before wiring it up.

        # Never hand a long-running child an unread PIPE: once the OS pipe buffer
        # fills (~64 KiB), its next print blocks forever. Send output to a log file
        # instead; PYTHONUNBUFFERED keeps that log current for the console view.
        log_file = study_dir / "run.log"
        env = dict(os.environ, PYTHONUNBUFFERED="1")
        with open(log_file, 'ab') as log_handle:
            proc = subprocess.Popen(
                cmd,
                cwd=str(study_dir),
                stdout=log_handle,
                stderr=subprocess.STDOUT,
                start_new_session=True,
                env=env
            )
        # The child holds its own copy of the log handle; closing ours is safe.

        _running_processes[study_id] = proc.pid

        return {
            "success": True,
            "message": "Optimization started successfully",
            "pid": proc.pid,
            "command": ' '.join(cmd),
            "log_file": str(log_file)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start optimization: {str(e)}")
class StopRequest(BaseModel):
    """Optional body for POST /studies/{study_id}/stop."""
    # True: kill the optimization process and its children immediately.
    # False: terminate the parent first, escalating to kill if it has not
    # exited within 5 seconds.
    force: bool = True  # Default to force kill
def _kill_process_list(procs, killed_pids: List[int]) -> None:
    """Hard-kill each process in ``procs``, appending the PIDs actually killed to ``killed_pids``."""
    for p in procs:
        try:
            p.kill()  # SIGKILL on Unix, TerminateProcess on Windows
            killed_pids.append(p.pid)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass


def _terminate_then_kill(procs, killed_pids: List[int], timeout: float = 5) -> None:
    """Ask each process in ``procs`` to exit; hard-kill any still alive after ``timeout`` seconds."""
    signalled = []
    for p in procs:
        try:
            p.terminate()
            signalled.append(p)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
    if signalled:
        _gone, still_alive = psutil.wait_procs(signalled, timeout=timeout)
        _kill_process_list(still_alive, killed_pids)


@router.post("/studies/{study_id}/stop")
async def stop_optimization(study_id: str, request: StopRequest = None):
    """
    Stop the optimization process for a study (hard kill by default).

    Args:
        study_id: Study identifier
        request.force: If True (default), immediately kill the process tree.
            If False, terminate the parent (killing the tree if it has not
            exited within 5 s), then stop any children that outlived it.

    Returns:
        JSON with ``success``, ``message`` and, when a process was found, its
        ``pid`` and the ``killed_pids`` that had to be hard-killed.

    Note:
        Waiting for processes to exit blocks the request for up to ~10 s in
        graceful mode.
    """
    if request is None:
        request = StopRequest()

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        proc = _find_optimization_process(study_id)
        if not proc:
            return {
                "success": False,
                "message": "No running optimization process found"
            }

        pid = proc.pid
        killed_pids = []

        try:
            # Snapshot the tree BEFORE stopping the parent: once it dies, its
            # children are re-parented and can no longer be found through it.
            try:
                children = proc.children(recursive=True)
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                children = []

            if request.force:
                # Hard kill: children first (bottom-up), then the parent
                _kill_process_list(reversed(children), killed_pids)
                try:
                    proc.kill()
                    killed_pids.append(pid)
                except psutil.NoSuchProcess:
                    pass
            else:
                try:
                    proc.terminate()
                    proc.wait(timeout=5)
                except psutil.TimeoutExpired:
                    # Didn't stop gracefully, force kill the whole tree
                    _kill_process_list(reversed(children), killed_pids)
                    proc.kill()
                    killed_pids.append(pid)
                except psutil.NoSuchProcess:
                    pass
                # terminate() signals only the parent PID; stop any children
                # (e.g. solver subprocesses) that outlived it instead of orphaning them.
                _terminate_then_kill(children, killed_pids)

            _running_processes.pop(study_id, None)

            return {
                "success": True,
                "message": f"Optimization killed (PID: {pid}, +{len(children)} children)",
                "pid": pid,
                "killed_pids": killed_pids
            }
        except psutil.NoSuchProcess:
            _running_processes.pop(study_id, None)
            return {
                "success": True,
                "message": "Process already terminated"
            }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to stop optimization: {str(e)}")
class ValidateRequest(BaseModel):
    """Optional body for POST /studies/{study_id}/validate."""
    topN: int = 5  # passed to the validation script as --top
@router.post("/studies/{study_id}/validate")
async def validate_optimization(study_id: str, request: ValidateRequest = None):
    """
    Run final FEA validation on the top NN predictions.

    Uses ``final_validation.py --top N`` if present. Otherwise it falls back to
    ``run_optimization.py --validate --top N``. Validation is refused while
    an optimization run is active. The child's output is appended to
    ``run.log`` in the study directory.

    Args:
        study_id: Study identifier
        request: Validation options (topN, default 5)

    Returns:
        JSON with ``success``, ``message``, ``pid``, the ``command`` run and
        the ``log_file`` receiving its output.
    """
    import os

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        if _find_optimization_process(study_id):
            return {
                "success": False,
                "message": "Cannot validate while optimization is running. Stop optimization first."
            }

        top_n = request.topN if request else 5
        validation_script = study_dir / "final_validation.py"
        if validation_script.exists():
            cmd = [sys.executable, str(validation_script), "--top", str(top_n)]
        else:
            # Fall back to run_optimization.py with --validate flag
            run_script = study_dir / "run_optimization.py"
            if not run_script.exists():
                raise HTTPException(status_code=404, detail="No validation script found")
            cmd = [sys.executable, str(run_script), "--validate", "--top", str(top_n)]

        # Redirect to a log file rather than unread PIPEs, which fill up and block
        # the child on a long validation run.
        log_file = study_dir / "run.log"
        env = dict(os.environ, PYTHONUNBUFFERED="1")
        with open(log_file, 'ab') as log_handle:
            proc = subprocess.Popen(
                cmd,
                cwd=str(study_dir),
                stdout=log_handle,
                stderr=subprocess.STDOUT,
                start_new_session=True,
                env=env
            )

        return {
            "success": True,
            "message": f"Validation started for top {top_n} NN predictions",
            "pid": proc.pid,
            "command": ' '.join(cmd),
            "log_file": str(log_file)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start validation: {str(e)}")
# ============================================================================
# Optuna Dashboard Launch
# ============================================================================
# Dashboard processes launched by launch_optuna_dashboard, keyed by study_id.
# In-process memory only; used to recognize a dashboard this server started.
_optuna_processes: Dict[str, subprocess.Popen] = {}
@router.post("/studies/{study_id}/optuna-dashboard")
async def launch_optuna_dashboard(study_id: str):
    """
    Launch Optuna dashboard for a specific study.

    Starts ``optuna-dashboard`` on port 8081 (bound to 0.0.0.0) against the
    study's ``study.db``. Only one dashboard can own the port, so a dashboard
    this server launched for a *different* study is stopped first; otherwise
    the returned URL would show the wrong database. If the port is held by a
    process not launched here, its URL is returned as-is. Dashboard output is
    written to ``optuna_dashboard.log`` in the study's results directory.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``success``, ``url``/``pid`` when available and a ``message``.
        On failure, the message includes the tail of the dashboard log.
    """
    import asyncio
    import platform
    import socket
    import time

    port = 8081

    def is_port_in_use(port: int) -> bool:
        """Check if a port is already in use"""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(('localhost', port)) == 0

    def read_log_tail(path: Path, limit: int = 4000) -> str:
        """Return up to the last ``limit`` characters of ``path`` ('' if unreadable)."""
        try:
            return path.read_text(encoding='utf-8', errors='replace')[-limit:]
        except OSError:
            return ""

    async def stop_dashboard(proc: subprocess.Popen) -> None:
        """Terminate ``proc``, escalating to kill after ~5 s, without blocking the event loop."""
        try:
            proc.terminate()
            for _ in range(10):
                if proc.poll() is not None:
                    return
                await asyncio.sleep(0.5)
            proc.kill()
        except OSError:
            pass  # already gone

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        results_dir = get_results_dir(study_dir)
        study_db = results_dir / "study.db"
        if not study_db.exists():
            raise HTTPException(status_code=404, detail=f"No Optuna database found for study {study_id}")

        # Forget dashboards that have exited; stop ones we launched for other studies.
        for other_id, other_proc in list(_optuna_processes.items()):
            if other_proc.poll() is not None:
                del _optuna_processes[other_id]
            elif other_id != study_id:
                await stop_dashboard(other_proc)
                del _optuna_processes[other_id]

        if is_port_in_use(port):
            own_proc = _optuna_processes.get(study_id)
            if own_proc is not None and own_proc.poll() is None:
                return {
                    "success": True,
                    "url": f"http://localhost:{port}",
                    "pid": own_proc.pid,
                    "message": "Optuna dashboard already running"
                }
            # Port in use but not by us - still return success since dashboard is available
            return {
                "success": True,
                "url": f"http://localhost:{port}",
                "pid": None,
                "message": "Optuna dashboard already running on port 8081"
            }

        # Absolute POSIX-style path for the SQLite URL
        storage_url = f"sqlite:///{study_db.absolute().as_posix()}"
        cmd = ["optuna-dashboard", storage_url, "--port", str(port), "--host", "0.0.0.0"]

        if platform.system() == 'Windows':
            # Windows-specific: create detached process
            DETACHED_PROCESS = 0x00000008
            CREATE_NEW_PROCESS_GROUP = 0x00000200
            popen_kwargs = {"creationflags": DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP}
        else:
            popen_kwargs = {"start_new_session": True}

        # The dashboard keeps writing output while it serves requests. An unread
        # PIPE would eventually fill up and freeze it, so log to a file instead.
        log_path = results_dir / "optuna_dashboard.log"
        with open(log_path, 'wb') as log_handle:
            proc = subprocess.Popen(
                cmd,
                stdout=log_handle,
                stderr=subprocess.STDOUT,
                **popen_kwargs
            )
        _optuna_processes[study_id] = proc

        # Wait for dashboard to start; asyncio.sleep keeps the server responsive meanwhile.
        deadline = time.monotonic() + 5
        while time.monotonic() < deadline:
            if is_port_in_use(port):
                return {
                    "success": True,
                    "url": f"http://localhost:{port}",
                    "pid": proc.pid,
                    "message": "Optuna dashboard launched successfully"
                }
            if proc.poll() is not None:
                return {
                    "success": False,
                    "message": f"Failed to start Optuna dashboard: {read_log_tail(log_path)}"
                }
            await asyncio.sleep(0.5)

        # Timeout - process might still be starting
        if proc.poll() is None:
            return {
                "success": True,
                "url": f"http://localhost:{port}",
                "pid": proc.pid,
                "message": "Optuna dashboard starting (may take a moment)"
            }
        return {
            "success": False,
            "message": f"Failed to start Optuna dashboard: {read_log_tail(log_path)}"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to launch Optuna dashboard: {str(e)}")
# ============================================================================
# Model Files Endpoint
# ============================================================================
@router.get("/studies/{study_id}/model-files")
async def get_model_files(study_id: str):
    """
    Get list of NX model files (.prt, .sim, .fem, .bdf, .dat, .op2, .f06, .inp) for a study.

    Scans, non-recursively:
    - ``1_setup/model``
    - ``model``
    - ``1_setup``
    - the study root

    Results are sorted by file type (prt, sim, fem, bdf, dat, op2, f06, inp),
    then by name.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``files`` (name, path, extension, size, mtime), their
        ``count``, and ``model_dir`` (the first directory that had a match).
    """
    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        model_dirs = [
            study_dir / "1_setup" / "model",
            study_dir / "model",
            study_dir / "1_setup",
            study_dir
        ]

        # NX and FEA file extensions to look for
        nx_extensions = {'.prt', '.sim', '.fem', '.bdf', '.dat', '.op2', '.f06', '.inp'}

        model_files = []
        model_dir_path = None
        for model_dir in model_dirs:
            if not model_dir.is_dir():
                continue
            for file_path in model_dir.iterdir():
                extension = file_path.suffix.lower()
                if extension not in nx_extensions or not file_path.is_file():
                    continue
                try:
                    # One stat per file. Skip files deleted between listing and
                    # stat rather than failing the whole request.
                    file_stat = file_path.stat()
                except OSError:
                    continue
                model_files.append({
                    "name": file_path.name,
                    "path": str(file_path),
                    "extension": extension,
                    "size_bytes": file_stat.st_size,
                    "size_display": _format_file_size(file_stat.st_size),
                    "modified": datetime.fromtimestamp(file_stat.st_mtime).isoformat()
                })
                if model_dir_path is None:
                    model_dir_path = str(model_dir)

        # Sort by extension for better display (prt first, then sim, fem, etc.)
        extension_order = {'.prt': 0, '.sim': 1, '.fem': 2, '.bdf': 3, '.dat': 4, '.op2': 5, '.f06': 6, '.inp': 7}
        model_files.sort(key=lambda x: (extension_order.get(x['extension'], 99), x['name']))

        return {
            "study_id": study_id,
            "model_dir": model_dir_path or str(study_dir / "1_setup" / "model"),
            "files": model_files,
            "count": len(model_files)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model files: {str(e)}")
def _format_file_size(size_bytes: int) -> str:
"""Format file size in human-readable form"""
if size_bytes < 1024:
return f"{size_bytes} B"
elif size_bytes < 1024 * 1024:
return f"{size_bytes / 1024:.1f} KB"
elif size_bytes < 1024 * 1024 * 1024:
return f"{size_bytes / (1024 * 1024):.1f} MB"
else:
return f"{size_bytes / (1024 * 1024 * 1024):.2f} GB"
@router.post("/studies/{study_id}/open-folder")
async def open_model_folder(study_id: str, folder_type: str = "model"):
    """
    Reveal one of a study's folders in the host's file manager.

    ``folder_type`` picks the folder:
    - ``"model"``: ``1_setup/model``, or ``1_setup`` when that is missing
    - ``"results"``: the study's results directory
    - ``"setup"``: ``1_setup``
    - anything else: the study root

    A chosen folder that does not exist falls back to the study root.

    Args:
        study_id: Study identifier
        folder_type: Which folder to open (model, results, setup)

    Returns:
        Dict with ``success``, a ``message`` and the ``path`` targeted.
        Failures of the OS launcher are reported as ``success: False``
        rather than as an HTTP error.
    """
    import os
    import platform

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        setup_dir = study_dir / "1_setup"
        if folder_type == "model":
            model_dir = setup_dir / "model"
            target_dir = model_dir if model_dir.exists() else setup_dir
        elif folder_type == "results":
            target_dir = get_results_dir(study_dir)
        elif folder_type == "setup":
            target_dir = setup_dir
        else:
            target_dir = study_dir

        if not target_dir.exists():
            target_dir = study_dir
        target = str(target_dir)

        host_os = platform.system()
        try:
            if host_os == "Windows":
                os.startfile(target)
            else:
                # macOS ships `open`; other platforms are expected to provide xdg-open.
                opener = "open" if host_os == "Darwin" else "xdg-open"
                subprocess.Popen([opener, target])
        except Exception as e:
            return {
                "success": False,
                "message": f"Failed to open folder: {str(e)}",
                "path": target
            }

        return {
            "success": True,
            "message": f"Opened {target_dir}",
            "path": target
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to open folder: {str(e)}")
def _load_trial_params(cursor: sqlite3.Cursor, trial_id: int) -> dict:
    """Return ``{param_name: param_value}`` for one trial; expects ``sqlite3.Row`` rows."""
    cursor.execute("""
        SELECT param_name, param_value
        FROM trial_params
        WHERE trial_id = ?
    """, (trial_id,))
    return {row['param_name']: row['param_value'] for row in cursor.fetchall()}


def _load_trial_user_attrs(cursor: sqlite3.Cursor, trial_id: int) -> dict:
    """Return one trial's user attributes, JSON-decoded where possible; expects ``sqlite3.Row`` rows."""
    cursor.execute("""
        SELECT key, value_json
        FROM trial_user_attributes
        WHERE trial_id = ?
    """, (trial_id,))
    attrs = {}
    for row in cursor.fetchall():
        try:
            attrs[row['key']] = json.loads(row['value_json'])
        except (TypeError, ValueError):
            # NULL or not valid JSON: keep the raw stored value
            attrs[row['key']] = row['value_json']
    return attrs


@router.get("/studies/{study_id}/best-solution")
async def get_best_solution(study_id: str):
    """
    Get the best trial for a study, plus the first trial and improvement metrics.

    "Best" is the lowest non-NULL value of the first objective (objective 0)
    over all COMPLETE trials of every Optuna study in the database. The
    optimization direction is not consulted; minimization is assumed.
    "First" is the completed trial with the lowest trial number.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``best_trial`` (params, user attrs, completion timestamp),
        ``first_trial``, ``improvements`` and ``total_trials``.
    """
    from contextlib import closing

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study '{study_id}' not found")

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            return {
                "study_id": study_id,
                "best_trial": None,
                "first_trial": None,
                "improvements": {},
                "total_trials": 0
            }

        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # objective = 0 keeps one row per trial (multi-objective studies store one
            # row per objective), matching /runs. NULLs sort first in SQLite ASC
            # order, so they are excluded. The timestamp comes from the trial itself;
            # value_id is just a row id.
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value AS objective,
                       t.datetime_complete AS timestamp
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE'
                  AND tv.objective = 0
                  AND tv.value IS NOT NULL
                ORDER BY tv.value ASC, t.trial_id ASC
                LIMIT 1
            """)
            best_row = cursor.fetchone()

            # First completed trial, for comparison
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value AS objective
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.state = 'COMPLETE'
                  AND tv.objective = 0
                ORDER BY t.number ASC, t.trial_id ASC
                LIMIT 1
            """)
            first_row = cursor.fetchone()

            cursor.execute("SELECT COUNT(*) FROM trials WHERE state = 'COMPLETE'")
            total_trials = cursor.fetchone()[0]

            best_trial = None
            first_trial = None
            improvements = {}

            if best_row:
                best_trial = {
                    "trial_number": best_row['number'],
                    "objective": best_row['objective'],
                    "design_variables": _load_trial_params(cursor, best_row['trial_id']),
                    "user_attrs": _load_trial_user_attrs(cursor, best_row['trial_id']),
                    "timestamp": best_row['timestamp']
                }

            if first_row:
                first_trial = {
                    "trial_number": first_row['number'],
                    "objective": first_row['objective'],
                    "design_variables": _load_trial_params(cursor, first_row['trial_id'])
                }

            # Relative improvement needs a finite, non-zero baseline
            first_obj = first_row['objective'] if first_row else None
            if best_row and first_obj is not None and first_obj != 0:
                best_obj = best_row['objective']
                improvement_pct = ((first_obj - best_obj) / abs(first_obj)) * 100
                improvements["objective"] = {
                    "initial": first_obj,
                    "final": best_obj,
                    "improvement_pct": round(improvement_pct, 2),
                    "absolute_change": round(first_obj - best_obj, 6)
                }

        return {
            "study_id": study_id,
            "best_trial": best_trial,
            "first_trial": first_trial,
            "improvements": improvements,
            "total_trials": total_trials
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get best solution: {str(e)}")
@router.get("/studies/{study_id}/runs")
async def get_study_runs(study_id: str):
    """
    Get all optimization runs/studies in the database for comparison.

    Many studies have multiple Optuna studies (e.g., v11_fea, v11_iter1_nn,
    v11_iter2_nn). This endpoint returns metrics for each sub-study with at
    least one COMPLETE trial. Best/average values use the first objective
    (objective 0); "best" is the minimum.

    Args:
        study_id: Study identifier

    Returns:
        JSON with ``runs`` (id, name, FEA/NN source, trial count, best/avg
        value, first/last trial time), ``total_runs`` and ``study_id``.
    """
    from contextlib import closing

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study '{study_id}' not found")

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            return {"runs": [], "total_runs": 0}

        runs = []
        # closing() guarantees the SQLite handle is released even if a query fails.
        with closing(sqlite3.connect(str(db_path))) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            cursor.execute("""
                SELECT study_id, study_name
                FROM studies
                ORDER BY study_id
            """)
            for study_row in cursor.fetchall():
                optuna_study_id = study_row['study_id']
                study_name = study_row['study_name']

                # Trial count and time range of completed trials, in one query
                cursor.execute("""
                    SELECT COUNT(*) AS trial_count,
                           MIN(datetime_start) AS first_trial,
                           MAX(datetime_complete) AS last_trial
                    FROM trials
                    WHERE study_id = ? AND state = 'COMPLETE'
                """, (optuna_study_id,))
                trial_stats = cursor.fetchone()
                if trial_stats['trial_count'] == 0:
                    continue

                # Best (minimum) and average of the first objective
                cursor.execute("""
                    SELECT MIN(tv.value) AS best_value, AVG(tv.value) AS avg_value
                    FROM trial_values tv
                    JOIN trials t ON tv.trial_id = t.trial_id
                    WHERE t.study_id = ? AND t.state = 'COMPLETE' AND tv.objective = 0
                """, (optuna_study_id,))
                value_stats = cursor.fetchone()

                runs.append({
                    "run_id": optuna_study_id,
                    "name": study_name,
                    "source": "NN" if "_nn" in study_name.lower() else "FEA",
                    "trial_count": trial_stats['trial_count'],
                    "best_value": value_stats['best_value'],
                    "avg_value": value_stats['avg_value'],
                    "first_trial": trial_stats['first_trial'],
                    "last_trial": trial_stats['last_trial']
                })

        return {
            "runs": runs,
            "total_runs": len(runs),
            "study_id": study_id
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get runs: {str(e)}")
class UpdateConfigRequest(BaseModel):
    """Body for PUT /studies/{study_id}/config."""
    config: dict  # replaces optimization_config.json wholesale (written as indented JSON)
@router.put("/studies/{study_id}/config")
async def update_study_config(study_id: str, request: UpdateConfigRequest):
    """
    Update the optimization_config.json for a study.

    The current file (in ``1_setup/`` or the study root) is first copied to
    ``optimization_config.json.backup``. The new content is then written to a
    temporary sibling and atomically swapped in, so the config is never left
    half-written.

    Args:
        study_id: Study identifier
        request: New configuration data

    Returns:
        JSON with success status, the config path and the backup path.

    Raises:
        HTTPException: 404 if the study or config is missing; 409 while an
            optimization is running; 500 on any other failure.
    """
    import os

    try:
        study_dir = resolve_study_path(study_id)
        if not study_dir.exists():
            raise HTTPException(status_code=404, detail=f"Study {study_id} not found")

        # Check if optimization is running - don't allow config changes while running
        if is_optimization_running(study_id):
            raise HTTPException(
                status_code=409,
                detail="Cannot modify config while optimization is running. Stop the optimization first."
            )

        config_file = study_dir / "1_setup" / "optimization_config.json"
        if not config_file.exists():
            config_file = study_dir / "optimization_config.json"
        if not config_file.exists():
            raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}")

        # Backup existing config
        backup_file = config_file.with_suffix('.json.backup')
        shutil.copy(config_file, backup_file)

        # Write to a temp file in the same directory, then os.replace it over the
        # original. The replace is atomic, so a failure mid-write (disk full,
        # crash) cannot leave a truncated optimization_config.json behind.
        tmp_file = config_file.with_name(config_file.name + '.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(request.config, f, indent=2)
            os.replace(tmp_file, config_file)
        except Exception:
            if tmp_file.exists():
                tmp_file.unlink()
            raise

        return {
            "success": True,
            "message": "Configuration updated successfully",
            "path": str(config_file),
            "backup_path": str(backup_file)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update config: {str(e)}")
# ============================================================================
# Zernike Analysis Endpoints
# ============================================================================
@router.get("/studies/{study_id}/zernike-available")
async def get_zernike_available_trials(study_id: str):
    """
    Get list of trial numbers that have Zernike analysis available (OP2 files).

    Scans ``2_iterations/iter<N>`` folders. A folder counts only if it
    contains at least one ``*.op2`` file. Folder ``iter<N>`` is reported as
    trial ``N - 1`` (1-indexed folders vs 0-indexed trials). The special
    test folder ``iter9999`` is reported as 9999. Folders whose suffix is
    not an integer are ignored.

    Returns:
        JSON with the sorted, de-duplicated list of trial numbers and its count
    """
    # Special test iteration that is reported under its own number.
    special_iteration = 9999
    try:
        # resolve_study_path raises 404 itself when the study does not exist.
        study_dir = resolve_study_path(study_id)

        iter_base = study_dir / "2_iterations"
        if not iter_base.is_dir():
            return {"study_id": study_id, "available_trials": [], "count": 0}

        trials = set()
        for d in iter_base.iterdir():
            if not (d.is_dir() and d.name.startswith('iter')):
                continue
            # Strip only the leading "iter" prefix.
            try:
                iter_num = int(d.name[len('iter'):])
            except ValueError:
                continue
            # Only folders that actually hold FEA results (OP2) qualify;
            # stop at the first match instead of listing every file.
            if next(d.glob("*.op2"), None) is None:
                continue
            trials.add(iter_num if iter_num == special_iteration else iter_num - 1)

        available_trials = sorted(trials)
        return {
            "study_id": study_id,
            "available_trials": available_trials,
            "count": len(available_trials)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get available trials: {str(e)}")
# ----------------------------------------------------------------------------
# Zernike analysis helpers (used by get_trial_zernike)
# ----------------------------------------------------------------------------

# Number of Zernike modes requested from the extractors.
_ZERNIKE_N_MODES = 50
# Display scaling applied to the plotted WFE residual surface (0.5x).
_ZERNIKE_AMP = 0.5
# Maximum number of nodes drawn per surface (keeps the generated HTML small).
_ZERNIKE_PLOT_DOWNSAMPLE = 5000
# Number of leading Zernike terms (J1-J4) removed from the plotted residual.
_ZERNIKE_FILTER_LOW_ORDERS = 4
# Plotly colorscale for the WFE view, e.g. 'RdBu_r', 'Viridis', 'Plasma', 'Turbo'.
_ZERNIKE_COLORSCALE = 'Turbo'
# Nastran subcase ID -> elevation angle label (degrees), used in titles.
_ZERNIKE_SUBCASE_MAP = {'1': '90', '2': '20', '3': '40', '4': '60'}
# (target subcase, reference subcase, result key) for each comparison produced.
_ZERNIKE_COMPARISONS = (
    ('3', '2', '40_vs_20'),
    ('4', '2', '60_vs_20'),
    ('1', '2', '90_vs_20'),
)


def _zernike_noll_indices(j: int):
    """Return the (n, m) orders for the 1-based Noll-style index ``j``.

    Radial order n ascends. Within each n the azimuthal orders are listed
    as 0 (even n only), then (-m, +m) pairs of increasing |m|.
    """
    # NOTE(review): this ordering must match the coefficient order returned by
    # the extractors (only J1-J4 are used here) -- presumably it does; verify.
    if j < 1:
        raise ValueError("Noll index j must be >= 1")
    count = 0
    n = 0
    while True:
        if n == 0:
            ms = [0]
        elif n % 2 == 0:
            ms = [0] + [m for k in range(1, n // 2 + 1) for m in (-2 * k, 2 * k)]
        else:
            ms = [m for k in range((n + 1) // 2) for m in (-(2 * k + 1), 2 * k + 1)]
        for m in ms:
            count += 1
            if count == j:
                return n, m
        n += 1


def _zernike_noll_polynomial(j: int, r, th):
    """Evaluate Zernike term ``j`` (unnormalized) at polar coordinates (r, th).

    Uses cos(m*th) for m > 0 and sin(|m|*th) for m < 0.
    """
    import numpy as np
    from math import factorial

    n, m = _zernike_noll_indices(j)
    abs_m = abs(m)
    radial = np.zeros_like(r)
    for s in range((n - abs_m) // 2 + 1):
        c = ((-1) ** s * factorial(n - s) /
             (factorial(s) *
              factorial((n + abs_m) // 2 - s) *
              factorial((n - abs_m) // 2 - s)))
        radial += c * r ** (n - 2 * s)
    if m == 0:
        return radial
    return radial * (np.cos(m * th) if m > 0 else np.sin(-m * th))


def _find_zernike_iteration_dir(study_dir: Path, trial_number: int) -> Optional[Path]:
    """Locate the FEA iteration folder for a trial number.

    Trial numbers in the Optuna DB may differ from iteration folder numbers,
    so ``iter{trial_number}`` is tried first, then ``iter{trial_number + 1}``
    (0-indexed trials vs 1-indexed folders). Returns None if neither exists.
    """
    # NOTE(review): /zernike-available reports folder iterN as trial N-1, while
    # this lookup prefers iterN for trial N when both folders exist; confirm
    # which convention the study runner uses.
    iter_base = study_dir / "2_iterations"
    for iter_num in (trial_number, trial_number + 1):
        candidate = iter_base / f"iter{iter_num}"
        if candidate.is_dir():
            return candidate
    return None


def _zernike_opd_relative_arrays(opd_extractor, target_sc: str, ref_sc: str):
    """Node-matched relative fields (target - reference) from the OPD extractor.

    Returns (X, Y, W_nm, dx_um, dy_um, dz_um). X/Y are the target subcase's
    deformed coordinates. Nodes absent from the reference are skipped, and
    displacements are scaled x1000 (mm -> um).
    """
    import numpy as np

    tgt = opd_extractor._build_figure_opd_data(target_sc)
    ref = opd_extractor._build_figure_opd_data(ref_sc)
    ref_index = {int(nid): i for i, nid in enumerate(ref['node_ids'])}
    pairs = [(i, ref_index[int(nid)]) for i, nid in enumerate(tgt['node_ids'])
             if int(nid) in ref_index]
    ti = np.array([p[0] for p in pairs], dtype=int)
    ri = np.array([p[1] for p in pairs], dtype=int)

    def relative(key):
        return np.asarray(tgt[key])[ti] - np.asarray(ref[key])[ri]

    X = np.asarray(tgt['x_deformed'])[ti]
    Y = np.asarray(tgt['y_deformed'])[ti]
    return (X, Y, relative('wfe_nm'),
            relative('dx') * 1000.0, relative('dy') * 1000.0, relative('dz') * 1000.0)


def _zernike_standard_relative_arrays(std_extractor, target_sc: str, ref_sc: str):
    """Node-matched relative fields from the standard (Z-only) extractor.

    Returns (X, Y, W_nm, dx_um, dy_um, dz_um). X/Y come from
    ``node_geometry``, and W is the Z-displacement difference scaled by
    ``wfe_factor``. Nodes missing from the reference subcase or from the
    geometry table are skipped. Displacements are scaled x1000 (mm -> um).
    """
    import numpy as np

    tgt = std_extractor.displacements[target_sc]
    ref = std_extractor.displacements[ref_sc]
    geometry = std_extractor.node_geometry
    ref_index = {int(nid): i for i, nid in enumerate(ref['node_ids'])}
    ti, ri, xs, ys = [], [], [], []
    for i, nid in enumerate(tgt['node_ids']):
        nid = int(nid)
        j = ref_index.get(nid)
        if j is None:
            continue
        geo = geometry.get(nid)
        if geo is None:
            continue
        ti.append(i)
        ri.append(j)
        xs.append(geo[0])
        ys.append(geo[1])

    tgt_disp = np.asarray(tgt['disp'])[np.array(ti, dtype=int)]
    ref_disp = np.asarray(ref['disp'])[np.array(ri, dtype=int)]
    factor = std_extractor.wfe_factor
    W = tgt_disp[:, 2] * factor - ref_disp[:, 2] * factor
    delta_um = (tgt_disp - ref_disp) * 1000.0
    return np.array(xs), np.array(ys), W, delta_um[:, 0], delta_um[:, 1], delta_um[:, 2]


def _zernike_surface_trace(Xp, Yp, Zp, colorscale: str, colorbar_title: str, unit: str):
    """Smooth-shaded Mesh3d surface over (Xp, Yp) coloured by Zp.

    Falls back to a Scatter3d point cloud when triangulation is impossible.
    """
    import plotly.graph_objects as go
    from matplotlib.tri import Triangulation

    try:
        tri = Triangulation(Xp, Yp)
        if tri.triangles is not None and len(tri.triangles) > 0:
            i_idx, j_idx, k_idx = tri.triangles.T
            return go.Mesh3d(
                x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(),
                i=i_idx.tolist(), j=j_idx.tolist(), k=k_idx.tolist(),
                intensity=Zp.tolist(),
                colorscale=colorscale,
                opacity=1.0,
                flatshading=False,
                lighting=dict(ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2),
                lightposition=dict(x=100, y=200, z=300),
                showscale=True,
                colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15, len=0.5),
                hovertemplate=f"X: %{{x:.1f}}<br>Y: %{{y:.1f}}<br>{unit}: %{{z:.3f}}"
            )
    except Exception as e:
        # Best effort (e.g. too few or collinear points): use a point cloud instead.
        print(f"Triangulation failed, using scatter fallback: {e}")
    return go.Scatter3d(
        x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(),
        mode='markers', marker=dict(size=2, color=Zp.tolist(), colorscale=colorscale, showscale=True)
    )


def _zernike_comparison(std_extractor, opd_extractor, target_sc: str, ref_sc: str, label: str) -> dict:
    """Compute metrics and the interactive HTML for one target-vs-reference pair.

    When ``opd_extractor`` is None, the standard (Z-only) method supplies both
    the plotted data and the primary metrics. The standard filtered RMS is
    always computed so the two methods can be compared.
    """
    import numpy as np
    import plotly.graph_objects as go

    use_opd = opd_extractor is not None
    target_angle = _ZERNIKE_SUBCASE_MAP.get(target_sc, target_sc)
    ref_angle = _ZERNIKE_SUBCASE_MAP.get(ref_sc, ref_sc)

    opd_rel = opd_extractor.extract_relative(target_sc, ref_sc) if use_opd else None
    std_rel = std_extractor.extract_relative(target_sc, ref_sc, include_coefficients=True)
    if use_opd:
        X, Y, W, dx, dy, dz = _zernike_opd_relative_arrays(opd_extractor, target_sc, ref_sc)
        rms_global = opd_rel['relative_global_rms_nm']
        rms_filtered = opd_rel['relative_filtered_rms_nm']
        coefficients = np.array(opd_rel.get('delta_coefficients', std_rel['coefficients']))
    else:
        X, Y, W, dx, dy, dz = _zernike_standard_relative_arrays(std_extractor, target_sc, ref_sc)
        rms_global = std_rel['relative_global_rms_nm']
        rms_filtered = std_rel['relative_filtered_rms_nm']
        coefficients = np.array(std_rel['coefficients'])
    rms_filtered_std = std_rel['relative_filtered_rms_nm']

    if X.size == 0:
        raise ValueError(f"No common nodes between subcases {target_sc} and {ref_sc}")

    # Lateral (in-plane) displacement magnitude, um.
    lateral_um = np.sqrt(dx ** 2 + dy ** 2)
    max_lateral = float(np.max(np.abs(lateral_um)))
    rms_lateral = float(np.sqrt(np.mean(lateral_um ** 2)))

    # Residual surface: remove the low-order terms J1..J4 fitted by the extractor.
    Xc = X - np.mean(X)
    Yc = Y - np.mean(Y)
    R = float(np.max(np.hypot(Xc, Yc))) or 1.0  # guard the degenerate single-point case
    r = np.hypot(Xc / R, Yc / R)
    th = np.arctan2(Yc, Xc)
    n_low = _ZERNIKE_FILTER_LOW_ORDERS
    low_basis = np.column_stack([_zernike_noll_polynomial(j, r, th) for j in range(1, n_low + 1)])
    W_res_filt = W - low_basis.dot(coefficients[:n_low])

    # Downsample for display (fixed seed -> reproducible plots).
    n_points = len(X)
    if n_points > _ZERNIKE_PLOT_DOWNSAMPLE:
        rng = np.random.default_rng(42)
        sel = rng.choice(n_points, size=_ZERNIKE_PLOT_DOWNSAMPLE, replace=False)
        Xp, Yp, Wp = X[sel], Y[sel], W_res_filt[sel]
        dxp, dyp, dzp = dx[sel], dy[sel], dz[sel]
    else:
        Xp, Yp, Wp = X, Y, W_res_filt
        dxp, dyp, dzp = dx, dy, dz
    res_amp = _ZERNIKE_AMP * Wp

    # One trace per view; the button row toggles which one is visible.
    views = [
        ("WFE (nm)", _zernike_surface_trace(Xp, Yp, res_amp, _ZERNIKE_COLORSCALE, "WFE (nm)", "WFE nm")),
        ("ΔX (µm)", _zernike_surface_trace(Xp, Yp, dxp, 'RdBu_r', "ΔX (µm)", "ΔX µm")),
        ("ΔY (µm)", _zernike_surface_trace(Xp, Yp, dyp, 'RdBu_r', "ΔY (µm)", "ΔY µm")),
        ("ΔZ (µm)", _zernike_surface_trace(Xp, Yp, dzp, 'RdBu_r', "ΔZ (µm)", "ΔZ µm")),
    ]
    fig = go.Figure()
    for idx, (_, trace) in enumerate(views):
        trace.visible = idx == 0
        fig.add_trace(trace)
    buttons = [
        dict(label=view_label, method="update",
             args=[{"visible": [k == idx for k in range(len(views))]}])
        for idx, (view_label, _) in enumerate(views)
    ]
    fig.update_layout(
        updatemenus=[
            dict(
                type="buttons",
                direction="right",
                x=0.0, y=1.12,
                xanchor="left",
                showactive=True,
                buttons=buttons,
                font=dict(size=12),
                pad=dict(r=10, t=10),
            )
        ]
    )

    # Relative difference between the primary and the standard filtered RMS.
    pct_diff = 100.0 * (rms_filtered - rms_filtered_std) / rms_filtered_std if rms_filtered_std > 0 else 0.0

    method_label = "OPD (X,Y,Z)" if use_opd else "Standard (Z-only)"
    method_note = '← More Accurate' if use_opd else '(BDF not found)'
    summary_lines = [
        f"Method: {method_label} {method_note}",
        "RMS Metrics (Filtered J1-J4):",
        f"• OPD: {rms_filtered:.2f} nm",
        f"• Standard: {rms_filtered_std:.2f} nm",
        f"• Δ: {pct_diff:+.1f}%",
        "Lateral Displacement:",
        f"• Max: {max_lateral:.3f} µm",
        f"• RMS: {rms_lateral:.3f} µm",
        "Displacement RMS:",
        f"• ΔX: {float(np.sqrt(np.mean(dx ** 2))):.3f} µm",
        f"• ΔY: {float(np.sqrt(np.mean(dy ** 2))):.3f} µm",
        f"• ΔZ: {float(np.sqrt(np.mean(dz ** 2))):.3f} µm",
    ]

    axis_style = dict(showgrid=True, gridcolor='rgba(128,128,128,0.3)', showbackground=True)
    fig.update_layout(
        scene=dict(
            camera=dict(eye=dict(x=1.2, y=1.2, z=0.8), up=dict(x=0, y=0, z=1)),
            xaxis=dict(title="X (mm)", backgroundcolor='rgba(240,240,240,0.9)', **axis_style),
            yaxis=dict(title="Y (mm)", backgroundcolor='rgba(240,240,240,0.9)', **axis_style),
            zaxis=dict(title="Value", backgroundcolor='rgba(230,230,250,0.9)', **axis_style),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=0.4)
        ),
        width=1400,
        height=900,
        margin=dict(t=120, b=20, l=20, r=20),
        title=dict(
            text=f"{label}: {target_angle}° vs {ref_angle}°<br>Click buttons to switch: WFE, ΔX, ΔY, ΔZ",
            font=dict(size=18),
            x=0.5
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        annotations=[
            dict(
                text="<br>".join(summary_lines),
                xref="paper", yref="paper",
                x=1.02, y=0.98,
                xanchor="left", yanchor="top",
                showarrow=False,
                font=dict(family="monospace", size=11),
                align="left",
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.3)",
                borderwidth=1,
                borderpad=8
            ),
            dict(
                text="View:",
                xref="paper", yref="paper",
                x=0.0, y=1.15,
                xanchor="left", yanchor="top",
                showarrow=False,
                font=dict(size=12)
            )
        ]
    )

    return {
        "html": fig.to_html(include_plotlyjs='cdn', full_html=True),
        "rms_global": rms_global,
        "rms_filtered": rms_filtered,
        "title": f"{target_angle}° vs {ref_angle}°",
        "method": 'opd' if use_opd else 'standard',
        "rms_std": rms_filtered_std,
        "pct_diff": pct_diff,
        "max_lateral_um": max_lateral,
        "rms_lateral_um": rms_lateral
    }


def _build_trial_zernike_results(study_id: str, trial_number: int, iter_dir: Path, op2_path: Path) -> dict:
    """Run the Zernike extractors on one iteration folder and build every comparison.

    Blocking (file parsing and numerics); the endpoint runs it in a worker thread.
    """
    # Heavy dependencies are imported only once an OP2 file is known to exist.
    from optimization_engine.extractors.extract_zernike_figure import ZernikeOPDExtractor
    from optimization_engine.extractors import ZernikeExtractor

    # A BDF/DAT geometry file enables the OPD method, which accounts for
    # lateral X,Y displacement. Without it (or if it fails) fall back to Standard.
    bdf_files = list(iter_dir.glob("*.dat")) + list(iter_dir.glob("*.bdf"))
    opd_extractor = None
    if bdf_files:
        try:
            opd_extractor = ZernikeOPDExtractor(
                str(op2_path),
                bdf_path=str(bdf_files[0]),
                n_modes=_ZERNIKE_N_MODES,
                filter_orders=_ZERNIKE_FILTER_LOW_ORDERS
            )
        except Exception as e:
            print(f"OPD extractor failed, falling back to Standard: {e}")
    # The Standard extractor is always needed for the method comparison.
    std_extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=_ZERNIKE_N_MODES)

    results = {}
    for target_sc, ref_sc, key in _ZERNIKE_COMPARISONS:
        if target_sc not in std_extractor.displacements or ref_sc not in std_extractor.displacements:
            continue
        results[key] = _zernike_comparison(std_extractor, opd_extractor, target_sc, ref_sc, iter_dir.name)

    if not results:
        raise HTTPException(
            status_code=500,
            detail="Failed to generate Zernike analysis. Check if subcases are available."
        )
    return {
        "study_id": study_id,
        "trial_number": trial_number,
        "comparisons": results,
        "available_comparisons": list(results.keys()),
        "method": "opd" if opd_extractor is not None else "standard"
    }


@router.get("/studies/{study_id}/trials/{trial_number}/zernike")
async def get_trial_zernike(study_id: str, trial_number: int):
    """
    Generate interactive Zernike wavefront analysis HTML for a specific trial.

    For each available comparison (40_vs_20, 60_vs_20, 90_vs_20), this
    produces a 3D surface of the J1-J4-filtered residual WFE plus ΔX/ΔY/ΔZ
    displacement views and RMS / lateral-displacement metrics. The OPD method
    is used when a BDF/DAT geometry file is present; otherwise the standard
    Z-only method is used.

    Args:
        study_id: Study identifier
        trial_number: Trial/iteration number

    Returns:
        JSON with HTML content and metrics for each comparison

    Raises:
        HTTPException: 404 if the study, iteration folder or OP2 file is
            missing; 500 if the analysis fails.
    """
    try:
        # resolve_study_path raises 404 itself when the study does not exist.
        study_dir = resolve_study_path(study_id)

        iter_dir = _find_zernike_iteration_dir(study_dir, trial_number)
        if iter_dir is None:
            raise HTTPException(
                status_code=404,
                detail=f"No FEA results for trial {trial_number}. This trial may have used surrogate model (NN) prediction instead of full FEA simulation. Zernike analysis requires OP2 results from actual FEA runs."
            )

        # Check for OP2 file BEFORE doing expensive imports
        op2_files = list(iter_dir.glob("*.op2"))
        if not op2_files:
            raise HTTPException(
                status_code=404,
                detail=f"No OP2 results file found in {iter_dir.name}. FEA may not have completed."
            )

        # OP2 parsing and plotting are CPU-heavy; run them off the event loop
        # so other requests (status polling etc.) are not blocked.
        from fastapi.concurrency import run_in_threadpool
        return await run_in_threadpool(
            _build_trial_zernike_results, study_id, trial_number, iter_dir, op2_files[0]
        )
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to generate Zernike analysis: {str(e)}")
def _read_completed_trials_for_export(db_path: Path) -> List[Dict]:
    """Load every COMPLETE trial from an Optuna SQLite DB, ordered by trial number.

    Each entry holds ``trial_number``, ``objective`` (value of the lowest
    objective index), ``params``, ``user_attrs`` (JSON-decoded where
    possible) and ``objectives`` (all values ordered by objective index).
    Trials with no stored objective value are omitted.
    """
    conn = sqlite3.connect(str(db_path))
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # One row per (trial, objective). Rows are grouped per trial so that
        # multi-objective studies do not produce duplicate trial entries.
        cursor.execute("""
            SELECT t.trial_id, t.number, tv.objective, tv.value
            FROM trials t
            JOIN trial_values tv ON t.trial_id = tv.trial_id
            WHERE t.state = 'COMPLETE'
            ORDER BY t.number, t.trial_id, tv.objective
        """)
        trials = {}  # trial_id -> entry; insertion order follows ORDER BY
        for row in cursor.fetchall():
            entry = trials.get(row['trial_id'])
            if entry is None:
                entry = trials[row['trial_id']] = {
                    "trial_number": row['number'],
                    "objective": row['value'],
                    "params": {},
                    "user_attrs": {},
                    "objectives": [],
                }
            entry["objectives"].append(row['value'])

        # Params and user attributes are fetched in one query each instead of
        # two queries per trial.
        cursor.execute("""
            SELECT tp.trial_id, tp.param_name, tp.param_value
            FROM trial_params tp
            JOIN trials t ON tp.trial_id = t.trial_id
            WHERE t.state = 'COMPLETE'
        """)
        for row in cursor.fetchall():
            entry = trials.get(row['trial_id'])
            if entry is not None:
                entry["params"][row['param_name']] = row['param_value']

        cursor.execute("""
            SELECT ua.trial_id, ua.key, ua.value_json
            FROM trial_user_attributes ua
            JOIN trials t ON ua.trial_id = t.trial_id
            WHERE t.state = 'COMPLETE'
        """)
        for row in cursor.fetchall():
            entry = trials.get(row['trial_id'])
            if entry is None:
                continue
            try:
                value = json.loads(row['value_json'])
            except (TypeError, ValueError):
                # Not valid JSON (or NULL): keep the raw stored value.
                value = row['value_json']
            entry["user_attrs"][row['key']] = value

        return list(trials.values())
    finally:
        conn.close()


@router.get("/studies/{study_id}/export/{format}")
async def export_study_data(study_id: str, format: str):
    """
    Export study data.

    Supported formats (case-insensitive):
        json   - completed trials with params, user attributes and objective values
        csv    - completed trials as CSV text (trial_number, objective, one column per param)
        config - the study's optimization_config.json

    Returns:
        JSONResponse; for csv/config the payload carries filename, content and
        content_type.

    Raises:
        HTTPException: 400 for unsupported formats, 404 if the study, its
            database or its config file is missing, 500 on other failures.
    """
    try:
        # resolve_study_path raises 404 itself when the study does not exist.
        study_dir = resolve_study_path(study_id)

        fmt = format.lower()
        if fmt not in ("json", "csv", "config"):
            raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")

        if fmt == "config":
            # Config export does not need the study database. Same lookup
            # order as the config update endpoint: 1_setup/ first, then root.
            config_path = study_dir / "1_setup" / "optimization_config.json"
            if not config_path.exists():
                config_path = study_dir / "optimization_config.json"
            if not config_path.exists():
                raise HTTPException(status_code=404, detail="Config file not found")
            with open(config_path, 'r') as f:
                config = json.load(f)
            return JSONResponse(content={
                "filename": f"{study_id}_config.json",
                "content": json.dumps(config, indent=2),
                "content_type": "application/json"
            })

        db_path = get_results_dir(study_dir) / "study.db"
        if not db_path.exists():
            raise HTTPException(status_code=404, detail="No study data available")
        trials_data = _read_completed_trials_for_export(db_path)

        if fmt == "json":
            return JSONResponse(content={
                "study_id": study_id,
                "total_trials": len(trials_data),
                "trials": trials_data
            })

        # fmt == "csv"
        import io
        import csv
        if not trials_data:
            return JSONResponse(content={"error": "No data to export"})

        # Union of parameter names across trials; missing values stay empty.
        param_names = sorted({key for trial in trials_data for key in trial['params']})
        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=['trial_number', 'objective'] + param_names)
        writer.writeheader()
        for trial in trials_data:
            row_data = {
                'trial_number': trial['trial_number'],
                'objective': trial['objective']
            }
            row_data.update(trial['params'])
            writer.writerow(row_data)

        return JSONResponse(content={
            "filename": f"{study_id}_data.csv",
            "content": output.getvalue(),
            "content_type": "text/csv"
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to export data: {str(e)}")