Files
Atomizer/optimization_engine/gnn/train_zernike_gnn.py
Antoine 96b196de58 feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00

601 lines
20 KiB
Python

"""
Training Pipeline for ZernikeGNN
=================================
This module provides the complete training pipeline for the Zernike GNN surrogate.
Training Flow:
1. Load displacement field data from gnn_data/ folders
2. Interpolate to fixed polar grid
3. Normalize inputs (design vars) and outputs (displacements)
4. Train with multi-task loss (field + objectives)
5. Validate on held-out data
6. Save best model checkpoint
Usage:
# Command line
python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200
# Python API
from optimization_engine.gnn.train_zernike_gnn import ZernikeGNNTrainer
trainer = ZernikeGNNTrainer(['V11', 'V12'])
trainer.train(epochs=200)
trainer.save_checkpoint('model.pt')
"""
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer, ZernikeRMSLoss
class MirrorDataset(Dataset):
    """PyTorch Dataset wrapping mirror displacement-field samples.

    Each item yields a z-score-normalized design vector and a scaled
    z-displacement field ready for the GNN, plus the raw design vector
    and trial number for bookkeeping.
    """

    def __init__(
        self,
        data_list: List[Dict[str, Any]],
        design_mean: Optional[torch.Tensor] = None,
        design_std: Optional[torch.Tensor] = None,
        disp_scale: float = 1e6  # multiplier on raw displacements for numerical stability (1e6: mm -> nm)
    ):
        """
        Args:
            data_list: Output from create_mirror_dataset()
            design_mean: Mean for design normalization (computed if None)
            design_std: Std for design normalization (computed if None)
            disp_scale: Scale factor applied to displacement fields
        """
        self.data_list = data_list
        self.disp_scale = disp_scale
        # Normalization statistics come from this dataset unless the caller
        # supplies them (e.g. a validation set reusing the training stats).
        stacked = np.stack([sample['design_vars'] for sample in data_list])
        self.design_mean = (
            design_mean
            if design_mean is not None
            else torch.tensor(np.mean(stacked, axis=0), dtype=torch.float32)
        )
        # The small epsilon guards against division by zero for any design
        # parameter that is constant across the dataset.
        self.design_std = (
            design_std
            if design_std is not None
            else torch.tensor(np.std(stacked, axis=0) + 1e-6, dtype=torch.float32)
        )

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.data_list[idx]
        raw_design = torch.tensor(sample['design_vars'], dtype=torch.float32)
        return {
            # Z-score-normalized design variables (model input).
            'design': (raw_design - self.design_mean) / self.design_std,
            'design_raw': raw_design,
            # Displacement field scaled for numerical stability (model target).
            'z_displacement': torch.tensor(sample['z_displacement'], dtype=torch.float32) * self.disp_scale,
            'trial_number': sample['trial_number'],
        }
class ZernikeGNNTrainer:
    """
    Complete training pipeline for ZernikeGNN.
    Handles:
    - Data loading and preprocessing
    - Model initialization
    - Training loop with validation
    - Checkpointing
    - Metrics tracking
    """
    def __init__(
        self,
        study_versions: List[str],
        base_dir: Optional[Path] = None,
        model_type: str = 'full',
        hidden_dim: int = 128,
        n_layers: int = 6,
        device: str = 'auto'
    ):
        """
        Args:
            study_versions: List of study versions (e.g., ['V11', 'V12'])
            base_dir: Base Atomizer directory
            model_type: 'full' or 'lite'
            hidden_dim: Model hidden dimension
            n_layers: Number of message passing layers
            device: 'cpu', 'cuda', or 'auto'
        """
        # Default base_dir is three levels above this module (the project root).
        if base_dir is None:
            base_dir = Path(__file__).parent.parent.parent
        self.base_dir = Path(base_dir)
        self.study_versions = study_versions
        # Determine device ('auto' prefers CUDA when available)
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"[TRAINER] Device: {self.device}", flush=True)
        # Create polar graph (fixed structure): every sample shares the same
        # node/edge layout, so graph tensors are built once and reused below.
        # Grid parameters are hard-coded; radii presumably in mm — TODO confirm.
        self.polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"[TRAINER] Polar graph: {self.polar_graph.n_nodes} nodes, {self.polar_graph.edge_index.shape[1]} edges", flush=True)
        # Prepare graph tensors (moved to the training device once, up front)
        self.node_features = torch.tensor(
            self.polar_graph.get_node_features(normalized=True),
            dtype=torch.float32
        ).to(self.device)
        self.edge_index = torch.tensor(
            self.polar_graph.edge_index,
            dtype=torch.long
        ).to(self.device)
        self.edge_attr = torch.tensor(
            self.polar_graph.get_edge_features(normalized=True),
            dtype=torch.float32
        ).to(self.device)
        # Load data (populates train/val datasets and normalization stats)
        self._load_data()
        # Create model; the design-vector width is inferred from the first sample.
        self.model_config = {
            'model_type': model_type,
            'n_design_vars': len(self.train_dataset.data_list[0]['design_vars']),
            'n_subcases': 4,
            'hidden_dim': hidden_dim,
            'n_layers': n_layers,
        }
        self.model = create_model(**self.model_config).to(self.device)
        print(f"[TRAINER] Model: {self.model.__class__.__name__} with {sum(p.numel() for p in self.model.parameters()):,} parameters", flush=True)
        # Objective layer for evaluation (differentiable Zernike fit, 50 modes)
        self.objective_layer = ZernikeObjectiveLayer(self.polar_graph, n_modes=50)
        # Training state
        self.best_val_loss = float('inf')
        self.history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
    def _load_data(self):
        """Load and prepare training data from studies.

        Aggregates samples across all requested study versions, then performs
        a deterministic 80/20 train/val split. The validation dataset reuses
        the training set's normalization statistics.
        """
        all_data = []
        for version in self.study_versions:
            study_dir = self.base_dir / "studies" / f"m1_mirror_adaptive_{version}"
            # A missing study is skipped with a warning rather than a hard failure.
            if not study_dir.exists():
                print(f"[WARN] Study not found: {study_dir}", flush=True)
                continue
            print(f"[TRAINER] Loading data from {study_dir.name}...", flush=True)
            dataset = create_mirror_dataset(study_dir, polar_graph=self.polar_graph, verbose=True)
            print(f"[TRAINER] Loaded {len(dataset)} samples", flush=True)
            all_data.extend(dataset)
        if not all_data:
            raise ValueError("No data loaded!")
        print(f"[TRAINER] Total samples: {len(all_data)}", flush=True)
        # Train/val split (80/20); the fixed seed makes the split reproducible
        # across runs (and across checkpoint reloads, which re-run this method).
        np.random.seed(42)
        indices = np.random.permutation(len(all_data))
        n_train = int(0.8 * len(all_data))
        train_data = [all_data[i] for i in indices[:n_train]]
        val_data = [all_data[i] for i in indices[n_train:]]
        print(f"[TRAINER] Train: {len(train_data)}, Val: {len(val_data)}", flush=True)
        # Create datasets; val reuses train statistics to avoid leakage.
        self.train_dataset = MirrorDataset(train_data)
        self.val_dataset = MirrorDataset(
            val_data,
            design_mean=self.train_dataset.design_mean,
            design_std=self.train_dataset.design_std
        )
        # Store normalization params for inference
        self.design_mean = self.train_dataset.design_mean
        self.design_std = self.train_dataset.design_std
        self.disp_scale = self.train_dataset.disp_scale
    def train(
        self,
        epochs: int = 200,
        lr: float = 1e-3,
        weight_decay: float = 1e-5,
        batch_size: int = 4,
        field_weight: float = 1.0,
        patience: int = 50,
        verbose: bool = True
    ):
        """
        Train the GNN model.

        Args:
            epochs: Number of training epochs
            lr: Learning rate
            weight_decay: Weight decay for regularization
            batch_size: Training batch size
            field_weight: Weight for field loss.
                NOTE(review): currently unused in the loss below — confirm intent.
            patience: Early stopping patience
            verbose: Print training progress
        """
        # Create data loaders
        train_loader = DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True
        )
        val_loader = DataLoader(
            self.val_dataset, batch_size=batch_size, shuffle=False
        )
        # Optimizer (AdamW = Adam with decoupled weight decay)
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )
        # Learning rate scheduler: cosine decay over the full epoch budget
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        # Training loop
        no_improve = 0
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            for batch in train_loader:
                optimizer.zero_grad()
                # Move to device
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)
                # Forward pass for each sample in batch: the model is called
                # with one design vector at a time on the shared graph, and
                # the per-sample losses are averaged into one batch loss.
                batch_loss = 0.0
                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features,
                        self.edge_index,
                        self.edge_attr,
                        design[i]
                    )
                    # MSE loss on displacement field
                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    batch_loss = batch_loss + loss
                batch_loss = batch_loss / design.size(0)
                batch_loss.backward()
                optimizer.step()
                train_loss += batch_loss.item()
            # Per-batch average (note: _validate averages per sample instead).
            train_loss /= len(train_loader)
            scheduler.step()
            # Validation
            val_loss, val_metrics = self._validate(val_loader)
            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_r2'].append(val_metrics.get('r2_mean', 0))
            # Early stopping: keep a CPU copy of the best weights so far.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1
            if verbose and epoch % 10 == 0:
                print(f"[Epoch {epoch:3d}] Train: {train_loss:.6f}, Val: {val_loss:.6f}, "
                      f"R²: {val_metrics.get('r2_mean', 0):.4f}, LR: {scheduler.get_last_lr()[0]:.2e}", flush=True)
            if no_improve >= patience:
                print(f"[TRAINER] Early stopping at epoch {epoch}", flush=True)
                break
        # Restore best model.
        # NOTE(review): assumes at least one epoch ran — with epochs=0 the
        # attribute best_model_state is never set and this line would raise.
        self.model.load_state_dict(self.best_model_state)
        print(f"[TRAINER] Training complete. Best val loss: {self.best_val_loss:.6f}", flush=True)
    def _validate(self, val_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
        """Run validation and compute metrics.

        Returns:
            Tuple of (mean per-sample MSE, metrics dict with mean and
            per-subcase R² over the flattened displacement fields).
        """
        self.model.eval()
        val_loss = 0.0
        all_pred = []
        all_true = []
        with torch.no_grad():
            for batch in val_loader:
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)
                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features,
                        self.edge_index,
                        self.edge_attr,
                        design[i]
                    )
                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    val_loss += loss.item()
                    all_pred.append(z_disp_pred.cpu())
                    all_true.append(z_disp_true[i].cpu())
        # Per-sample average (train() averages per batch — scales differ slightly).
        val_loss /= len(self.val_dataset)
        # Compute R² for each subcase
        all_pred = torch.stack(all_pred)  # [n_val, n_nodes, 4]
        all_true = torch.stack(all_true)
        r2_per_subcase = []
        for sc in range(4):
            pred_flat = all_pred[:, :, sc].flatten()
            true_flat = all_true[:, :, sc].flatten()
            ss_res = ((true_flat - pred_flat) ** 2).sum()
            ss_tot = ((true_flat - true_flat.mean()) ** 2).sum()
            # Epsilon guards against a constant-valued subcase (ss_tot == 0).
            r2 = 1 - ss_res / (ss_tot + 1e-8)
            r2_per_subcase.append(r2.item())
        metrics = {
            'r2_mean': np.mean(r2_per_subcase),
            'r2_per_subcase': r2_per_subcase,
        }
        return val_loss, metrics
    def evaluate_objectives(self) -> Dict[str, Any]:
        """
        Evaluate objective prediction accuracy on validation set.

        Runs the GNN on each validation sample, unscales predicted and true
        fields to raw units, derives the optical objectives via the Zernike
        objective layer, and compares them (MAE, MAPE, R² per objective).

        Returns:
            Dictionary with per-objective metrics
        """
        self.model.eval()
        obj_pred_all = {k: [] for k in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']}
        obj_true_all = {k: [] for k in obj_pred_all}
        # Objectives are computed on CPU tensors (validation set is small).
        with torch.no_grad():
            for i in range(len(self.val_dataset)):
                item = self.val_dataset[i]
                design = item['design'].to(self.device)
                z_disp_true = item['z_displacement']  # Already scaled
                # Predict
                z_disp_pred = self.model(
                    self.node_features,
                    self.edge_index,
                    self.edge_attr,
                    design
                ).cpu()
                # Unscale for objective computation (back to raw stored units)
                z_disp_pred_mm = z_disp_pred / self.disp_scale
                z_disp_true_mm = z_disp_true / self.disp_scale
                # Compute objectives
                obj_pred = self.objective_layer(z_disp_pred_mm)
                obj_true = self.objective_layer(z_disp_true_mm)
                for k in obj_pred_all:
                    obj_pred_all[k].append(obj_pred[k].item())
                    obj_true_all[k].append(obj_true[k].item())
        # Compute metrics per objective
        results = {}
        for k in obj_pred_all:
            pred = np.array(obj_pred_all[k])
            true = np.array(obj_true_all[k])
            mae = np.mean(np.abs(pred - true))
            # Epsilon avoids division by zero when a true value is near zero.
            mape = np.mean(np.abs(pred - true) / (np.abs(true) + 1e-6)) * 100
            ss_res = np.sum((true - pred) ** 2)
            ss_tot = np.sum((true - np.mean(true)) ** 2)
            r2 = 1 - ss_res / (ss_tot + 1e-8)
            results[k] = {
                'mae': mae,
                'mape': mape,
                'r2': r2,
                'pred_range': [pred.min(), pred.max()],
                'true_range': [true.min(), true.max()],
            }
        return results
    def save_checkpoint(self, path: Path) -> None:
        """Save model checkpoint.

        The checkpoint bundles weights, model config, normalization stats,
        training history and provenance — everything load_checkpoint() needs.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'config': self.model_config,
            'design_mean': self.design_mean,
            'design_std': self.design_std,
            'disp_scale': self.disp_scale,
            'history': self.history,
            'best_val_loss': self.best_val_loss,
            'study_versions': self.study_versions,
            'timestamp': datetime.now().isoformat(),
        }
        torch.save(checkpoint, path)
        print(f"[TRAINER] Saved checkpoint to {path}", flush=True)
    @classmethod
    def load_checkpoint(cls, path: Path, device: str = 'auto') -> 'ZernikeGNNTrainer':
        """Load trainer from checkpoint.

        NOTE(review): the constructor re-runs _load_data(), so the original
        study folders must still exist on disk when a checkpoint is loaded.
        """
        checkpoint = torch.load(path, map_location='cpu')
        # Create trainer with same config
        trainer = cls(
            study_versions=checkpoint['study_versions'],
            model_type=checkpoint['config']['model_type'],
            hidden_dim=checkpoint['config']['hidden_dim'],
            n_layers=checkpoint['config']['n_layers'],
            device=device,
        )
        # Load model weights
        trainer.model.load_state_dict(checkpoint['model_state_dict'])
        # Restore normalization (overrides the stats recomputed by _load_data)
        trainer.design_mean = checkpoint['design_mean']
        trainer.design_std = checkpoint['design_std']
        trainer.disp_scale = checkpoint['disp_scale']
        # Restore history
        trainer.history = checkpoint['history']
        trainer.best_val_loss = checkpoint['best_val_loss']
        return trainer
    def predict(self, design_vars: Dict[str, float]) -> Dict[str, Any]:
        """
        Make prediction for new design.

        Args:
            design_vars: Dictionary of design parameter values; keys must
                cover the 'design_names' recorded with the training data.
        Returns:
            Dictionary with displacement field and objectives
        """
        self.model.eval()
        # Convert to tensor, ordering values to match the training layout.
        design_names = self.train_dataset.data_list[0]['design_names']
        design = torch.tensor(
            [design_vars[name] for name in design_names],
            dtype=torch.float32
        )
        # Normalize with the stored training statistics.
        design_norm = (design - self.design_mean) / self.design_std
        with torch.no_grad():
            z_disp_scaled = self.model(
                self.node_features,
                self.edge_index,
                self.edge_attr,
                design_norm.to(self.device)
            ).cpu()
        # Unscale back to raw physical units
        z_disp_mm = z_disp_scaled / self.disp_scale
        # Compute objectives
        objectives = self.objective_layer(z_disp_mm)
        return {
            'z_displacement': z_disp_mm.numpy(),
            'objectives': {k: v.item() for k, v in objectives.items()},
        }
# =============================================================================
# CLI
# =============================================================================
def main():
    """Command-line entry point: train a ZernikeGNN surrogate and save a checkpoint."""
    import argparse
    parser = argparse.ArgumentParser(description='Train ZernikeGNN surrogate')
    parser.add_argument('studies', nargs='+', help='Study versions (e.g., V11 V12)')
    parser.add_argument('--epochs', type=int, default=200, help='Training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--n-layers', type=int, default=6, help='Message passing layers')
    parser.add_argument('--model-type', choices=['full', 'lite'], default='full')
    parser.add_argument('--output', '-o', type=Path, help='Output checkpoint path')
    parser.add_argument('--device', default='auto', help='Device (cpu, cuda, auto)')
    args = parser.parse_args()

    banner = "=" * 60
    print(banner, flush=True)
    print("ZERNIKE GNN TRAINING", flush=True)
    print(banner, flush=True)

    # Build the trainer (loads study data and constructs the model).
    trainer = ZernikeGNNTrainer(
        study_versions=args.studies,
        model_type=args.model_type,
        hidden_dim=args.hidden_dim,
        n_layers=args.n_layers,
        device=args.device,
    )

    # Train with early stopping and best-checkpoint restoration.
    trainer.train(epochs=args.epochs, lr=args.lr, batch_size=args.batch_size)

    # Report per-objective surrogate accuracy on the validation split.
    print("\n--- Objective Prediction Evaluation ---", flush=True)
    for name, metrics in trainer.evaluate_objectives().items():
        print(f"\n{name}:", flush=True)
        print(f" MAE: {metrics['mae']:.2f} nm", flush=True)
        print(f" MAPE: {metrics['mape']:.1f}%", flush=True)
        print(f" R²: {metrics['r2']:.4f}", flush=True)

    # Persist the trained model (default path when --output is omitted).
    output_path = args.output if args.output else Path("zernike_gnn_checkpoint.pt")
    trainer.save_checkpoint(output_path)

    print("\n" + banner, flush=True)
    print("✓ Training complete!", flush=True)
    print(banner, flush=True)
if __name__ == '__main__':
    main()