Files
Atomizer/optimization_engine/gnn/gnn_optimizer.py

719 lines
26 KiB
Python
Raw Normal View History

feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00
"""
GNN-Based Optimizer for Zernike Mirror Optimization
====================================================
This module provides a fast GNN-based optimization workflow:
1. Load trained GNN checkpoint
2. Run thousands of fast GNN predictions
3. Select top candidates
4. Validate with FEA (optional)
Usage:
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt')
results = optimizer.turbo_optimize(n_trials=5000)
# Get best designs
best = results.get_best(n=10)
# Validate with FEA
validated = optimizer.validate_with_fea(best, study_dir)
"""
import json
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from datetime import datetime
import optuna
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
from optimization_engine.gnn.zernike_gnn import create_model, load_model
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer
@dataclass
class GNNPrediction:
    """Single GNN prediction result.

    Bundles the sampled design variables with the objectives the surrogate
    predicted for them, plus (optionally) the raw displacement field.
    """

    # Design variable values keyed by variable name.
    design_vars: Dict[str, float]
    # Predicted objective values keyed by objective name.
    objectives: Dict[str, float]
    # Raw field, [3000, 4] if stored; omitted from to_dict() on purpose.
    z_displacement: Optional[np.ndarray] = None

    def to_dict(self) -> Dict:
        """Return a JSON-serializable view (the numpy field is excluded)."""
        serializable = {
            'design_vars': self.design_vars,
            'objectives': self.objectives,
        }
        return serializable
@dataclass
class OptimizationResults:
    """Container for optimization results.

    Accumulates ``GNNPrediction`` objects and offers ranking utilities:
    single-objective top-N selection and multi-objective Pareto extraction.
    All objectives are treated as *minimized*.
    """

    # All predictions in insertion (trial) order.
    # Annotations are quoted so the class does not require GNNPrediction to be
    # defined at class-creation time (forward reference; behavior unchanged).
    predictions: "List[GNNPrediction]" = field(default_factory=list)
    # Indices into `predictions` of the most recently computed Pareto front.
    pareto_front: List[int] = field(default_factory=list)

    def add(self, pred: "GNNPrediction") -> None:
        """Append a single prediction to the container."""
        self.predictions.append(pred)

    def get_best(self, n: int = 10, objective: str = 'rel_filtered_rms_40_vs_20') -> "List[GNNPrediction]":
        """Get top N designs by a single objective (ascending).

        Predictions missing the objective sort last (treated as +inf).
        """
        sorted_preds = sorted(self.predictions, key=lambda p: p.objectives.get(objective, float('inf')))
        return sorted_preds[:n]

    def get_pareto_front(self, objectives: Optional[List[str]] = None) -> "List[GNNPrediction]":
        """Get Pareto-optimal designs (all objectives minimized).

        The dominance test is vectorized with NumPy: for each candidate j we
        mark, in one row-wise operation, every i that j dominates (j <= i in
        all objectives and j < i in at least one). This replaces the previous
        pure-Python double loop with the same O(n^2) comparisons done in C.

        Side effect: caches the front's indices in ``self.pareto_front``.
        """
        if objectives is None:
            objectives = ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']
        # Empty container: nothing to rank.
        if not self.predictions:
            self.pareto_front = []
            return []
        # Extract objective values; missing objectives become +inf (never dominate).
        obj_values = np.array([
            [p.objectives.get(obj, float('inf')) for obj in objectives]
            for p in self.predictions
        ])
        n = len(self.predictions)
        dominated = np.zeros(n, dtype=bool)
        for j in range(n):
            # j dominates i if j is <= in all objectives and < in at least one.
            # For i == j, lt_any is False, so a point never dominates itself.
            le_all = np.all(obj_values[j] <= obj_values, axis=1)
            lt_any = np.any(obj_values[j] < obj_values, axis=1)
            dominated |= le_all & lt_any
        # tolist() yields plain Python ints, keeping save() JSON-serializable.
        pareto_indices = np.nonzero(~dominated)[0].tolist()
        self.pareto_front = pareto_indices
        return [self.predictions[i] for i in pareto_indices]

    def to_dataframe(self):
        """Convert to pandas DataFrame (one row per prediction).

        Design variables become columns under their own names; objectives are
        prefixed with ``obj_``.
        """
        import pandas as pd
        rows = []
        for i, pred in enumerate(self.predictions):
            row = {'index': i}
            row.update(pred.design_vars)
            row.update({f'obj_{k}': v for k, v in pred.objectives.items()})
            rows.append(row)
        return pd.DataFrame(rows)

    def save(self, path: Path) -> None:
        """Save results to JSON.

        Writes the prediction dicts (without field data), the cached Pareto
        indices, and a timestamp. Call ``get_pareto_front()`` first if you
        want ``pareto_front_indices`` populated.
        """
        data = {
            'n_predictions': len(self.predictions),
            'pareto_front_indices': self.pareto_front,
            'predictions': [p.to_dict() for p in self.predictions],
            'timestamp': datetime.now().isoformat(),
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)
class ZernikeGNNOptimizer:
    """
    GNN-based optimizer for Zernike mirror optimization.
    Provides fast objective prediction using trained GNN surrogate.
    """

    def __init__(
        self,
        model: nn.Module,
        polar_graph: PolarMirrorGraph,
        design_names: List[str],
        design_bounds: Dict[str, Tuple[float, float]],
        design_mean: torch.Tensor,
        design_std: torch.Tensor,
        device: str = 'cpu',
        disp_scale: float = 1.0
    ):
        """Wrap a trained GNN and its fixed graph for fast objective prediction.

        Args:
            model: Trained GNN; moved to `device` and put in eval mode.
            polar_graph: Fixed polar mirror graph (node/edge structure).
            design_names: Ordered design variable names (order must match the
                normalization statistics below).
            design_bounds: Per-variable (low, high) sampling bounds.
            design_mean / design_std: Normalization statistics from training.
            device: Torch device string ('cpu' or 'cuda').
            disp_scale: Displacement scaling factor used during training.
        """
        self.model = model.to(device)
        # Inference only: freeze dropout/batch-norm behavior.
        self.model.eval()
        self.polar_graph = polar_graph
        self.design_names = design_names
        self.design_bounds = design_bounds
        self.design_mean = design_mean.to(device)
        self.design_std = design_std.to(device)
        self.device = torch.device(device)
        self.disp_scale = disp_scale  # Scaling factor from training
        # Prepare fixed graph tensors once; the graph never changes between
        # predictions, only the design-conditioning vector does.
        self.node_features = torch.tensor(
            polar_graph.get_node_features(normalized=True),
            dtype=torch.float32
        ).to(device)
        self.edge_index = torch.tensor(
            polar_graph.edge_index,
            dtype=torch.long
        ).to(device)
        self.edge_attr = torch.tensor(
            polar_graph.get_edge_features(normalized=True),
            dtype=torch.float32
        ).to(device)
        # Objective computation layer (must be on same device as model)
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes=50).to(device)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: Path,
        config_path: Optional[Path] = None,
        device: str = 'auto'
    ) -> 'ZernikeGNNOptimizer':
        """
        Load optimizer from trained checkpoint.

        Args:
            checkpoint_path: Path to zernike_gnn_checkpoint.pt
            config_path: Path to optimization_config.json (for design bounds)
            device: Device to use ('cpu', 'cuda', 'auto')
        """
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        checkpoint_path = Path(checkpoint_path)
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        # Create polar graph
        # NOTE(review): graph dimensions (r 100-650, 50x60 grid) are hard-coded
        # here and assumed to match what the checkpointed model was trained on.
        polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        # Create model - handle both old ('model_config') and new ('config') format
        model_config = checkpoint.get('model_config') or checkpoint.get('config', {})
        model = create_model(**model_config)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Get design info from checkpoint
        design_mean = checkpoint['design_mean']
        design_std = checkpoint['design_std']
        disp_scale = checkpoint.get('disp_scale', 1.0)  # Displacement scaling factor
        # Try to get design names and bounds from config
        design_names = []
        design_bounds = {}
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                config = json.load(f)
            for var in config.get('design_variables', []):
                name = var['name']
                design_names.append(name)
                design_bounds[name] = (var['min'], var['max'])
        else:
            # Use generic names based on checkpoint
            n_vars = len(design_mean)
            design_names = [f'var_{i}' for i in range(n_vars)]
            # Default bounds (will be overridden if config provided)
            for name in design_names:
                design_bounds[name] = (-100, 100)
        return cls(
            model=model,
            polar_graph=polar_graph,
            design_names=design_names,
            design_bounds=design_bounds,
            design_mean=design_mean,
            design_std=design_std,
            device=device,
            disp_scale=disp_scale
        )

    @torch.no_grad()
    def predict(self, design_vars: Dict[str, float], return_field: bool = False) -> GNNPrediction:
        """
        Predict objectives for a single design.

        Args:
            design_vars: Dict mapping variable names to values
            return_field: Whether to include displacement field in result

        Returns:
            GNNPrediction with objectives
        """
        # Convert to tensor; variables missing from the dict default to 0.0.
        design_values = [design_vars.get(name, 0.0) for name in self.design_names]
        design_tensor = torch.tensor(design_values, dtype=torch.float32).to(self.device)
        # Normalize with the training-set statistics from the checkpoint.
        design_norm = (design_tensor - self.design_mean) / self.design_std
        # Forward pass
        z_disp_scaled = self.model(
            self.node_features,
            self.edge_index,
            self.edge_attr,
            design_norm
        )  # [3000, 4] in scaled units (μm if disp_scale=1e6)
        # Convert back to mm before computing objectives
        # During training: z_disp_mm * disp_scale = z_disp_scaled
        # So: z_disp_mm = z_disp_scaled / disp_scale
        z_disp_mm = z_disp_scaled / self.disp_scale
        # Compute objectives (ZernikeObjectiveLayer expects mm input)
        objectives = self.objective_layer(z_disp_mm)
        # Objectives are now directly in nm (no additional scaling needed)
        obj_dict = {
            'rel_filtered_rms_40_vs_20': objectives['rel_filtered_rms_40_vs_20'].item(),
            'rel_filtered_rms_60_vs_20': objectives['rel_filtered_rms_60_vs_20'].item(),
            'mfg_90_optician_workload': objectives['mfg_90_optician_workload'].item(),
        }
        field_data = z_disp_mm.cpu().numpy() if return_field else None
        return GNNPrediction(
            design_vars=design_vars,
            objectives=obj_dict,
            z_displacement=field_data
        )

    @torch.no_grad()
    def predict_batch(self, designs: List[Dict[str, float]]) -> List[GNNPrediction]:
        """
        Predict objectives for multiple designs (batched for efficiency).

        NOTE(review): despite the docstring, this currently evaluates designs
        one at a time in a Python loop; true batching would require the model
        to accept a batched graph/design input — confirm before relying on
        "batched" throughput.

        Args:
            designs: List of design variable dicts

        Returns:
            List of GNNPrediction
        """
        results = []
        for design in designs:
            results.append(self.predict(design))
        return results

    def random_design(self) -> Dict[str, float]:
        """Generate a random design within bounds."""
        design = {}
        for name in self.design_names:
            # Fall back to (-100, 100) for variables without explicit bounds.
            low, high = self.design_bounds.get(name, (-100, 100))
            design[name] = np.random.uniform(low, high)
        return design

    def turbo_optimize(
        self,
        n_trials: int = 5000,
        sampler: str = 'tpe',
        seed: int = 42,
        verbose: bool = True
    ) -> OptimizationResults:
        """
        Run fast GNN-based optimization.

        Args:
            n_trials: Number of GNN trials to run
            sampler: Optuna sampler ('tpe', 'random', 'cmaes')
            seed: Random seed
            verbose: Print progress

        Returns:
            OptimizationResults with all predictions
        """
        np.random.seed(seed)
        results = OptimizationResults()
        if verbose:
            print(f"\n{'='*60}")
            print("GNN TURBO OPTIMIZATION")
            print(f"{'='*60}")
            print(f"Trials: {n_trials}")
            print(f"Sampler: {sampler}")
            print(f"Design variables: {len(self.design_names)}")
            print(f"Device: {self.device}")
        # Create Optuna study for smart sampling; unknown names fall back to TPE.
        if sampler == 'tpe':
            optuna_sampler = optuna.samplers.TPESampler(seed=seed)
        elif sampler == 'random':
            optuna_sampler = optuna.samplers.RandomSampler(seed=seed)
        elif sampler == 'cmaes':
            # NOTE(review): Optuna's CmaEsSampler does not support
            # multi-objective studies — selecting 'cmaes' here will likely
            # fail at optimize() time; confirm against the Optuna version used.
            optuna_sampler = optuna.samplers.CmaEsSampler(seed=seed)
        else:
            optuna_sampler = optuna.samplers.TPESampler(seed=seed)
        study = optuna.create_study(
            directions=['minimize', 'minimize', 'minimize'],  # 3 objectives
            sampler=optuna_sampler
        )
        start_time = datetime.now()

        def objective(trial):
            # Sample design
            design = {}
            for name in self.design_names:
                low, high = self.design_bounds.get(name, (-100, 100))
                design[name] = trial.suggest_float(name, low, high)
            # Predict with GNN; every evaluated design is kept in `results`.
            pred = self.predict(design)
            results.add(pred)
            return (
                pred.objectives['rel_filtered_rms_40_vs_20'],
                pred.objectives['rel_filtered_rms_60_vs_20'],
                pred.objectives['mfg_90_optician_workload']
            )

        # Run optimization
        if verbose:
            print(f"\nRunning {n_trials} GNN trials...")
        optuna.logging.set_verbosity(optuna.logging.WARNING)
        study.optimize(objective, n_trials=n_trials, show_progress_bar=verbose)
        elapsed = (datetime.now() - start_time).total_seconds()
        if verbose:
            print(f"\nCompleted in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")
            # Compute Pareto front
            # NOTE(review): the front is only (pre)computed in verbose mode;
            # callers should invoke results.get_pareto_front() themselves.
            pareto = results.get_pareto_front()
            print(f"Pareto front: {len(pareto)} designs")
            # Best by each objective
            print("\nBest by objective:")
            for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                best = results.get_best(n=1, objective=obj)[0]
                print(f" {obj}: {best.objectives[obj]:.2f} nm")
        return results

    def validate_with_fea(
        self,
        candidates: List[GNNPrediction],
        study_dir: Path,
        verbose: bool = True,
        start_trial_num: int = 9000
    ) -> List[Dict]:
        """
        Validate GNN predictions with actual FEA.

        This runs the full NX + Nastran workflow on each candidate
        to get true objective values.

        Args:
            candidates: GNN predictions to validate
            study_dir: Path to study directory (for config and scripts)
            verbose: Print progress
            start_trial_num: Starting trial number for iteration folders

        Returns:
            List of dicts with 'gnn' and 'fea' objectives for comparison
        """
        # Heavy project-local dependencies are imported lazily so GNN-only
        # use of this class does not require the NX/Nastran toolchain.
        import time
        import re
        from optimization_engine.nx_solver import NXSolver
        from optimization_engine.extractors import ZernikeExtractor
        study_dir = Path(study_dir)
        config_path = study_dir / "1_setup" / "optimization_config.json"
        model_dir = study_dir / "1_setup" / "model"
        iterations_dir = study_dir / "2_iterations"
        # Load config
        if not config_path.exists():
            raise FileNotFoundError(f"Config not found: {config_path}")
        with open(config_path) as f:
            config = json.load(f)
        # Initialize NX Solver; the Nastran version is parsed from the
        # install path (e.g. '...\\NX2506' -> '2506'), defaulting to '2506'.
        nx_settings = config.get('nx_settings', {})
        nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
        version_match = re.search(r'NX(\d+)', nx_install_dir)
        nastran_version = version_match.group(1) if version_match else "2506"
        solver = NXSolver(
            master_model_dir=str(model_dir),
            nx_install_dir=nx_install_dir,
            nastran_version=nastran_version,
            timeout=nx_settings.get('simulation_timeout_s', 600),
            use_iteration_folders=True,
            study_name="gnn_validation"
        )
        iterations_dir.mkdir(exist_ok=True)
        results = []
        if verbose:
            print(f"\n{'='*60}")
            print("FEA VALIDATION OF GNN PREDICTIONS")
            print(f"{'='*60}")
            print(f"Validating {len(candidates)} candidates")
            print(f"Study: {study_dir.name}")
        # Each candidate gets its own iteration folder and solve; failures are
        # recorded with a 'status' tag and the loop continues (best-effort).
        for i, candidate in enumerate(candidates):
            trial_num = start_trial_num + i
            if verbose:
                print(f"\n[{i+1}/{len(candidates)}] Trial {trial_num}")
                print(f" GNN predicted: 40vs20={candidate.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
            # Build expression updates from design variables; the NX expression
            # name may differ from the optimization variable name.
            expressions = {}
            for var in config.get('design_variables', []):
                var_name = var['name']
                expr_name = var.get('expression_name', var_name)
                if var_name in candidate.design_vars:
                    expressions[expr_name] = candidate.design_vars[var_name]
            # Create iteration folder with model copies
            try:
                iter_folder = solver.create_iteration_folder(
                    iterations_base_dir=iterations_dir,
                    iteration_number=trial_num,
                    expression_updates=expressions
                )
            except Exception as e:
                if verbose:
                    print(f" ERROR creating iteration folder: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'error',
                    'error': str(e)
                })
                continue
            # Run simulation
            sim_file = iter_folder / nx_settings.get('sim_file', 'ASSY_M1_assyfem1_sim1.sim')
            solution_name = nx_settings.get('solution_name', 'Solution 1')
            t_start = time.time()
            try:
                solve_result = solver.run_simulation(
                    sim_file=sim_file,
                    working_dir=iter_folder,
                    expression_updates=expressions,
                    solution_name=solution_name,
                    cleanup=False
                )
            except Exception as e:
                if verbose:
                    print(f" ERROR in simulation: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'solve_error',
                    'error': str(e)
                })
                continue
            solve_time = time.time() - t_start
            if not solve_result['success']:
                if verbose:
                    print(f" Solve FAILED: {solve_result.get('errors', ['Unknown'])}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'solve_failed',
                    'errors': solve_result.get('errors', [])
                })
                continue
            if verbose:
                print(f" Solved in {solve_time:.1f}s")
            # Extract objectives using ZernikeExtractor
            op2_path = solve_result['op2_file']
            if op2_path is None or not Path(op2_path).exists():
                if verbose:
                    print(f" ERROR: OP2 file not found")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'no_op2',
                })
                continue
            try:
                zernike_settings = config.get('zernike_settings', {})
                extractor = ZernikeExtractor(
                    op2_path,
                    bdf_path=None,
                    displacement_unit=zernike_settings.get('displacement_unit', 'mm'),
                    n_modes=zernike_settings.get('n_modes', 50),
                    filter_orders=zernike_settings.get('filter_low_orders', 4)
                )
                ref = zernike_settings.get('reference_subcase', '2')
                # Extract objectives: 40 vs 20, 60 vs 20, mfg 90
                # NOTE(review): subcase IDs "3"/"4"/"1" are assumed to map to
                # the 40deg/60deg/90deg load cases relative to reference "2" —
                # confirm against the study's simulation setup.
                rel_40 = extractor.extract_relative("3", ref)
                rel_60 = extractor.extract_relative("4", ref)
                rel_90 = extractor.extract_relative("1", ref)
                fea_objectives = {
                    'rel_filtered_rms_40_vs_20': rel_40['relative_filtered_rms_nm'],
                    'rel_filtered_rms_60_vs_20': rel_60['relative_filtered_rms_nm'],
                    'mfg_90_optician_workload': rel_90['relative_rms_filter_j1to3'],
                }
            except Exception as e:
                if verbose:
                    print(f" ERROR in Zernike extraction: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'extraction_error',
                    'error': str(e)
                })
                continue
            # Compute errors; the max(fea_val, 0.01) floor avoids division by
            # zero (and blow-up) for near-zero FEA objective values.
            errors = {}
            for obj_name in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                gnn_val = candidate.objectives[obj_name]
                fea_val = fea_objectives[obj_name]
                errors[f'{obj_name}_abs_error'] = abs(gnn_val - fea_val)
                errors[f'{obj_name}_pct_error'] = 100 * abs(gnn_val - fea_val) / max(fea_val, 0.01)
            if verbose:
                print(f" FEA results:")
                print(f" 40vs20: {fea_objectives['rel_filtered_rms_40_vs_20']:.2f} nm "
                      f"(GNN: {candidate.objectives['rel_filtered_rms_40_vs_20']:.2f}, "
                      f"err: {errors['rel_filtered_rms_40_vs_20_pct_error']:.1f}%)")
                print(f" 60vs20: {fea_objectives['rel_filtered_rms_60_vs_20']:.2f} nm "
                      f"(GNN: {candidate.objectives['rel_filtered_rms_60_vs_20']:.2f}, "
                      f"err: {errors['rel_filtered_rms_60_vs_20_pct_error']:.1f}%)")
                print(f" mfg90: {fea_objectives['mfg_90_optician_workload']:.2f} nm "
                      f"(GNN: {candidate.objectives['mfg_90_optician_workload']:.2f}, "
                      f"err: {errors['mfg_90_optician_workload_pct_error']:.1f}%)")
            results.append({
                'design': candidate.design_vars,
                'gnn_objectives': candidate.objectives,
                'fea_objectives': fea_objectives,
                'errors': errors,
                'solve_time': solve_time,
                'trial_num': trial_num,
                'status': 'success'
            })
        # Summary
        if verbose:
            successful = [r for r in results if r['status'] == 'success']
            print(f"\n{'='*60}")
            print(f"VALIDATION SUMMARY")
            print(f"{'='*60}")
            print(f"Successful: {len(successful)}/{len(candidates)}")
            if successful:
                avg_errors = {}
                for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                    avg_errors[obj] = np.mean([r['errors'][f'{obj}_pct_error'] for r in successful])
                print(f"\nAverage prediction errors:")
                print(f" 40 vs 20: {avg_errors['rel_filtered_rms_40_vs_20']:.1f}%")
                print(f" 60 vs 20: {avg_errors['rel_filtered_rms_60_vs_20']:.1f}%")
                print(f" mfg 90: {avg_errors['mfg_90_optician_workload']:.1f}%")
        return results

    def save_validation_report(
        self,
        validation_results: List[Dict],
        output_path: Path
    ):
        """Save validation results to JSON file.

        Writes per-candidate results plus, when any validation succeeded, an
        'error_summary' of mean/std/max percent prediction error per objective.
        """
        report = {
            'timestamp': datetime.now().isoformat(),
            'n_candidates': len(validation_results),
            'n_successful': len([r for r in validation_results if r['status'] == 'success']),
            'results': validation_results,
        }
        # Compute summary statistics if we have successful results
        successful = [r for r in validation_results if r['status'] == 'success']
        if successful:
            avg_errors = {}
            for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                errors = [r['errors'][f'{obj}_pct_error'] for r in successful]
                avg_errors[obj] = {
                    'mean_pct': float(np.mean(errors)),
                    'std_pct': float(np.std(errors)),
                    'max_pct': float(np.max(errors)),
                }
            report['error_summary'] = avg_errors
        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2)
        print(f"Validation report saved to: {output_path}")
def _parse_args():
    """Build the CLI for this script and parse sys.argv."""
    import argparse
    parser = argparse.ArgumentParser(description='GNN-based Zernike optimization')
    parser.add_argument('checkpoint', type=Path, help='Path to GNN checkpoint')
    parser.add_argument('--config', type=Path, help='Path to optimization_config.json')
    parser.add_argument('--trials', type=int, default=5000, help='Number of GNN trials')
    parser.add_argument('--output', '-o', type=Path, help='Output results JSON')
    parser.add_argument('--top-n', type=int, default=20, help='Number of top candidates to show')
    return parser.parse_args()


def main():
    """Example usage of GNN optimizer."""
    args = _parse_args()
    # Restore the optimizer (model + normalization stats) from the checkpoint.
    print(f"Loading GNN from {args.checkpoint}...")
    optimizer = ZernikeGNNOptimizer.from_checkpoint(
        args.checkpoint,
        config_path=args.config
    )
    # Fast surrogate-only optimization pass.
    results = optimizer.turbo_optimize(n_trials=args.trials)
    # Report the leading candidates ranked by the primary objective.
    print(f"\n{'='*60}")
    print(f"TOP {args.top_n} CANDIDATES (by rel_filtered_rms_40_vs_20)")
    print(f"{'='*60}")
    top = results.get_best(n=args.top_n, objective='rel_filtered_rms_40_vs_20')
    for i, pred in enumerate(top):
        print(f"\n#{i+1}:")
        print(f" 40 vs 20: {pred.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
        print(f" 60 vs 20: {pred.objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
        print(f" mfg_90: {pred.objectives['mfg_90_optician_workload']:.2f} nm")
    # Persist everything when an output path was supplied.
    if args.output:
        results.save(args.output)
        print(f"\nResults saved to {args.output}")
    # Finish with the multi-objective Pareto front (first 10 entries).
    pareto = results.get_pareto_front()
    print(f"\n{'='*60}")
    print(f"PARETO FRONT: {len(pareto)} designs")
    print(f"{'='*60}")
    for i, pred in enumerate(pareto[:10]):  # Show first 10
        print(f" [{i+1}] 40vs20={pred.objectives['rel_filtered_rms_40_vs_20']:.1f}, "
              f"60vs20={pred.objectives['rel_filtered_rms_60_vs_20']:.1f}, "
              f"mfg={pred.objectives['mfg_90_optician_workload']:.1f}")


if __name__ == '__main__':
    main()