diff --git a/atomizer-dashboard/backend/api/routes/optimization.py b/atomizer-dashboard/backend/api/routes/optimization.py
index 5466b6de..2e01fc04 100644
--- a/atomizer-dashboard/backend/api/routes/optimization.py
+++ b/atomizer-dashboard/backend/api/routes/optimization.py
@@ -692,12 +692,33 @@ async def get_study_metadata(study_id: str):
if "unit" not in obj or not obj["unit"]:
obj["unit"] = _infer_objective_unit(obj)
+ # Get sampler/algorithm info
+ optimization = config.get("optimization", {})
+ algorithm = optimization.get("algorithm", "TPE")
+
+ # Map algorithm names to Optuna sampler names for frontend display
+ sampler_map = {
+ "CMA-ES": "CmaEsSampler",
+ "cma-es": "CmaEsSampler",
+ "cmaes": "CmaEsSampler",
+ "TPE": "TPESampler",
+ "tpe": "TPESampler",
+ "NSGA-II": "NSGAIISampler",
+ "nsga-ii": "NSGAIISampler",
+ "NSGA-III": "NSGAIIISampler",
+ "Random": "RandomSampler",
+ }
+ sampler = sampler_map.get(algorithm, algorithm)
+
return {
"objectives": objectives,
"design_variables": config.get("design_variables", []),
"constraints": config.get("constraints", []),
"study_name": config.get("study_name", study_id),
- "description": config.get("description", "")
+ "description": config.get("description", ""),
+ "sampler": sampler,
+ "algorithm": algorithm,
+ "n_trials": optimization.get("n_trials", 100)
}
except FileNotFoundError:
@@ -2475,6 +2496,7 @@ async def get_trial_zernike(study_id: str, trial_number: int):
# Only import heavy dependencies after we know we have an OP2 file
sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent))
+ from optimization_engine.extractors.extract_zernike_figure import ZernikeOPDExtractor
from optimization_engine.extractors import ZernikeExtractor
import numpy as np
from math import factorial
@@ -2482,6 +2504,10 @@ async def get_trial_zernike(study_id: str, trial_number: int):
from plotly.subplots import make_subplots
from matplotlib.tri import Triangulation
+ # Also find BDF/DAT geometry file for OPD extractor
+ bdf_files = list(iter_dir.glob("*.dat")) + list(iter_dir.glob("*.bdf"))
+ bdf_path = bdf_files[0] if bdf_files else None
+
# Configuration
N_MODES = 50
AMP = 0.5 # Reduced deformation scaling (0.5x)
@@ -2867,10 +2893,320 @@ async def get_trial_zernike(study_id: str, trial_number: int):
return fig.to_html(include_plotlyjs='cdn', full_html=True)
- # Load OP2 and generate reports
+ # =====================================================================
+ # NEW: Use OPD method (accounts for lateral X,Y displacement)
+ # =====================================================================
op2_path = op2_files[0]
- extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=N_MODES)
+ # Try OPD extractor first (more accurate), fall back to Standard if no BDF
+ use_opd = bdf_path is not None
+ if use_opd:
+ try:
+ opd_extractor = ZernikeOPDExtractor(
+ str(op2_path),
+ bdf_path=str(bdf_path),
+ n_modes=N_MODES,
+ filter_orders=FILTER_LOW_ORDERS
+ )
+ except Exception as e:
+ print(f"OPD extractor failed, falling back to Standard: {e}")
+ use_opd = False
+
+ # Also create Standard extractor for comparison
+ std_extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=N_MODES)
+
+ def generate_dual_method_html(
+ title: str,
+ target_sc: str,
+ ref_sc: str,
+ is_manufacturing: bool = False
+ ) -> tuple:
+ """Generate HTML with OPD method and displacement component views.
+
+ Returns: (html_content, rms_global_opd, rms_filtered_opd, lateral_stats)
+ """
+ target_angle = SUBCASE_MAP.get(target_sc, target_sc)
+ ref_angle = SUBCASE_MAP.get(ref_sc, ref_sc)
+
+ # Extract using OPD method (primary)
+ if use_opd:
+ opd_rel = opd_extractor.extract_relative(target_sc, ref_sc)
+ opd_abs = opd_extractor.extract_subcase(target_sc)
+ else:
+ opd_rel = None
+ opd_abs = None
+
+ # Extract using Standard method (for comparison)
+ std_rel = std_extractor.extract_relative(target_sc, ref_sc, include_coefficients=True)
+ std_abs = std_extractor.extract_subcase(target_sc, include_coefficients=True)
+
+ # Get OPD data with full arrays for visualization
+ if use_opd:
+ opd_data = opd_extractor._build_figure_opd_data(target_sc)
+ opd_ref_data = opd_extractor._build_figure_opd_data(ref_sc)
+
+ # Build relative displacement arrays (node-by-node)
+ ref_map = {int(nid): i for i, nid in enumerate(opd_ref_data['node_ids'])}
+ X_list, Y_list, WFE_list = [], [], []
+ dx_list, dy_list, dz_list = [], [], []
+
+ for i, nid in enumerate(opd_data['node_ids']):
+ nid = int(nid)
+ if nid not in ref_map:
+ continue
+ ref_idx = ref_map[nid]
+
+ # Use deformed coordinates from OPD
+ X_list.append(opd_data['x_deformed'][i])
+ Y_list.append(opd_data['y_deformed'][i])
+ WFE_list.append(opd_data['wfe_nm'][i] - opd_ref_data['wfe_nm'][ref_idx])
+
+ # Relative displacements (target - reference)
+ dx_list.append(opd_data['dx'][i] - opd_ref_data['dx'][ref_idx])
+ dy_list.append(opd_data['dy'][i] - opd_ref_data['dy'][ref_idx])
+ dz_list.append(opd_data['dz'][i] - opd_ref_data['dz'][ref_idx])
+
+ X = np.array(X_list)
+ Y = np.array(Y_list)
+ W = np.array(WFE_list)
+ dx = np.array(dx_list) * 1000.0 # mm to µm
+ dy = np.array(dy_list) * 1000.0
+ dz = np.array(dz_list) * 1000.0
+
+ # Lateral displacement magnitude
+ lateral_um = np.sqrt(dx**2 + dy**2)
+ max_lateral = float(np.max(np.abs(lateral_um)))
+ rms_lateral = float(np.sqrt(np.mean(lateral_um**2)))
+
+ rms_global_opd = opd_rel['relative_global_rms_nm']
+ rms_filtered_opd = opd_rel['relative_filtered_rms_nm']
+ coefficients = np.array(opd_rel.get('delta_coefficients', std_rel['coefficients']))
+ else:
+ # Fallback to Standard method arrays
+ target_disp = std_extractor.displacements[target_sc]
+ ref_disp = std_extractor.displacements[ref_sc]
+ ref_map = {int(nid): i for i, nid in enumerate(ref_disp['node_ids'])}
+
+ X_list, Y_list, W_list = [], [], []
+ dx_list, dy_list, dz_list = [], [], []
+
+ for i, nid in enumerate(target_disp['node_ids']):
+ nid = int(nid)
+ if nid not in ref_map:
+ continue
+ geo = std_extractor.node_geometry.get(nid)
+ if geo is None:
+ continue
+ ref_idx = ref_map[nid]
+
+ X_list.append(geo[0])
+ Y_list.append(geo[1])
+
+ target_wfe = target_disp['disp'][i, 2] * std_extractor.wfe_factor
+ ref_wfe = ref_disp['disp'][ref_idx, 2] * std_extractor.wfe_factor
+ W_list.append(target_wfe - ref_wfe)
+
+ # Relative displacements (mm to µm)
+ dx_list.append((target_disp['disp'][i, 0] - ref_disp['disp'][ref_idx, 0]) * 1000.0)
+ dy_list.append((target_disp['disp'][i, 1] - ref_disp['disp'][ref_idx, 1]) * 1000.0)
+ dz_list.append((target_disp['disp'][i, 2] - ref_disp['disp'][ref_idx, 2]) * 1000.0)
+
+ X = np.array(X_list)
+ Y = np.array(Y_list)
+ W = np.array(W_list)
+ dx = np.array(dx_list)
+ dy = np.array(dy_list)
+ dz = np.array(dz_list)
+
+ lateral_um = np.sqrt(dx**2 + dy**2)
+ max_lateral = float(np.max(np.abs(lateral_um)))
+ rms_lateral = float(np.sqrt(np.mean(lateral_um**2)))
+
+ rms_global_opd = std_rel['relative_global_rms_nm']
+ rms_filtered_opd = std_rel['relative_filtered_rms_nm']
+ coefficients = np.array(std_rel['coefficients'])
+
+ # Standard method RMS values
+ rms_global_std = std_rel['relative_global_rms_nm']
+ rms_filtered_std = std_rel['relative_filtered_rms_nm']
+
+ # Compute residual surface
+ Xc = X - np.mean(X)
+ Yc = Y - np.mean(Y)
+ R = float(np.max(np.hypot(Xc, Yc)))
+ r = np.hypot(Xc/R, Yc/R)
+ th = np.arctan2(Yc, Xc)
+
+ Z_basis = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES+1)])
+ W_res_filt = W - Z_basis[:, :FILTER_LOW_ORDERS].dot(coefficients[:FILTER_LOW_ORDERS])
+
+ # Downsample for display
+ n = len(X)
+ if n > PLOT_DOWNSAMPLE:
+ rng = np.random.default_rng(42)
+ sel = rng.choice(n, size=PLOT_DOWNSAMPLE, replace=False)
+ Xp, Yp = X[sel], Y[sel]
+ Wp = W_res_filt[sel]
+ dxp, dyp, dzp = dx[sel], dy[sel], dz[sel]
+ else:
+ Xp, Yp, Wp = X, Y, W_res_filt
+ dxp, dyp, dzp = dx, dy, dz
+
+ res_amp = AMP * Wp
+ max_amp = float(np.max(np.abs(res_amp))) if res_amp.size else 1.0
+
+ # Helper to build mesh trace
+ def build_mesh_trace(Zp, colorscale, colorbar_title, unit):
+ try:
+ tri = Triangulation(Xp, Yp)
+ if tri.triangles is not None and len(tri.triangles) > 0:
+ i_idx, j_idx, k_idx = tri.triangles.T
+ return go.Mesh3d(
+ x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(),
+ i=i_idx.tolist(), j=j_idx.tolist(), k=k_idx.tolist(),
+ intensity=Zp.tolist(),
+ colorscale=colorscale,
+ opacity=1.0,
+ flatshading=False,
+ lighting=dict(ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2),
+ lightposition=dict(x=100, y=200, z=300),
+ showscale=True,
+ colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15, len=0.5),
+ hovertemplate=f"X: %{{x:.1f}}
Y: %{{y:.1f}}
{unit}: %{{z:.3f}}"
+ )
+ except Exception:
+ pass
+ return go.Scatter3d(
+ x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(),
+ mode='markers', marker=dict(size=2, color=Zp.tolist(), colorscale=colorscale, showscale=True)
+ )
+
+ # Build traces for each view
+ trace_wfe = build_mesh_trace(res_amp, COLORSCALE, "WFE (nm)", "WFE nm")
+ trace_dx = build_mesh_trace(dxp, 'RdBu_r', "ΔX (µm)", "ΔX µm")
+ trace_dy = build_mesh_trace(dyp, 'RdBu_r', "ΔY (µm)", "ΔY µm")
+ trace_dz = build_mesh_trace(dzp, 'RdBu_r', "ΔZ (µm)", "ΔZ µm")
+
+ # Create figure with dropdown to switch views
+ fig = go.Figure()
+
+ # Add all traces (only WFE visible initially)
+ trace_wfe.visible = True
+ trace_dx.visible = False
+ trace_dy.visible = False
+ trace_dz.visible = False
+
+ fig.add_trace(trace_wfe)
+ fig.add_trace(trace_dx)
+ fig.add_trace(trace_dy)
+ fig.add_trace(trace_dz)
+
+ # Dropdown menu for view selection
+ fig.update_layout(
+ updatemenus=[
+ dict(
+ type="buttons",
+ direction="right",
+ x=0.0, y=1.12,
+ xanchor="left",
+ showactive=True,
+ buttons=[
+ dict(label="WFE (nm)", method="update",
+ args=[{"visible": [True, False, False, False]}]),
+ dict(label="ΔX (µm)", method="update",
+ args=[{"visible": [False, True, False, False]}]),
+ dict(label="ΔY (µm)", method="update",
+ args=[{"visible": [False, False, True, False]}]),
+ dict(label="ΔZ (µm)", method="update",
+ args=[{"visible": [False, False, False, True]}]),
+ ],
+ font=dict(size=12),
+ pad=dict(r=10, t=10),
+ )
+ ]
+ )
+
+ # Compute method difference
+ pct_diff = 100.0 * (rms_filtered_opd - rms_filtered_std) / rms_filtered_std if rms_filtered_std > 0 else 0.0
+
+ # Annotations for metrics
+ method_label = "OPD (X,Y,Z)" if use_opd else "Standard (Z-only)"
+ annotations_text = f"""
+Method: {method_label} {'← More Accurate' if use_opd else '(BDF not found)'}
+RMS Metrics (Filtered J1-J4):
+ • OPD: {rms_filtered_opd:.2f} nm
+ • Standard: {rms_filtered_std:.2f} nm
+ • Δ: {pct_diff:+.1f}%
+Lateral Displacement:
+ • Max: {max_lateral:.3f} µm
+ • RMS: {rms_lateral:.3f} µm
+Displacement RMS:
+ • ΔX: {float(np.sqrt(np.mean(dx**2))):.3f} µm
+ • ΔY: {float(np.sqrt(np.mean(dy**2))):.3f} µm
+ • ΔZ: {float(np.sqrt(np.mean(dz**2))):.3f} µm
+"""
+
+ # Z-axis range for different views
+ max_disp = max(float(np.max(np.abs(dxp))), float(np.max(np.abs(dyp))), float(np.max(np.abs(dzp))), 0.1)
+
+ fig.update_layout(
+ scene=dict(
+ camera=dict(eye=dict(x=1.2, y=1.2, z=0.8), up=dict(x=0, y=0, z=1)),
+ xaxis=dict(title="X (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)',
+ showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'),
+ yaxis=dict(title="Y (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)',
+ showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'),
+ zaxis=dict(title="Value", showgrid=True, gridcolor='rgba(128,128,128,0.3)',
+ showbackground=True, backgroundcolor='rgba(230,230,250,0.9)'),
+ aspectmode='manual',
+ aspectratio=dict(x=1, y=1, z=0.4)
+ ),
+ width=1400,
+ height=900,
+ margin=dict(t=120, b=20, l=20, r=20),
+ title=dict(
+ text=f"{title}
Click buttons to switch: WFE, ΔX, ΔY, ΔZ",
+ font=dict(size=18),
+ x=0.5
+ ),
+ paper_bgcolor='white',
+ plot_bgcolor='white',
+ annotations=[
+ dict(
+ text=annotations_text.replace('\n', '
'),
+ xref="paper", yref="paper",
+ x=1.02, y=0.98,
+ xanchor="left", yanchor="top",
+ showarrow=False,
+ font=dict(family="monospace", size=11),
+ align="left",
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.3)",
+ borderwidth=1,
+ borderpad=8
+ ),
+ dict(
+ text="View:",
+ xref="paper", yref="paper",
+ x=0.0, y=1.15,
+ xanchor="left", yanchor="top",
+ showarrow=False,
+ font=dict(size=12)
+ )
+ ]
+ )
+
+ html_content = fig.to_html(include_plotlyjs='cdn', full_html=True)
+
+ return (html_content, rms_global_opd, rms_filtered_opd, {
+ 'max_lateral_um': max_lateral,
+ 'rms_lateral_um': rms_lateral,
+ 'method': 'opd' if use_opd else 'standard',
+ 'rms_std': rms_filtered_std,
+ 'pct_diff': pct_diff
+ })
+
+ # Generate results for each comparison
results = {}
comparisons = [
('3', '2', '40_vs_20', '40 deg vs 20 deg'),
@@ -2878,84 +3214,31 @@ async def get_trial_zernike(study_id: str, trial_number: int):
('1', '2', '90_vs_20', '90 deg vs 20 deg (manufacturing)'),
]
- # Pre-compute absolute 90 deg metrics for manufacturing view
- abs_90_data = None
- abs_90_metrics = None
- if '1' in extractor.displacements:
- abs_90_data = extractor.extract_subcase('1', include_coefficients=True)
- abs_90_metrics = compute_manufacturing_metrics(np.array(abs_90_data['coefficients']))
-
for target_sc, ref_sc, key, title_suffix in comparisons:
- if target_sc not in extractor.displacements:
+ if target_sc not in std_extractor.displacements:
continue
- # Get relative data with coefficients
- rel_data = extractor.extract_relative(target_sc, ref_sc, include_coefficients=True)
-
- # Get absolute data for this subcase
- abs_data = extractor.extract_subcase(target_sc, include_coefficients=True)
-
- # Build coordinate arrays
- target_disp = extractor.displacements[target_sc]
- ref_disp = extractor.displacements[ref_sc]
-
- ref_node_to_idx = {int(nid): i for i, nid in enumerate(ref_disp['node_ids'])}
- X_list, Y_list, W_list = [], [], []
-
- for i, nid in enumerate(target_disp['node_ids']):
- nid = int(nid)
- if nid not in ref_node_to_idx:
- continue
- geo = extractor.node_geometry.get(nid)
- if geo is None:
- continue
-
- ref_idx = ref_node_to_idx[nid]
- target_wfe = target_disp['disp'][i, 2] * extractor.wfe_factor
- ref_wfe = ref_disp['disp'][ref_idx, 2] * extractor.wfe_factor
-
- X_list.append(geo[0])
- Y_list.append(geo[1])
- W_list.append(target_wfe - ref_wfe)
-
- X = np.array(X_list)
- Y = np.array(Y_list)
- W = np.array(W_list)
-
target_angle = SUBCASE_MAP.get(target_sc, target_sc)
ref_angle = SUBCASE_MAP.get(ref_sc, ref_sc)
-
- # Check if this is the manufacturing (90 deg) comparison
is_mfg = (key == '90_vs_20')
- # Compute correction metrics (relative coefficients) for manufacturing view
- correction_metrics = None
- if is_mfg and 'coefficients' in rel_data:
- correction_metrics = compute_manufacturing_metrics(np.array(rel_data['coefficients']))
- # Also compute rms_filter_j1to3 for the relative data
- R = float(np.max(np.hypot(X - np.mean(X), Y - np.mean(Y))))
- correction_metrics['rms_filter_j1to3'] = compute_rms_filter_j1to3(
- X, Y, W, np.array(rel_data['coefficients']), R
- )
-
- html_content = generate_zernike_html(
- title=f"iter{trial_number}: {target_angle} deg vs {ref_angle} deg",
- X=X, Y=Y, W_nm=W,
- coefficients=np.array(rel_data['coefficients']),
- rms_global=rel_data['relative_global_rms_nm'],
- rms_filtered=rel_data['relative_filtered_rms_nm'],
- ref_title=f"{ref_angle} deg",
- abs_pair=(abs_data['global_rms_nm'], abs_data['filtered_rms_nm']),
- is_manufacturing=is_mfg,
- mfg_metrics=abs_90_metrics if is_mfg else None,
- correction_metrics=correction_metrics
+ html_content, rms_global, rms_filtered, lateral_stats = generate_dual_method_html(
+ title=f"iter{trial_number}: {target_angle}° vs {ref_angle}°",
+ target_sc=target_sc,
+ ref_sc=ref_sc,
+ is_manufacturing=is_mfg
)
results[key] = {
"html": html_content,
- "rms_global": rel_data['relative_global_rms_nm'],
- "rms_filtered": rel_data['relative_filtered_rms_nm'],
- "title": f"{target_angle}° vs {ref_angle}°"
+ "rms_global": rms_global,
+ "rms_filtered": rms_filtered,
+ "title": f"{target_angle}° vs {ref_angle}°",
+ "method": lateral_stats['method'],
+ "rms_std": lateral_stats['rms_std'],
+ "pct_diff": lateral_stats['pct_diff'],
+ "max_lateral_um": lateral_stats['max_lateral_um'],
+ "rms_lateral_um": lateral_stats['rms_lateral_um']
}
if not results:
@@ -2968,7 +3251,8 @@ async def get_trial_zernike(study_id: str, trial_number: int):
"study_id": study_id,
"trial_number": trial_number,
"comparisons": results,
- "available_comparisons": list(results.keys())
+ "available_comparisons": list(results.keys()),
+ "method": "opd" if use_opd else "standard"
}
except HTTPException: