Atomizer/optimization_engine/gnn/polar_graph.py
Antoine 96b196de58 feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials
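
The 3000-node figure quoted for `polar_graph.py` follows from the default 50 × 60 grid in the source file. A minimal NumPy sketch of that grid, using the constructor defaults and omitting the rest of the class:

```python
import numpy as np

# Defaults taken from PolarMirrorGraph.__init__ in polar_graph.py.
r_inner, r_outer = 100.0, 650.0  # mm
n_radial, n_angular = 50, 60

r_1d = np.linspace(r_inner, r_outer, n_radial)
# endpoint=False drops 2*pi so the seam node is not duplicated.
theta_1d = np.linspace(0.0, 2.0 * np.pi, n_angular, endpoint=False)

# Rows index the radial ring, columns the angular position.
theta_grid, r_grid = np.meshgrid(theta_1d, r_1d)
r = r_grid.ravel()
theta = theta_grid.ravel()
x, y = r * np.cos(theta), r * np.sin(theta)

n_nodes = r.size
print(n_nodes)  # 3000
```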

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.
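
`zernike_gnn.py` is not shown here, so the following is only an illustration of the general idea of design-conditioned message passing, not the module's actual layer. It assumes a simple sigmoid gate computed from the design vector; the weights `W_msg` and `W_gate`, the ring graph, and `conditioned_layer` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_nodes, n_hidden, n_design = 6, 4, 3

# Hypothetical ring graph with edges in both directions.
src = np.arange(n_nodes)
dst = (src + 1) % n_nodes
edge_index = np.stack([np.concatenate([src, dst]), np.concatenate([dst, src])])

h = rng.normal(size=(n_nodes, n_hidden))  # node embeddings
design = rng.normal(size=n_design)        # global design parameters

# Stand-ins for learned weights (random here, purely illustrative).
W_msg = rng.normal(size=(n_hidden, n_hidden))
W_gate = rng.normal(size=(n_design, n_hidden))


def conditioned_layer(h, edge_index, design):
    """One message-passing step whose messages are gated by the design vector."""
    gate = 1.0 / (1.0 + np.exp(-(design @ W_gate)))  # [n_hidden], values in (0, 1)
    messages = (h[edge_index[0]] @ W_msg) * gate      # modulate every message
    agg = np.zeros_like(h)
    np.add.at(agg, edge_index[1], messages)           # sum messages per target node
    return np.tanh(h + agg)


h_next = conditioned_layer(h, edge_index, design)
print(h_next.shape)
```

The same node embeddings produce different updates when the design vector changes, which is the property the commit message describes.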

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py
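
The turbo-mode flow (many cheap surrogate predictions, then FEA on the best few) can be sketched as below. This is a hedged illustration only: `surrogate_objective` is a made-up stand-in for a GNN prediction, and the number of candidates kept for validation is an assumption, not a value from `run_gnn_turbo.py`.

```python
import heapq
import random

random.seed(0)


def surrogate_objective(design):
    """Stand-in for a GNN prediction; NOT the real ZernikeGNN."""
    return sum((v - 0.3) ** 2 for v in design)


n_predictions, n_validate, n_vars = 5000, 5, 4  # n_validate and n_vars are assumed
candidates = [[random.random() for _ in range(n_vars)] for _ in range(n_predictions)]

# Cheap surrogate pass over every candidate, then keep the best few.
top = heapq.nsmallest(n_validate, candidates, key=surrogate_objective)

for design in top:
    # In the real workflow each of these would be sent to FEA for validation.
    print([round(v, 3) for v in design], round(surrogate_objective(design), 5))
```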

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeded with 217 FEA trials from V11 and V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]
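
For readers unfamiliar with the term, a Pareto front is the set of trials that no other trial beats on every objective. A minimal sketch of that filter for minimization, using made-up two-objective values rather than study data:

```python
def dominates(a, b):
    """True if a is no worse than b everywhere and strictly better somewhere."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def pareto_front(points):
    """Keep only the points that no other point dominates."""
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Illustrative objective pairs (both minimized); not results from V13.
trials = [(1.0, 5.0), (2.0, 3.0), (3.0, 4.0), (4.0, 1.0), (2.5, 3.5)]
front = pareto_front(trials)
print(front)  # [(1.0, 5.0), (2.0, 3.0), (4.0, 1.0)]
```

NSGA-II itself adds non-dominated sorting into ranks and crowding-distance selection on top of this dominance test; only the dominance test is shown here.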

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00

"""
Polar Mirror Graph for GNN Training
====================================

This module creates a fixed polar grid graph structure for the mirror optical surface.
The key insight is that the mirror has a fixed topology (circular annulus), so we can
use a fixed graph structure regardless of FEA mesh variations.

Why Polar Grid?
1. Matches mirror geometry (annulus)
2. Same approach as extract_zernike_surface.py
3. Enables mesh-independent training
4. Edge structure respects radial/angular physics

Grid Structure:
- n_radial points from r_inner to r_outer
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
- Total nodes = n_radial × n_angular
- Edges connect radial, angular (wrap-around), and outward diagonal neighbors

Usage:
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)

    # Interpolate FEA results to fixed grid
    z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)

    # Get PyTorch Geometric data
    data = graph.to_pyg_data(z_disp_grid, design_vars)
"""

import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
import json

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
    from scipy.spatial import Delaunay
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class PolarMirrorGraph:
    """
    Fixed polar grid graph for mirror optical surface.

    This creates a mesh-independent graph structure that can be used for GNN training
    regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.

    Attributes:
        n_nodes: Total number of nodes (n_radial × n_angular)
        r: Radial coordinates [n_nodes]
        theta: Angular coordinates [n_nodes]
        x: Cartesian X coordinates [n_nodes]
        y: Cartesian Y coordinates [n_nodes]
        edge_index: Graph edges [2, n_edges]
        edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
    """

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 650.0,
        n_radial: int = 50,
        n_angular: int = 60
    ):
        """
        Initialize polar grid graph.

        Args:
            r_inner: Inner radius (central hole), mm
            r_outer: Outer radius, mm
            n_radial: Number of radial samples
            n_angular: Number of angular samples
        """
        self.r_inner = r_inner
        self.r_outer = r_outer
        self.n_radial = n_radial
        self.n_angular = n_angular
        self.n_nodes = n_radial * n_angular

        # Create polar grid coordinates
        r_1d = np.linspace(r_inner, r_outer, n_radial)
        theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)

        # Meshgrid: theta varies fast (angular index), r varies slow (radial index)
        # Shape after flatten: [n_angular * n_radial] with angular varying fastest
        Theta, R = np.meshgrid(theta_1d, r_1d)  # R shape: [n_radial, n_angular]

        # Flatten: radial index varies slowest
        self.r = R.flatten().astype(np.float32)
        self.theta = Theta.flatten().astype(np.float32)
        self.x = (self.r * np.cos(self.theta)).astype(np.float32)
        self.y = (self.r * np.sin(self.theta)).astype(np.float32)

        # Build graph edges
        self.edge_index, self.edge_attr = self._build_polar_edges()

        # Precompute normalization factors
        self._r_mean = (r_inner + r_outer) / 2
        self._r_std = (r_outer - r_inner) / 2

    def _node_index(self, i_r: int, i_theta: int) -> int:
        """Convert (radial_index, angular_index) to flat node index."""
        # Angular wraps around
        i_theta = i_theta % self.n_angular
        return i_r * self.n_angular + i_theta

    def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create graph edges respecting polar topology.

        Edge types:
        1. Radial edges: Connect adjacent radial rings
        2. Angular edges: Connect adjacent angular positions (with wrap-around)
        3. Diagonal edges: Connect diagonal neighbors for better message passing

        Returns:
            edge_index: [2, n_edges] array of (source, target) pairs
            edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
        """
        edges = []
        edge_features = []

        for i_r in range(self.n_radial):
            for i_theta in range(self.n_angular):
                node = self._node_index(i_r, i_theta)

                # Radial neighbor (outward)
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])  # Bidirectional

                    # Edge features: (dr, dtheta, distance, relative_angle)
                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 0.0
                    dist = abs(dr)
                    angle = 0.0  # Radial direction
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, dtheta, dist, np.pi])  # Reverse

                # Angular neighbor (counterclockwise, with wrap-around)
                neighbor = self._node_index(i_r, i_theta + 1)
                edges.append([node, neighbor])
                edges.append([neighbor, node])  # Bidirectional

                # Edge features for angular edge
                dr = 0.0
                dtheta = 2 * np.pi / self.n_angular
                # Arc length at this radius
                dist = self.r[node] * dtheta
                angle = np.pi / 2  # Tangential direction
                edge_features.append([dr, dtheta, dist, angle])
                edge_features.append([dr, -dtheta, dist, -np.pi / 2])  # Reverse

                # Diagonal neighbor (outward + counterclockwise) for better connectivity
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta + 1)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])

                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 2 * np.pi / self.n_angular
                    dx = self.x[neighbor] - self.x[node]
                    dy = self.y[neighbor] - self.y[node]
                    dist = np.sqrt(dx**2 + dy**2)
                    angle = np.arctan2(dy, dx)
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, -dtheta, dist, angle + np.pi])

        edge_index = np.array(edges, dtype=np.int64).T  # [2, n_edges]
        edge_attr = np.array(edge_features, dtype=np.float32)  # [n_edges, 4]
        return edge_index, edge_attr

    def get_node_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get node features for GNN input.

        Features: (r, theta, x, y) - polar and Cartesian coordinates

        Args:
            normalized: If True, normalize features to ~[-1, 1] range

        Returns:
            Node features [n_nodes, 4]
        """
        if normalized:
            r_norm = (self.r - self._r_mean) / self._r_std
            theta_norm = self.theta / np.pi - 1  # [0, 2π) → [-1, 1)
            x_norm = self.x / self.r_outer
            y_norm = self.y / self.r_outer
            return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
        else:
            return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)

    def get_edge_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get edge features for GNN input.

        Features: (dr, dtheta, distance, angle)

        Args:
            normalized: If True, normalize features

        Returns:
            Edge features [n_edges, 4]
        """
        if normalized:
            edge_attr = self.edge_attr.copy()
            edge_attr[:, 0] /= self._r_std   # dr
            edge_attr[:, 1] /= np.pi         # dtheta
            edge_attr[:, 2] /= self.r_outer  # distance
            edge_attr[:, 3] /= np.pi         # angle
            return edge_attr
        else:
            return self.edge_attr

    def interpolate_from_mesh(
        self,
        mesh_coords: np.ndarray,
        mesh_values: np.ndarray,
        method: str = 'rbf'
    ) -> np.ndarray:
        """
        Interpolate FEA results from mesh nodes to fixed polar grid.

        Args:
            mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z])
            mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
            method: Interpolation method ('rbf', 'linear', 'clough_tocher')

        Returns:
            Interpolated values on polar grid [n_nodes] or [n_nodes, n_features]
        """
        if not HAS_SCIPY:
            raise ImportError("scipy required for interpolation: pip install scipy")

        # Use only X, Y coordinates
        xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords

        # Handle multi-dimensional values
        values_1d = mesh_values.ndim == 1
        if values_1d:
            mesh_values = mesh_values.reshape(-1, 1)

        # Target coordinates
        target_xy = np.column_stack([self.x, self.y])

        result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)

        for i in range(mesh_values.shape[1]):
            vals = mesh_values[:, i]

            if method == 'rbf':
                # RBF interpolation - smooth, handles scattered data well
                interp = RBFInterpolator(
                    xy, vals,
                    kernel='thin_plate_spline',
                    smoothing=0.0
                )
                result[:, i] = interp(target_xy)

            elif method == 'linear':
                # Linear interpolation via Delaunay triangulation
                interp = LinearNDInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN (points outside convex hull) with nearest neighbor
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]

            elif method == 'clough_tocher':
                # Clough-Tocher (C1 smooth) interpolation
                interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]

            else:
                raise ValueError(f"Unknown interpolation method: {method}")

        return result[:, 0] if values_1d else result

    def interpolate_field_data(
        self,
        field_data: Dict[str, Any],
        subcases: List[int] = [1, 2, 3, 4],
        method: str = 'linear'  # Changed from 'rbf' - much faster
    ) -> Dict[str, np.ndarray]:
        """
        Interpolate field data from extract_displacement_field() to polar grid.

        Args:
            field_data: Output from extract_displacement_field()
            subcases: List of subcases to interpolate
            method: Interpolation method

        Returns:
            Dictionary with:
            - z_displacement: [n_nodes, n_subcases] array
            - original_n_nodes: Number of FEA nodes
        """
        mesh_coords = field_data['node_coords']
        z_disp_dict = field_data['z_displacement']

        # Stack subcases
        z_disp_list = []
        for sc in subcases:
            if sc in z_disp_dict:
                z_disp_list.append(z_disp_dict[sc])
            else:
                raise KeyError(f"Subcase {sc} not found in field_data")

        # [n_fea_nodes, n_subcases]
        z_disp_mesh = np.column_stack(z_disp_list)

        # Interpolate to polar grid
        z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)

        return {
            'z_displacement': z_disp_grid,  # [n_nodes, n_subcases]
            'original_n_nodes': len(mesh_coords),
        }

    def to_pyg_data(
        self,
        z_displacement: np.ndarray,
        design_vars: np.ndarray,
        objectives: Optional[Dict[str, float]] = None
    ):
        """
        Convert to PyTorch Geometric Data object.

        Args:
            z_displacement: [n_nodes, n_subcases] displacement field
            design_vars: [n_design_vars] design parameters
            objectives: Optional dict of objective values (ground truth)

        Returns:
            torch_geometric.data.Data object
        """
        if not HAS_TORCH:
            raise ImportError("PyTorch required: pip install torch")

        try:
            from torch_geometric.data import Data
        except ImportError:
            raise ImportError("PyTorch Geometric required: pip install torch-geometric")

        # Node features: (r, theta, x, y)
        node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)

        # Edge index and features
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)

        # Target: Z-displacement field
        y = torch.tensor(z_displacement, dtype=torch.float32)

        # Design variables (global feature)
        design = torch.tensor(design_vars, dtype=torch.float32)

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            design=design,
        )

        # Add objectives if provided
        if objectives:
            for key, value in objectives.items():
                setattr(data, key, torch.tensor([value], dtype=torch.float32))

        return data

    def save(self, path: Path) -> None:
        """Save graph structure to JSON file."""
        path = Path(path)
        data = {
            'r_inner': self.r_inner,
            'r_outer': self.r_outer,
            'n_radial': self.n_radial,
            'n_angular': self.n_angular,
            'n_nodes': self.n_nodes,
            'n_edges': self.edge_index.shape[1],
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

        # Save arrays separately for efficiency
        np.savez_compressed(
            path.with_suffix('.npz'),
            r=self.r,
            theta=self.theta,
            x=self.x,
            y=self.y,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
        )

    @classmethod
    def load(cls, path: Path) -> 'PolarMirrorGraph':
        """Load graph structure from file."""
        path = Path(path)
        with open(path, 'r') as f:
            data = json.load(f)

        return cls(
            r_inner=data['r_inner'],
            r_outer=data['r_outer'],
            n_radial=data['n_radial'],
            n_angular=data['n_angular'],
        )

    def __repr__(self) -> str:
        return (
            f"PolarMirrorGraph("
            f"r=[{self.r_inner}, {self.r_outer}]mm, "
            f"grid={self.n_radial}×{self.n_angular}, "
            f"nodes={self.n_nodes}, "
            f"edges={self.edge_index.shape[1]})"
        )


# =============================================================================
# Convenience functions
# =============================================================================

def create_mirror_dataset(
    study_dir: Path,
    polar_graph: Optional[PolarMirrorGraph] = None,
    subcases: List[int] = [1, 2, 3, 4],
    verbose: bool = True
) -> List[Dict[str, Any]]:
    """
    Create GNN dataset from a study's gnn_data folder.

    Args:
        study_dir: Path to study directory
        polar_graph: PolarMirrorGraph instance (created if None)
        subcases: Subcases to include
        verbose: Print progress

    Returns:
        List of data dictionaries, each containing:
        - z_displacement: [n_nodes, n_subcases]
        - design_vars: [n_vars]
        - design_names: list of design variable names
        - trial_number: int
        - original_n_nodes: int
    """
    from optimization_engine.gnn.extract_displacement_field import load_field

    study_dir = Path(study_dir)
    gnn_data_dir = study_dir / "gnn_data"

    if not gnn_data_dir.exists():
        raise FileNotFoundError(f"No gnn_data folder in {study_dir}")

    # Load index
    index_path = gnn_data_dir / "dataset_index.json"
    with open(index_path, 'r') as f:
        index = json.load(f)

    if polar_graph is None:
        polar_graph = PolarMirrorGraph()

    dataset = []

    for trial_num, trial_info in index['trials'].items():
        if trial_info.get('status') != 'success':
            continue

        trial_dir = study_dir / trial_info['trial_dir']

        # Find field file
        field_path = None
        for ext in ['.h5', '.npz']:
            candidate = trial_dir / f"displacement_field{ext}"
            if candidate.exists():
                field_path = candidate
                break

        if field_path is None:
            if verbose:
                print(f"[WARN] No field file for trial {trial_num}")
            continue

        try:
            # Load field data
            field_data = load_field(field_path)

            # Interpolate to polar grid
            interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)

            # Get design parameters (empty float32 array when the trial has none)
            params = trial_info.get('params', {})
            design_vars = (
                np.array(list(params.values()), dtype=np.float32)
                if params else np.array([], dtype=np.float32)
            )

            dataset.append({
                'z_displacement': interp_result['z_displacement'],
                'design_vars': design_vars,
                'design_names': list(params.keys()) if params else [],
                'trial_number': int(trial_num),
                'original_n_nodes': interp_result['original_n_nodes'],
            })

            if verbose:
                print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes")

        except Exception as e:
            if verbose:
                print(f"[ERR] Trial {trial_num}: {e}")

    return dataset


# =============================================================================
# CLI
# =============================================================================

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
    parser.add_argument('--test', action='store_true', help='Run basic tests')
    parser.add_argument('--study', type=Path, help='Create dataset from study')

    args = parser.parse_args()

    if args.test:
        print("="*60)
        print("TESTING PolarMirrorGraph")
        print("="*60)

        # Create graph
        graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"\n{graph}")

        # Check node features
        node_feat = graph.get_node_features(normalized=True)
        print(f"\nNode features shape: {node_feat.shape}")
        print(f"  r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
        print(f"  theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")

        # Check edge features
        edge_feat = graph.get_edge_features(normalized=True)
        print(f"\nEdge features shape: {edge_feat.shape}")
        print(f"  dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
        print(f"  distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")

        # Test interpolation with synthetic data
        print("\n--- Testing Interpolation ---")

        # Create fake mesh data (random points in annulus)
        np.random.seed(42)
        n_mesh = 5000
        r_mesh = np.random.uniform(100, 650, n_mesh)
        theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
        x_mesh = r_mesh * np.cos(theta_mesh)
        y_mesh = r_mesh * np.sin(theta_mesh)
        mesh_coords = np.column_stack([x_mesh, y_mesh])

        # Synthetic displacement: smooth function
        mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)

        # Interpolate
        grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
        print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
        print(f"  Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
        print(f"  Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")

        print("\n✓ All tests passed!")

    elif args.study:
        # Create dataset from study
        print(f"Creating dataset from: {args.study}")

        graph = PolarMirrorGraph()
        dataset = create_mirror_dataset(args.study, polar_graph=graph)

        print(f"\nDataset: {len(dataset)} samples")
        if dataset:
            print(f"  Z-displacement shape: {dataset[0]['z_displacement'].shape}")
            print(f"  Design vars: {len(dataset[0]['design_vars'])} variables")

    else:
        # Default: just show info
        graph = PolarMirrorGraph()
        print(graph)
        print(f"\nNode features: {graph.get_node_features().shape}")
        print(f"Edge index: {graph.edge_index.shape}")
        print(f"Edge features: {graph.edge_attr.shape}")
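
The edge construction in `_build_polar_edges` implies a closed-form edge count, which is a handy sanity check on the default graph. The sketch below re-implements only the connectivity (no edge features) in plain Python; `polar_edges` is a hypothetical helper written for this check, not part of the module.

```python
def polar_edges(n_radial, n_angular):
    """Directed edge list with the same neighbour rules as _build_polar_edges."""
    def idx(i_r, i_t):
        return i_r * n_angular + (i_t % n_angular)  # angular index wraps around

    edges = []
    for i_r in range(n_radial):
        for i_t in range(n_angular):
            node = idx(i_r, i_t)
            neighbors = [idx(i_r, i_t + 1)]  # angular neighbour, every ring
            if i_r < n_radial - 1:
                # Radial and diagonal neighbours exist for all but the outer ring.
                neighbors += [idx(i_r + 1, i_t), idx(i_r + 1, i_t + 1)]
            for nb in neighbors:
                edges += [(node, nb), (nb, node)]  # stored in both directions
    return edges


edges = polar_edges(50, 60)
# Angular: 50*60 pairs; radial and diagonal: 49*60 pairs each; times 2 directions.
expected = 2 * (50 * 60 + 2 * 49 * 60)
print(len(edges), expected)  # 17760 17760
```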