719 lines
26 KiB
Python
719 lines
26 KiB
Python
|
|
"""
|
||
|
|
GNN-Based Optimizer for Zernike Mirror Optimization
|
||
|
|
====================================================
|
||
|
|
|
||
|
|
This module provides a fast GNN-based optimization workflow:
|
||
|
|
1. Load trained GNN checkpoint
|
||
|
|
2. Run thousands of fast GNN predictions
|
||
|
|
3. Select top candidates
|
||
|
|
4. Validate with FEA (optional)
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
|
||
|
|
|
||
|
|
optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt')
|
||
|
|
results = optimizer.turbo_optimize(n_trials=5000)
|
||
|
|
|
||
|
|
# Get best designs
|
||
|
|
best = results.get_best(n=10)
|
||
|
|
|
||
|
|
# Validate with FEA
|
||
|
|
validated = optimizer.validate_with_fea(best, study_dir)
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Dict, List, Optional, Tuple, Any
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from datetime import datetime
|
||
|
|
import optuna
|
||
|
|
|
||
|
|
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
||
|
|
from optimization_engine.gnn.zernike_gnn import create_model, load_model
|
||
|
|
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class GNNPrediction:
|
||
|
|
"""Single GNN prediction result."""
|
||
|
|
design_vars: Dict[str, float]
|
||
|
|
objectives: Dict[str, float]
|
||
|
|
z_displacement: Optional[np.ndarray] = None # [3000, 4] if stored
|
||
|
|
|
||
|
|
def to_dict(self) -> Dict:
|
||
|
|
return {
|
||
|
|
'design_vars': self.design_vars,
|
||
|
|
'objectives': self.objectives,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class OptimizationResults:
|
||
|
|
"""Container for optimization results."""
|
||
|
|
predictions: List[GNNPrediction] = field(default_factory=list)
|
||
|
|
pareto_front: List[int] = field(default_factory=list) # Indices
|
||
|
|
|
||
|
|
def add(self, pred: GNNPrediction):
|
||
|
|
self.predictions.append(pred)
|
||
|
|
|
||
|
|
def get_best(self, n: int = 10, objective: str = 'rel_filtered_rms_40_vs_20') -> List[GNNPrediction]:
|
||
|
|
"""Get top N designs by a single objective."""
|
||
|
|
sorted_preds = sorted(self.predictions, key=lambda p: p.objectives.get(objective, float('inf')))
|
||
|
|
return sorted_preds[:n]
|
||
|
|
|
||
|
|
def get_pareto_front(self, objectives: List[str] = None) -> List[GNNPrediction]:
|
||
|
|
"""Get Pareto-optimal designs."""
|
||
|
|
if objectives is None:
|
||
|
|
objectives = ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']
|
||
|
|
|
||
|
|
# Extract objective values
|
||
|
|
obj_values = np.array([
|
||
|
|
[p.objectives.get(obj, float('inf')) for obj in objectives]
|
||
|
|
for p in self.predictions
|
||
|
|
])
|
||
|
|
|
||
|
|
# Find Pareto front (all objectives are minimized)
|
||
|
|
pareto_indices = []
|
||
|
|
for i in range(len(self.predictions)):
|
||
|
|
is_dominated = False
|
||
|
|
for j in range(len(self.predictions)):
|
||
|
|
if i != j:
|
||
|
|
# j dominates i if j is <= in all objectives and < in at least one
|
||
|
|
if np.all(obj_values[j] <= obj_values[i]) and np.any(obj_values[j] < obj_values[i]):
|
||
|
|
is_dominated = True
|
||
|
|
break
|
||
|
|
if not is_dominated:
|
||
|
|
pareto_indices.append(i)
|
||
|
|
|
||
|
|
self.pareto_front = pareto_indices
|
||
|
|
return [self.predictions[i] for i in pareto_indices]
|
||
|
|
|
||
|
|
def to_dataframe(self):
|
||
|
|
"""Convert to pandas DataFrame."""
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
rows = []
|
||
|
|
for i, pred in enumerate(self.predictions):
|
||
|
|
row = {'index': i}
|
||
|
|
row.update(pred.design_vars)
|
||
|
|
row.update({f'obj_{k}': v for k, v in pred.objectives.items()})
|
||
|
|
rows.append(row)
|
||
|
|
|
||
|
|
return pd.DataFrame(rows)
|
||
|
|
|
||
|
|
def save(self, path: Path):
|
||
|
|
"""Save results to JSON."""
|
||
|
|
data = {
|
||
|
|
'n_predictions': len(self.predictions),
|
||
|
|
'pareto_front_indices': self.pareto_front,
|
||
|
|
'predictions': [p.to_dict() for p in self.predictions],
|
||
|
|
'timestamp': datetime.now().isoformat(),
|
||
|
|
}
|
||
|
|
with open(path, 'w') as f:
|
||
|
|
json.dump(data, f, indent=2)
|
||
|
|
|
||
|
|
|
||
|
|
class ZernikeGNNOptimizer:
|
||
|
|
"""
|
||
|
|
GNN-based optimizer for Zernike mirror optimization.
|
||
|
|
|
||
|
|
Provides fast objective prediction using trained GNN surrogate.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
model: nn.Module,
|
||
|
|
polar_graph: PolarMirrorGraph,
|
||
|
|
design_names: List[str],
|
||
|
|
design_bounds: Dict[str, Tuple[float, float]],
|
||
|
|
design_mean: torch.Tensor,
|
||
|
|
design_std: torch.Tensor,
|
||
|
|
device: str = 'cpu',
|
||
|
|
disp_scale: float = 1.0
|
||
|
|
):
|
||
|
|
self.model = model.to(device)
|
||
|
|
self.model.eval()
|
||
|
|
self.polar_graph = polar_graph
|
||
|
|
self.design_names = design_names
|
||
|
|
self.design_bounds = design_bounds
|
||
|
|
self.design_mean = design_mean.to(device)
|
||
|
|
self.design_std = design_std.to(device)
|
||
|
|
self.device = torch.device(device)
|
||
|
|
self.disp_scale = disp_scale # Scaling factor from training
|
||
|
|
|
||
|
|
# Prepare fixed graph tensors
|
||
|
|
self.node_features = torch.tensor(
|
||
|
|
polar_graph.get_node_features(normalized=True),
|
||
|
|
dtype=torch.float32
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
self.edge_index = torch.tensor(
|
||
|
|
polar_graph.edge_index,
|
||
|
|
dtype=torch.long
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
self.edge_attr = torch.tensor(
|
||
|
|
polar_graph.get_edge_features(normalized=True),
|
||
|
|
dtype=torch.float32
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
# Objective computation layer (must be on same device as model)
|
||
|
|
self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes=50).to(device)
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_checkpoint(
|
||
|
|
cls,
|
||
|
|
checkpoint_path: Path,
|
||
|
|
config_path: Optional[Path] = None,
|
||
|
|
device: str = 'auto'
|
||
|
|
) -> 'ZernikeGNNOptimizer':
|
||
|
|
"""
|
||
|
|
Load optimizer from trained checkpoint.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
checkpoint_path: Path to zernike_gnn_checkpoint.pt
|
||
|
|
config_path: Path to optimization_config.json (for design bounds)
|
||
|
|
device: Device to use ('cpu', 'cuda', 'auto')
|
||
|
|
"""
|
||
|
|
if device == 'auto':
|
||
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||
|
|
|
||
|
|
checkpoint_path = Path(checkpoint_path)
|
||
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
|
||
|
|
|
||
|
|
# Create polar graph
|
||
|
|
polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
|
||
|
|
|
||
|
|
# Create model - handle both old ('model_config') and new ('config') format
|
||
|
|
model_config = checkpoint.get('model_config') or checkpoint.get('config', {})
|
||
|
|
model = create_model(**model_config)
|
||
|
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||
|
|
|
||
|
|
# Get design info from checkpoint
|
||
|
|
design_mean = checkpoint['design_mean']
|
||
|
|
design_std = checkpoint['design_std']
|
||
|
|
disp_scale = checkpoint.get('disp_scale', 1.0) # Displacement scaling factor
|
||
|
|
|
||
|
|
# Try to get design names and bounds from config
|
||
|
|
design_names = []
|
||
|
|
design_bounds = {}
|
||
|
|
|
||
|
|
if config_path and Path(config_path).exists():
|
||
|
|
with open(config_path, 'r') as f:
|
||
|
|
config = json.load(f)
|
||
|
|
|
||
|
|
for var in config.get('design_variables', []):
|
||
|
|
name = var['name']
|
||
|
|
design_names.append(name)
|
||
|
|
design_bounds[name] = (var['min'], var['max'])
|
||
|
|
else:
|
||
|
|
# Use generic names based on checkpoint
|
||
|
|
n_vars = len(design_mean)
|
||
|
|
design_names = [f'var_{i}' for i in range(n_vars)]
|
||
|
|
# Default bounds (will be overridden if config provided)
|
||
|
|
for name in design_names:
|
||
|
|
design_bounds[name] = (-100, 100)
|
||
|
|
|
||
|
|
return cls(
|
||
|
|
model=model,
|
||
|
|
polar_graph=polar_graph,
|
||
|
|
design_names=design_names,
|
||
|
|
design_bounds=design_bounds,
|
||
|
|
design_mean=design_mean,
|
||
|
|
design_std=design_std,
|
||
|
|
device=device,
|
||
|
|
disp_scale=disp_scale
|
||
|
|
)
|
||
|
|
|
||
|
|
@torch.no_grad()
|
||
|
|
def predict(self, design_vars: Dict[str, float], return_field: bool = False) -> GNNPrediction:
|
||
|
|
"""
|
||
|
|
Predict objectives for a single design.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
design_vars: Dict mapping variable names to values
|
||
|
|
return_field: Whether to include displacement field in result
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
GNNPrediction with objectives
|
||
|
|
"""
|
||
|
|
# Convert to tensor
|
||
|
|
design_values = [design_vars.get(name, 0.0) for name in self.design_names]
|
||
|
|
design_tensor = torch.tensor(design_values, dtype=torch.float32).to(self.device)
|
||
|
|
|
||
|
|
# Normalize
|
||
|
|
design_norm = (design_tensor - self.design_mean) / self.design_std
|
||
|
|
|
||
|
|
# Forward pass
|
||
|
|
z_disp_scaled = self.model(
|
||
|
|
self.node_features,
|
||
|
|
self.edge_index,
|
||
|
|
self.edge_attr,
|
||
|
|
design_norm
|
||
|
|
) # [3000, 4] in scaled units (μm if disp_scale=1e6)
|
||
|
|
|
||
|
|
# Convert back to mm before computing objectives
|
||
|
|
# During training: z_disp_mm * disp_scale = z_disp_scaled
|
||
|
|
# So: z_disp_mm = z_disp_scaled / disp_scale
|
||
|
|
z_disp_mm = z_disp_scaled / self.disp_scale
|
||
|
|
|
||
|
|
# Compute objectives (ZernikeObjectiveLayer expects mm input)
|
||
|
|
objectives = self.objective_layer(z_disp_mm)
|
||
|
|
|
||
|
|
# Objectives are now directly in nm (no additional scaling needed)
|
||
|
|
obj_dict = {
|
||
|
|
'rel_filtered_rms_40_vs_20': objectives['rel_filtered_rms_40_vs_20'].item(),
|
||
|
|
'rel_filtered_rms_60_vs_20': objectives['rel_filtered_rms_60_vs_20'].item(),
|
||
|
|
'mfg_90_optician_workload': objectives['mfg_90_optician_workload'].item(),
|
||
|
|
}
|
||
|
|
|
||
|
|
field_data = z_disp_mm.cpu().numpy() if return_field else None
|
||
|
|
|
||
|
|
return GNNPrediction(
|
||
|
|
design_vars=design_vars,
|
||
|
|
objectives=obj_dict,
|
||
|
|
z_displacement=field_data
|
||
|
|
)
|
||
|
|
|
||
|
|
@torch.no_grad()
|
||
|
|
def predict_batch(self, designs: List[Dict[str, float]]) -> List[GNNPrediction]:
|
||
|
|
"""
|
||
|
|
Predict objectives for multiple designs (batched for efficiency).
|
||
|
|
|
||
|
|
Args:
|
||
|
|
designs: List of design variable dicts
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
List of GNNPrediction
|
||
|
|
"""
|
||
|
|
results = []
|
||
|
|
for design in designs:
|
||
|
|
results.append(self.predict(design))
|
||
|
|
return results
|
||
|
|
|
||
|
|
def random_design(self) -> Dict[str, float]:
|
||
|
|
"""Generate a random design within bounds."""
|
||
|
|
design = {}
|
||
|
|
for name in self.design_names:
|
||
|
|
low, high = self.design_bounds.get(name, (-100, 100))
|
||
|
|
design[name] = np.random.uniform(low, high)
|
||
|
|
return design
|
||
|
|
|
||
|
|
def turbo_optimize(
|
||
|
|
self,
|
||
|
|
n_trials: int = 5000,
|
||
|
|
sampler: str = 'tpe',
|
||
|
|
seed: int = 42,
|
||
|
|
verbose: bool = True
|
||
|
|
) -> OptimizationResults:
|
||
|
|
"""
|
||
|
|
Run fast GNN-based optimization.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
n_trials: Number of GNN trials to run
|
||
|
|
sampler: Optuna sampler ('tpe', 'random', 'cmaes')
|
||
|
|
seed: Random seed
|
||
|
|
verbose: Print progress
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
OptimizationResults with all predictions
|
||
|
|
"""
|
||
|
|
np.random.seed(seed)
|
||
|
|
results = OptimizationResults()
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print("GNN TURBO OPTIMIZATION")
|
||
|
|
print(f"{'='*60}")
|
||
|
|
print(f"Trials: {n_trials}")
|
||
|
|
print(f"Sampler: {sampler}")
|
||
|
|
print(f"Design variables: {len(self.design_names)}")
|
||
|
|
print(f"Device: {self.device}")
|
||
|
|
|
||
|
|
# Create Optuna study for smart sampling
|
||
|
|
if sampler == 'tpe':
|
||
|
|
optuna_sampler = optuna.samplers.TPESampler(seed=seed)
|
||
|
|
elif sampler == 'random':
|
||
|
|
optuna_sampler = optuna.samplers.RandomSampler(seed=seed)
|
||
|
|
elif sampler == 'cmaes':
|
||
|
|
optuna_sampler = optuna.samplers.CmaEsSampler(seed=seed)
|
||
|
|
else:
|
||
|
|
optuna_sampler = optuna.samplers.TPESampler(seed=seed)
|
||
|
|
|
||
|
|
study = optuna.create_study(
|
||
|
|
directions=['minimize', 'minimize', 'minimize'], # 3 objectives
|
||
|
|
sampler=optuna_sampler
|
||
|
|
)
|
||
|
|
|
||
|
|
start_time = datetime.now()
|
||
|
|
|
||
|
|
def objective(trial):
|
||
|
|
# Sample design
|
||
|
|
design = {}
|
||
|
|
for name in self.design_names:
|
||
|
|
low, high = self.design_bounds.get(name, (-100, 100))
|
||
|
|
design[name] = trial.suggest_float(name, low, high)
|
||
|
|
|
||
|
|
# Predict with GNN
|
||
|
|
pred = self.predict(design)
|
||
|
|
results.add(pred)
|
||
|
|
|
||
|
|
return (
|
||
|
|
pred.objectives['rel_filtered_rms_40_vs_20'],
|
||
|
|
pred.objectives['rel_filtered_rms_60_vs_20'],
|
||
|
|
pred.objectives['mfg_90_optician_workload']
|
||
|
|
)
|
||
|
|
|
||
|
|
# Run optimization
|
||
|
|
if verbose:
|
||
|
|
print(f"\nRunning {n_trials} GNN trials...")
|
||
|
|
|
||
|
|
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||
|
|
study.optimize(objective, n_trials=n_trials, show_progress_bar=verbose)
|
||
|
|
|
||
|
|
elapsed = (datetime.now() - start_time).total_seconds()
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f"\nCompleted in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")
|
||
|
|
|
||
|
|
# Compute Pareto front
|
||
|
|
pareto = results.get_pareto_front()
|
||
|
|
print(f"Pareto front: {len(pareto)} designs")
|
||
|
|
|
||
|
|
# Best by each objective
|
||
|
|
print("\nBest by objective:")
|
||
|
|
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
|
||
|
|
best = results.get_best(n=1, objective=obj)[0]
|
||
|
|
print(f" {obj}: {best.objectives[obj]:.2f} nm")
|
||
|
|
|
||
|
|
return results
|
||
|
|
|
||
|
|
def validate_with_fea(
|
||
|
|
self,
|
||
|
|
candidates: List[GNNPrediction],
|
||
|
|
study_dir: Path,
|
||
|
|
verbose: bool = True,
|
||
|
|
start_trial_num: int = 9000
|
||
|
|
) -> List[Dict]:
|
||
|
|
"""
|
||
|
|
Validate GNN predictions with actual FEA.
|
||
|
|
|
||
|
|
This runs the full NX + Nastran workflow on each candidate
|
||
|
|
to get true objective values.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
candidates: GNN predictions to validate
|
||
|
|
study_dir: Path to study directory (for config and scripts)
|
||
|
|
verbose: Print progress
|
||
|
|
start_trial_num: Starting trial number for iteration folders
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
List of dicts with 'gnn' and 'fea' objectives for comparison
|
||
|
|
"""
|
||
|
|
import time
|
||
|
|
import re
|
||
|
|
from optimization_engine.nx_solver import NXSolver
|
||
|
|
from optimization_engine.extractors import ZernikeExtractor
|
||
|
|
|
||
|
|
study_dir = Path(study_dir)
|
||
|
|
config_path = study_dir / "1_setup" / "optimization_config.json"
|
||
|
|
model_dir = study_dir / "1_setup" / "model"
|
||
|
|
iterations_dir = study_dir / "2_iterations"
|
||
|
|
|
||
|
|
# Load config
|
||
|
|
if not config_path.exists():
|
||
|
|
raise FileNotFoundError(f"Config not found: {config_path}")
|
||
|
|
|
||
|
|
with open(config_path) as f:
|
||
|
|
config = json.load(f)
|
||
|
|
|
||
|
|
# Initialize NX Solver
|
||
|
|
nx_settings = config.get('nx_settings', {})
|
||
|
|
nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
|
||
|
|
version_match = re.search(r'NX(\d+)', nx_install_dir)
|
||
|
|
nastran_version = version_match.group(1) if version_match else "2506"
|
||
|
|
|
||
|
|
solver = NXSolver(
|
||
|
|
master_model_dir=str(model_dir),
|
||
|
|
nx_install_dir=nx_install_dir,
|
||
|
|
nastran_version=nastran_version,
|
||
|
|
timeout=nx_settings.get('simulation_timeout_s', 600),
|
||
|
|
use_iteration_folders=True,
|
||
|
|
study_name="gnn_validation"
|
||
|
|
)
|
||
|
|
|
||
|
|
iterations_dir.mkdir(exist_ok=True)
|
||
|
|
|
||
|
|
results = []
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print("FEA VALIDATION OF GNN PREDICTIONS")
|
||
|
|
print(f"{'='*60}")
|
||
|
|
print(f"Validating {len(candidates)} candidates")
|
||
|
|
print(f"Study: {study_dir.name}")
|
||
|
|
|
||
|
|
for i, candidate in enumerate(candidates):
|
||
|
|
trial_num = start_trial_num + i
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f"\n[{i+1}/{len(candidates)}] Trial {trial_num}")
|
||
|
|
print(f" GNN predicted: 40vs20={candidate.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
|
||
|
|
|
||
|
|
# Build expression updates from design variables
|
||
|
|
expressions = {}
|
||
|
|
for var in config.get('design_variables', []):
|
||
|
|
var_name = var['name']
|
||
|
|
expr_name = var.get('expression_name', var_name)
|
||
|
|
if var_name in candidate.design_vars:
|
||
|
|
expressions[expr_name] = candidate.design_vars[var_name]
|
||
|
|
|
||
|
|
# Create iteration folder with model copies
|
||
|
|
try:
|
||
|
|
iter_folder = solver.create_iteration_folder(
|
||
|
|
iterations_base_dir=iterations_dir,
|
||
|
|
iteration_number=trial_num,
|
||
|
|
expression_updates=expressions
|
||
|
|
)
|
||
|
|
except Exception as e:
|
||
|
|
if verbose:
|
||
|
|
print(f" ERROR creating iteration folder: {e}")
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': None,
|
||
|
|
'status': 'error',
|
||
|
|
'error': str(e)
|
||
|
|
})
|
||
|
|
continue
|
||
|
|
|
||
|
|
# Run simulation
|
||
|
|
sim_file = iter_folder / nx_settings.get('sim_file', 'ASSY_M1_assyfem1_sim1.sim')
|
||
|
|
solution_name = nx_settings.get('solution_name', 'Solution 1')
|
||
|
|
|
||
|
|
t_start = time.time()
|
||
|
|
try:
|
||
|
|
solve_result = solver.run_simulation(
|
||
|
|
sim_file=sim_file,
|
||
|
|
working_dir=iter_folder,
|
||
|
|
expression_updates=expressions,
|
||
|
|
solution_name=solution_name,
|
||
|
|
cleanup=False
|
||
|
|
)
|
||
|
|
except Exception as e:
|
||
|
|
if verbose:
|
||
|
|
print(f" ERROR in simulation: {e}")
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': None,
|
||
|
|
'status': 'solve_error',
|
||
|
|
'error': str(e)
|
||
|
|
})
|
||
|
|
continue
|
||
|
|
|
||
|
|
solve_time = time.time() - t_start
|
||
|
|
|
||
|
|
if not solve_result['success']:
|
||
|
|
if verbose:
|
||
|
|
print(f" Solve FAILED: {solve_result.get('errors', ['Unknown'])}")
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': None,
|
||
|
|
'status': 'solve_failed',
|
||
|
|
'errors': solve_result.get('errors', [])
|
||
|
|
})
|
||
|
|
continue
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f" Solved in {solve_time:.1f}s")
|
||
|
|
|
||
|
|
# Extract objectives using ZernikeExtractor
|
||
|
|
op2_path = solve_result['op2_file']
|
||
|
|
if op2_path is None or not Path(op2_path).exists():
|
||
|
|
if verbose:
|
||
|
|
print(f" ERROR: OP2 file not found")
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': None,
|
||
|
|
'status': 'no_op2',
|
||
|
|
})
|
||
|
|
continue
|
||
|
|
|
||
|
|
try:
|
||
|
|
zernike_settings = config.get('zernike_settings', {})
|
||
|
|
extractor = ZernikeExtractor(
|
||
|
|
op2_path,
|
||
|
|
bdf_path=None,
|
||
|
|
displacement_unit=zernike_settings.get('displacement_unit', 'mm'),
|
||
|
|
n_modes=zernike_settings.get('n_modes', 50),
|
||
|
|
filter_orders=zernike_settings.get('filter_low_orders', 4)
|
||
|
|
)
|
||
|
|
|
||
|
|
ref = zernike_settings.get('reference_subcase', '2')
|
||
|
|
|
||
|
|
# Extract objectives: 40 vs 20, 60 vs 20, mfg 90
|
||
|
|
rel_40 = extractor.extract_relative("3", ref)
|
||
|
|
rel_60 = extractor.extract_relative("4", ref)
|
||
|
|
rel_90 = extractor.extract_relative("1", ref)
|
||
|
|
|
||
|
|
fea_objectives = {
|
||
|
|
'rel_filtered_rms_40_vs_20': rel_40['relative_filtered_rms_nm'],
|
||
|
|
'rel_filtered_rms_60_vs_20': rel_60['relative_filtered_rms_nm'],
|
||
|
|
'mfg_90_optician_workload': rel_90['relative_rms_filter_j1to3'],
|
||
|
|
}
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
if verbose:
|
||
|
|
print(f" ERROR in Zernike extraction: {e}")
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': None,
|
||
|
|
'status': 'extraction_error',
|
||
|
|
'error': str(e)
|
||
|
|
})
|
||
|
|
continue
|
||
|
|
|
||
|
|
# Compute errors
|
||
|
|
errors = {}
|
||
|
|
for obj_name in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
|
||
|
|
gnn_val = candidate.objectives[obj_name]
|
||
|
|
fea_val = fea_objectives[obj_name]
|
||
|
|
errors[f'{obj_name}_abs_error'] = abs(gnn_val - fea_val)
|
||
|
|
errors[f'{obj_name}_pct_error'] = 100 * abs(gnn_val - fea_val) / max(fea_val, 0.01)
|
||
|
|
|
||
|
|
if verbose:
|
||
|
|
print(f" FEA results:")
|
||
|
|
print(f" 40vs20: {fea_objectives['rel_filtered_rms_40_vs_20']:.2f} nm "
|
||
|
|
f"(GNN: {candidate.objectives['rel_filtered_rms_40_vs_20']:.2f}, "
|
||
|
|
f"err: {errors['rel_filtered_rms_40_vs_20_pct_error']:.1f}%)")
|
||
|
|
print(f" 60vs20: {fea_objectives['rel_filtered_rms_60_vs_20']:.2f} nm "
|
||
|
|
f"(GNN: {candidate.objectives['rel_filtered_rms_60_vs_20']:.2f}, "
|
||
|
|
f"err: {errors['rel_filtered_rms_60_vs_20_pct_error']:.1f}%)")
|
||
|
|
print(f" mfg90: {fea_objectives['mfg_90_optician_workload']:.2f} nm "
|
||
|
|
f"(GNN: {candidate.objectives['mfg_90_optician_workload']:.2f}, "
|
||
|
|
f"err: {errors['mfg_90_optician_workload_pct_error']:.1f}%)")
|
||
|
|
|
||
|
|
results.append({
|
||
|
|
'design': candidate.design_vars,
|
||
|
|
'gnn_objectives': candidate.objectives,
|
||
|
|
'fea_objectives': fea_objectives,
|
||
|
|
'errors': errors,
|
||
|
|
'solve_time': solve_time,
|
||
|
|
'trial_num': trial_num,
|
||
|
|
'status': 'success'
|
||
|
|
})
|
||
|
|
|
||
|
|
# Summary
|
||
|
|
if verbose:
|
||
|
|
successful = [r for r in results if r['status'] == 'success']
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print(f"VALIDATION SUMMARY")
|
||
|
|
print(f"{'='*60}")
|
||
|
|
print(f"Successful: {len(successful)}/{len(candidates)}")
|
||
|
|
|
||
|
|
if successful:
|
||
|
|
avg_errors = {}
|
||
|
|
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
|
||
|
|
avg_errors[obj] = np.mean([r['errors'][f'{obj}_pct_error'] for r in successful])
|
||
|
|
|
||
|
|
print(f"\nAverage prediction errors:")
|
||
|
|
print(f" 40 vs 20: {avg_errors['rel_filtered_rms_40_vs_20']:.1f}%")
|
||
|
|
print(f" 60 vs 20: {avg_errors['rel_filtered_rms_60_vs_20']:.1f}%")
|
||
|
|
print(f" mfg 90: {avg_errors['mfg_90_optician_workload']:.1f}%")
|
||
|
|
|
||
|
|
return results
|
||
|
|
|
||
|
|
def save_validation_report(
|
||
|
|
self,
|
||
|
|
validation_results: List[Dict],
|
||
|
|
output_path: Path
|
||
|
|
):
|
||
|
|
"""Save validation results to JSON file."""
|
||
|
|
report = {
|
||
|
|
'timestamp': datetime.now().isoformat(),
|
||
|
|
'n_candidates': len(validation_results),
|
||
|
|
'n_successful': len([r for r in validation_results if r['status'] == 'success']),
|
||
|
|
'results': validation_results,
|
||
|
|
}
|
||
|
|
|
||
|
|
# Compute summary statistics if we have successful results
|
||
|
|
successful = [r for r in validation_results if r['status'] == 'success']
|
||
|
|
if successful:
|
||
|
|
avg_errors = {}
|
||
|
|
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
|
||
|
|
errors = [r['errors'][f'{obj}_pct_error'] for r in successful]
|
||
|
|
avg_errors[obj] = {
|
||
|
|
'mean_pct': float(np.mean(errors)),
|
||
|
|
'std_pct': float(np.std(errors)),
|
||
|
|
'max_pct': float(np.max(errors)),
|
||
|
|
}
|
||
|
|
report['error_summary'] = avg_errors
|
||
|
|
|
||
|
|
with open(output_path, 'w') as f:
|
||
|
|
json.dump(report, f, indent=2)
|
||
|
|
|
||
|
|
print(f"Validation report saved to: {output_path}")
|
||
|
|
|
||
|
|
|
||
|
|
def main():
|
||
|
|
"""Example usage of GNN optimizer."""
|
||
|
|
import argparse
|
||
|
|
|
||
|
|
parser = argparse.ArgumentParser(description='GNN-based Zernike optimization')
|
||
|
|
parser.add_argument('checkpoint', type=Path, help='Path to GNN checkpoint')
|
||
|
|
parser.add_argument('--config', type=Path, help='Path to optimization_config.json')
|
||
|
|
parser.add_argument('--trials', type=int, default=5000, help='Number of GNN trials')
|
||
|
|
parser.add_argument('--output', '-o', type=Path, help='Output results JSON')
|
||
|
|
parser.add_argument('--top-n', type=int, default=20, help='Number of top candidates to show')
|
||
|
|
|
||
|
|
args = parser.parse_args()
|
||
|
|
|
||
|
|
# Load optimizer
|
||
|
|
print(f"Loading GNN from {args.checkpoint}...")
|
||
|
|
optimizer = ZernikeGNNOptimizer.from_checkpoint(
|
||
|
|
args.checkpoint,
|
||
|
|
config_path=args.config
|
||
|
|
)
|
||
|
|
|
||
|
|
# Run turbo optimization
|
||
|
|
results = optimizer.turbo_optimize(n_trials=args.trials)
|
||
|
|
|
||
|
|
# Show top candidates
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print(f"TOP {args.top_n} CANDIDATES (by rel_filtered_rms_40_vs_20)")
|
||
|
|
print(f"{'='*60}")
|
||
|
|
|
||
|
|
top = results.get_best(n=args.top_n, objective='rel_filtered_rms_40_vs_20')
|
||
|
|
for i, pred in enumerate(top):
|
||
|
|
print(f"\n#{i+1}:")
|
||
|
|
print(f" 40 vs 20: {pred.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
|
||
|
|
print(f" 60 vs 20: {pred.objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
|
||
|
|
print(f" mfg_90: {pred.objectives['mfg_90_optician_workload']:.2f} nm")
|
||
|
|
|
||
|
|
# Save results
|
||
|
|
if args.output:
|
||
|
|
results.save(args.output)
|
||
|
|
print(f"\nResults saved to {args.output}")
|
||
|
|
|
||
|
|
# Show Pareto front
|
||
|
|
pareto = results.get_pareto_front()
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print(f"PARETO FRONT: {len(pareto)} designs")
|
||
|
|
print(f"{'='*60}")
|
||
|
|
|
||
|
|
for i, pred in enumerate(pareto[:10]): # Show first 10
|
||
|
|
print(f" [{i+1}] 40vs20={pred.objectives['rel_filtered_rms_40_vs_20']:.1f}, "
|
||
|
|
f"60vs20={pred.objectives['rel_filtered_rms_60_vs_20']:.1f}, "
|
||
|
|
f"mfg={pred.objectives['mfg_90_optician_workload']:.1f}")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
|
||
|
|
main()
|