""" Training Pipeline for ZernikeGNN ================================= This module provides the complete training pipeline for the Zernike GNN surrogate. Training Flow: 1. Load displacement field data from gnn_data/ folders 2. Interpolate to fixed polar grid 3. Normalize inputs (design vars) and outputs (displacements) 4. Train with multi-task loss (field + objectives) 5. Validate on held-out data 6. Save best model checkpoint Usage: # Command line python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200 # Python API from optimization_engine.gnn.train_zernike_gnn import ZernikeGNNTrainer trainer = ZernikeGNNTrainer(['V11', 'V12']) trainer.train(epochs=200) trainer.save_checkpoint('model.pt') """ import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from datetime import datetime import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset from optimization_engine.gnn.zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer, ZernikeRMSLoss class MirrorDataset(Dataset): """PyTorch Dataset for mirror displacement fields.""" def __init__( self, data_list: List[Dict[str, Any]], design_mean: Optional[torch.Tensor] = None, design_std: Optional[torch.Tensor] = None, disp_scale: float = 1e6 # mm → μm for numerical stability ): """ Args: data_list: Output from create_mirror_dataset() design_mean: Mean for design normalization (computed if None) design_std: Std for design normalization (computed if None) disp_scale: Scale factor for displacements """ self.data_list = data_list self.disp_scale = disp_scale # Stack all design variables for normalization all_designs = np.stack([d['design_vars'] for d in data_list]) if design_mean is None: self.design_mean = torch.tensor(np.mean(all_designs, axis=0), dtype=torch.float32) else: self.design_mean = design_mean if design_std is None: self.design_std = torch.tensor(np.std(all_designs, axis=0) + 1e-6, dtype=torch.float32) else: self.design_std = design_std def __len__(self) -> int: return len(self.data_list) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: item = self.data_list[idx] # Normalize design variables design = torch.tensor(item['design_vars'], dtype=torch.float32) design_norm = (design - self.design_mean) / self.design_std # Scale displacements for numerical stability z_disp = torch.tensor(item['z_displacement'], dtype=torch.float32) z_disp_scaled = z_disp * self.disp_scale return { 'design': design_norm, 'design_raw': design, 'z_displacement': z_disp_scaled, 'trial_number': item['trial_number'], } class ZernikeGNNTrainer: """ Complete training pipeline for ZernikeGNN. 

class ZernikeGNNTrainer:
    """
    Complete training pipeline for ZernikeGNN.

    Handles:
        - Data loading and preprocessing
        - Model initialization
        - Training loop with validation
        - Checkpointing
        - Metrics tracking
    """

    def __init__(
        self,
        study_versions: List[str],
        base_dir: Optional[Path] = None,
        model_type: str = 'full',
        hidden_dim: int = 128,
        n_layers: int = 6,
        device: str = 'auto'
    ):
        """
        Args:
            study_versions: List of study versions (e.g., ['V11', 'V12'])
            base_dir: Base Atomizer directory
            model_type: 'full' or 'lite'
            hidden_dim: Model hidden dimension
            n_layers: Number of message passing layers
            device: 'cpu', 'cuda', or 'auto'
        """
        if base_dir is None:
            base_dir = Path(__file__).parent.parent.parent
        self.base_dir = Path(base_dir)
        self.study_versions = study_versions

        # Determine device
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"[TRAINER] Device: {self.device}", flush=True)

        # Create polar graph (fixed structure)
        self.polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"[TRAINER] Polar graph: {self.polar_graph.n_nodes} nodes, "
              f"{self.polar_graph.edge_index.shape[1]} edges", flush=True)

        # Prepare graph tensors
        self.node_features = torch.tensor(
            self.polar_graph.get_node_features(normalized=True),
            dtype=torch.float32
        ).to(self.device)
        self.edge_index = torch.tensor(
            self.polar_graph.edge_index,
            dtype=torch.long
        ).to(self.device)
        self.edge_attr = torch.tensor(
            self.polar_graph.get_edge_features(normalized=True),
            dtype=torch.float32
        ).to(self.device)

        # Load data
        self._load_data()

        # Create model
        self.model_config = {
            'model_type': model_type,
            'n_design_vars': len(self.train_dataset.data_list[0]['design_vars']),
            'n_subcases': 4,
            'hidden_dim': hidden_dim,
            'n_layers': n_layers,
        }
        self.model = create_model(**self.model_config).to(self.device)
        print(f"[TRAINER] Model: {self.model.__class__.__name__} with "
              f"{sum(p.numel() for p in self.model.parameters()):,} parameters", flush=True)

        # Objective layer for evaluation
        self.objective_layer = ZernikeObjectiveLayer(self.polar_graph, n_modes=50)

        # Training state
        self.best_val_loss = float('inf')
        self.history = {'train_loss': [], 'val_loss': [], 'val_r2': []}

    def _load_data(self):
        """Load and prepare training data from studies."""
        all_data = []
        for version in self.study_versions:
            study_dir = self.base_dir / "studies" / f"m1_mirror_adaptive_{version}"
            if not study_dir.exists():
                print(f"[WARN] Study not found: {study_dir}", flush=True)
                continue

            print(f"[TRAINER] Loading data from {study_dir.name}...", flush=True)
            dataset = create_mirror_dataset(study_dir, polar_graph=self.polar_graph, verbose=True)
            print(f"[TRAINER] Loaded {len(dataset)} samples", flush=True)
            all_data.extend(dataset)

        if not all_data:
            raise ValueError("No data loaded!")

        print(f"[TRAINER] Total samples: {len(all_data)}", flush=True)

        # Train/val split (80/20), fixed seed for reproducibility
        np.random.seed(42)
        indices = np.random.permutation(len(all_data))
        n_train = int(0.8 * len(all_data))
        train_data = [all_data[i] for i in indices[:n_train]]
        val_data = [all_data[i] for i in indices[n_train:]]

        print(f"[TRAINER] Train: {len(train_data)}, Val: {len(val_data)}", flush=True)

        # Create datasets; validation reuses the training normalization
        # statistics to avoid leakage
        self.train_dataset = MirrorDataset(train_data)
        self.val_dataset = MirrorDataset(
            val_data,
            design_mean=self.train_dataset.design_mean,
            design_std=self.train_dataset.design_std
        )

        # Store normalization params for inference
        self.design_mean = self.train_dataset.design_mean
        self.design_std = self.train_dataset.design_std
        self.disp_scale = self.train_dataset.disp_scale
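
    # --- Illustrative sketch (not called anywhere) ---------------------------
    # The module docstring advertises a multi-task loss over the field and
    # the objectives, while train() below optimizes the field MSE only. This
    # sketch shows one way the objective terms could be folded in, reusing
    # the trainer's own ZernikeObjectiveLayer. `obj_weight` is a hypothetical
    # knob, not an existing parameter; this assumes the objective layer is
    # differentiable end to end and placed on the same device as the fields.
    def _multi_task_loss_sketch(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        field_weight: float = 1.0,
        obj_weight: float = 0.1,
    ) -> torch.Tensor:
        field_loss = F.mse_loss(z_disp_pred, z_disp_true)
        # Objectives are computed on unscaled (mm) fields, as in
        # evaluate_objectives(); the layer returns a dict of scalar tensors.
        obj_pred = self.objective_layer(z_disp_pred / self.disp_scale)
        obj_true = self.objective_layer(z_disp_true / self.disp_scale)
        obj_loss = sum(F.mse_loss(obj_pred[k], obj_true[k]) for k in obj_pred)
        return field_weight * field_loss + obj_weight * obj_loss
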
    def train(
        self,
        epochs: int = 200,
        lr: float = 1e-3,
        weight_decay: float = 1e-5,
        batch_size: int = 4,
        field_weight: float = 1.0,
        patience: int = 50,
        verbose: bool = True
    ):
        """
        Train the GNN model.

        Args:
            epochs: Number of training epochs
            lr: Learning rate
            weight_decay: Weight decay for regularization
            batch_size: Training batch size
            field_weight: Weight for field loss (reserved; the current loss
                is the field MSE only)
            patience: Early stopping patience
            verbose: Print training progress
        """
        # Create data loaders
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

        # Optimizer
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )

        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Training loop
        no_improve = 0
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0

            for batch in train_loader:
                optimizer.zero_grad()

                # Move to device
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)

                # Forward pass per sample in the batch (the model runs on a
                # single graph at a time)
                batch_loss = 0.0
                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features, self.edge_index, self.edge_attr, design[i]
                    )
                    # MSE loss on displacement field
                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    batch_loss = batch_loss + loss

                batch_loss = batch_loss / design.size(0)
                batch_loss.backward()
                optimizer.step()

                train_loss += batch_loss.item()

            train_loss /= len(train_loader)
            scheduler.step()

            # Validation
            val_loss, val_metrics = self._validate(val_loader)

            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_r2'].append(val_metrics.get('r2_mean', 0))

            # Early stopping bookkeeping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = {
                    k: v.cpu().clone() for k, v in self.model.state_dict().items()
                }
                no_improve = 0
            else:
                no_improve += 1

            if verbose and epoch % 10 == 0:
                print(f"[Epoch {epoch:3d}] Train: {train_loss:.6f}, Val: {val_loss:.6f}, "
                      f"R²: {val_metrics.get('r2_mean', 0):.4f}, "
                      f"LR: {scheduler.get_last_lr()[0]:.2e}", flush=True)

            if no_improve >= patience:
                print(f"[TRAINER] Early stopping at epoch {epoch}", flush=True)
                break

        # Restore best model
        self.model.load_state_dict(self.best_model_state)
        print(f"[TRAINER] Training complete. Best val loss: {self.best_val_loss:.6f}", flush=True)
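
    # --- Illustrative sketch --------------------------------------------------
    # Closed form of the cosine-annealing schedule used by train() above
    # (CosineAnnealingLR with its default eta_min = 0): the LR decays from
    # `base_lr` at epoch 0 toward 0 at epoch `t_max`. Provided only as a
    # reference for interpreting the logged LR values; train() relies on the
    # PyTorch scheduler itself, not this helper.
    @staticmethod
    def _cosine_lr_at(base_lr: float, epoch: int, t_max: int) -> float:
        import math
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / t_max))
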
    def _validate(self, val_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
        """Run validation and compute metrics."""
        self.model.eval()
        val_loss = 0.0
        all_pred = []
        all_true = []

        with torch.no_grad():
            for batch in val_loader:
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)

                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features, self.edge_index, self.edge_attr, design[i]
                    )
                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    val_loss += loss.item()

                    all_pred.append(z_disp_pred.cpu())
                    all_true.append(z_disp_true[i].cpu())

        val_loss /= len(self.val_dataset)

        # Compute R² for each subcase
        all_pred = torch.stack(all_pred)  # [n_val, n_nodes, 4]
        all_true = torch.stack(all_true)

        r2_per_subcase = []
        for sc in range(4):
            pred_flat = all_pred[:, :, sc].flatten()
            true_flat = all_true[:, :, sc].flatten()
            ss_res = ((true_flat - pred_flat) ** 2).sum()
            ss_tot = ((true_flat - true_flat.mean()) ** 2).sum()
            r2 = 1 - ss_res / (ss_tot + 1e-8)
            r2_per_subcase.append(r2.item())

        metrics = {
            'r2_mean': np.mean(r2_per_subcase),
            'r2_per_subcase': r2_per_subcase,
        }

        return val_loss, metrics

    def evaluate_objectives(self) -> Dict[str, Any]:
        """
        Evaluate objective prediction accuracy on the validation set.

        Returns:
            Dictionary with per-objective metrics
        """
        self.model.eval()

        obj_pred_all = {k: [] for k in [
            'rel_filtered_rms_40_vs_20',
            'rel_filtered_rms_60_vs_20',
            'mfg_90_optician_workload',
        ]}
        obj_true_all = {k: [] for k in obj_pred_all}

        # Objectives are evaluated on CPU for now (small dataset); predictions
        # are moved off-device below
        with torch.no_grad():
            for i in range(len(self.val_dataset)):
                item = self.val_dataset[i]
                design = item['design'].to(self.device)
                z_disp_true = item['z_displacement']  # Already scaled

                # Predict
                z_disp_pred = self.model(
                    self.node_features, self.edge_index, self.edge_attr, design
                ).cpu()

                # Unscale for objective computation
                z_disp_pred_mm = z_disp_pred / self.disp_scale
                z_disp_true_mm = z_disp_true / self.disp_scale

                # Compute objectives
                obj_pred = self.objective_layer(z_disp_pred_mm)
                obj_true = self.objective_layer(z_disp_true_mm)

                for k in obj_pred_all:
                    obj_pred_all[k].append(obj_pred[k].item())
                    obj_true_all[k].append(obj_true[k].item())

        # Compute metrics per objective
        results = {}
        for k in obj_pred_all:
            pred = np.array(obj_pred_all[k])
            true = np.array(obj_true_all[k])

            mae = np.mean(np.abs(pred - true))
            mape = np.mean(np.abs(pred - true) / (np.abs(true) + 1e-6)) * 100

            ss_res = np.sum((true - pred) ** 2)
            ss_tot = np.sum((true - np.mean(true)) ** 2)
            r2 = 1 - ss_res / (ss_tot + 1e-8)

            results[k] = {
                'mae': mae,
                'mape': mape,
                'r2': r2,
                'pred_range': [pred.min(), pred.max()],
                'true_range': [true.min(), true.max()],
            }

        return results

    def save_checkpoint(self, path: Path) -> None:
        """Save model checkpoint."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'config': self.model_config,
            'design_mean': self.design_mean,
            'design_std': self.design_std,
            'disp_scale': self.disp_scale,
            'history': self.history,
            'best_val_loss': self.best_val_loss,
            'study_versions': self.study_versions,
            'timestamp': datetime.now().isoformat(),
        }
        torch.save(checkpoint, path)
        print(f"[TRAINER] Saved checkpoint to {path}", flush=True)
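
    # --- Note on the checkpoint round trip ------------------------------------
    # The checkpoint stores weights, model config, normalization constants,
    # and training history, but not the training data itself. Because
    # load_checkpoint() below re-runs __init__ (which calls _load_data), the
    # original study folders must still be available on disk to restore a
    # trainer from a checkpoint.
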
    @classmethod
    def load_checkpoint(cls, path: Path, device: str = 'auto') -> 'ZernikeGNNTrainer':
        """Load trainer from checkpoint."""
        checkpoint = torch.load(path, map_location='cpu')

        # Create trainer with same config
        trainer = cls(
            study_versions=checkpoint['study_versions'],
            model_type=checkpoint['config']['model_type'],
            hidden_dim=checkpoint['config']['hidden_dim'],
            n_layers=checkpoint['config']['n_layers'],
            device=device,
        )

        # Load model weights
        trainer.model.load_state_dict(checkpoint['model_state_dict'])

        # Restore normalization
        trainer.design_mean = checkpoint['design_mean']
        trainer.design_std = checkpoint['design_std']
        trainer.disp_scale = checkpoint['disp_scale']

        # Restore history
        trainer.history = checkpoint['history']
        trainer.best_val_loss = checkpoint['best_val_loss']

        return trainer

    def predict(self, design_vars: Dict[str, float]) -> Dict[str, Any]:
        """
        Make prediction for new design.

        Args:
            design_vars: Dictionary of design parameter values

        Returns:
            Dictionary with displacement field and objectives
        """
        self.model.eval()

        # Convert to tensor, in the canonical design-variable order
        design_names = self.train_dataset.data_list[0]['design_names']
        design = torch.tensor(
            [design_vars[name] for name in design_names],
            dtype=torch.float32
        )

        # Normalize
        design_norm = (design - self.design_mean) / self.design_std

        with torch.no_grad():
            z_disp_scaled = self.model(
                self.node_features, self.edge_index, self.edge_attr,
                design_norm.to(self.device)
            ).cpu()

        # Unscale
        z_disp_mm = z_disp_scaled / self.disp_scale

        # Compute objectives
        objectives = self.objective_layer(z_disp_mm)

        return {
            'z_displacement': z_disp_mm.numpy(),
            'objectives': {k: v.item() for k, v in objectives.items()},
        }


# =============================================================================
# CLI
# =============================================================================

def main():
    import argparse

    parser = argparse.ArgumentParser(description='Train ZernikeGNN surrogate')
    parser.add_argument('studies', nargs='+', help='Study versions (e.g., V11 V12)')
    parser.add_argument('--epochs', type=int, default=200, help='Training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--n-layers', type=int, default=6, help='Message passing layers')
    parser.add_argument('--model-type', choices=['full', 'lite'], default='full')
    parser.add_argument('--output', '-o', type=Path, help='Output checkpoint path')
    parser.add_argument('--device', default='auto', help='Device (cpu, cuda, auto)')

    args = parser.parse_args()

    # Create trainer
    print("=" * 60, flush=True)
    print("ZERNIKE GNN TRAINING", flush=True)
    print("=" * 60, flush=True)

    trainer = ZernikeGNNTrainer(
        study_versions=args.studies,
        model_type=args.model_type,
        hidden_dim=args.hidden_dim,
        n_layers=args.n_layers,
        device=args.device,
    )

    # Train
    trainer.train(
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
    )

    # Evaluate objectives
    print("\n--- Objective Prediction Evaluation ---", flush=True)
    obj_results = trainer.evaluate_objectives()
    for k, v in obj_results.items():
        print(f"\n{k}:", flush=True)
        print(f"  MAE:  {v['mae']:.2f} nm", flush=True)
        print(f"  MAPE: {v['mape']:.1f}%", flush=True)
        print(f"  R²:   {v['r2']:.4f}", flush=True)

    # Save checkpoint
    if args.output:
        output_path = args.output
    else:
        output_path = Path("zernike_gnn_checkpoint.pt")
    trainer.save_checkpoint(output_path)

    print("\n" + "=" * 60, flush=True)
    print("✓ Training complete!", flush=True)
    print("=" * 60, flush=True)


if __name__ == '__main__':
    main()
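
# ---------------------------------------------------------------------------
# Illustrative helper (not wired into main()): restore a checkpoint and run a
# single prediction. The design_vars keys must match the study's design
# variable names; the key shown in the docstring is a placeholder.
# ---------------------------------------------------------------------------
def example_inference(checkpoint_path: Path, design_vars: Dict[str, float]) -> Dict[str, Any]:
    """Sketch of inference-only use, e.g.:

        result = example_inference(Path('zernike_gnn_checkpoint.pt'),
                                   {'thickness': 25.0})  # hypothetical key

    Returns the dict produced by ZernikeGNNTrainer.predict().
    """
    trainer = ZernikeGNNTrainer.load_checkpoint(checkpoint_path)
    return trainer.predict(design_vars)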