feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies

This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.
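The conditioning idea can be sketched in a few lines of numpy (a hypothetical, minimal version — the actual `ZernikeGNN` layers and parameter names may differ): each message-passing step aggregates neighbor features, then scales and shifts the result with values computed from the design vector, FiLM-style.

```python
import numpy as np

def design_conditioned_step(h, adj, W, gamma_w, beta_w, design):
    """One message-passing step modulated by design parameters.

    Toy sketch, not the actual ZernikeGNN:
    h:       [n_nodes, d]          node features
    adj:     [n_nodes, n_nodes]    row-normalized adjacency
    W:       [d, d]                shared message weight
    gamma_w, beta_w: [n_design, d] linear maps design -> (scale, shift)
    design:  [n_design]            structural design parameters
    """
    gamma = design @ gamma_w            # per-feature scale from the design
    beta = design @ beta_w              # per-feature shift from the design
    msgs = adj @ (h @ W)                # aggregate transformed neighbor features
    return np.tanh(gamma * msgs + beta) # design-modulated node update

rng = np.random.default_rng(0)
n, d, p = 5, 8, 3
adj = np.ones((n, n)) / n               # dense toy graph
h = rng.normal(size=(n, d))
out = design_conditioned_step(h, adj, rng.normal(size=(d, d)),
                              rng.normal(size=(p, d)), rng.normal(size=(p, d)),
                              rng.normal(size=p))
print(out.shape)
```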

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py
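The calibration step can be illustrated as follows (a sketch with made-up numbers; the actual `compute_full_calibration.py` may use a different model): fit a simple per-objective correction from GNN-predicted objectives to FEA-validated ones on the candidates that were re-run in FEA.

```python
import numpy as np

# Hypothetical data: GNN-predicted vs FEA-validated RMS values (nm)
# for candidates that went through FEA validation.
gnn_pred = np.array([120.0, 95.0, 140.0, 110.0, 80.0])
fea_true = np.array([131.0, 104.0, 155.0, 119.0, 88.0])

# Least-squares fit of fea ~= a * gnn + b (one scale/offset per objective).
a, b = np.polyfit(gnn_pred, fea_true, deg=1)

def calibrate(pred):
    """Apply the GNN-to-FEA correction to new surrogate predictions."""
    return a * pred + b

print(round(calibrate(100.0), 1))
```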

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]
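Extracting the ground-truth front from the completed FEA trials amounts to a plain non-dominated filter over the objective values (sketch below; all three study objectives are minimized):

```python
import numpy as np

def pareto_mask(objs):
    """Boolean mask of non-dominated rows (all objectives minimized).

    objs: [n_trials, n_objectives] array of FEA objective values.
    A trial is dominated if some other trial is <= in every objective
    and strictly < in at least one.
    """
    n = objs.shape[0]
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        dominated = np.all(objs <= objs[i], axis=1) & np.any(objs < objs[i], axis=1)
        if dominated.any():      # a row never dominates itself: the strict
            mask[i] = False      # inequality fails on the diagonal
    return mask

objs = np.array([[1.0, 5.0], [2.0, 4.0], [3.0, 3.0], [2.5, 4.5], [4.0, 1.0]])
print(pareto_mask(objs))  # [ True  True  True False  True]
```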

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Antoine
Date: 2025-12-10 08:44:04 -05:00
Parent: c6f39bfd6c
Commit: 96b196de58
22 changed files with 8329 additions and 2 deletions

optimization_engine/gnn/__init__.py

@@ -0,0 +1,69 @@
"""
GNN (Graph Neural Network) Surrogate Module for Atomizer
=========================================================
This module provides Graph Neural Network-based surrogates for FEA optimization,
particularly designed for Zernike-based mirror optimization where spatial structure
matters.
Key Components:
- PolarMirrorGraph: Fixed polar grid graph structure for mirror surface
- ZernikeGNN: GNN model for predicting displacement fields
- DifferentiableZernikeFit: GPU-accelerated Zernike fitting
- ZernikeObjectiveLayer: Compute objectives from displacement fields
- ZernikeGNNTrainer: Complete training pipeline
Why GNN over MLP for Zernike?
1. Spatial awareness: GNN learns smooth deformation fields via message passing
2. Correct relative computation: Predicts fields, then subtracts (like FEA)
3. Multi-task learning: Field + objective supervision
4. Physics-informed: Edge structure respects mirror geometry
Usage:
# Training
python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200
# API
from optimization_engine.gnn import PolarMirrorGraph, ZernikeGNN, ZernikeGNNTrainer
"""
__version__ = "1.0.0"
# Core components
from .polar_graph import PolarMirrorGraph, create_mirror_dataset
from .zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model, load_model
from .differentiable_zernike import (
DifferentiableZernikeFit,
ZernikeObjectiveLayer,
ZernikeRMSLoss,
build_zernike_matrix,
)
from .extract_displacement_field import (
extract_displacement_field,
save_field,
load_field,
)
from .train_zernike_gnn import ZernikeGNNTrainer, MirrorDataset
__all__ = [
# Polar Graph
'PolarMirrorGraph',
'create_mirror_dataset',
# GNN Model
'ZernikeGNN',
'ZernikeGNNLite',
'create_model',
'load_model',
# Zernike Layers
'DifferentiableZernikeFit',
'ZernikeObjectiveLayer',
'ZernikeRMSLoss',
'build_zernike_matrix',
# Field Extraction
'extract_displacement_field',
'save_field',
'load_field',
# Training
'ZernikeGNNTrainer',
'MirrorDataset',
]

optimization_engine/gnn/backfill_field_data.py

@@ -0,0 +1,475 @@
"""
Backfill Displacement Field Data from Existing Trials
======================================================
This script scans existing mirror optimization studies (V11, V12, etc.) and extracts
displacement field data from OP2 files for GNN training.
Structure it expects:
studies/m1_mirror_adaptive_V11/
├── 2_iterations/
│ ├── iter91/
│ │ ├── assy_m1_assyfem1_sim1-solution_1.op2
│ │ ├── assy_m1_assyfem1_sim1-solution_1.dat
│ │ └── params.exp
│ ├── iter92/
│ │ └── ...
└── 3_results/
└── study.db (Optuna database)
Output structure:
studies/m1_mirror_adaptive_V11/
└── gnn_data/
├── trial_0000/
│ ├── displacement_field.h5
│ └── metadata.json
├── trial_0001/
│ └── ...
└── dataset_index.json (maps iter -> trial)
Usage:
python -m optimization_engine.gnn.backfill_field_data V11
python -m optimization_engine.gnn.backfill_field_data V11 V12 --merge
"""
import json
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import numpy as np
# Add parent to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from optimization_engine.gnn.extract_displacement_field import (
extract_displacement_field,
save_field,
load_field,
HAS_H5PY,
)
def find_studies(base_dir: Path, pattern: str = "m1_mirror_adaptive_V*") -> List[Path]:
"""Find all matching study directories."""
studies_dir = base_dir / "studies"
matches = list(studies_dir.glob(pattern))
return sorted(matches)
def find_op2_files(study_dir: Path) -> List[Tuple[int, Path, Path]]:
"""
Find all OP2 files in iteration folders.
Returns:
List of (iter_number, op2_path, dat_path) tuples
"""
iterations_dir = study_dir / "2_iterations"
if not iterations_dir.exists():
print(f"[WARN] No 2_iterations folder in {study_dir.name}")
return []
results = []
for iter_dir in sorted(iterations_dir.iterdir()):
if not iter_dir.is_dir():
continue
# Extract iteration number
match = re.match(r'iter(\d+)', iter_dir.name)
if not match:
continue
iter_num = int(match.group(1))
# Find OP2 file
op2_files = list(iter_dir.glob('*-solution_1.op2'))
if not op2_files:
op2_files = list(iter_dir.glob('*.op2'))
if not op2_files:
continue
op2_path = op2_files[0]
# Find DAT file
dat_path = op2_path.with_suffix('.dat')
if not dat_path.exists():
dat_path = op2_path.with_suffix('.bdf')
if not dat_path.exists():
print(f"[WARN] No DAT/BDF for {op2_path.name}, skipping")
continue
results.append((iter_num, op2_path, dat_path))
return results
def read_params_exp(iter_dir: Path) -> Optional[Dict[str, float]]:
"""Read design parameters from params.exp file."""
params_file = iter_dir / "params.exp"
if not params_file.exists():
return None
params = {}
with open(params_file, 'r') as f:
for line in f:
line = line.strip()
if '=' in line:
# Format: name = value
parts = line.split('=')
if len(parts) == 2:
name = parts[0].strip()
try:
value = float(parts[1].strip())
params[name] = value
except ValueError:
pass
return params
def backfill_study(
study_dir: Path,
output_dir: Optional[Path] = None,
r_inner: float = 100.0,
r_outer: float = 650.0,
overwrite: bool = False,
verbose: bool = True
) -> Dict[str, Any]:
"""
Backfill displacement field data for a single study.
Args:
study_dir: Path to study directory
output_dir: Output directory (default: study_dir/gnn_data)
r_inner: Inner radius for surface identification
r_outer: Outer radius for surface identification
overwrite: Overwrite existing field data
verbose: Print progress
Returns:
Summary dictionary with statistics
"""
if output_dir is None:
output_dir = study_dir / "gnn_data"
output_dir.mkdir(parents=True, exist_ok=True)
if verbose:
print(f"\n{'='*60}")
print(f"BACKFILLING: {study_dir.name}")
print(f"{'='*60}")
# Find all OP2 files
op2_list = find_op2_files(study_dir)
if verbose:
print(f"Found {len(op2_list)} iterations with OP2 files")
# Track results
success_count = 0
skip_count = 0
error_count = 0
index = {}
for iter_num, op2_path, dat_path in op2_list:
# Create trial directory
trial_dir = output_dir / f"trial_{iter_num:04d}"
# Check if already exists
field_ext = '.h5' if HAS_H5PY else '.npz'
field_path = trial_dir / f"displacement_field{field_ext}"
if field_path.exists() and not overwrite:
if verbose:
print(f"[SKIP] iter{iter_num}: already processed")
skip_count += 1
index[iter_num] = {
'trial_dir': str(trial_dir.relative_to(study_dir)),
'status': 'skipped',
}
continue
try:
# Extract displacement field
if verbose:
print(f"[{iter_num:3d}] Extracting from {op2_path.name}...", end=' ')
field_data = extract_displacement_field(
op2_path,
bdf_path=dat_path,
r_inner=r_inner,
r_outer=r_outer,
verbose=False
)
# Save field data
trial_dir.mkdir(parents=True, exist_ok=True)
save_field(field_data, field_path)
# Read params if available
params = read_params_exp(op2_path.parent)
# Save metadata
meta = {
'iter_number': iter_num,
'op2_file': str(op2_path.name),
'n_nodes': len(field_data['node_ids']),
'subcases': list(field_data['z_displacement'].keys()),
'params': params,
'extraction_timestamp': datetime.now().isoformat(),
}
meta_path = trial_dir / "metadata.json"
with open(meta_path, 'w') as f:
json.dump(meta, f, indent=2)
if verbose:
print(f"OK ({len(field_data['node_ids'])} nodes)")
success_count += 1
index[iter_num] = {
'trial_dir': str(trial_dir.relative_to(study_dir)),
'n_nodes': len(field_data['node_ids']),
'params': params,
'status': 'success',
}
except Exception as e:
if verbose:
print(f"ERROR: {e}")
error_count += 1
index[iter_num] = {
'trial_dir': str(trial_dir.relative_to(study_dir)) if trial_dir.exists() else None,
'error': str(e),
'status': 'error',
}
# Save index file
index_path = output_dir / "dataset_index.json"
index_data = {
'study_name': study_dir.name,
'generated': datetime.now().isoformat(),
'summary': {
'total': len(op2_list),
'success': success_count,
'skipped': skip_count,
'errors': error_count,
},
'trials': index,
}
with open(index_path, 'w') as f:
json.dump(index_data, f, indent=2)
if verbose:
print(f"\n{'='*60}")
print(f"SUMMARY: {study_dir.name}")
print(f" Success: {success_count}")
print(f" Skipped: {skip_count}")
print(f" Errors: {error_count}")
print(f" Index: {index_path}")
print(f"{'='*60}")
return index_data
def merge_datasets(
study_dirs: List[Path],
output_dir: Path,
train_ratio: float = 0.8,
verbose: bool = True
) -> Dict[str, Any]:
"""
Merge displacement field data from multiple studies into a single dataset.
Args:
study_dirs: List of study directories
output_dir: Output directory for merged dataset
train_ratio: Fraction of data for training (rest for validation)
verbose: Print progress
Returns:
Dataset metadata dictionary
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if verbose:
print(f"\n{'='*60}")
print("MERGING DATASETS")
print(f"{'='*60}")
all_trials = []
for study_dir in study_dirs:
gnn_data_dir = study_dir / "gnn_data"
index_path = gnn_data_dir / "dataset_index.json"
if not index_path.exists():
print(f"[WARN] No index for {study_dir.name}, run backfill first")
continue
with open(index_path, 'r') as f:
index = json.load(f)
study_name = study_dir.name
for iter_num, trial_info in index['trials'].items():
if trial_info.get('status') != 'success':
continue
trial_dir = study_dir / trial_info['trial_dir']
all_trials.append({
'study': study_name,
'iter': int(iter_num),
'trial_dir': trial_dir,
'params': trial_info.get('params', {}),
'n_nodes': trial_info.get('n_nodes'),
})
if verbose:
print(f"Total successful trials: {len(all_trials)}")
# Shuffle and split
np.random.seed(42)
indices = np.random.permutation(len(all_trials))
n_train = int(len(all_trials) * train_ratio)
train_indices = indices[:n_train]
val_indices = indices[n_train:]
# Create split files
splits = {
'train': [all_trials[i] for i in train_indices],
'val': [all_trials[i] for i in val_indices],
}
for split_name, trials in splits.items():
split_dir = output_dir / split_name
split_dir.mkdir(exist_ok=True)
split_meta = []
for i, trial in enumerate(trials):
# Copy/link field data
src_ext = '.h5' if HAS_H5PY else '.npz'
src_path = trial['trial_dir'] / f"displacement_field{src_ext}"
dst_path = split_dir / f"sample_{i:04d}{src_ext}"
if src_path.exists():
# Copy file (or could use symlink on Linux)
import shutil
shutil.copy(src_path, dst_path)
split_meta.append({
'index': i,
'source_study': trial['study'],
'source_iter': trial['iter'],
'params': trial['params'],
'n_nodes': trial['n_nodes'],
})
# Save split metadata
meta_path = split_dir / "metadata.json"
with open(meta_path, 'w') as f:
json.dump({
'split': split_name,
'n_samples': len(split_meta),
'samples': split_meta,
}, f, indent=2)
if verbose:
print(f" {split_name}: {len(split_meta)} samples")
# Save overall metadata
dataset_meta = {
'created': datetime.now().isoformat(),
'source_studies': [str(s.name) for s in study_dirs],
'total_samples': len(all_trials),
'train_samples': len(splits['train']),
'val_samples': len(splits['val']),
'train_ratio': train_ratio,
}
with open(output_dir / "dataset_meta.json", 'w') as f:
json.dump(dataset_meta, f, indent=2)
if verbose:
print(f"\nDataset saved to: {output_dir}")
print(f" Train: {len(splits['train'])} samples")
print(f" Val: {len(splits['val'])} samples")
return dataset_meta
# =============================================================================
# CLI
# =============================================================================
def main():
import argparse
parser = argparse.ArgumentParser(
description='Backfill displacement field data for GNN training'
)
parser.add_argument('studies', nargs='+', type=str,
help='Study versions (e.g., V11 V12) or "all"')
parser.add_argument('--merge', action='store_true',
help='Merge data from multiple studies')
parser.add_argument('--output', '-o', type=Path,
help='Output directory for merged dataset')
parser.add_argument('--r-inner', type=float, default=100.0,
help='Inner radius (mm)')
parser.add_argument('--r-outer', type=float, default=650.0,
help='Outer radius (mm)')
parser.add_argument('--overwrite', action='store_true',
help='Overwrite existing field data')
parser.add_argument('--train-ratio', type=float, default=0.8,
help='Train/val split ratio')
args = parser.parse_args()
# Find base directory
base_dir = Path(__file__).parent.parent.parent
# Find studies
if args.studies == ['all']:
study_dirs = find_studies(base_dir, "m1_mirror_adaptive_V*")
else:
study_dirs = []
for s in args.studies:
if s.startswith('V'):
pattern = f"m1_mirror_adaptive_{s}"
else:
pattern = s
matches = find_studies(base_dir, pattern)
study_dirs.extend(matches)
if not study_dirs:
print("No studies found!")
return 1
print(f"Found {len(study_dirs)} studies:")
for s in study_dirs:
print(f" - {s.name}")
# Backfill each study
for study_dir in study_dirs:
backfill_study(
study_dir,
r_inner=args.r_inner,
r_outer=args.r_outer,
overwrite=args.overwrite,
)
# Merge if requested
if args.merge and len(study_dirs) > 1:
output_dir = args.output
if output_dir is None:
output_dir = base_dir / "studies" / "gnn_merged_dataset"
merge_datasets(
study_dirs,
output_dir,
train_ratio=args.train_ratio,
)
return 0
if __name__ == '__main__':
sys.exit(main())

optimization_engine/gnn/differentiable_zernike.py

@@ -0,0 +1,544 @@
"""
Differentiable Zernike Fitting Layer
=====================================
This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
The key innovation is putting Zernike fitting INSIDE the neural network for
end-to-end training.
Why Differentiable Zernike?
Current MLP approach:
design → MLP → coefficients (learn 200 outputs independently)
GNN + Differentiable Zernike:
design → GNN → displacement field → Zernike fit → coefficients
Differentiable! Gradients flow back
This allows the network to learn:
1. Spatially coherent displacement fields
2. Fields that produce accurate Zernike coefficients
3. Correct relative deformation computation
Components:
1. DifferentiableZernikeFit - Fits coefficients from displacement field
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing
Usage:
from optimization_engine.gnn.differentiable_zernike import (
DifferentiableZernikeFit,
ZernikeObjectiveLayer
)
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph()
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
# In forward pass:
z_disp = gnn_model(...) # [n_nodes, 4]
objectives = objective_layer(z_disp) # Dict with RMS values
"""
import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from pathlib import Path
def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
"""
Compute Zernike polynomial for Noll index j.
Uses the standard Noll indexing convention:
j=1: piston, j=2: tilt-y, j=3: tilt-x, j=4: defocus, etc.
Args:
j: Noll index (1-based)
r: Radial coordinates (normalized to [0, 1])
theta: Angular coordinates (radians)
Returns:
Zernike polynomial values at (r, theta)
"""
    # Convert Noll index to (n, m): walk down the rows of the Zernike
    # pyramid until the within-row position j1 is found; |m| follows from
    # that position, and the sign from the parity of j (odd j -> sine
    # term, m < 0; even j -> cosine term, m > 0).
    n = 0
    j1 = j - 1
    while j1 > n:
        n += 1
        j1 -= n
    m = (-1) ** j * ((n % 2) + 2 * ((j1 + (n + 1) % 2) // 2))
# Compute radial polynomial R_n^|m|(r)
R = np.zeros_like(r)
m_abs = abs(m)
for k in range((n - m_abs) // 2 + 1):
coef = ((-1) ** k * math.factorial(n - k) /
(math.factorial(k) *
math.factorial((n + m_abs) // 2 - k) *
math.factorial((n - m_abs) // 2 - k)))
R += coef * r ** (n - 2 * k)
# Combine with angular part
if m >= 0:
Z = R * np.cos(m_abs * theta)
else:
Z = R * np.sin(m_abs * theta)
# Normalization factor
if m == 0:
norm = np.sqrt(n + 1)
else:
norm = np.sqrt(2 * (n + 1))
return norm * Z
def build_zernike_matrix(
r: np.ndarray,
theta: np.ndarray,
n_modes: int = 50,
    r_max: Optional[float] = None
) -> np.ndarray:
"""
Build Zernike basis matrix for a set of points.
Args:
r: Radial coordinates
theta: Angular coordinates
n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
r_max: Maximum radius for normalization (if None, use max(r))
Returns:
Z: [n_points, n_modes] Zernike basis matrix
"""
if r_max is None:
r_max = r.max()
r_norm = r / r_max
n_points = len(r)
Z = np.zeros((n_points, n_modes), dtype=np.float64)
for j in range(1, n_modes + 1):
Z[:, j - 1] = zernike_noll(j, r_norm, theta)
return Z
class DifferentiableZernikeFit(nn.Module):
"""
GPU-accelerated, differentiable Zernike polynomial fitting.
This layer fits Zernike coefficients to a displacement field using
least squares. The key insight is that least squares has a closed-form
solution: c = (Z^T Z)^{-1} Z^T @ values
By precomputing (Z^T Z)^{-1} Z^T, we can fit coefficients with a single
matrix multiply, which is fully differentiable.
"""
def __init__(
self,
polar_graph,
n_modes: int = 50,
regularization: float = 1e-6
):
"""
Args:
polar_graph: PolarMirrorGraph instance
n_modes: Number of Zernike modes to fit
regularization: Tikhonov regularization for stability
"""
super().__init__()
self.n_modes = n_modes
# Get coordinates from polar graph
r = polar_graph.r
theta = polar_graph.theta
r_max = polar_graph.r_outer
# Build Zernike basis matrix [n_nodes, n_modes]
Z = build_zernike_matrix(r, theta, n_modes, r_max)
# Convert to tensor and register as buffer
Z_tensor = torch.tensor(Z, dtype=torch.float32)
self.register_buffer('Z', Z_tensor)
# Precompute pseudo-inverse with regularization
# c = (Z^T Z + λI)^{-1} Z^T @ values
ZtZ = Z_tensor.T @ Z_tensor
ZtZ_reg = ZtZ + regularization * torch.eye(n_modes)
ZtZ_inv = torch.inverse(ZtZ_reg)
pseudo_inv = ZtZ_inv @ Z_tensor.T # [n_modes, n_nodes]
self.register_buffer('pseudo_inverse', pseudo_inv)
def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
"""
Fit Zernike coefficients to displacement field.
Args:
z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement
Returns:
coefficients: [n_modes] or [n_subcases, n_modes]
"""
if z_displacement.dim() == 1:
# Single field: [n_nodes] → [n_modes]
return self.pseudo_inverse @ z_displacement
else:
# Multiple subcases: [n_nodes, n_subcases] → [n_subcases, n_modes]
# Transpose, multiply, transpose back
return (self.pseudo_inverse @ z_displacement).T
def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
"""
Reconstruct displacement field from coefficients.
Args:
coefficients: [n_modes] or [n_subcases, n_modes]
Returns:
z_displacement: [n_nodes] or [n_nodes, n_subcases]
"""
if coefficients.dim() == 1:
return self.Z @ coefficients
else:
return self.Z @ coefficients.T
def fit_and_residual(
self,
z_displacement: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Fit coefficients and return residual.
Args:
z_displacement: [n_nodes] or [n_nodes, n_subcases]
Returns:
coefficients, residual
"""
coeffs = self.forward(z_displacement)
reconstruction = self.reconstruct(coeffs)
residual = z_displacement - reconstruction
return coeffs, residual
class ZernikeObjectiveLayer(nn.Module):
"""
Compute Zernike-based optimization objectives from displacement field.
This layer replicates the exact computation done in FEA post-processing:
1. Compute relative displacement (e.g., 40° - 20°)
2. Convert to wavefront error (× 2 for reflection, mm → nm)
3. Fit Zernike and remove low-order terms
4. Compute filtered RMS
The computation is fully differentiable, allowing end-to-end training
with objective-based loss.
"""
def __init__(
self,
polar_graph,
n_modes: int = 50,
regularization: float = 1e-6
):
"""
Args:
polar_graph: PolarMirrorGraph instance
n_modes: Number of Zernike modes
regularization: Regularization for Zernike fitting
"""
super().__init__()
self.n_modes = n_modes
self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)
# Precompute Zernike basis subsets for filtering
Z = self.zernike_fit.Z
# Low-order modes (J1-J4: piston, tip, tilt, defocus)
self.register_buffer('Z_j1_to_j4', Z[:, :4])
# Only J1-J3 for manufacturing objective
self.register_buffer('Z_j1_to_j3', Z[:, :3])
# Store node count
self.n_nodes = Z.shape[0]
def forward(
self,
z_disp_all_subcases: torch.Tensor,
return_all: bool = False
) -> Dict[str, torch.Tensor]:
"""
Compute Zernike objectives from displacement field.
Args:
z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases
Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
return_all: If True, return additional diagnostics
Returns:
Dictionary with objective values:
- rel_filtered_rms_40_vs_20: RMS after J1-J4 removal (nm)
- rel_filtered_rms_60_vs_20: RMS after J1-J4 removal (nm)
- mfg_90_optician_workload: RMS after J1-J3 removal (nm)
"""
# Unpack subcases
disp_90 = z_disp_all_subcases[:, 0] # Subcase 1: 90°
disp_20 = z_disp_all_subcases[:, 1] # Subcase 2: 20° (reference)
disp_40 = z_disp_all_subcases[:, 2] # Subcase 3: 40°
disp_60 = z_disp_all_subcases[:, 3] # Subcase 4: 60°
# === Objective 1: Relative filtered RMS 40° vs 20° ===
disp_rel_40 = disp_40 - disp_20
wfe_rel_40 = 2.0 * disp_rel_40 * 1e6 # mm → nm, ×2 for reflection
rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)
# === Objective 2: Relative filtered RMS 60° vs 20° ===
disp_rel_60 = disp_60 - disp_20
wfe_rel_60 = 2.0 * disp_rel_60 * 1e6
rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)
# === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
disp_rel_90 = disp_90 - disp_20
wfe_rel_90 = 2.0 * disp_rel_90 * 1e6
mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)
result = {
'rel_filtered_rms_40_vs_20': rms_40_vs_20,
'rel_filtered_rms_60_vs_20': rms_60_vs_20,
'mfg_90_optician_workload': mfg_90,
}
if return_all:
# Include intermediate values for debugging
result['wfe_rel_40'] = wfe_rel_40
result['wfe_rel_60'] = wfe_rel_60
result['wfe_rel_90'] = wfe_rel_90
return result
def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
"""
Compute RMS after removing J1-J4 (piston, tip, tilt, defocus).
This is the standard filtered RMS for optical performance.
"""
        # Fit low-order coefficients by solving the small (4x4) normal equations
        # c = (Z_low^T Z_low)^{-1} Z_low^T @ wfe
Z_low = self.Z_j1_to_j4
ZtZ_low = Z_low.T @ Z_low
coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
# Reconstruct low-order surface
wfe_low = Z_low @ coeffs_low
# Residual (high-order content)
wfe_filtered = wfe - wfe_low
# RMS
return torch.sqrt(torch.mean(wfe_filtered ** 2))
def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
"""
Compute RMS after removing only J1-J3 (piston, tip, tilt).
This keeps defocus (J4), which is harder to polish out - represents
actual manufacturing workload.
"""
Z_low = self.Z_j1_to_j3
ZtZ_low = Z_low.T @ Z_low
coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
wfe_low = Z_low @ coeffs_low
wfe_filtered = wfe - wfe_low
return torch.sqrt(torch.mean(wfe_filtered ** 2))
class ZernikeRMSLoss(nn.Module):
"""
Combined loss function for GNN training.
This loss combines:
1. Displacement field reconstruction loss (MSE)
2. Objective prediction loss (relative Zernike RMS)
The multi-task loss helps the network learn both accurate
displacement fields AND accurate objective predictions.
"""
def __init__(
self,
polar_graph,
field_weight: float = 1.0,
objective_weight: float = 0.1,
n_modes: int = 50
):
"""
Args:
polar_graph: PolarMirrorGraph instance
field_weight: Weight for displacement field loss
objective_weight: Weight for objective loss
n_modes: Number of Zernike modes
"""
super().__init__()
self.field_weight = field_weight
self.objective_weight = objective_weight
self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)
def forward(
self,
z_disp_pred: torch.Tensor,
z_disp_true: torch.Tensor,
objectives_true: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Compute combined loss.
Args:
z_disp_pred: Predicted displacement [n_nodes, 4]
z_disp_true: Ground truth displacement [n_nodes, 4]
objectives_true: Optional dict of true objective values
Returns:
total_loss, loss_components dict
"""
# Field reconstruction loss
loss_field = nn.functional.mse_loss(z_disp_pred, z_disp_true)
# Scale field loss to account for small displacement values
# Displacements are ~1e-4 mm, so MSE is ~1e-8
loss_field_scaled = loss_field * 1e8
components = {
'loss_field': loss_field_scaled,
}
total_loss = self.field_weight * loss_field_scaled
# Objective loss (if ground truth provided)
if objectives_true is not None and self.objective_weight > 0:
objectives_pred = self.objective_layer(z_disp_pred)
loss_obj = 0.0
for key in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
if key in objectives_true:
pred = objectives_pred[key]
true = objectives_true[key]
# Relative error squared
rel_err = ((pred - true) / (true + 1e-6)) ** 2
loss_obj = loss_obj + rel_err
components[f'loss_{key}'] = rel_err
components['loss_objectives'] = loss_obj
total_loss = total_loss + self.objective_weight * loss_obj
components['total_loss'] = total_loss
return total_loss, components
# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
import sys
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))  # repo root, not a hardcoded user path
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
print("="*60)
print("Testing Differentiable Zernike Layer")
print("="*60)
# Create polar graph
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\nPolar Graph: {graph.n_nodes} nodes")
# Create Zernike fitting layer
zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
print(f"Zernike Fit: {zernike_fit.n_modes} modes")
print(f" Z matrix: {zernike_fit.Z.shape}")
print(f" Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")
# Test with synthetic displacement
print("\n--- Test Zernike Fitting ---")
# Create synthetic displacement (defocus + astigmatism pattern)
r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
theta = torch.tensor(graph.theta, dtype=torch.float32)
# Defocus (J4) + Astigmatism (J5)
synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)
# Fit coefficients
coeffs = zernike_fit(synthetic_disp)
print(f"Fitted coefficients shape: {coeffs.shape}")
print(f"First 10 coefficients: {coeffs[:10].tolist()}")
# Reconstruct
recon = zernike_fit.reconstruct(coeffs)
error = (synthetic_disp - recon).abs()
print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")
# Test with multiple subcases
print("\n--- Test Multi-Subcase ---")
z_disp_multi = torch.stack([
synthetic_disp,
synthetic_disp * 0.5,
synthetic_disp * 0.7,
synthetic_disp * 0.9,
], dim=1) # [n_nodes, 4]
coeffs_multi = zernike_fit(z_disp_multi)
print(f"Multi-subcase coefficients: {coeffs_multi.shape}")
# Test objective layer
print("\n--- Test Objective Layer ---")
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
objectives = objective_layer(z_disp_multi)
print("Computed objectives:")
for key, val in objectives.items():
print(f" {key}: {val.item():.2f} nm")
# Test gradient flow
print("\n--- Test Gradient Flow ---")
z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
objectives = objective_layer(z_disp_grad)
loss = objectives['rel_filtered_rms_40_vs_20']
loss.backward()
print(f"Gradient shape: {z_disp_grad.grad.shape}")
print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
print("✓ Gradients flow through Zernike fitting!")
# Test loss function
print("\n--- Test Loss Function ---")
loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)
z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)
total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
print(f"Total loss: {total_loss.item():.6f}")
for key, val in components.items():
if isinstance(val, torch.Tensor):
print(f" {key}: {val.item():.6f}")
print("\n" + "="*60)
print("✓ All tests passed!")
print("="*60)

optimization_engine/gnn/extract_displacement_field.py

@@ -0,0 +1,455 @@
"""
Displacement Field Extraction for GNN Training
===============================================
This module extracts full displacement fields from Nastran OP2 files for GNN training.
Unlike the Zernike extractors that reduce to coefficients, this preserves the raw
spatial data that GNN needs to learn the physics.
Key Features:
1. Extract Z-displacement for all optical surface nodes
2. Store node coordinates for graph construction
3. Support for multiple subcases (gravity orientations)
4. HDF5 storage for efficient loading during training
Output Format (HDF5):
/node_ids - [n_nodes] int array
/node_coords - [n_nodes, 3] float array (X, Y, Z)
/subcase_1 - [n_nodes] Z-displacement for subcase 1
/subcase_2 - [n_nodes] Z-displacement for subcase 2
/subcase_3 - [n_nodes] Z-displacement for subcase 3
/subcase_4 - [n_nodes] Z-displacement for subcase 4
/metadata - JSON string with extraction info
Usage:
from optimization_engine.gnn.extract_displacement_field import extract_displacement_field
field_data = extract_displacement_field(op2_path, bdf_path)
save_field_to_hdf5(field_data, output_path)
"""
import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
try:
import h5py
HAS_H5PY = True
except ImportError:
HAS_H5PY = False
from pyNastran.op2.op2 import OP2
from pyNastran.bdf.bdf import BDF
def identify_optical_surface_nodes(
node_coords: Dict[int, np.ndarray],
r_inner: float = 100.0,
r_outer: float = 650.0,
z_tolerance: float = 100.0
) -> Tuple[List[int], np.ndarray]:
"""
Identify nodes on the optical surface by spatial filtering.
The optical surface is identified by:
1. Radial position (between inner and outer radius)
2. Consistent Z range (nodes on the curved mirror surface)
Args:
node_coords: Dictionary mapping node ID to (X, Y, Z) coordinates
r_inner: Inner radius cutoff (central hole)
r_outer: Outer radius limit
z_tolerance: Maximum Z deviation from mean to include
Returns:
Tuple of (node_ids list, coordinates array [n, 3])
"""
# Get all coordinates as arrays
nids = list(node_coords.keys())
coords = np.array([node_coords[nid] for nid in nids])
# Calculate radial position
r = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)
# Initial radial filter
radial_mask = (r >= r_inner) & (r <= r_outer)
# Find nodes in radial range
radial_nids = np.array(nids)[radial_mask]
radial_coords = coords[radial_mask]
if len(radial_coords) == 0:
raise ValueError(f"No nodes found in radial range [{r_inner}, {r_outer}]")
# The optical surface should have a relatively small Z range
z_vals = radial_coords[:, 2]
z_mean = np.mean(z_vals)
# Filter to nodes within z_tolerance of the mean Z
z_mask = np.abs(radial_coords[:, 2] - z_mean) < z_tolerance
surface_nids = radial_nids[z_mask].tolist()
surface_coords = radial_coords[z_mask]
return surface_nids, surface_coords
def extract_displacement_field(
op2_path: Path,
bdf_path: Optional[Path] = None,
r_inner: float = 100.0,
r_outer: float = 650.0,
subcases: Optional[List[int]] = None,
verbose: bool = True
) -> Dict[str, Any]:
"""
Extract full displacement field for GNN training.
This function extracts Z-displacement from OP2 files for nodes on the optical
surface (defined by radial position). It builds node coordinates directly from
the OP2 data matched against BDF geometry, then filters by radial position.
Args:
op2_path: Path to OP2 file
bdf_path: Path to BDF/DAT file (auto-detected if None)
r_inner: Inner radius for surface identification (mm)
r_outer: Outer radius for surface identification (mm)
subcases: List of subcases to extract (default: [1, 2, 3, 4])
verbose: Print progress messages
Returns:
Dictionary containing:
- node_ids: List of node IDs on optical surface
- node_coords: Array [n_nodes, 3] of coordinates
- z_displacement: Dict mapping subcase -> [n_nodes] Z-displacements
- metadata: Extraction metadata
"""
op2_path = Path(op2_path)
# Find BDF file
if bdf_path is None:
for ext in ['.dat', '.bdf']:
candidate = op2_path.with_suffix(ext)
if candidate.exists():
bdf_path = candidate
break
if bdf_path is None:
raise FileNotFoundError(f"No .dat or .bdf found for {op2_path}")
if subcases is None:
subcases = [1, 2, 3, 4]
if verbose:
print(f"[FIELD] Reading geometry from: {bdf_path.name}")
# Read geometry from BDF
bdf = BDF()
bdf.read_bdf(str(bdf_path))
node_geo = {int(nid): node.get_position() for nid, node in bdf.nodes.items()}
if verbose:
print(f"[FIELD] Total nodes in BDF: {len(node_geo)}")
# Read OP2
if verbose:
print(f"[FIELD] Reading displacements from: {op2_path.name}")
op2 = OP2()
op2.read_op2(str(op2_path))
if not op2.displacements:
raise RuntimeError("No displacement data in OP2")
# Extract data by iterating through OP2 nodes and matching to BDF geometry
# This approach works even when node numbering differs between sources
subcase_data = {}
for key, darr in op2.displacements.items():
isub = int(getattr(darr, 'isubcase', key))
if isub not in subcases:
continue
data = darr.data
dmat = data[0] if data.ndim == 3 else data
ngt = darr.node_gridtype
op2_node_ids = ngt[:, 0] if ngt.ndim == 2 else ngt
# Build arrays of matched data
nids = []
X = []
Y = []
Z = []
disp_z = []
for i, nid in enumerate(op2_node_ids):
nid_int = int(nid)
if nid_int in node_geo:
pos = node_geo[nid_int]
nids.append(nid_int)
X.append(pos[0])
Y.append(pos[1])
Z.append(pos[2])
disp_z.append(float(dmat[i, 2])) # Z component
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
Z = np.array(Z, dtype=np.float32)
disp_z = np.array(disp_z, dtype=np.float32)
nids = np.array(nids, dtype=np.int32)
# Filter to optical surface by radial position
r = np.sqrt(X**2 + Y**2)
surface_mask = (r >= r_inner) & (r <= r_outer)
subcase_data[isub] = {
'node_ids': nids[surface_mask],
'coords': np.column_stack([X[surface_mask], Y[surface_mask], Z[surface_mask]]),
'disp_z': disp_z[surface_mask],
}
if verbose:
print(f"[FIELD] Subcase {isub}: {len(nids)} matched, {np.sum(surface_mask)} on surface")
# Get common nodes across all subcases (should be the same)
all_subcase_keys = list(subcase_data.keys())
if not all_subcase_keys:
raise RuntimeError("No subcases found in OP2")
# Use first subcase to define node list
ref_subcase = all_subcase_keys[0]
surface_nids = subcase_data[ref_subcase]['node_ids'].tolist()
surface_coords = subcase_data[ref_subcase]['coords']
# Build displacement dict for all subcases
z_displacement = {}
for isub in subcases:
if isub in subcase_data:
z_displacement[isub] = subcase_data[isub]['disp_z']
if verbose:
print(f"[FIELD] Final surface: {len(surface_nids)} nodes")
r_surface = np.sqrt(surface_coords[:, 0]**2 + surface_coords[:, 1]**2)
print(f"[FIELD] Radial range: [{r_surface.min():.1f}, {r_surface.max():.1f}] mm")
# Build metadata
metadata = {
'extraction_timestamp': datetime.now().isoformat(),
'op2_file': str(op2_path.name),
'bdf_file': str(bdf_path.name),
'n_nodes': len(surface_nids),
'r_inner': r_inner,
'r_outer': r_outer,
'subcases': list(z_displacement.keys()),
}
return {
'node_ids': surface_nids,
'node_coords': surface_coords,
'z_displacement': z_displacement,
'metadata': metadata,
}
def save_field_to_hdf5(
field_data: Dict[str, Any],
output_path: Path,
compression: str = 'gzip'
) -> None:
"""
Save displacement field data to HDF5 file.
Args:
field_data: Output from extract_displacement_field()
output_path: Path to save HDF5 file
compression: Compression algorithm ('gzip', 'lzf', or None)
"""
if not HAS_H5PY:
raise ImportError("h5py required for HDF5 storage: pip install h5py")
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(output_path, 'w') as f:
# Node data
f.create_dataset('node_ids', data=np.array(field_data['node_ids'], dtype=np.int32),
compression=compression)
f.create_dataset('node_coords', data=field_data['node_coords'].astype(np.float32),
compression=compression)
# Displacement for each subcase
for subcase, z_disp in field_data['z_displacement'].items():
f.create_dataset(f'subcase_{subcase}', data=z_disp.astype(np.float32),
compression=compression)
# Metadata as JSON string
f.attrs['metadata'] = json.dumps(field_data['metadata'])
# Report file size
size_kb = output_path.stat().st_size / 1024
print(f"[FIELD] Saved to {output_path.name} ({size_kb:.1f} KB)")
def load_field_from_hdf5(hdf5_path: Path) -> Dict[str, Any]:
"""
Load displacement field data from HDF5 file.
Args:
hdf5_path: Path to HDF5 file
Returns:
Dictionary with same structure as extract_displacement_field()
"""
if not HAS_H5PY:
raise ImportError("h5py required for HDF5 storage: pip install h5py")
with h5py.File(hdf5_path, 'r') as f:
node_ids = f['node_ids'][:].tolist()
node_coords = f['node_coords'][:]
# Load subcases
z_displacement = {}
for key in f.keys():
if key.startswith('subcase_'):
subcase = int(key.split('_')[1])
z_displacement[subcase] = f[key][:]
metadata = json.loads(f.attrs['metadata'])
return {
'node_ids': node_ids,
'node_coords': node_coords,
'z_displacement': z_displacement,
'metadata': metadata,
}
def save_field_to_npz(
field_data: Dict[str, Any],
output_path: Path
) -> None:
"""
Save displacement field data to compressed NPZ file (fallback if no h5py).
Args:
field_data: Output from extract_displacement_field()
output_path: Path to save NPZ file
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
save_dict = {
'node_ids': np.array(field_data['node_ids'], dtype=np.int32),
'node_coords': field_data['node_coords'].astype(np.float32),
'metadata_json': np.array([json.dumps(field_data['metadata'])]),
}
# Add subcases
for subcase, z_disp in field_data['z_displacement'].items():
save_dict[f'subcase_{subcase}'] = z_disp.astype(np.float32)
np.savez_compressed(output_path, **save_dict)
size_kb = output_path.stat().st_size / 1024
print(f"[FIELD] Saved to {output_path.name} ({size_kb:.1f} KB)")
def load_field_from_npz(npz_path: Path) -> Dict[str, Any]:
"""
Load displacement field data from NPZ file.
Args:
npz_path: Path to NPZ file
Returns:
Dictionary with same structure as extract_displacement_field()
"""
data = np.load(npz_path, allow_pickle=True)
node_ids = data['node_ids'].tolist()
node_coords = data['node_coords']
metadata = json.loads(str(data['metadata_json'][0]))
# Load subcases
z_displacement = {}
for key in data.keys():
if key.startswith('subcase_'):
subcase = int(key.split('_')[1])
z_displacement[subcase] = data[key]
return {
'node_ids': node_ids,
'node_coords': node_coords,
'z_displacement': z_displacement,
'metadata': metadata,
}
# =============================================================================
# Convenience functions
# =============================================================================
def save_field(field_data: Dict[str, Any], output_path: Path) -> None:
"""Save field data using best available format (HDF5 preferred)."""
output_path = Path(output_path)
if HAS_H5PY and output_path.suffix == '.h5':
save_field_to_hdf5(field_data, output_path)
else:
if output_path.suffix != '.npz':
output_path = output_path.with_suffix('.npz')
save_field_to_npz(field_data, output_path)
def load_field(path: Path) -> Dict[str, Any]:
"""Load field data from HDF5 or NPZ file."""
path = Path(path)
if path.suffix == '.h5':
return load_field_from_hdf5(path)
else:
return load_field_from_npz(path)
# =============================================================================
# CLI
# =============================================================================
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(
description='Extract displacement field from Nastran OP2 for GNN training'
)
parser.add_argument('op2_path', type=Path, help='Path to OP2 file')
parser.add_argument('-o', '--output', type=Path, help='Output path (default: same dir as OP2)')
parser.add_argument('--r-inner', type=float, default=100.0, help='Inner radius (mm)')
parser.add_argument('--r-outer', type=float, default=650.0, help='Outer radius (mm)')
parser.add_argument('--format', choices=['h5', 'npz'], default='h5',
help='Output format (default: h5)')
args = parser.parse_args()
# Extract field
field_data = extract_displacement_field(
args.op2_path,
r_inner=args.r_inner,
r_outer=args.r_outer,
)
# Determine output path
if args.output:
output_path = args.output
else:
output_path = args.op2_path.parent / f'displacement_field.{args.format}'
# Save
save_field(field_data, output_path)
# Print summary
print("\n" + "="*60)
print("EXTRACTION SUMMARY")
print("="*60)
print(f"Nodes: {len(field_data['node_ids'])}")
print(f"Subcases: {list(field_data['z_displacement'].keys())}")
for sc, disp in field_data['z_displacement'].items():
print(f" Subcase {sc}: Z range [{disp.min():.4f}, {disp.max():.4f}] mm")
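The NPZ save/load pair above can be exercised without Nastran or h5py. A minimal round-trip sketch with synthetic field data (the dict layout mirrors `extract_displacement_field()` output; the file name and values here are illustrative only):

```python
import json
import numpy as np

# Synthetic stand-in for extract_displacement_field() output:
# 10 surface nodes, two gravity subcases, metadata as a plain dict.
rng = np.random.default_rng(0)
field_data = {
    'node_ids': np.arange(1, 11, dtype=np.int32),
    'node_coords': rng.uniform(-650.0, 650.0, size=(10, 3)).astype(np.float32),
    'z_displacement': {1: rng.normal(0, 1e-3, 10).astype(np.float32),
                       2: rng.normal(0, 1e-3, 10).astype(np.float32)},
    'metadata': {'n_nodes': 10, 'subcases': [1, 2]},
}

# Save: flatten per-subcase arrays into 'subcase_<n>' keys, JSON-encode metadata
save_dict = {
    'node_ids': field_data['node_ids'],
    'node_coords': field_data['node_coords'],
    'metadata_json': np.array([json.dumps(field_data['metadata'])]),
}
for subcase, z in field_data['z_displacement'].items():
    save_dict[f'subcase_{subcase}'] = z
np.savez_compressed('field_demo.npz', **save_dict)

# Load: recover subcases by key prefix, decode metadata
data = np.load('field_demo.npz', allow_pickle=True)
loaded = {
    'node_ids': data['node_ids'],
    'node_coords': data['node_coords'],
    'z_displacement': {int(k.split('_')[1]): data[k]
                       for k in data.files if k.startswith('subcase_')},
    'metadata': json.loads(str(data['metadata_json'][0])),
}
assert np.allclose(loaded['z_displacement'][1], field_data['z_displacement'][1])
print(loaded['metadata']['subcases'])  # [1, 2]
```

The same prefix convention (`subcase_<n>`) is what lets `load_field_from_npz` and `load_field_from_hdf5` share a structure with the extractor's in-memory dict.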


@@ -0,0 +1,718 @@
"""
GNN-Based Optimizer for Zernike Mirror Optimization
====================================================
This module provides a fast GNN-based optimization workflow:
1. Load trained GNN checkpoint
2. Run thousands of fast GNN predictions
3. Select top candidates
4. Validate with FEA (optional)
Usage:
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt')
results = optimizer.turbo_optimize(n_trials=5000)
# Get best designs
best = results.get_best(n=10)
# Validate with FEA
validated = optimizer.validate_with_fea(best, study_dir)
"""
import json
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from datetime import datetime
import optuna
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
from optimization_engine.gnn.zernike_gnn import create_model, load_model
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer
@dataclass
class GNNPrediction:
"""Single GNN prediction result."""
design_vars: Dict[str, float]
objectives: Dict[str, float]
z_displacement: Optional[np.ndarray] = None # [3000, 4] if stored
def to_dict(self) -> Dict:
return {
'design_vars': self.design_vars,
'objectives': self.objectives,
}
@dataclass
class OptimizationResults:
"""Container for optimization results."""
predictions: List[GNNPrediction] = field(default_factory=list)
pareto_front: List[int] = field(default_factory=list) # Indices
def add(self, pred: GNNPrediction):
self.predictions.append(pred)
def get_best(self, n: int = 10, objective: str = 'rel_filtered_rms_40_vs_20') -> List[GNNPrediction]:
"""Get top N designs by a single objective."""
sorted_preds = sorted(self.predictions, key=lambda p: p.objectives.get(objective, float('inf')))
return sorted_preds[:n]
def get_pareto_front(self, objectives: Optional[List[str]] = None) -> List[GNNPrediction]:
"""Get Pareto-optimal designs."""
if objectives is None:
objectives = ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']
# Extract objective values
obj_values = np.array([
[p.objectives.get(obj, float('inf')) for obj in objectives]
for p in self.predictions
])
# Find Pareto front (all objectives are minimized)
pareto_indices = []
for i in range(len(self.predictions)):
is_dominated = False
for j in range(len(self.predictions)):
if i != j:
# j dominates i if j is <= in all objectives and < in at least one
if np.all(obj_values[j] <= obj_values[i]) and np.any(obj_values[j] < obj_values[i]):
is_dominated = True
break
if not is_dominated:
pareto_indices.append(i)
self.pareto_front = pareto_indices
return [self.predictions[i] for i in pareto_indices]
def to_dataframe(self):
"""Convert to pandas DataFrame."""
import pandas as pd
rows = []
for i, pred in enumerate(self.predictions):
row = {'index': i}
row.update(pred.design_vars)
row.update({f'obj_{k}': v for k, v in pred.objectives.items()})
rows.append(row)
return pd.DataFrame(rows)
def save(self, path: Path):
"""Save results to JSON."""
data = {
'n_predictions': len(self.predictions),
'pareto_front_indices': self.pareto_front,
'predictions': [p.to_dict() for p in self.predictions],
'timestamp': datetime.now().isoformat(),
}
with open(path, 'w') as f:
json.dump(data, f, indent=2)
class ZernikeGNNOptimizer:
"""
GNN-based optimizer for Zernike mirror optimization.
Provides fast objective prediction using trained GNN surrogate.
"""
def __init__(
self,
model: nn.Module,
polar_graph: PolarMirrorGraph,
design_names: List[str],
design_bounds: Dict[str, Tuple[float, float]],
design_mean: torch.Tensor,
design_std: torch.Tensor,
device: str = 'cpu',
disp_scale: float = 1.0
):
self.model = model.to(device)
self.model.eval()
self.polar_graph = polar_graph
self.design_names = design_names
self.design_bounds = design_bounds
self.design_mean = design_mean.to(device)
self.design_std = design_std.to(device)
self.device = torch.device(device)
self.disp_scale = disp_scale # Scaling factor from training
# Prepare fixed graph tensors
self.node_features = torch.tensor(
polar_graph.get_node_features(normalized=True),
dtype=torch.float32
).to(device)
self.edge_index = torch.tensor(
polar_graph.edge_index,
dtype=torch.long
).to(device)
self.edge_attr = torch.tensor(
polar_graph.get_edge_features(normalized=True),
dtype=torch.float32
).to(device)
# Objective computation layer (must be on same device as model)
self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes=50).to(device)
@classmethod
def from_checkpoint(
cls,
checkpoint_path: Path,
config_path: Optional[Path] = None,
device: str = 'auto'
) -> 'ZernikeGNNOptimizer':
"""
Load optimizer from trained checkpoint.
Args:
checkpoint_path: Path to zernike_gnn_checkpoint.pt
config_path: Path to optimization_config.json (for design bounds)
device: Device to use ('cpu', 'cuda', 'auto')
"""
if device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_path = Path(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
# Create the fixed polar graph (must match training: 50 x 60 = 3000 nodes)
polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
# Create model - handle both old ('model_config') and new ('config') format
model_config = checkpoint.get('model_config') or checkpoint.get('config', {})
model = create_model(**model_config)
model.load_state_dict(checkpoint['model_state_dict'])
# Get design info from checkpoint
design_mean = checkpoint['design_mean']
design_std = checkpoint['design_std']
disp_scale = checkpoint.get('disp_scale', 1.0) # Displacement scaling factor
# Try to get design names and bounds from config
design_names = []
design_bounds = {}
if config_path and Path(config_path).exists():
with open(config_path, 'r') as f:
config = json.load(f)
for var in config.get('design_variables', []):
name = var['name']
design_names.append(name)
design_bounds[name] = (var['min'], var['max'])
else:
# No config available: fall back to generic names sized from the checkpoint
n_vars = len(design_mean)
design_names = [f'var_{i}' for i in range(n_vars)]
# Placeholder bounds; pass config_path to get the real design bounds
for name in design_names:
design_bounds[name] = (-100, 100)
return cls(
model=model,
polar_graph=polar_graph,
design_names=design_names,
design_bounds=design_bounds,
design_mean=design_mean,
design_std=design_std,
device=device,
disp_scale=disp_scale
)
@torch.no_grad()
def predict(self, design_vars: Dict[str, float], return_field: bool = False) -> GNNPrediction:
"""
Predict objectives for a single design.
Args:
design_vars: Dict mapping variable names to values
return_field: Whether to include displacement field in result
Returns:
GNNPrediction with objectives
"""
# Convert to tensor
design_values = [design_vars.get(name, 0.0) for name in self.design_names]
design_tensor = torch.tensor(design_values, dtype=torch.float32).to(self.device)
# Normalize
design_norm = (design_tensor - self.design_mean) / self.design_std
# Forward pass
z_disp_scaled = self.model(
self.node_features,
self.edge_index,
self.edge_attr,
design_norm
) # [3000, 4] in scaled units (nm if disp_scale=1e6)
# Convert back to mm before computing objectives
# During training: z_disp_mm * disp_scale = z_disp_scaled
# So: z_disp_mm = z_disp_scaled / disp_scale
z_disp_mm = z_disp_scaled / self.disp_scale
# Compute objectives (ZernikeObjectiveLayer expects mm input)
objectives = self.objective_layer(z_disp_mm)
# Objectives are now directly in nm (no additional scaling needed)
obj_dict = {
'rel_filtered_rms_40_vs_20': objectives['rel_filtered_rms_40_vs_20'].item(),
'rel_filtered_rms_60_vs_20': objectives['rel_filtered_rms_60_vs_20'].item(),
'mfg_90_optician_workload': objectives['mfg_90_optician_workload'].item(),
}
field_data = z_disp_mm.cpu().numpy() if return_field else None
return GNNPrediction(
design_vars=design_vars,
objectives=obj_dict,
z_displacement=field_data
)
@torch.no_grad()
def predict_batch(self, designs: List[Dict[str, float]]) -> List[GNNPrediction]:
"""
Predict objectives for multiple designs. Note: currently a sequential loop over predict(); true batched inference is not implemented.
Args:
designs: List of design variable dicts
Returns:
List of GNNPrediction
"""
results = []
for design in designs:
results.append(self.predict(design))
return results
def random_design(self) -> Dict[str, float]:
"""Generate a random design within bounds."""
design = {}
for name in self.design_names:
low, high = self.design_bounds.get(name, (-100, 100))
design[name] = np.random.uniform(low, high)
return design
def turbo_optimize(
self,
n_trials: int = 5000,
sampler: str = 'tpe',
seed: int = 42,
verbose: bool = True
) -> OptimizationResults:
"""
Run fast GNN-based optimization.
Args:
n_trials: Number of GNN trials to run
sampler: Optuna sampler ('tpe', 'random', 'cmaes')
seed: Random seed
verbose: Print progress
Returns:
OptimizationResults with all predictions
"""
np.random.seed(seed)
results = OptimizationResults()
if verbose:
print(f"\n{'='*60}")
print("GNN TURBO OPTIMIZATION")
print(f"{'='*60}")
print(f"Trials: {n_trials}")
print(f"Sampler: {sampler}")
print(f"Design variables: {len(self.design_names)}")
print(f"Device: {self.device}")
# Create Optuna study for smart sampling
if sampler == 'tpe':
optuna_sampler = optuna.samplers.TPESampler(seed=seed)
elif sampler == 'random':
optuna_sampler = optuna.samplers.RandomSampler(seed=seed)
elif sampler == 'cmaes':
optuna_sampler = optuna.samplers.CmaEsSampler(seed=seed)
else:
optuna_sampler = optuna.samplers.TPESampler(seed=seed)
study = optuna.create_study(
directions=['minimize', 'minimize', 'minimize'], # 3 objectives
sampler=optuna_sampler
)
start_time = datetime.now()
def objective(trial):
# Sample design
design = {}
for name in self.design_names:
low, high = self.design_bounds.get(name, (-100, 100))
design[name] = trial.suggest_float(name, low, high)
# Predict with GNN
pred = self.predict(design)
results.add(pred)
return (
pred.objectives['rel_filtered_rms_40_vs_20'],
pred.objectives['rel_filtered_rms_60_vs_20'],
pred.objectives['mfg_90_optician_workload']
)
# Run optimization
if verbose:
print(f"\nRunning {n_trials} GNN trials...")
optuna.logging.set_verbosity(optuna.logging.WARNING)
study.optimize(objective, n_trials=n_trials, show_progress_bar=verbose)
elapsed = (datetime.now() - start_time).total_seconds()
if verbose:
print(f"\nCompleted in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")
# Compute Pareto front
pareto = results.get_pareto_front()
print(f"Pareto front: {len(pareto)} designs")
# Best by each objective
print("\nBest by objective:")
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
best = results.get_best(n=1, objective=obj)[0]
print(f" {obj}: {best.objectives[obj]:.2f} nm")
return results
def validate_with_fea(
self,
candidates: List[GNNPrediction],
study_dir: Path,
verbose: bool = True,
start_trial_num: int = 9000
) -> List[Dict]:
"""
Validate GNN predictions with actual FEA.
This runs the full NX + Nastran workflow on each candidate
to get true objective values.
Args:
candidates: GNN predictions to validate
study_dir: Path to study directory (for config and scripts)
verbose: Print progress
start_trial_num: Starting trial number for iteration folders
Returns:
List of dicts with 'gnn' and 'fea' objectives for comparison
"""
import time
import re
from optimization_engine.nx_solver import NXSolver
from optimization_engine.extractors import ZernikeExtractor
study_dir = Path(study_dir)
config_path = study_dir / "1_setup" / "optimization_config.json"
model_dir = study_dir / "1_setup" / "model"
iterations_dir = study_dir / "2_iterations"
# Load config
if not config_path.exists():
raise FileNotFoundError(f"Config not found: {config_path}")
with open(config_path) as f:
config = json.load(f)
# Initialize NX Solver
nx_settings = config.get('nx_settings', {})
nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
version_match = re.search(r'NX(\d+)', nx_install_dir)
nastran_version = version_match.group(1) if version_match else "2506"
solver = NXSolver(
master_model_dir=str(model_dir),
nx_install_dir=nx_install_dir,
nastran_version=nastran_version,
timeout=nx_settings.get('simulation_timeout_s', 600),
use_iteration_folders=True,
study_name="gnn_validation"
)
iterations_dir.mkdir(exist_ok=True)
results = []
if verbose:
print(f"\n{'='*60}")
print("FEA VALIDATION OF GNN PREDICTIONS")
print(f"{'='*60}")
print(f"Validating {len(candidates)} candidates")
print(f"Study: {study_dir.name}")
for i, candidate in enumerate(candidates):
trial_num = start_trial_num + i
if verbose:
print(f"\n[{i+1}/{len(candidates)}] Trial {trial_num}")
print(f" GNN predicted: 40vs20={candidate.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
# Build expression updates from design variables
expressions = {}
for var in config.get('design_variables', []):
var_name = var['name']
expr_name = var.get('expression_name', var_name)
if var_name in candidate.design_vars:
expressions[expr_name] = candidate.design_vars[var_name]
# Create iteration folder with model copies
try:
iter_folder = solver.create_iteration_folder(
iterations_base_dir=iterations_dir,
iteration_number=trial_num,
expression_updates=expressions
)
except Exception as e:
if verbose:
print(f" ERROR creating iteration folder: {e}")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': None,
'status': 'error',
'error': str(e)
})
continue
# Run simulation
sim_file = iter_folder / nx_settings.get('sim_file', 'ASSY_M1_assyfem1_sim1.sim')
solution_name = nx_settings.get('solution_name', 'Solution 1')
t_start = time.time()
try:
solve_result = solver.run_simulation(
sim_file=sim_file,
working_dir=iter_folder,
expression_updates=expressions,
solution_name=solution_name,
cleanup=False
)
except Exception as e:
if verbose:
print(f" ERROR in simulation: {e}")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': None,
'status': 'solve_error',
'error': str(e)
})
continue
solve_time = time.time() - t_start
if not solve_result['success']:
if verbose:
print(f" Solve FAILED: {solve_result.get('errors', ['Unknown'])}")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': None,
'status': 'solve_failed',
'errors': solve_result.get('errors', [])
})
continue
if verbose:
print(f" Solved in {solve_time:.1f}s")
# Extract objectives using ZernikeExtractor
op2_path = solve_result['op2_file']
if op2_path is None or not Path(op2_path).exists():
if verbose:
print(f" ERROR: OP2 file not found")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': None,
'status': 'no_op2',
})
continue
try:
zernike_settings = config.get('zernike_settings', {})
extractor = ZernikeExtractor(
op2_path,
bdf_path=None,
displacement_unit=zernike_settings.get('displacement_unit', 'mm'),
n_modes=zernike_settings.get('n_modes', 50),
filter_orders=zernike_settings.get('filter_low_orders', 4)
)
ref = zernike_settings.get('reference_subcase', '2')
# Extract objectives: 40 vs 20, 60 vs 20, mfg 90
rel_40 = extractor.extract_relative("3", ref)
rel_60 = extractor.extract_relative("4", ref)
rel_90 = extractor.extract_relative("1", ref)
fea_objectives = {
'rel_filtered_rms_40_vs_20': rel_40['relative_filtered_rms_nm'],
'rel_filtered_rms_60_vs_20': rel_60['relative_filtered_rms_nm'],
'mfg_90_optician_workload': rel_90['relative_rms_filter_j1to3'],
}
except Exception as e:
if verbose:
print(f" ERROR in Zernike extraction: {e}")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': None,
'status': 'extraction_error',
'error': str(e)
})
continue
# Compute errors
errors = {}
for obj_name in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
gnn_val = candidate.objectives[obj_name]
fea_val = fea_objectives[obj_name]
errors[f'{obj_name}_abs_error'] = abs(gnn_val - fea_val)
errors[f'{obj_name}_pct_error'] = 100 * abs(gnn_val - fea_val) / max(fea_val, 0.01)
if verbose:
print(f" FEA results:")
print(f" 40vs20: {fea_objectives['rel_filtered_rms_40_vs_20']:.2f} nm "
f"(GNN: {candidate.objectives['rel_filtered_rms_40_vs_20']:.2f}, "
f"err: {errors['rel_filtered_rms_40_vs_20_pct_error']:.1f}%)")
print(f" 60vs20: {fea_objectives['rel_filtered_rms_60_vs_20']:.2f} nm "
f"(GNN: {candidate.objectives['rel_filtered_rms_60_vs_20']:.2f}, "
f"err: {errors['rel_filtered_rms_60_vs_20_pct_error']:.1f}%)")
print(f" mfg90: {fea_objectives['mfg_90_optician_workload']:.2f} nm "
f"(GNN: {candidate.objectives['mfg_90_optician_workload']:.2f}, "
f"err: {errors['mfg_90_optician_workload_pct_error']:.1f}%)")
results.append({
'design': candidate.design_vars,
'gnn_objectives': candidate.objectives,
'fea_objectives': fea_objectives,
'errors': errors,
'solve_time': solve_time,
'trial_num': trial_num,
'status': 'success'
})
# Summary
if verbose:
successful = [r for r in results if r['status'] == 'success']
print(f"\n{'='*60}")
print(f"VALIDATION SUMMARY")
print(f"{'='*60}")
print(f"Successful: {len(successful)}/{len(candidates)}")
if successful:
avg_errors = {}
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
avg_errors[obj] = np.mean([r['errors'][f'{obj}_pct_error'] for r in successful])
print(f"\nAverage prediction errors:")
print(f" 40 vs 20: {avg_errors['rel_filtered_rms_40_vs_20']:.1f}%")
print(f" 60 vs 20: {avg_errors['rel_filtered_rms_60_vs_20']:.1f}%")
print(f" mfg 90: {avg_errors['mfg_90_optician_workload']:.1f}%")
return results
def save_validation_report(
self,
validation_results: List[Dict],
output_path: Path
):
"""Save validation results to JSON file."""
report = {
'timestamp': datetime.now().isoformat(),
'n_candidates': len(validation_results),
'n_successful': len([r for r in validation_results if r['status'] == 'success']),
'results': validation_results,
}
# Compute summary statistics if we have successful results
successful = [r for r in validation_results if r['status'] == 'success']
if successful:
avg_errors = {}
for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
errors = [r['errors'][f'{obj}_pct_error'] for r in successful]
avg_errors[obj] = {
'mean_pct': float(np.mean(errors)),
'std_pct': float(np.std(errors)),
'max_pct': float(np.max(errors)),
}
report['error_summary'] = avg_errors
with open(output_path, 'w') as f:
json.dump(report, f, indent=2)
print(f"Validation report saved to: {output_path}")
def main():
"""Example usage of GNN optimizer."""
import argparse
parser = argparse.ArgumentParser(description='GNN-based Zernike optimization')
parser.add_argument('checkpoint', type=Path, help='Path to GNN checkpoint')
parser.add_argument('--config', type=Path, help='Path to optimization_config.json')
parser.add_argument('--trials', type=int, default=5000, help='Number of GNN trials')
parser.add_argument('--output', '-o', type=Path, help='Output results JSON')
parser.add_argument('--top-n', type=int, default=20, help='Number of top candidates to show')
args = parser.parse_args()
# Load optimizer
print(f"Loading GNN from {args.checkpoint}...")
optimizer = ZernikeGNNOptimizer.from_checkpoint(
args.checkpoint,
config_path=args.config
)
# Run turbo optimization
results = optimizer.turbo_optimize(n_trials=args.trials)
# Show top candidates
print(f"\n{'='*60}")
print(f"TOP {args.top_n} CANDIDATES (by rel_filtered_rms_40_vs_20)")
print(f"{'='*60}")
top = results.get_best(n=args.top_n, objective='rel_filtered_rms_40_vs_20')
for i, pred in enumerate(top):
print(f"\n#{i+1}:")
print(f" 40 vs 20: {pred.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
print(f" 60 vs 20: {pred.objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
print(f" mfg_90: {pred.objectives['mfg_90_optician_workload']:.2f} nm")
# Save results
if args.output:
results.save(args.output)
print(f"\nResults saved to {args.output}")
# Show Pareto front
pareto = results.get_pareto_front()
print(f"\n{'='*60}")
print(f"PARETO FRONT: {len(pareto)} designs")
print(f"{'='*60}")
for i, pred in enumerate(pareto[:10]): # Show first 10
print(f" [{i+1}] 40vs20={pred.objectives['rel_filtered_rms_40_vs_20']:.1f}, "
f"60vs20={pred.objectives['rel_filtered_rms_60_vs_20']:.1f}, "
f"mfg={pred.objectives['mfg_90_optician_workload']:.1f}")
if __name__ == '__main__':
main()
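The dominance test inside `get_pareto_front` above is standard non-dominated filtering for minimization. A self-contained sketch on toy objective vectors (function name and values are illustrative, not part of the module):

```python
import numpy as np

def pareto_indices(obj_values: np.ndarray) -> list:
    """Indices of non-dominated rows, all objectives minimized.
    Row j dominates row i if j <= i in every objective and < in at least one."""
    n = len(obj_values)
    keep = []
    for i in range(n):
        dominated = any(
            np.all(obj_values[j] <= obj_values[i]) and np.any(obj_values[j] < obj_values[i])
            for j in range(n) if j != i
        )
        if not dominated:
            keep.append(i)
    return keep

# Toy 3-objective values (rows = designs).
objs = np.array([
    [1.0, 2.0, 3.0],
    [2.0, 1.0, 3.0],
    [2.0, 3.0, 4.0],   # dominated by row 0 (worse or equal everywhere, worse somewhere)
    [1.0, 2.0, 3.0],   # duplicate of row 0: neither strictly dominates the other
])
print(pareto_indices(objs))  # [0, 1, 3]
```

Note the duplicate-row behavior: with this weak-dominance test, ties survive the filter, which is why identical GNN predictions can all appear on the reported Pareto front. The O(n²) scan is fine at turbo-mode scales (thousands of trials) but would need a faster non-dominated sort for much larger runs.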


@@ -0,0 +1,617 @@
"""
Polar Mirror Graph for GNN Training
====================================
This module creates a fixed polar grid graph structure for the mirror optical surface.
The key insight is that the mirror has a fixed topology (circular annulus), so we can
use a fixed graph structure regardless of FEA mesh variations.
Why Polar Grid?
1. Matches mirror geometry (annulus)
2. Same approach as extract_zernike_surface.py
3. Enables mesh-independent training
4. Edge structure respects radial/angular physics
Grid Structure:
- n_radial points from r_inner to r_outer
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
- Total nodes = n_radial × n_angular
- Edges connect radial neighbors and angular neighbors (wrap-around)
Usage:
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
# Interpolate FEA results to fixed grid
z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)
# Get PyTorch Geometric data
data = graph.to_pyg_data(z_disp_grid, design_vars)
"""
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
import json
try:
import torch
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
try:
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
from scipy.spatial import Delaunay
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
class PolarMirrorGraph:
"""
Fixed polar grid graph for mirror optical surface.
This creates a mesh-independent graph structure that can be used for GNN training
regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.
Attributes:
n_nodes: Total number of nodes (n_radial × n_angular)
r: Radial coordinates [n_nodes]
theta: Angular coordinates [n_nodes]
x: Cartesian X coordinates [n_nodes]
y: Cartesian Y coordinates [n_nodes]
edge_index: Graph edges [2, n_edges]
edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
"""
def __init__(
self,
r_inner: float = 100.0,
r_outer: float = 650.0,
n_radial: int = 50,
n_angular: int = 60
):
"""
Initialize polar grid graph.
Args:
r_inner: Inner radius (central hole), mm
r_outer: Outer radius, mm
n_radial: Number of radial samples
n_angular: Number of angular samples
"""
self.r_inner = r_inner
self.r_outer = r_outer
self.n_radial = n_radial
self.n_angular = n_angular
self.n_nodes = n_radial * n_angular
# Create polar grid coordinates
r_1d = np.linspace(r_inner, r_outer, n_radial)
theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)
# Meshgrid: theta varies fast (angular index), r varies slow (radial index)
# Shape after flatten: [n_angular * n_radial] with angular varying fastest
Theta, R = np.meshgrid(theta_1d, r_1d) # R shape: [n_radial, n_angular]
# Flatten: radial index varies slowest
self.r = R.flatten().astype(np.float32)
self.theta = Theta.flatten().astype(np.float32)
self.x = (self.r * np.cos(self.theta)).astype(np.float32)
self.y = (self.r * np.sin(self.theta)).astype(np.float32)
# Build graph edges
self.edge_index, self.edge_attr = self._build_polar_edges()
# Precompute normalization factors
self._r_mean = (r_inner + r_outer) / 2
self._r_std = (r_outer - r_inner) / 2
def _node_index(self, i_r: int, i_theta: int) -> int:
"""Convert (radial_index, angular_index) to flat node index."""
# Angular wraps around
i_theta = i_theta % self.n_angular
return i_r * self.n_angular + i_theta
def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Create graph edges respecting polar topology.
Edge types:
1. Radial edges: Connect adjacent radial rings
2. Angular edges: Connect adjacent angular positions (with wrap-around)
3. Diagonal edges: Connect diagonal neighbors for better message passing
Returns:
edge_index: [2, n_edges] array of (source, target) pairs
edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
"""
edges = []
edge_features = []
for i_r in range(self.n_radial):
for i_theta in range(self.n_angular):
node = self._node_index(i_r, i_theta)
# Radial neighbor (outward)
if i_r < self.n_radial - 1:
neighbor = self._node_index(i_r + 1, i_theta)
edges.append([node, neighbor])
edges.append([neighbor, node]) # Bidirectional
# Edge features: (dr, dtheta, distance, relative_angle)
dr = self.r[neighbor] - self.r[node]
dtheta = 0.0
dist = abs(dr)
angle = 0.0 # Radial direction
edge_features.append([dr, dtheta, dist, angle])
edge_features.append([-dr, dtheta, dist, np.pi]) # Reverse
# Angular neighbor (counterclockwise, with wrap-around)
neighbor = self._node_index(i_r, i_theta + 1)
edges.append([node, neighbor])
edges.append([neighbor, node]) # Bidirectional
# Edge features for angular edge
dr = 0.0
dtheta = 2 * np.pi / self.n_angular
# Arc length at this radius
dist = self.r[node] * dtheta
angle = np.pi / 2 # Tangential direction
edge_features.append([dr, dtheta, dist, angle])
edge_features.append([dr, -dtheta, dist, -np.pi / 2]) # Reverse
# Diagonal neighbor (outward + counterclockwise) for better connectivity
if i_r < self.n_radial - 1:
neighbor = self._node_index(i_r + 1, i_theta + 1)
edges.append([node, neighbor])
edges.append([neighbor, node])
dr = self.r[neighbor] - self.r[node]
dtheta = 2 * np.pi / self.n_angular
dx = self.x[neighbor] - self.x[node]
dy = self.y[neighbor] - self.y[node]
dist = np.sqrt(dx**2 + dy**2)
angle = np.arctan2(dy, dx)
edge_features.append([dr, dtheta, dist, angle])
edge_features.append([-dr, -dtheta, dist, angle + np.pi])
edge_index = np.array(edges, dtype=np.int64).T # [2, n_edges]
edge_attr = np.array(edge_features, dtype=np.float32) # [n_edges, 4]
return edge_index, edge_attr
def get_node_features(self, normalized: bool = True) -> np.ndarray:
"""
Get node features for GNN input.
Features: (r, theta, x, y) - polar and Cartesian coordinates
Args:
normalized: If True, normalize features to ~[-1, 1] range
Returns:
Node features [n_nodes, 4]
"""
if normalized:
r_norm = (self.r - self._r_mean) / self._r_std
theta_norm = self.theta / np.pi - 1 # [0, 2π] → [-1, 1]
x_norm = self.x / self.r_outer
y_norm = self.y / self.r_outer
return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
else:
return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)
def get_edge_features(self, normalized: bool = True) -> np.ndarray:
"""
Get edge features for GNN input.
Features: (dr, dtheta, distance, angle)
Args:
normalized: If True, normalize features
Returns:
Edge features [n_edges, 4]
"""
if normalized:
edge_attr = self.edge_attr.copy()
edge_attr[:, 0] /= self._r_std # dr
edge_attr[:, 1] /= np.pi # dtheta
edge_attr[:, 2] /= self.r_outer # distance
edge_attr[:, 3] /= np.pi # angle
return edge_attr
else:
return self.edge_attr
def interpolate_from_mesh(
self,
mesh_coords: np.ndarray,
mesh_values: np.ndarray,
method: str = 'rbf'
) -> np.ndarray:
"""
Interpolate FEA results from mesh nodes to fixed polar grid.
Args:
mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z])
mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
method: Interpolation method ('rbf', 'linear', 'clough_tocher')
Returns:
Interpolated values on polar grid [n_nodes] or [n_nodes, n_features]
"""
if not HAS_SCIPY:
raise ImportError("scipy required for interpolation: pip install scipy")
# Use only X, Y coordinates
xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords
# Handle multi-dimensional values
values_1d = mesh_values.ndim == 1
if values_1d:
mesh_values = mesh_values.reshape(-1, 1)
# Target coordinates
target_xy = np.column_stack([self.x, self.y])
result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)
for i in range(mesh_values.shape[1]):
vals = mesh_values[:, i]
if method == 'rbf':
# RBF interpolation - smooth, handles scattered data well
interp = RBFInterpolator(
xy, vals,
kernel='thin_plate_spline',
smoothing=0.0
)
result[:, i] = interp(target_xy)
elif method == 'linear':
# Linear interpolation via Delaunay triangulation
interp = LinearNDInterpolator(xy, vals, fill_value=np.nan)
result[:, i] = interp(target_xy)
# Handle NaN (points outside convex hull) with nearest neighbor
nan_mask = np.isnan(result[:, i])
if nan_mask.any():
from scipy.spatial import cKDTree
tree = cKDTree(xy)
_, idx = tree.query(target_xy[nan_mask])
result[nan_mask, i] = vals[idx]
elif method == 'clough_tocher':
# Clough-Tocher (C1 smooth) interpolation
interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan)
result[:, i] = interp(target_xy)
# Handle NaN
nan_mask = np.isnan(result[:, i])
if nan_mask.any():
from scipy.spatial import cKDTree
tree = cKDTree(xy)
_, idx = tree.query(target_xy[nan_mask])
result[nan_mask, i] = vals[idx]
else:
raise ValueError(f"Unknown interpolation method: {method}")
return result[:, 0] if values_1d else result
def interpolate_field_data(
self,
field_data: Dict[str, Any],
subcases: List[int] = [1, 2, 3, 4],
method: str = 'linear' # Changed from 'rbf' - much faster
) -> Dict[str, np.ndarray]:
"""
Interpolate field data from extract_displacement_field() to polar grid.
Args:
field_data: Output from extract_displacement_field()
subcases: List of subcases to interpolate
method: Interpolation method
Returns:
Dictionary with:
- z_displacement: [n_nodes, n_subcases] array
- original_n_nodes: Number of FEA nodes
"""
mesh_coords = field_data['node_coords']
z_disp_dict = field_data['z_displacement']
# Stack subcases
z_disp_list = []
for sc in subcases:
if sc in z_disp_dict:
z_disp_list.append(z_disp_dict[sc])
else:
raise KeyError(f"Subcase {sc} not found in field_data")
# [n_fea_nodes, n_subcases]
z_disp_mesh = np.column_stack(z_disp_list)
# Interpolate to polar grid
z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)
return {
'z_displacement': z_disp_grid, # [n_nodes, n_subcases]
'original_n_nodes': len(mesh_coords),
}
def to_pyg_data(
self,
z_displacement: np.ndarray,
design_vars: np.ndarray,
objectives: Optional[Dict[str, float]] = None
):
"""
Convert to PyTorch Geometric Data object.
Args:
z_displacement: [n_nodes, n_subcases] displacement field
design_vars: [n_design_vars] design parameters
objectives: Optional dict of objective values (ground truth)
Returns:
torch_geometric.data.Data object
"""
if not HAS_TORCH:
raise ImportError("PyTorch required: pip install torch")
try:
from torch_geometric.data import Data
except ImportError:
raise ImportError("PyTorch Geometric required: pip install torch-geometric")
# Node features: (r, theta, x, y)
node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)
# Edge index and features
edge_index = torch.tensor(self.edge_index, dtype=torch.long)
edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)
# Target: Z-displacement field
y = torch.tensor(z_displacement, dtype=torch.float32)
# Design variables (global feature)
design = torch.tensor(design_vars, dtype=torch.float32)
data = Data(
x=node_features,
edge_index=edge_index,
edge_attr=edge_attr,
y=y,
design=design,
)
# Add objectives if provided
if objectives:
for key, value in objectives.items():
setattr(data, key, torch.tensor([value], dtype=torch.float32))
return data
def save(self, path: Path) -> None:
"""Save graph structure to JSON file."""
path = Path(path)
data = {
'r_inner': self.r_inner,
'r_outer': self.r_outer,
'n_radial': self.n_radial,
'n_angular': self.n_angular,
'n_nodes': self.n_nodes,
'n_edges': self.edge_index.shape[1],
}
with open(path, 'w') as f:
json.dump(data, f, indent=2)
# Save arrays separately for efficiency
np.savez_compressed(
path.with_suffix('.npz'),
r=self.r,
theta=self.theta,
x=self.x,
y=self.y,
edge_index=self.edge_index,
edge_attr=self.edge_attr,
)
@classmethod
def load(cls, path: Path) -> 'PolarMirrorGraph':
"""Load graph structure from file."""
path = Path(path)
with open(path, 'r') as f:
data = json.load(f)
return cls(
r_inner=data['r_inner'],
r_outer=data['r_outer'],
n_radial=data['n_radial'],
n_angular=data['n_angular'],
)
def __repr__(self) -> str:
return (
f"PolarMirrorGraph("
f"r=[{self.r_inner}, {self.r_outer}]mm, "
f"grid={self.n_radial}×{self.n_angular}, "
f"nodes={self.n_nodes}, "
f"edges={self.edge_index.shape[1]})"
)
# =============================================================================
# Convenience functions
# =============================================================================
def create_mirror_dataset(
study_dir: Path,
polar_graph: Optional[PolarMirrorGraph] = None,
subcases: List[int] = [1, 2, 3, 4],
verbose: bool = True
) -> List[Dict[str, Any]]:
"""
Create GNN dataset from a study's gnn_data folder.
Args:
study_dir: Path to study directory
polar_graph: PolarMirrorGraph instance (created if None)
subcases: Subcases to include
verbose: Print progress
Returns:
List of data dictionaries, each containing:
- z_displacement: [n_nodes, n_subcases]
- design_vars: [n_vars]
- trial_number: int
- original_n_nodes: int
"""
from optimization_engine.gnn.extract_displacement_field import load_field
study_dir = Path(study_dir)
gnn_data_dir = study_dir / "gnn_data"
if not gnn_data_dir.exists():
raise FileNotFoundError(f"No gnn_data folder in {study_dir}")
# Load index
index_path = gnn_data_dir / "dataset_index.json"
with open(index_path, 'r') as f:
index = json.load(f)
if polar_graph is None:
polar_graph = PolarMirrorGraph()
dataset = []
for trial_num, trial_info in index['trials'].items():
if trial_info.get('status') != 'success':
continue
trial_dir = study_dir / trial_info['trial_dir']
# Find field file
field_path = None
for ext in ['.h5', '.npz']:
candidate = trial_dir / f"displacement_field{ext}"
if candidate.exists():
field_path = candidate
break
if field_path is None:
if verbose:
print(f"[WARN] No field file for trial {trial_num}")
continue
try:
# Load field data
field_data = load_field(field_path)
# Interpolate to polar grid
interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)
# Get design parameters
params = trial_info.get('params', {})
design_vars = np.array(list(params.values()), dtype=np.float32) if params else np.array([])
dataset.append({
'z_displacement': interp_result['z_displacement'],
'design_vars': design_vars,
'design_names': list(params.keys()) if params else [],
'trial_number': int(trial_num),
'original_n_nodes': interp_result['original_n_nodes'],
})
if verbose:
print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes")
except Exception as e:
if verbose:
print(f"[ERR] Trial {trial_num}: {e}")
return dataset
# =============================================================================
# CLI
# =============================================================================
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
parser.add_argument('--test', action='store_true', help='Run basic tests')
parser.add_argument('--study', type=Path, help='Create dataset from study')
args = parser.parse_args()
if args.test:
print("="*60)
print("TESTING PolarMirrorGraph")
print("="*60)
# Create graph
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\n{graph}")
# Check node features
node_feat = graph.get_node_features(normalized=True)
print(f"\nNode features shape: {node_feat.shape}")
print(f" r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
print(f" theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")
# Check edge features
edge_feat = graph.get_edge_features(normalized=True)
print(f"\nEdge features shape: {edge_feat.shape}")
print(f" dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
print(f" distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")
# Test interpolation with synthetic data
print("\n--- Testing Interpolation ---")
# Create fake mesh data (random points in annulus)
np.random.seed(42)
n_mesh = 5000
r_mesh = np.random.uniform(100, 650, n_mesh)
theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
x_mesh = r_mesh * np.cos(theta_mesh)
y_mesh = r_mesh * np.sin(theta_mesh)
mesh_coords = np.column_stack([x_mesh, y_mesh])
# Synthetic displacement: smooth function
mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)
# Interpolate
grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
print(f" Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
print(f" Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")
print("\n✓ All tests passed!")
elif args.study:
# Create dataset from study
print(f"Creating dataset from: {args.study}")
graph = PolarMirrorGraph()
dataset = create_mirror_dataset(args.study, polar_graph=graph)
print(f"\nDataset: {len(dataset)} samples")
if dataset:
print(f" Z-displacement shape: {dataset[0]['z_displacement'].shape}")
print(f" Design vars: {len(dataset[0]['design_vars'])} variables")
else:
# Default: just show info
graph = PolarMirrorGraph()
print(graph)
print(f"\nNode features: {graph.get_node_features().shape}")
print(f"Edge index: {graph.edge_index.shape}")
print(f"Edge features: {graph.edge_attr.shape}")
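Because the polar topology is fixed, the graph size is known before any FEA data is seen. A quick self-contained sanity sketch (re-deriving the counts from the edge-building rules above, rather than importing the module):

```python
def polar_graph_counts(n_radial: int, n_angular: int):
    """Node/edge counts implied by _build_polar_edges()."""
    nodes = n_radial * n_angular
    radial = 2 * (n_radial - 1) * n_angular    # bidirectional, ring to ring
    angular = 2 * n_radial * n_angular         # wrap-around within each ring
    diagonal = 2 * (n_radial - 1) * n_angular  # outward + counterclockwise
    return nodes, radial + angular + diagonal

nodes, edges = polar_graph_counts(50, 60)
print(nodes, edges)  # 3000 17760
```

So the default 50×60 grid always yields 3000 nodes and 17760 directed edges, matching the `--test` output regardless of the underlying FEA mesh.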

View File

@@ -0,0 +1,37 @@
"""Quick test script for displacement field extraction."""
import h5py
import numpy as np
from pathlib import Path
# Test file
h5_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/gnn_data/trial_0091/displacement_field.h5")
print(f"Testing: {h5_path}")
print(f"Exists: {h5_path.exists()}")
if h5_path.exists():
with h5py.File(h5_path, 'r') as f:
print(f"\nDatasets in file: {list(f.keys())}")
node_coords = f['node_coords'][:]
node_ids = f['node_ids'][:]
print(f"\nTotal nodes: {len(node_ids)}")
# Calculate radial position
r = np.sqrt(node_coords[:, 0]**2 + node_coords[:, 1]**2)
print(f"Radial range: [{r.min():.1f}, {r.max():.1f}] mm")
print(f"Z range: [{node_coords[:, 2].min():.1f}, {node_coords[:, 2].max():.1f}] mm")
# Check nodes in optical surface range (100-650 mm radius)
surface_mask = (r >= 100) & (r <= 650)
print(f"Nodes in r=[100, 650]: {np.sum(surface_mask)}")
# Check subcases
subcases = [k for k in f.keys() if k.startswith("subcase_")]
print(f"Subcases: {subcases}")
if subcases:
for sc in subcases:
disp = f[sc][:]
print(f" {sc}: Z-disp range [{disp.min():.4f}, {disp.max():.4f}] mm")

View File

@@ -0,0 +1,35 @@
"""Test the fixed extraction function directly on OP2."""
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")
import numpy as np
from pathlib import Path
from optimization_engine.gnn.extract_displacement_field import extract_displacement_field
# Test direct extraction from OP2
op2_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/2_iterations/iter91/assy_m1_assyfem1_sim1-solution_1.op2")
print(f"Testing extraction from: {op2_path.name}")
print(f"Exists: {op2_path.exists()}")
if op2_path.exists():
field_data = extract_displacement_field(op2_path, r_inner=100.0, r_outer=650.0)
print(f"\n=== EXTRACTION RESULT ===")
print(f"Total surface nodes: {len(field_data['node_ids'])}")
coords = field_data['node_coords']
r = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)
print(f"Radial range: [{r.min():.1f}, {r.max():.1f}] mm")
print(f"Z range: [{coords[:, 2].min():.1f}, {coords[:, 2].max():.1f}] mm")
print(f"\nSubcases: {list(field_data['z_displacement'].keys())}")
for sc, disp in field_data['z_displacement'].items():
nan_count = np.sum(np.isnan(disp))
if nan_count == 0:
print(f" Subcase {sc}: Z-disp range [{disp.min():.6f}, {disp.max():.6f}] mm")
else:
valid = disp[~np.isnan(disp)]
print(f" Subcase {sc}: {nan_count}/{len(disp)} NaN values, valid range: [{valid.min():.6f}, {valid.max():.6f}]")
else:
print("OP2 file not found!")

View File

@@ -0,0 +1,108 @@
"""Test PolarMirrorGraph with actual V11 data."""
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")
import numpy as np
from pathlib import Path
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.extract_displacement_field import load_field
# Test 1: Basic graph construction
print("="*60)
print("TEST 1: Graph Construction")
print("="*60)
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\n{graph}")
node_feat = graph.get_node_features(normalized=True)
edge_feat = graph.get_edge_features(normalized=True)
print(f"\nNode features: {node_feat.shape}")
print(f" r normalized: [{node_feat[:, 0].min():.3f}, {node_feat[:, 0].max():.3f}]")
print(f" theta normalized: [{node_feat[:, 1].min():.3f}, {node_feat[:, 1].max():.3f}]")
print(f" x normalized: [{node_feat[:, 2].min():.3f}, {node_feat[:, 2].max():.3f}]")
print(f" y normalized: [{node_feat[:, 3].min():.3f}, {node_feat[:, 3].max():.3f}]")
print(f"\nEdge features: {edge_feat.shape}")
print(f" Edges per node: {edge_feat.shape[0] / graph.n_nodes:.1f}")
# Test 2: Load actual V11 field data and interpolate
print("\n" + "="*60)
print("TEST 2: Interpolation from V11 Data")
print("="*60)
field_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/gnn_data/trial_0091/displacement_field.h5")
if field_path.exists():
field_data = load_field(field_path)
print(f"\nLoaded field data:")
print(f" FEA nodes: {len(field_data['node_ids'])}")
print(f" Subcases: {list(field_data['z_displacement'].keys())}")
# Interpolate to polar grid
result = graph.interpolate_field_data(field_data, subcases=[1, 2, 3, 4])
z_grid = result['z_displacement']
print(f"\nInterpolation result:")
print(f" Shape: {z_grid.shape} (expected: {graph.n_nodes} × 4)")
print(f" NaN count: {np.sum(np.isnan(z_grid))}")
for i, sc in enumerate([1, 2, 3, 4]):
disp = z_grid[:, i]
print(f" Subcase {sc}: [{disp.min():.6f}, {disp.max():.6f}] mm")
# Test relative deformation computation
print("\n--- Relative Deformations (like Zernike extraction) ---")
disp_90 = z_grid[:, 0] # Subcase 1 = 90°
disp_20 = z_grid[:, 1] # Subcase 2 = 20° (reference)
disp_40 = z_grid[:, 2] # Subcase 3 = 40°
disp_60 = z_grid[:, 3] # Subcase 4 = 60°
rel_40_vs_20 = disp_40 - disp_20
rel_60_vs_20 = disp_60 - disp_20
rel_90_vs_20 = disp_90 - disp_20
print(f" 40° - 20°: [{rel_40_vs_20.min():.6f}, {rel_40_vs_20.max():.6f}] mm, RMS={np.std(rel_40_vs_20)*1e6:.2f} nm")
print(f" 60° - 20°: [{rel_60_vs_20.min():.6f}, {rel_60_vs_20.max():.6f}] mm, RMS={np.std(rel_60_vs_20)*1e6:.2f} nm")
print(f" 90° - 20°: [{rel_90_vs_20.min():.6f}, {rel_90_vs_20.max():.6f}] mm, RMS={np.std(rel_90_vs_20)*1e6:.2f} nm")
else:
print(f"Field file not found: {field_path}")
# Test 3: Create full dataset from V11
print("\n" + "="*60)
print("TEST 3: Create Dataset from V11")
print("="*60)
study_dir = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11")
if (study_dir / "gnn_data").exists():
dataset = create_mirror_dataset(study_dir, polar_graph=graph, verbose=True)
print(f"\n--- Dataset Summary ---")
print(f"Total samples: {len(dataset)}")
if dataset:
# Check consistency
shapes = [d['z_displacement'].shape for d in dataset]
unique_shapes = set(shapes)
print(f"Unique shapes: {unique_shapes}")
# Design variable info
n_vars = len(dataset[0]['design_vars'])
print(f"Design variables: {n_vars}")
if dataset[0]['design_names']:
print(f" Names: {dataset[0]['design_names'][:3]}...")
# Stack for statistics
all_z = np.stack([d['z_displacement'] for d in dataset])
print(f"\nAll data shape: {all_z.shape}")
print(f" Per-subcase ranges:")
for i in range(4):
print(f" Subcase {i+1}: [{all_z[:,:,i].min():.6f}, {all_z[:,:,i].max():.6f}] mm")
else:
print(f"No gnn_data folder found in {study_dir}")
print("\n" + "="*60)
print("✓ All tests completed!")
print("="*60)
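The relative-deformation step in Test 2 reduces to a reference subtraction followed by an RMS in nanometers. A minimal numpy sketch of that arithmetic, using synthetic displacements rather than V11 data:

```python
import numpy as np

rng = np.random.default_rng(0)
disp_20 = rng.normal(0.0, 1e-6, 3000)             # 20° reference case, mm
disp_40 = disp_20 + rng.normal(0.0, 5e-7, 3000)   # 40° case = reference + delta

rel_40_vs_20 = disp_40 - disp_20                  # relative deformation, mm
rms_nm = np.std(rel_40_vs_20) * 1e6               # mm → nm
print(f"RMS = {rms_nm:.2f} nm")
```

Note that `np.std` equals the true RMS only for zero-mean fields (piston removed), which matches the convention used in the Zernike extraction scripts.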

View File

@@ -0,0 +1,600 @@
"""
Training Pipeline for ZernikeGNN
=================================
This module provides the complete training pipeline for the Zernike GNN surrogate.
Training Flow:
1. Load displacement field data from gnn_data/ folders
2. Interpolate to fixed polar grid
3. Normalize inputs (design vars) and outputs (displacements)
4. Train with multi-task loss (field + objectives)
5. Validate on held-out data
6. Save best model checkpoint
Usage:
# Command line
python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200
# Python API
from optimization_engine.gnn.train_zernike_gnn import ZernikeGNNTrainer
trainer = ZernikeGNNTrainer(['V11', 'V12'])
trainer.train(epochs=200)
trainer.save_checkpoint('model.pt')
"""
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer, ZernikeRMSLoss
class MirrorDataset(Dataset):
"""PyTorch Dataset for mirror displacement fields."""
def __init__(
self,
data_list: List[Dict[str, Any]],
design_mean: Optional[torch.Tensor] = None,
design_std: Optional[torch.Tensor] = None,
disp_scale: float = 1e6 # mm → nm for numerical stability
):
"""
Args:
data_list: Output from create_mirror_dataset()
design_mean: Mean for design normalization (computed if None)
design_std: Std for design normalization (computed if None)
disp_scale: Scale factor for displacements
"""
self.data_list = data_list
self.disp_scale = disp_scale
# Stack all design variables for normalization
all_designs = np.stack([d['design_vars'] for d in data_list])
if design_mean is None:
self.design_mean = torch.tensor(np.mean(all_designs, axis=0), dtype=torch.float32)
else:
self.design_mean = design_mean
if design_std is None:
self.design_std = torch.tensor(np.std(all_designs, axis=0) + 1e-6, dtype=torch.float32)
else:
self.design_std = design_std
def __len__(self) -> int:
return len(self.data_list)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
item = self.data_list[idx]
# Normalize design variables
design = torch.tensor(item['design_vars'], dtype=torch.float32)
design_norm = (design - self.design_mean) / self.design_std
# Scale displacements for numerical stability
z_disp = torch.tensor(item['z_displacement'], dtype=torch.float32)
z_disp_scaled = z_disp * self.disp_scale
return {
'design': design_norm,
'design_raw': design,
'z_displacement': z_disp_scaled,
'trial_number': item['trial_number'],
}
class ZernikeGNNTrainer:
"""
Complete training pipeline for ZernikeGNN.
Handles:
- Data loading and preprocessing
- Model initialization
- Training loop with validation
- Checkpointing
- Metrics tracking
"""
def __init__(
self,
study_versions: List[str],
base_dir: Optional[Path] = None,
model_type: str = 'full',
hidden_dim: int = 128,
n_layers: int = 6,
device: str = 'auto'
):
"""
Args:
study_versions: List of study versions (e.g., ['V11', 'V12'])
base_dir: Base Atomizer directory
model_type: 'full' or 'lite'
hidden_dim: Model hidden dimension
n_layers: Number of message passing layers
device: 'cpu', 'cuda', or 'auto'
"""
if base_dir is None:
base_dir = Path(__file__).parent.parent.parent
self.base_dir = Path(base_dir)
self.study_versions = study_versions
# Determine device
if device == 'auto':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
print(f"[TRAINER] Device: {self.device}", flush=True)
# Create polar graph (fixed structure)
self.polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"[TRAINER] Polar graph: {self.polar_graph.n_nodes} nodes, {self.polar_graph.edge_index.shape[1]} edges", flush=True)
# Prepare graph tensors
self.node_features = torch.tensor(
self.polar_graph.get_node_features(normalized=True),
dtype=torch.float32
).to(self.device)
self.edge_index = torch.tensor(
self.polar_graph.edge_index,
dtype=torch.long
).to(self.device)
self.edge_attr = torch.tensor(
self.polar_graph.get_edge_features(normalized=True),
dtype=torch.float32
).to(self.device)
# Load data
self._load_data()
# Create model
self.model_config = {
'model_type': model_type,
'n_design_vars': len(self.train_dataset.data_list[0]['design_vars']),
'n_subcases': 4,
'hidden_dim': hidden_dim,
'n_layers': n_layers,
}
self.model = create_model(**self.model_config).to(self.device)
print(f"[TRAINER] Model: {self.model.__class__.__name__} with {sum(p.numel() for p in self.model.parameters()):,} parameters", flush=True)
# Objective layer for evaluation
self.objective_layer = ZernikeObjectiveLayer(self.polar_graph, n_modes=50)
# Training state
self.best_val_loss = float('inf')
self.history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
def _load_data(self):
"""Load and prepare training data from studies."""
all_data = []
for version in self.study_versions:
study_dir = self.base_dir / "studies" / f"m1_mirror_adaptive_{version}"
if not study_dir.exists():
print(f"[WARN] Study not found: {study_dir}", flush=True)
continue
print(f"[TRAINER] Loading data from {study_dir.name}...", flush=True)
dataset = create_mirror_dataset(study_dir, polar_graph=self.polar_graph, verbose=True)
print(f"[TRAINER] Loaded {len(dataset)} samples", flush=True)
all_data.extend(dataset)
if not all_data:
raise ValueError("No data loaded!")
print(f"[TRAINER] Total samples: {len(all_data)}", flush=True)
# Train/val split (80/20)
np.random.seed(42)
indices = np.random.permutation(len(all_data))
n_train = int(0.8 * len(all_data))
train_data = [all_data[i] for i in indices[:n_train]]
val_data = [all_data[i] for i in indices[n_train:]]
print(f"[TRAINER] Train: {len(train_data)}, Val: {len(val_data)}", flush=True)
# Create datasets
self.train_dataset = MirrorDataset(train_data)
self.val_dataset = MirrorDataset(
val_data,
design_mean=self.train_dataset.design_mean,
design_std=self.train_dataset.design_std
)
# Store normalization params for inference
self.design_mean = self.train_dataset.design_mean
self.design_std = self.train_dataset.design_std
self.disp_scale = self.train_dataset.disp_scale
def train(
self,
epochs: int = 200,
lr: float = 1e-3,
weight_decay: float = 1e-5,
batch_size: int = 4,
field_weight: float = 1.0,
patience: int = 50,
verbose: bool = True
):
"""
Train the GNN model.
Args:
epochs: Number of training epochs
lr: Learning rate
weight_decay: Weight decay for regularization
batch_size: Training batch size
field_weight: Weight for field loss
patience: Early stopping patience
verbose: Print training progress
"""
# Create data loaders
train_loader = DataLoader(
self.train_dataset, batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(
self.val_dataset, batch_size=batch_size, shuffle=False
)
# Optimizer
optimizer = torch.optim.AdamW(
self.model.parameters(), lr=lr, weight_decay=weight_decay
)
# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Training loop
no_improve = 0
for epoch in range(epochs):
# Training
self.model.train()
train_loss = 0.0
for batch in train_loader:
optimizer.zero_grad()
# Move to device
design = batch['design'].to(self.device)
z_disp_true = batch['z_displacement'].to(self.device)
# Forward pass for each sample in batch
batch_loss = 0.0
for i in range(design.size(0)):
z_disp_pred = self.model(
self.node_features,
self.edge_index,
self.edge_attr,
design[i]
)
# MSE loss on displacement field
loss = F.mse_loss(z_disp_pred, z_disp_true[i])
batch_loss = batch_loss + loss
batch_loss = batch_loss / design.size(0)
batch_loss.backward()
optimizer.step()
train_loss += batch_loss.item()
train_loss /= len(train_loader)
scheduler.step()
# Validation
val_loss, val_metrics = self._validate(val_loader)
# Track history
self.history['train_loss'].append(train_loss)
self.history['val_loss'].append(val_loss)
self.history['val_r2'].append(val_metrics.get('r2_mean', 0))
# Early stopping
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
no_improve = 0
else:
no_improve += 1
if verbose and epoch % 10 == 0:
print(f"[Epoch {epoch:3d}] Train: {train_loss:.6f}, Val: {val_loss:.6f}, "
f"R²: {val_metrics.get('r2_mean', 0):.4f}, LR: {scheduler.get_last_lr()[0]:.2e}", flush=True)
if no_improve >= patience:
print(f"[TRAINER] Early stopping at epoch {epoch}", flush=True)
break
# Restore best model
self.model.load_state_dict(self.best_model_state)
print(f"[TRAINER] Training complete. Best val loss: {self.best_val_loss:.6f}", flush=True)
def _validate(self, val_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
"""Run validation and compute metrics."""
self.model.eval()
val_loss = 0.0
all_pred = []
all_true = []
with torch.no_grad():
for batch in val_loader:
design = batch['design'].to(self.device)
z_disp_true = batch['z_displacement'].to(self.device)
for i in range(design.size(0)):
z_disp_pred = self.model(
self.node_features,
self.edge_index,
self.edge_attr,
design[i]
)
loss = F.mse_loss(z_disp_pred, z_disp_true[i])
val_loss += loss.item()
all_pred.append(z_disp_pred.cpu())
all_true.append(z_disp_true[i].cpu())
val_loss /= len(self.val_dataset)
# Compute R² for each subcase
all_pred = torch.stack(all_pred) # [n_val, n_nodes, 4]
all_true = torch.stack(all_true)
r2_per_subcase = []
for sc in range(4):
pred_flat = all_pred[:, :, sc].flatten()
true_flat = all_true[:, :, sc].flatten()
ss_res = ((true_flat - pred_flat) ** 2).sum()
ss_tot = ((true_flat - true_flat.mean()) ** 2).sum()
r2 = 1 - ss_res / (ss_tot + 1e-8)
r2_per_subcase.append(r2.item())
metrics = {
'r2_mean': np.mean(r2_per_subcase),
'r2_per_subcase': r2_per_subcase,
}
return val_loss, metrics
def evaluate_objectives(self) -> Dict[str, Any]:
"""
Evaluate objective prediction accuracy on validation set.
Returns:
Dictionary with per-objective metrics
"""
self.model.eval()
obj_pred_all = {k: [] for k in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']}
obj_true_all = {k: [] for k in obj_pred_all}
        # Objectives are computed on CPU for now (validation set is small)
with torch.no_grad():
for i in range(len(self.val_dataset)):
item = self.val_dataset[i]
design = item['design'].to(self.device)
z_disp_true = item['z_displacement'] # Already scaled
# Predict
z_disp_pred = self.model(
self.node_features,
self.edge_index,
self.edge_attr,
design
).cpu()
# Unscale for objective computation
z_disp_pred_mm = z_disp_pred / self.disp_scale
z_disp_true_mm = z_disp_true / self.disp_scale
# Compute objectives
obj_pred = self.objective_layer(z_disp_pred_mm)
obj_true = self.objective_layer(z_disp_true_mm)
for k in obj_pred_all:
obj_pred_all[k].append(obj_pred[k].item())
obj_true_all[k].append(obj_true[k].item())
# Compute metrics per objective
results = {}
for k in obj_pred_all:
pred = np.array(obj_pred_all[k])
true = np.array(obj_true_all[k])
mae = np.mean(np.abs(pred - true))
mape = np.mean(np.abs(pred - true) / (np.abs(true) + 1e-6)) * 100
ss_res = np.sum((true - pred) ** 2)
ss_tot = np.sum((true - np.mean(true)) ** 2)
r2 = 1 - ss_res / (ss_tot + 1e-8)
results[k] = {
'mae': mae,
'mape': mape,
'r2': r2,
'pred_range': [pred.min(), pred.max()],
'true_range': [true.min(), true.max()],
}
return results
def save_checkpoint(self, path: Path) -> None:
"""Save model checkpoint."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = {
'model_state_dict': self.model.state_dict(),
'config': self.model_config,
'design_mean': self.design_mean,
'design_std': self.design_std,
'disp_scale': self.disp_scale,
'history': self.history,
'best_val_loss': self.best_val_loss,
'study_versions': self.study_versions,
'timestamp': datetime.now().isoformat(),
}
torch.save(checkpoint, path)
print(f"[TRAINER] Saved checkpoint to {path}", flush=True)
@classmethod
def load_checkpoint(cls, path: Path, device: str = 'auto') -> 'ZernikeGNNTrainer':
"""Load trainer from checkpoint."""
checkpoint = torch.load(path, map_location='cpu')
# Create trainer with same config
trainer = cls(
study_versions=checkpoint['study_versions'],
model_type=checkpoint['config']['model_type'],
hidden_dim=checkpoint['config']['hidden_dim'],
n_layers=checkpoint['config']['n_layers'],
device=device,
)
# Load model weights
trainer.model.load_state_dict(checkpoint['model_state_dict'])
# Restore normalization
trainer.design_mean = checkpoint['design_mean']
trainer.design_std = checkpoint['design_std']
trainer.disp_scale = checkpoint['disp_scale']
# Restore history
trainer.history = checkpoint['history']
trainer.best_val_loss = checkpoint['best_val_loss']
return trainer
def predict(self, design_vars: Dict[str, float]) -> Dict[str, Any]:
"""
Make prediction for new design.
Args:
design_vars: Dictionary of design parameter values
Returns:
Dictionary with displacement field and objectives
"""
self.model.eval()
# Convert to tensor
design_names = self.train_dataset.data_list[0]['design_names']
design = torch.tensor(
[design_vars[name] for name in design_names],
dtype=torch.float32
)
# Normalize
design_norm = (design - self.design_mean) / self.design_std
with torch.no_grad():
z_disp_scaled = self.model(
self.node_features,
self.edge_index,
self.edge_attr,
design_norm.to(self.device)
).cpu()
# Unscale
z_disp_mm = z_disp_scaled / self.disp_scale
# Compute objectives
objectives = self.objective_layer(z_disp_mm)
return {
'z_displacement': z_disp_mm.numpy(),
'objectives': {k: v.item() for k, v in objectives.items()},
}
# =============================================================================
# CLI
# =============================================================================
def main():
import argparse
parser = argparse.ArgumentParser(description='Train ZernikeGNN surrogate')
parser.add_argument('studies', nargs='+', help='Study versions (e.g., V11 V12)')
parser.add_argument('--epochs', type=int, default=200, help='Training epochs')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension')
parser.add_argument('--n-layers', type=int, default=6, help='Message passing layers')
parser.add_argument('--model-type', choices=['full', 'lite'], default='full')
parser.add_argument('--output', '-o', type=Path, help='Output checkpoint path')
parser.add_argument('--device', default='auto', help='Device (cpu, cuda, auto)')
args = parser.parse_args()
# Create trainer
print("="*60, flush=True)
print("ZERNIKE GNN TRAINING", flush=True)
print("="*60, flush=True)
trainer = ZernikeGNNTrainer(
study_versions=args.studies,
model_type=args.model_type,
hidden_dim=args.hidden_dim,
n_layers=args.n_layers,
device=args.device,
)
# Train
trainer.train(
epochs=args.epochs,
lr=args.lr,
batch_size=args.batch_size,
)
# Evaluate objectives
print("\n--- Objective Prediction Evaluation ---", flush=True)
obj_results = trainer.evaluate_objectives()
for k, v in obj_results.items():
print(f"\n{k}:", flush=True)
print(f" MAE: {v['mae']:.2f} nm", flush=True)
print(f" MAPE: {v['mape']:.1f}%", flush=True)
print(f" R²: {v['r2']:.4f}", flush=True)
# Save checkpoint
if args.output:
output_path = args.output
else:
output_path = Path("zernike_gnn_checkpoint.pt")
trainer.save_checkpoint(output_path)
print("\n" + "="*60, flush=True)
print("✓ Training complete!", flush=True)
print("="*60, flush=True)
if __name__ == '__main__':
main()
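The trainer above combines cosine-annealing LR scheduling with patience-based early stopping: it snapshots the model state whenever validation loss improves and stops after `patience` epochs without improvement, then restores the best snapshot. A minimal, dependency-free sketch of just the stopping logic (the function name and return values are illustrative, not part of the module):

```python
def early_stopping_trace(val_losses, patience=3):
    """Return (best_loss, best_epoch, stop_epoch) for a sequence of validation losses.

    Mirrors the trainer's logic: track the best loss seen so far (where the
    real trainer snapshots model.state_dict()), reset the no-improvement
    counter on improvement, and stop once `patience` epochs pass without one.
    """
    best_loss = float('inf')
    best_epoch = None
    no_improve = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:
            best_loss = val_loss   # new best: this is where the snapshot happens
            best_epoch = epoch
            no_improve = 0
        else:
            no_improve += 1
        if no_improve >= patience:
            return best_loss, best_epoch, epoch  # early stop
    return best_loss, best_epoch, len(val_losses) - 1

# Loss improves at epoch 1, then stalls for 3 epochs -> stop at epoch 4
print(early_stopping_trace([1.0, 0.8, 0.9, 0.95, 0.99], patience=3))
# → (0.8, 1, 4)
```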


@@ -0,0 +1,582 @@
"""
Zernike GNN Model for Mirror Surface Deformation Prediction
============================================================
This module implements a Graph Neural Network specifically designed for predicting
mirror surface displacement fields from design parameters. The key innovation is
using design-conditioned message passing on a polar grid graph.
Architecture:

    Design Variables [11]
              │
              ▼
    Design Encoder [11 → 128]
              │
              └──────────────────┐
    Node Features                │
    [r, θ, x, y]                 │
              │                  │
              ▼                  │
    Node Encoder                 │
    [4 → 128]                    │
              │                  │
              └─────────┬────────┘
                        ▼
    ┌─────────────────────────────┐
    │ Design-Conditioned          │
    │ Message Passing (× 6)       │
    │                             │
    │ • Polar-aware edges         │
    │ • Design modulates messages │
    │ • Residual connections     │
    └─────────────┬───────────────┘
                  ▼
    Per-Node Decoder [128 → 4]
                  │
                  ▼
    Z-Displacement Field [3000, 4]
    (one value per node per subcase)
Usage:
from optimization_engine.gnn.zernike_gnn import ZernikeGNN
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph()
model = ZernikeGNN(n_design_vars=11, n_subcases=4)
# Forward pass
z_disp = model(
node_features=graph.get_node_features(),
edge_index=graph.edge_index,
edge_attr=graph.get_edge_features(),
design_vars=design_tensor
)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
try:
from torch_geometric.nn import MessagePassing
HAS_PYG = True
except ImportError:
HAS_PYG = False
    MessagePassing = nn.Module  # Fallback base class when torch_geometric is unavailable
class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
"""
Message passing layer conditioned on global design parameters.
This layer propagates information through the polar graph while
conditioning on design parameters. The design embedding modulates
how messages flow between nodes.
Key insight: Design parameters affect the stiffness distribution
in the mirror support structure. This layer learns how those changes
propagate spatially through the optical surface.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
design_channels: int,
edge_channels: int = 4,
aggr: str = 'mean'
):
"""
Args:
in_channels: Input node feature dimension
out_channels: Output node feature dimension
design_channels: Design embedding dimension
edge_channels: Edge feature dimension
aggr: Aggregation method ('mean', 'sum', 'max')
"""
if HAS_PYG:
super().__init__(aggr=aggr)
else:
super().__init__()
self.aggr = aggr
self.in_channels = in_channels
self.out_channels = out_channels
# Message network: source node + target node + design + edge
msg_input_dim = 2 * in_channels + design_channels + edge_channels
self.message_net = nn.Sequential(
nn.Linear(msg_input_dim, out_channels * 2),
nn.LayerNorm(out_channels * 2),
nn.SiLU(),
nn.Dropout(0.1),
nn.Linear(out_channels * 2, out_channels),
)
# Update network: combines aggregated messages with original features
self.update_net = nn.Sequential(
nn.Linear(in_channels + out_channels, out_channels),
nn.LayerNorm(out_channels),
nn.SiLU(),
)
# Design gate: allows design to modulate message importance
self.design_gate = nn.Sequential(
nn.Linear(design_channels, out_channels),
nn.Sigmoid(),
)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
design_embed: torch.Tensor
) -> torch.Tensor:
"""
Forward pass with design conditioning.
Args:
x: Node features [n_nodes, in_channels]
edge_index: Graph connectivity [2, n_edges]
edge_attr: Edge features [n_edges, edge_channels]
design_embed: Design embedding [design_channels]
Returns:
Updated node features [n_nodes, out_channels]
"""
if HAS_PYG:
# Use PyG's message passing
out = self.propagate(
edge_index, x=x, edge_attr=edge_attr, design=design_embed
)
else:
# Fallback implementation without PyG
out = self._manual_propagate(x, edge_index, edge_attr, design_embed)
# Apply design-based gating
gate = self.design_gate(design_embed)
out = out * gate
return out
def message(
self,
x_i: torch.Tensor,
x_j: torch.Tensor,
edge_attr: torch.Tensor,
design: torch.Tensor
) -> torch.Tensor:
"""
Compute messages from source (j) to target (i) nodes.
Args:
x_i: Target node features [n_edges, in_channels]
x_j: Source node features [n_edges, in_channels]
edge_attr: Edge features [n_edges, edge_channels]
design: Design embedding, broadcast to edges
Returns:
Messages [n_edges, out_channels]
"""
# Broadcast design to all edges
design_broadcast = design.expand(x_i.size(0), -1)
# Concatenate all inputs
msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
return self.message_net(msg_input)
def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Update node features with aggregated messages.
Args:
aggr_out: Aggregated messages [n_nodes, out_channels]
x: Original node features [n_nodes, in_channels]
Returns:
Updated node features [n_nodes, out_channels]
"""
return self.update_net(torch.cat([x, aggr_out], dim=-1))
def _manual_propagate(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
design: torch.Tensor
) -> torch.Tensor:
"""Fallback message passing without PyG."""
row, col = edge_index # row = target, col = source
# Gather features
x_i = x[row] # Target features
x_j = x[col] # Source features
# Compute messages
design_broadcast = design.expand(x_i.size(0), -1)
msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
messages = self.message_net(msg_input)
        # Aggregate by target node (the fallback ignores self.aggr and always averages)
n_nodes = x.size(0)
aggr_out = torch.zeros(n_nodes, messages.size(-1), device=x.device)
count = torch.zeros(n_nodes, 1, device=x.device)
aggr_out.scatter_add_(0, row.unsqueeze(-1).expand_as(messages), messages)
count.scatter_add_(0, row.unsqueeze(-1), torch.ones_like(row, dtype=torch.float).unsqueeze(-1))
count = count.clamp(min=1)
aggr_out = aggr_out / count
# Update
return self.update_net(torch.cat([x, aggr_out], dim=-1))
class ZernikeGNN(nn.Module):
"""
Graph Neural Network for mirror surface displacement prediction.
This model learns to predict Z-displacement fields for all 4 gravity
subcases from 11 design parameters. It uses a fixed polar grid graph
structure and design-conditioned message passing.
The key advantages over MLP:
1. Spatial awareness through message passing
2. Design conditioning modulates spatial information flow
3. Predicts full field (enabling correct relative computation)
4. Respects physics: smooth fields, radial/angular structure
"""
def __init__(
self,
n_design_vars: int = 11,
n_subcases: int = 4,
hidden_dim: int = 128,
n_layers: int = 6,
node_feat_dim: int = 4,
edge_feat_dim: int = 4,
dropout: float = 0.1
):
"""
Args:
n_design_vars: Number of design parameters (11 for mirror)
n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
hidden_dim: Hidden layer dimension
n_layers: Number of message passing layers
node_feat_dim: Node feature dimension (r, theta, x, y)
edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
dropout: Dropout rate
"""
super().__init__()
self.n_design_vars = n_design_vars
self.n_subcases = n_subcases
self.hidden_dim = hidden_dim
self.n_layers = n_layers
# === Design Encoder ===
# Maps design parameters to hidden space
self.design_encoder = nn.Sequential(
nn.Linear(n_design_vars, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
)
# === Node Encoder ===
# Maps polar coordinates to hidden space
self.node_encoder = nn.Sequential(
nn.Linear(node_feat_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
)
# === Edge Encoder ===
# Maps edge features (dr, dtheta, distance, angle) to hidden space
edge_hidden = hidden_dim // 2
self.edge_encoder = nn.Sequential(
nn.Linear(edge_feat_dim, edge_hidden),
nn.SiLU(),
nn.Linear(edge_hidden, edge_hidden),
)
# === Message Passing Layers ===
self.conv_layers = nn.ModuleList([
DesignConditionedConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
design_channels=hidden_dim,
edge_channels=edge_hidden,
)
for _ in range(n_layers)
])
# Layer norms for residual connections
self.layer_norms = nn.ModuleList([
nn.LayerNorm(hidden_dim) for _ in range(n_layers)
])
# === Displacement Decoder ===
# Predicts Z-displacement for each subcase
self.displacement_decoder = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.SiLU(),
nn.Linear(hidden_dim // 2, n_subcases),
)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize weights with Xavier/Glorot initialization."""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(
self,
node_features: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
design_vars: torch.Tensor
) -> torch.Tensor:
"""
Forward pass: design parameters → displacement field.
Args:
node_features: [n_nodes, 4] - (r, theta, x, y) normalized
edge_index: [2, n_edges] - graph connectivity
edge_attr: [n_edges, 4] - edge features normalized
design_vars: [n_design_vars] or [batch, n_design_vars]
Returns:
z_displacement: [n_nodes, n_subcases] - Z-disp per subcase
or [batch, n_nodes, n_subcases] if batched
"""
# Handle batched vs single design
is_batched = design_vars.dim() == 2
if not is_batched:
design_vars = design_vars.unsqueeze(0) # [1, n_design_vars]
batch_size = design_vars.size(0)
n_nodes = node_features.size(0)
# Encode inputs
design_h = self.design_encoder(design_vars) # [batch, hidden]
node_h = self.node_encoder(node_features) # [n_nodes, hidden]
edge_h = self.edge_encoder(edge_attr) # [n_edges, edge_hidden]
# Process each batch item
outputs = []
for b in range(batch_size):
h = node_h.clone() # Start fresh for each design
# Message passing with residual connections
for conv, norm in zip(self.conv_layers, self.layer_norms):
h_new = conv(h, edge_index, edge_h, design_h[b])
h = norm(h + h_new) # Residual + LayerNorm
# Decode to displacement
z_disp = self.displacement_decoder(h) # [n_nodes, n_subcases]
outputs.append(z_disp)
# Stack outputs
if is_batched:
return torch.stack(outputs, dim=0) # [batch, n_nodes, n_subcases]
else:
return outputs[0] # [n_nodes, n_subcases]
def count_parameters(self) -> int:
"""Count trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
class ZernikeGNNLite(nn.Module):
"""
Lightweight version of ZernikeGNN for faster training/inference.
Uses fewer layers and smaller hidden dimension, suitable for
initial experiments or when training data is limited.
"""
def __init__(
self,
n_design_vars: int = 11,
n_subcases: int = 4,
hidden_dim: int = 64,
n_layers: int = 4
):
super().__init__()
self.n_subcases = n_subcases
# Simpler design encoder
self.design_encoder = nn.Sequential(
nn.Linear(n_design_vars, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# Simpler node encoder
self.node_encoder = nn.Sequential(
nn.Linear(4, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# Edge encoder
self.edge_encoder = nn.Linear(4, hidden_dim // 2)
# Message passing
self.conv_layers = nn.ModuleList([
DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
for _ in range(n_layers)
])
# Decoder
self.decoder = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.SiLU(),
nn.Linear(hidden_dim // 2, n_subcases),
)
def forward(self, node_features, edge_index, edge_attr, design_vars):
"""Forward pass."""
design_h = self.design_encoder(design_vars)
node_h = self.node_encoder(node_features)
edge_h = self.edge_encoder(edge_attr)
for conv in self.conv_layers:
node_h = node_h + conv(node_h, edge_index, edge_h, design_h)
return self.decoder(node_h)
# =============================================================================
# Utility functions
# =============================================================================
def create_model(
n_design_vars: int = 11,
n_subcases: int = 4,
model_type: str = 'full',
**kwargs
) -> nn.Module:
"""
Factory function to create GNN model.
Args:
n_design_vars: Number of design parameters
n_subcases: Number of subcases
model_type: 'full' or 'lite'
**kwargs: Additional arguments passed to model
Returns:
GNN model instance
"""
if model_type == 'lite':
return ZernikeGNNLite(n_design_vars, n_subcases, **kwargs)
else:
return ZernikeGNN(n_design_vars, n_subcases, **kwargs)
def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
"""
Load trained model from checkpoint.
Args:
checkpoint_path: Path to .pt checkpoint file
device: Device to load model to
Returns:
Loaded model in eval mode
"""
checkpoint = torch.load(checkpoint_path, map_location=device)
# Get model config
config = checkpoint.get('config', {})
model_type = config.get('model_type', 'full')
# Create model
model = create_model(
n_design_vars=config.get('n_design_vars', 11),
n_subcases=config.get('n_subcases', 4),
model_type=model_type,
hidden_dim=config.get('hidden_dim', 128),
n_layers=config.get('n_layers', 6),
)
# Load weights
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
return model
# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
print("="*60)
print("Testing ZernikeGNN")
print("="*60)
# Create model
model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
print(f"\nModel: {model.__class__.__name__}")
print(f"Parameters: {model.count_parameters():,}")
# Create dummy inputs
n_nodes = 3000
n_edges = 17760
node_features = torch.randn(n_nodes, 4)
edge_index = torch.randint(0, n_nodes, (2, n_edges))
edge_attr = torch.randn(n_edges, 4)
design_vars = torch.randn(11)
# Forward pass
print("\n--- Single Forward Pass ---")
with torch.no_grad():
output = model(node_features, edge_index, edge_attr, design_vars)
print(f"Input design: {design_vars.shape}")
print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")
# Batched forward pass
print("\n--- Batched Forward Pass ---")
batch_design = torch.randn(8, 11)
with torch.no_grad():
output_batch = model(node_features, edge_index, edge_attr, batch_design)
print(f"Batch design: {batch_design.shape}")
print(f"Batch output: {output_batch.shape}")
# Test lite model
print("\n--- Lite Model ---")
model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")
with torch.no_grad():
output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
print(f"Lite output shape: {output_lite.shape}")
print("\n" + "="*60)
print("✓ All tests passed!")
print("="*60)