diff --git a/atomizer-dashboard/backend/api/routes/optimization.py b/atomizer-dashboard/backend/api/routes/optimization.py index 5466b6de..2e01fc04 100644 --- a/atomizer-dashboard/backend/api/routes/optimization.py +++ b/atomizer-dashboard/backend/api/routes/optimization.py @@ -692,12 +692,33 @@ async def get_study_metadata(study_id: str): if "unit" not in obj or not obj["unit"]: obj["unit"] = _infer_objective_unit(obj) + # Get sampler/algorithm info + optimization = config.get("optimization", {}) + algorithm = optimization.get("algorithm", "TPE") + + # Map algorithm names to Optuna sampler names for frontend display + sampler_map = { + "CMA-ES": "CmaEsSampler", + "cma-es": "CmaEsSampler", + "cmaes": "CmaEsSampler", + "TPE": "TPESampler", + "tpe": "TPESampler", + "NSGA-II": "NSGAIISampler", + "nsga-ii": "NSGAIISampler", + "NSGA-III": "NSGAIIISampler", + "Random": "RandomSampler", + } + sampler = sampler_map.get(algorithm, algorithm) + return { "objectives": objectives, "design_variables": config.get("design_variables", []), "constraints": config.get("constraints", []), "study_name": config.get("study_name", study_id), - "description": config.get("description", "") + "description": config.get("description", ""), + "sampler": sampler, + "algorithm": algorithm, + "n_trials": optimization.get("n_trials", 100) } except FileNotFoundError: @@ -2475,6 +2496,7 @@ async def get_trial_zernike(study_id: str, trial_number: int): # Only import heavy dependencies after we know we have an OP2 file sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + from optimization_engine.extractors.extract_zernike_figure import ZernikeOPDExtractor from optimization_engine.extractors import ZernikeExtractor import numpy as np from math import factorial @@ -2482,6 +2504,10 @@ async def get_trial_zernike(study_id: str, trial_number: int): from plotly.subplots import make_subplots from matplotlib.tri import Triangulation + # Also find BDF/DAT geometry file for OPD extractor + bdf_files = list(iter_dir.glob("*.dat")) + list(iter_dir.glob("*.bdf")) + bdf_path = bdf_files[0] if bdf_files else None + # Configuration N_MODES = 50 AMP = 0.5 # Reduced deformation scaling (0.5x) @@ -2867,10 +2893,320 @@ async def get_trial_zernike(study_id: str, trial_number: int): return fig.to_html(include_plotlyjs='cdn', full_html=True) - # Load OP2 and generate reports + # ===================================================================== + # NEW: Use OPD method (accounts for lateral X,Y displacement) + # ===================================================================== op2_path = op2_files[0] - extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=N_MODES) + # Try OPD extractor first (more accurate), fall back to Standard if no BDF + use_opd = bdf_path is not None + if use_opd: + try: + opd_extractor = ZernikeOPDExtractor( + str(op2_path), + bdf_path=str(bdf_path), + n_modes=N_MODES, + filter_orders=FILTER_LOW_ORDERS + ) + except Exception as e: + print(f"OPD extractor failed, falling back to Standard: {e}") + use_opd = False + + # Also create Standard extractor for comparison + std_extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=N_MODES) + + def generate_dual_method_html( + title: str, + target_sc: str, + ref_sc: str, + is_manufacturing: bool = False + ) -> tuple: + """Generate HTML with OPD method and displacement component views. + + Returns: (html_content, rms_global_opd, rms_filtered_opd, lateral_stats) + """ + target_angle = SUBCASE_MAP.get(target_sc, target_sc) + ref_angle = SUBCASE_MAP.get(ref_sc, ref_sc) + + # Extract using OPD method (primary) + if use_opd: + opd_rel = opd_extractor.extract_relative(target_sc, ref_sc) + opd_abs = opd_extractor.extract_subcase(target_sc) + else: + opd_rel = None + opd_abs = None + + # Extract using Standard method (for comparison) + std_rel = std_extractor.extract_relative(target_sc, ref_sc, include_coefficients=True) + std_abs = std_extractor.extract_subcase(target_sc, include_coefficients=True) + + # Get OPD data with full arrays for visualization + if use_opd: + opd_data = opd_extractor._build_figure_opd_data(target_sc) + opd_ref_data = opd_extractor._build_figure_opd_data(ref_sc) + + # Build relative displacement arrays (node-by-node) + ref_map = {int(nid): i for i, nid in enumerate(opd_ref_data['node_ids'])} + X_list, Y_list, WFE_list = [], [], [] + dx_list, dy_list, dz_list = [], [], [] + + for i, nid in enumerate(opd_data['node_ids']): + nid = int(nid) + if nid not in ref_map: + continue + ref_idx = ref_map[nid] + + # Use deformed coordinates from OPD + X_list.append(opd_data['x_deformed'][i]) + Y_list.append(opd_data['y_deformed'][i]) + WFE_list.append(opd_data['wfe_nm'][i] - opd_ref_data['wfe_nm'][ref_idx]) + + # Relative displacements (target - reference) + dx_list.append(opd_data['dx'][i] - opd_ref_data['dx'][ref_idx]) + dy_list.append(opd_data['dy'][i] - opd_ref_data['dy'][ref_idx]) + dz_list.append(opd_data['dz'][i] - opd_ref_data['dz'][ref_idx]) + + X = np.array(X_list) + Y = np.array(Y_list) + W = np.array(WFE_list) + dx = np.array(dx_list) * 1000.0 # mm to µm + dy = np.array(dy_list) * 1000.0 + dz = np.array(dz_list) * 1000.0 + + # Lateral displacement magnitude + lateral_um = np.sqrt(dx**2 + dy**2) + max_lateral = float(np.max(np.abs(lateral_um))) + rms_lateral = float(np.sqrt(np.mean(lateral_um**2))) + + rms_global_opd = opd_rel['relative_global_rms_nm'] + rms_filtered_opd = opd_rel['relative_filtered_rms_nm'] + coefficients = np.array(opd_rel.get('delta_coefficients', std_rel['coefficients'])) + else: + # Fallback to Standard method arrays + target_disp = std_extractor.displacements[target_sc] + ref_disp = std_extractor.displacements[ref_sc] + ref_map = {int(nid): i for i, nid in enumerate(ref_disp['node_ids'])} + + X_list, Y_list, W_list = [], [], [] + dx_list, dy_list, dz_list = [], [], [] + + for i, nid in enumerate(target_disp['node_ids']): + nid = int(nid) + if nid not in ref_map: + continue + geo = std_extractor.node_geometry.get(nid) + if geo is None: + continue + ref_idx = ref_map[nid] + + X_list.append(geo[0]) + Y_list.append(geo[1]) + + target_wfe = target_disp['disp'][i, 2] * std_extractor.wfe_factor + ref_wfe = ref_disp['disp'][ref_idx, 2] * std_extractor.wfe_factor + W_list.append(target_wfe - ref_wfe) + + # Relative displacements (mm to µm) + dx_list.append((target_disp['disp'][i, 0] - ref_disp['disp'][ref_idx, 0]) * 1000.0) + dy_list.append((target_disp['disp'][i, 1] - ref_disp['disp'][ref_idx, 1]) * 1000.0) + dz_list.append((target_disp['disp'][i, 2] - ref_disp['disp'][ref_idx, 2]) * 1000.0) + + X = np.array(X_list) + Y = np.array(Y_list) + W = np.array(W_list) + dx = np.array(dx_list) + dy = np.array(dy_list) + dz = np.array(dz_list) + + lateral_um = np.sqrt(dx**2 + dy**2) + max_lateral = float(np.max(np.abs(lateral_um))) + rms_lateral = float(np.sqrt(np.mean(lateral_um**2))) + + rms_global_opd = std_rel['relative_global_rms_nm'] + rms_filtered_opd = std_rel['relative_filtered_rms_nm'] + coefficients = np.array(std_rel['coefficients']) + + # Standard method RMS values + rms_global_std = std_rel['relative_global_rms_nm'] + rms_filtered_std = std_rel['relative_filtered_rms_nm'] + + # Compute residual surface + Xc = X - np.mean(X) + Yc = Y - np.mean(Y) + R = float(np.max(np.hypot(Xc, Yc))) + r = np.hypot(Xc/R, Yc/R) + th = np.arctan2(Yc, Xc) + + Z_basis = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES+1)]) + W_res_filt = W - Z_basis[:, :FILTER_LOW_ORDERS].dot(coefficients[:FILTER_LOW_ORDERS]) + + # Downsample for display + n = len(X) + if n > PLOT_DOWNSAMPLE: + rng = np.random.default_rng(42) + sel = rng.choice(n, size=PLOT_DOWNSAMPLE, replace=False) + Xp, Yp = X[sel], Y[sel] + Wp = W_res_filt[sel] + dxp, dyp, dzp = dx[sel], dy[sel], dz[sel] + else: + Xp, Yp, Wp = X, Y, W_res_filt + dxp, dyp, dzp = dx, dy, dz + + res_amp = AMP * Wp + max_amp = float(np.max(np.abs(res_amp))) if res_amp.size else 1.0 + + # Helper to build mesh trace + def build_mesh_trace(Zp, colorscale, colorbar_title, unit): + try: + tri = Triangulation(Xp, Yp) + if tri.triangles is not None and len(tri.triangles) > 0: + i_idx, j_idx, k_idx = tri.triangles.T + return go.Mesh3d( + x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(), + i=i_idx.tolist(), j=j_idx.tolist(), k=k_idx.tolist(), + intensity=Zp.tolist(), + colorscale=colorscale, + opacity=1.0, + flatshading=False, + lighting=dict(ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2), + lightposition=dict(x=100, y=200, z=300), + showscale=True, + colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15, len=0.5), + hovertemplate=f"X: %{{x:.1f}}
Y: %{{y:.1f}}
{unit}: %{{z:.3f}}" + ) + except Exception: + pass + return go.Scatter3d( + x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(), + mode='markers', marker=dict(size=2, color=Zp.tolist(), colorscale=colorscale, showscale=True) + ) + + # Build traces for each view + trace_wfe = build_mesh_trace(res_amp, COLORSCALE, "WFE (nm)", "WFE nm") + trace_dx = build_mesh_trace(dxp, 'RdBu_r', "ΔX (µm)", "ΔX µm") + trace_dy = build_mesh_trace(dyp, 'RdBu_r', "ΔY (µm)", "ΔY µm") + trace_dz = build_mesh_trace(dzp, 'RdBu_r', "ΔZ (µm)", "ΔZ µm") + + # Create figure with dropdown to switch views + fig = go.Figure() + + # Add all traces (only WFE visible initially) + trace_wfe.visible = True + trace_dx.visible = False + trace_dy.visible = False + trace_dz.visible = False + + fig.add_trace(trace_wfe) + fig.add_trace(trace_dx) + fig.add_trace(trace_dy) + fig.add_trace(trace_dz) + + # Dropdown menu for view selection + fig.update_layout( + updatemenus=[ + dict( + type="buttons", + direction="right", + x=0.0, y=1.12, + xanchor="left", + showactive=True, + buttons=[ + dict(label="WFE (nm)", method="update", + args=[{"visible": [True, False, False, False]}]), + dict(label="ΔX (µm)", method="update", + args=[{"visible": [False, True, False, False]}]), + dict(label="ΔY (µm)", method="update", + args=[{"visible": [False, False, True, False]}]), + dict(label="ΔZ (µm)", method="update", + args=[{"visible": [False, False, False, True]}]), + ], + font=dict(size=12), + pad=dict(r=10, t=10), + ) + ] + ) + + # Compute method difference + pct_diff = 100.0 * (rms_filtered_opd - rms_filtered_std) / rms_filtered_std if rms_filtered_std > 0 else 0.0 + + # Annotations for metrics + method_label = "OPD (X,Y,Z)" if use_opd else "Standard (Z-only)" + annotations_text = f""" +Method: {method_label} {'← More Accurate' if use_opd else '(BDF not found)'} +RMS Metrics (Filtered J1-J4): + • OPD: {rms_filtered_opd:.2f} nm + • Standard: {rms_filtered_std:.2f} nm + • Δ: {pct_diff:+.1f}% +Lateral Displacement: + • Max: {max_lateral:.3f} µm + • RMS: {rms_lateral:.3f} µm +Displacement RMS: + • ΔX: {float(np.sqrt(np.mean(dx**2))):.3f} µm + • ΔY: {float(np.sqrt(np.mean(dy**2))):.3f} µm + • ΔZ: {float(np.sqrt(np.mean(dz**2))):.3f} µm +""" + + # Z-axis range for different views + max_disp = max(float(np.max(np.abs(dxp))), float(np.max(np.abs(dyp))), float(np.max(np.abs(dzp))), 0.1) + + fig.update_layout( + scene=dict( + camera=dict(eye=dict(x=1.2, y=1.2, z=0.8), up=dict(x=0, y=0, z=1)), + xaxis=dict(title="X (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)', + showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'), + yaxis=dict(title="Y (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)', + showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'), + zaxis=dict(title="Value", showgrid=True, gridcolor='rgba(128,128,128,0.3)', + showbackground=True, backgroundcolor='rgba(230,230,250,0.9)'), + aspectmode='manual', + aspectratio=dict(x=1, y=1, z=0.4) + ), + width=1400, + height=900, + margin=dict(t=120, b=20, l=20, r=20), + title=dict( + text=f"{title}
Click buttons to switch: WFE, ΔX, ΔY, ΔZ", + font=dict(size=18), + x=0.5 + ), + paper_bgcolor='white', + plot_bgcolor='white', + annotations=[ + dict( + text=annotations_text.replace('\n', '
'), + xref="paper", yref="paper", + x=1.02, y=0.98, + xanchor="left", yanchor="top", + showarrow=False, + font=dict(family="monospace", size=11), + align="left", + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.3)", + borderwidth=1, + borderpad=8 + ), + dict( + text="View:", + xref="paper", yref="paper", + x=0.0, y=1.15, + xanchor="left", yanchor="top", + showarrow=False, + font=dict(size=12) + ) + ] + ) + + html_content = fig.to_html(include_plotlyjs='cdn', full_html=True) + + return (html_content, rms_global_opd, rms_filtered_opd, { + 'max_lateral_um': max_lateral, + 'rms_lateral_um': rms_lateral, + 'method': 'opd' if use_opd else 'standard', + 'rms_std': rms_filtered_std, + 'pct_diff': pct_diff + }) + + # Generate results for each comparison results = {} comparisons = [ ('3', '2', '40_vs_20', '40 deg vs 20 deg'), @@ -2878,84 +3214,31 @@ async def get_trial_zernike(study_id: str, trial_number: int): ('1', '2', '90_vs_20', '90 deg vs 20 deg (manufacturing)'), ] - # Pre-compute absolute 90 deg metrics for manufacturing view - abs_90_data = None - abs_90_metrics = None - if '1' in extractor.displacements: - abs_90_data = extractor.extract_subcase('1', include_coefficients=True) - abs_90_metrics = compute_manufacturing_metrics(np.array(abs_90_data['coefficients'])) - for target_sc, ref_sc, key, title_suffix in comparisons: - if target_sc not in extractor.displacements: + if target_sc not in std_extractor.displacements: continue - # Get relative data with coefficients - rel_data = extractor.extract_relative(target_sc, ref_sc, include_coefficients=True) - - # Get absolute data for this subcase - abs_data = extractor.extract_subcase(target_sc, include_coefficients=True) - - # Build coordinate arrays - target_disp = extractor.displacements[target_sc] - ref_disp = extractor.displacements[ref_sc] - - ref_node_to_idx = {int(nid): i for i, nid in enumerate(ref_disp['node_ids'])} - X_list, Y_list, W_list = [], [], [] - - for i, nid in enumerate(target_disp['node_ids']): - nid = int(nid) - if nid not in ref_node_to_idx: - continue - geo = extractor.node_geometry.get(nid) - if geo is None: - continue - - ref_idx = ref_node_to_idx[nid] - target_wfe = target_disp['disp'][i, 2] * extractor.wfe_factor - ref_wfe = ref_disp['disp'][ref_idx, 2] * extractor.wfe_factor - - X_list.append(geo[0]) - Y_list.append(geo[1]) - W_list.append(target_wfe - ref_wfe) - - X = np.array(X_list) - Y = np.array(Y_list) - W = np.array(W_list) - target_angle = SUBCASE_MAP.get(target_sc, target_sc) ref_angle = SUBCASE_MAP.get(ref_sc, ref_sc) - - # Check if this is the manufacturing (90 deg) comparison is_mfg = (key == '90_vs_20') - # Compute correction metrics (relative coefficients) for manufacturing view - correction_metrics = None - if is_mfg and 'coefficients' in rel_data: - correction_metrics = compute_manufacturing_metrics(np.array(rel_data['coefficients'])) - # Also compute rms_filter_j1to3 for the relative data - R = float(np.max(np.hypot(X - np.mean(X), Y - np.mean(Y)))) - correction_metrics['rms_filter_j1to3'] = compute_rms_filter_j1to3( - X, Y, W, np.array(rel_data['coefficients']), R - ) - - html_content = generate_zernike_html( - title=f"iter{trial_number}: {target_angle} deg vs {ref_angle} deg", - X=X, Y=Y, W_nm=W, - coefficients=np.array(rel_data['coefficients']), - rms_global=rel_data['relative_global_rms_nm'], - rms_filtered=rel_data['relative_filtered_rms_nm'], - ref_title=f"{ref_angle} deg", - abs_pair=(abs_data['global_rms_nm'], abs_data['filtered_rms_nm']), - is_manufacturing=is_mfg, - mfg_metrics=abs_90_metrics if is_mfg else None, - correction_metrics=correction_metrics + html_content, rms_global, rms_filtered, lateral_stats = generate_dual_method_html( + title=f"iter{trial_number}: {target_angle}° vs {ref_angle}°", + target_sc=target_sc, + ref_sc=ref_sc, + is_manufacturing=is_mfg ) results[key] = { "html": html_content, - "rms_global": rel_data['relative_global_rms_nm'], - "rms_filtered": rel_data['relative_filtered_rms_nm'], - "title": f"{target_angle}° vs {ref_angle}°" + "rms_global": rms_global, + "rms_filtered": rms_filtered, + "title": f"{target_angle}° vs {ref_angle}°", + "method": lateral_stats['method'], + "rms_std": lateral_stats['rms_std'], + "pct_diff": lateral_stats['pct_diff'], + "max_lateral_um": lateral_stats['max_lateral_um'], + "rms_lateral_um": lateral_stats['rms_lateral_um'] } if not results: @@ -2968,7 +3251,8 @@ async def get_trial_zernike(study_id: str, trial_number: int): "study_id": study_id, "trial_number": trial_number, "comparisons": results, - "available_comparisons": list(results.keys()) + "available_comparisons": list(results.keys()), + "method": "opd" if use_opd else "standard" } except HTTPException: