""" GNN-Based Optimizer for Zernike Mirror Optimization ==================================================== This module provides a fast GNN-based optimization workflow: 1. Load trained GNN checkpoint 2. Run thousands of fast GNN predictions 3. Select top candidates 4. Validate with FEA (optional) Usage: from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt') results = optimizer.turbo_optimize(n_trials=5000) # Get best designs best = results.get_best(n=10) # Validate with FEA validated = optimizer.validate_with_fea(best, study_dir) """ import json import numpy as np import torch import torch.nn as nn from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass, field from datetime import datetime import optuna from optimization_engine.gnn.polar_graph import PolarMirrorGraph from optimization_engine.gnn.zernike_gnn import create_model, load_model from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer @dataclass class GNNPrediction: """Single GNN prediction result.""" design_vars: Dict[str, float] objectives: Dict[str, float] z_displacement: Optional[np.ndarray] = None # [3000, 4] if stored def to_dict(self) -> Dict: return { 'design_vars': self.design_vars, 'objectives': self.objectives, } @dataclass class OptimizationResults: """Container for optimization results.""" predictions: List[GNNPrediction] = field(default_factory=list) pareto_front: List[int] = field(default_factory=list) # Indices def add(self, pred: GNNPrediction): self.predictions.append(pred) def get_best(self, n: int = 10, objective: str = 'rel_filtered_rms_40_vs_20') -> List[GNNPrediction]: """Get top N designs by a single objective.""" sorted_preds = sorted(self.predictions, key=lambda p: p.objectives.get(objective, float('inf'))) return sorted_preds[:n] def get_pareto_front(self, objectives: List[str] = None) -> List[GNNPrediction]: """Get Pareto-optimal designs.""" if objectives is None: objectives = ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload'] # Extract objective values obj_values = np.array([ [p.objectives.get(obj, float('inf')) for obj in objectives] for p in self.predictions ]) # Find Pareto front (all objectives are minimized) pareto_indices = [] for i in range(len(self.predictions)): is_dominated = False for j in range(len(self.predictions)): if i != j: # j dominates i if j is <= in all objectives and < in at least one if np.all(obj_values[j] <= obj_values[i]) and np.any(obj_values[j] < obj_values[i]): is_dominated = True break if not is_dominated: pareto_indices.append(i) self.pareto_front = pareto_indices return [self.predictions[i] for i in pareto_indices] def to_dataframe(self): """Convert to pandas DataFrame.""" import pandas as pd rows = [] for i, pred in enumerate(self.predictions): row = {'index': i} row.update(pred.design_vars) row.update({f'obj_{k}': v for k, v in pred.objectives.items()}) rows.append(row) return pd.DataFrame(rows) def save(self, path: Path): """Save results to JSON.""" data = { 'n_predictions': len(self.predictions), 'pareto_front_indices': self.pareto_front, 'predictions': [p.to_dict() for p in self.predictions], 'timestamp': datetime.now().isoformat(), } with open(path, 'w') as f: json.dump(data, f, indent=2) class ZernikeGNNOptimizer: """ GNN-based optimizer for Zernike mirror optimization. Provides fast objective prediction using trained GNN surrogate. """ def __init__( self, model: nn.Module, polar_graph: PolarMirrorGraph, design_names: List[str], design_bounds: Dict[str, Tuple[float, float]], design_mean: torch.Tensor, design_std: torch.Tensor, device: str = 'cpu', disp_scale: float = 1.0 ): self.model = model.to(device) self.model.eval() self.polar_graph = polar_graph self.design_names = design_names self.design_bounds = design_bounds self.design_mean = design_mean.to(device) self.design_std = design_std.to(device) self.device = torch.device(device) self.disp_scale = disp_scale # Scaling factor from training # Prepare fixed graph tensors self.node_features = torch.tensor( polar_graph.get_node_features(normalized=True), dtype=torch.float32 ).to(device) self.edge_index = torch.tensor( polar_graph.edge_index, dtype=torch.long ).to(device) self.edge_attr = torch.tensor( polar_graph.get_edge_features(normalized=True), dtype=torch.float32 ).to(device) # Objective computation layer (must be on same device as model) self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes=50).to(device) @classmethod def from_checkpoint( cls, checkpoint_path: Path, config_path: Optional[Path] = None, device: str = 'auto' ) -> 'ZernikeGNNOptimizer': """ Load optimizer from trained checkpoint. Args: checkpoint_path: Path to zernike_gnn_checkpoint.pt config_path: Path to optimization_config.json (for design bounds) device: Device to use ('cpu', 'cuda', 'auto') """ if device == 'auto': device = 'cuda' if torch.cuda.is_available() else 'cpu' checkpoint_path = Path(checkpoint_path) checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) # Create polar graph polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60) # Create model - handle both old ('model_config') and new ('config') format model_config = checkpoint.get('model_config') or checkpoint.get('config', {}) model = create_model(**model_config) model.load_state_dict(checkpoint['model_state_dict']) # Get design info from checkpoint design_mean = checkpoint['design_mean'] design_std = checkpoint['design_std'] disp_scale = checkpoint.get('disp_scale', 1.0) # Displacement scaling factor # Try to get design names and bounds from config design_names = [] design_bounds = {} if config_path and Path(config_path).exists(): with open(config_path, 'r') as f: config = json.load(f) for var in config.get('design_variables', []): name = var['name'] design_names.append(name) design_bounds[name] = (var['min'], var['max']) else: # Use generic names based on checkpoint n_vars = len(design_mean) design_names = [f'var_{i}' for i in range(n_vars)] # Default bounds (will be overridden if config provided) for name in design_names: design_bounds[name] = (-100, 100) return cls( model=model, polar_graph=polar_graph, design_names=design_names, design_bounds=design_bounds, design_mean=design_mean, design_std=design_std, device=device, disp_scale=disp_scale ) @torch.no_grad() def predict(self, design_vars: Dict[str, float], return_field: bool = False) -> GNNPrediction: """ Predict objectives for a single design. Args: design_vars: Dict mapping variable names to values return_field: Whether to include displacement field in result Returns: GNNPrediction with objectives """ # Convert to tensor design_values = [design_vars.get(name, 0.0) for name in self.design_names] design_tensor = torch.tensor(design_values, dtype=torch.float32).to(self.device) # Normalize design_norm = (design_tensor - self.design_mean) / self.design_std # Forward pass z_disp_scaled = self.model( self.node_features, self.edge_index, self.edge_attr, design_norm ) # [3000, 4] in scaled units (μm if disp_scale=1e6) # Convert back to mm before computing objectives # During training: z_disp_mm * disp_scale = z_disp_scaled # So: z_disp_mm = z_disp_scaled / disp_scale z_disp_mm = z_disp_scaled / self.disp_scale # Compute objectives (ZernikeObjectiveLayer expects mm input) objectives = self.objective_layer(z_disp_mm) # Objectives are now directly in nm (no additional scaling needed) obj_dict = { 'rel_filtered_rms_40_vs_20': objectives['rel_filtered_rms_40_vs_20'].item(), 'rel_filtered_rms_60_vs_20': objectives['rel_filtered_rms_60_vs_20'].item(), 'mfg_90_optician_workload': objectives['mfg_90_optician_workload'].item(), } field_data = z_disp_mm.cpu().numpy() if return_field else None return GNNPrediction( design_vars=design_vars, objectives=obj_dict, z_displacement=field_data ) @torch.no_grad() def predict_batch(self, designs: List[Dict[str, float]]) -> List[GNNPrediction]: """ Predict objectives for multiple designs (batched for efficiency). Args: designs: List of design variable dicts Returns: List of GNNPrediction """ results = [] for design in designs: results.append(self.predict(design)) return results def random_design(self) -> Dict[str, float]: """Generate a random design within bounds.""" design = {} for name in self.design_names: low, high = self.design_bounds.get(name, (-100, 100)) design[name] = np.random.uniform(low, high) return design def turbo_optimize( self, n_trials: int = 5000, sampler: str = 'tpe', seed: int = 42, verbose: bool = True ) -> OptimizationResults: """ Run fast GNN-based optimization. Args: n_trials: Number of GNN trials to run sampler: Optuna sampler ('tpe', 'random', 'cmaes') seed: Random seed verbose: Print progress Returns: OptimizationResults with all predictions """ np.random.seed(seed) results = OptimizationResults() if verbose: print(f"\n{'='*60}") print("GNN TURBO OPTIMIZATION") print(f"{'='*60}") print(f"Trials: {n_trials}") print(f"Sampler: {sampler}") print(f"Design variables: {len(self.design_names)}") print(f"Device: {self.device}") # Create Optuna study for smart sampling if sampler == 'tpe': optuna_sampler = optuna.samplers.TPESampler(seed=seed) elif sampler == 'random': optuna_sampler = optuna.samplers.RandomSampler(seed=seed) elif sampler == 'cmaes': optuna_sampler = optuna.samplers.CmaEsSampler(seed=seed) else: optuna_sampler = optuna.samplers.TPESampler(seed=seed) study = optuna.create_study( directions=['minimize', 'minimize', 'minimize'], # 3 objectives sampler=optuna_sampler ) start_time = datetime.now() def objective(trial): # Sample design design = {} for name in self.design_names: low, high = self.design_bounds.get(name, (-100, 100)) design[name] = trial.suggest_float(name, low, high) # Predict with GNN pred = self.predict(design) results.add(pred) return ( pred.objectives['rel_filtered_rms_40_vs_20'], pred.objectives['rel_filtered_rms_60_vs_20'], pred.objectives['mfg_90_optician_workload'] ) # Run optimization if verbose: print(f"\nRunning {n_trials} GNN trials...") optuna.logging.set_verbosity(optuna.logging.WARNING) study.optimize(objective, n_trials=n_trials, show_progress_bar=verbose) elapsed = (datetime.now() - start_time).total_seconds() if verbose: print(f"\nCompleted in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)") # Compute Pareto front pareto = results.get_pareto_front() print(f"Pareto front: {len(pareto)} designs") # Best by each objective print("\nBest by objective:") for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']: best = results.get_best(n=1, objective=obj)[0] print(f" {obj}: {best.objectives[obj]:.2f} nm") return results def validate_with_fea( self, candidates: List[GNNPrediction], study_dir: Path, verbose: bool = True, start_trial_num: int = 9000 ) -> List[Dict]: """ Validate GNN predictions with actual FEA. This runs the full NX + Nastran workflow on each candidate to get true objective values. Args: candidates: GNN predictions to validate study_dir: Path to study directory (for config and scripts) verbose: Print progress start_trial_num: Starting trial number for iteration folders Returns: List of dicts with 'gnn' and 'fea' objectives for comparison """ import time import re from optimization_engine.nx.solver import NXSolver from optimization_engine.extractors import ZernikeExtractor study_dir = Path(study_dir) config_path = study_dir / "1_setup" / "optimization_config.json" model_dir = study_dir / "1_setup" / "model" iterations_dir = study_dir / "2_iterations" # Load config if not config_path.exists(): raise FileNotFoundError(f"Config not found: {config_path}") with open(config_path) as f: config = json.load(f) # Initialize NX Solver nx_settings = config.get('nx_settings', {}) nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506') version_match = re.search(r'NX(\d+)', nx_install_dir) nastran_version = version_match.group(1) if version_match else "2506" solver = NXSolver( master_model_dir=str(model_dir), nx_install_dir=nx_install_dir, nastran_version=nastran_version, timeout=nx_settings.get('simulation_timeout_s', 600), use_iteration_folders=True, study_name="gnn_validation" ) iterations_dir.mkdir(exist_ok=True) results = [] if verbose: print(f"\n{'='*60}") print("FEA VALIDATION OF GNN PREDICTIONS") print(f"{'='*60}") print(f"Validating {len(candidates)} candidates") print(f"Study: {study_dir.name}") for i, candidate in enumerate(candidates): trial_num = start_trial_num + i if verbose: print(f"\n[{i+1}/{len(candidates)}] Trial {trial_num}") print(f" GNN predicted: 40vs20={candidate.objectives['rel_filtered_rms_40_vs_20']:.2f} nm") # Build expression updates from design variables expressions = {} for var in config.get('design_variables', []): var_name = var['name'] expr_name = var.get('expression_name', var_name) if var_name in candidate.design_vars: expressions[expr_name] = candidate.design_vars[var_name] # Create iteration folder with model copies try: iter_folder = solver.create_iteration_folder( iterations_base_dir=iterations_dir, iteration_number=trial_num, expression_updates=expressions ) except Exception as e: if verbose: print(f" ERROR creating iteration folder: {e}") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': None, 'status': 'error', 'error': str(e) }) continue # Run simulation sim_file = iter_folder / nx_settings.get('sim_file', 'ASSY_M1_assyfem1_sim1.sim') solution_name = nx_settings.get('solution_name', 'Solution 1') t_start = time.time() try: solve_result = solver.run_simulation( sim_file=sim_file, working_dir=iter_folder, expression_updates=expressions, solution_name=solution_name, cleanup=False ) except Exception as e: if verbose: print(f" ERROR in simulation: {e}") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': None, 'status': 'solve_error', 'error': str(e) }) continue solve_time = time.time() - t_start if not solve_result['success']: if verbose: print(f" Solve FAILED: {solve_result.get('errors', ['Unknown'])}") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': None, 'status': 'solve_failed', 'errors': solve_result.get('errors', []) }) continue if verbose: print(f" Solved in {solve_time:.1f}s") # Extract objectives using ZernikeExtractor op2_path = solve_result['op2_file'] if op2_path is None or not Path(op2_path).exists(): if verbose: print(f" ERROR: OP2 file not found") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': None, 'status': 'no_op2', }) continue try: zernike_settings = config.get('zernike_settings', {}) extractor = ZernikeExtractor( op2_path, bdf_path=None, displacement_unit=zernike_settings.get('displacement_unit', 'mm'), n_modes=zernike_settings.get('n_modes', 50), filter_orders=zernike_settings.get('filter_low_orders', 4) ) ref = zernike_settings.get('reference_subcase', '2') # Extract objectives: 40 vs 20, 60 vs 20, mfg 90 rel_40 = extractor.extract_relative("3", ref) rel_60 = extractor.extract_relative("4", ref) rel_90 = extractor.extract_relative("1", ref) fea_objectives = { 'rel_filtered_rms_40_vs_20': rel_40['relative_filtered_rms_nm'], 'rel_filtered_rms_60_vs_20': rel_60['relative_filtered_rms_nm'], 'mfg_90_optician_workload': rel_90['relative_rms_filter_j1to3'], } except Exception as e: if verbose: print(f" ERROR in Zernike extraction: {e}") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': None, 'status': 'extraction_error', 'error': str(e) }) continue # Compute errors errors = {} for obj_name in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']: gnn_val = candidate.objectives[obj_name] fea_val = fea_objectives[obj_name] errors[f'{obj_name}_abs_error'] = abs(gnn_val - fea_val) errors[f'{obj_name}_pct_error'] = 100 * abs(gnn_val - fea_val) / max(fea_val, 0.01) if verbose: print(f" FEA results:") print(f" 40vs20: {fea_objectives['rel_filtered_rms_40_vs_20']:.2f} nm " f"(GNN: {candidate.objectives['rel_filtered_rms_40_vs_20']:.2f}, " f"err: {errors['rel_filtered_rms_40_vs_20_pct_error']:.1f}%)") print(f" 60vs20: {fea_objectives['rel_filtered_rms_60_vs_20']:.2f} nm " f"(GNN: {candidate.objectives['rel_filtered_rms_60_vs_20']:.2f}, " f"err: {errors['rel_filtered_rms_60_vs_20_pct_error']:.1f}%)") print(f" mfg90: {fea_objectives['mfg_90_optician_workload']:.2f} nm " f"(GNN: {candidate.objectives['mfg_90_optician_workload']:.2f}, " f"err: {errors['mfg_90_optician_workload_pct_error']:.1f}%)") results.append({ 'design': candidate.design_vars, 'gnn_objectives': candidate.objectives, 'fea_objectives': fea_objectives, 'errors': errors, 'solve_time': solve_time, 'trial_num': trial_num, 'status': 'success' }) # Summary if verbose: successful = [r for r in results if r['status'] == 'success'] print(f"\n{'='*60}") print(f"VALIDATION SUMMARY") print(f"{'='*60}") print(f"Successful: {len(successful)}/{len(candidates)}") if successful: avg_errors = {} for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']: avg_errors[obj] = np.mean([r['errors'][f'{obj}_pct_error'] for r in successful]) print(f"\nAverage prediction errors:") print(f" 40 vs 20: {avg_errors['rel_filtered_rms_40_vs_20']:.1f}%") print(f" 60 vs 20: {avg_errors['rel_filtered_rms_60_vs_20']:.1f}%") print(f" mfg 90: {avg_errors['mfg_90_optician_workload']:.1f}%") return results def save_validation_report( self, validation_results: List[Dict], output_path: Path ): """Save validation results to JSON file.""" report = { 'timestamp': datetime.now().isoformat(), 'n_candidates': len(validation_results), 'n_successful': len([r for r in validation_results if r['status'] == 'success']), 'results': validation_results, } # Compute summary statistics if we have successful results successful = [r for r in validation_results if r['status'] == 'success'] if successful: avg_errors = {} for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']: errors = [r['errors'][f'{obj}_pct_error'] for r in successful] avg_errors[obj] = { 'mean_pct': float(np.mean(errors)), 'std_pct': float(np.std(errors)), 'max_pct': float(np.max(errors)), } report['error_summary'] = avg_errors with open(output_path, 'w') as f: json.dump(report, f, indent=2) print(f"Validation report saved to: {output_path}") def main(): """Example usage of GNN optimizer.""" import argparse parser = argparse.ArgumentParser(description='GNN-based Zernike optimization') parser.add_argument('checkpoint', type=Path, help='Path to GNN checkpoint') parser.add_argument('--config', type=Path, help='Path to optimization_config.json') parser.add_argument('--trials', type=int, default=5000, help='Number of GNN trials') parser.add_argument('--output', '-o', type=Path, help='Output results JSON') parser.add_argument('--top-n', type=int, default=20, help='Number of top candidates to show') args = parser.parse_args() # Load optimizer print(f"Loading GNN from {args.checkpoint}...") optimizer = ZernikeGNNOptimizer.from_checkpoint( args.checkpoint, config_path=args.config ) # Run turbo optimization results = optimizer.turbo_optimize(n_trials=args.trials) # Show top candidates print(f"\n{'='*60}") print(f"TOP {args.top_n} CANDIDATES (by rel_filtered_rms_40_vs_20)") print(f"{'='*60}") top = results.get_best(n=args.top_n, objective='rel_filtered_rms_40_vs_20') for i, pred in enumerate(top): print(f"\n#{i+1}:") print(f" 40 vs 20: {pred.objectives['rel_filtered_rms_40_vs_20']:.2f} nm") print(f" 60 vs 20: {pred.objectives['rel_filtered_rms_60_vs_20']:.2f} nm") print(f" mfg_90: {pred.objectives['mfg_90_optician_workload']:.2f} nm") # Save results if args.output: results.save(args.output) print(f"\nResults saved to {args.output}") # Show Pareto front pareto = results.get_pareto_front() print(f"\n{'='*60}") print(f"PARETO FRONT: {len(pareto)} designs") print(f"{'='*60}") for i, pred in enumerate(pareto[:10]): # Show first 10 print(f" [{i+1}] 40vs20={pred.objectives['rel_filtered_rms_40_vs_20']:.1f}, " f"60vs20={pred.objectives['rel_filtered_rms_60_vs_20']:.1f}, " f"mfg={pred.objectives['mfg_90_optician_workload']:.1f}") if __name__ == '__main__': main()