Atomizer/optimization_engine/gnn/polar_graph.py

feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies

This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction.

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation

- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py

### V13: Pure NSGA-II FEA (Ground Truth)

- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00
"""
Polar Mirror Graph for GNN Training
===================================

This module creates a fixed polar grid graph structure for the mirror optical surface.
The key insight is that the mirror has a fixed topology (circular annulus), so we can
use a fixed graph structure regardless of FEA mesh variations.

Why Polar Grid?
1. Matches mirror geometry (annulus)
2. Same approach as extract_zernike_surface.py
3. Enables mesh-independent training
4. Edge structure respects radial/angular physics

Grid Structure:
- n_radial points from r_inner to r_outer
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
- Total nodes = n_radial × n_angular
- Edges connect radial neighbors and angular neighbors (wrap-around)

Usage:
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)

    # Interpolate FEA results to fixed grid
    z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)

    # Get PyTorch Geometric data
    data = graph.to_pyg_data(z_disp_grid, design_vars)
"""

import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
import json

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
    from scipy.spatial import Delaunay
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class PolarMirrorGraph:
    """
    Fixed polar grid graph for mirror optical surface.

    This creates a mesh-independent graph structure that can be used for GNN training
    regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.

    Attributes:
        n_nodes: Total number of nodes (n_radial × n_angular)
        r: Radial coordinates [n_nodes]
        theta: Angular coordinates [n_nodes]
        x: Cartesian X coordinates [n_nodes]
        y: Cartesian Y coordinates [n_nodes]
        edge_index: Graph edges [2, n_edges]
        edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
    """

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 650.0,
        n_radial: int = 50,
        n_angular: int = 60
    ):
        """
        Initialize polar grid graph.

        Args:
            r_inner: Inner radius (central hole), mm
            r_outer: Outer radius, mm
            n_radial: Number of radial samples
            n_angular: Number of angular samples
        """
        self.r_inner = r_inner
        self.r_outer = r_outer
        self.n_radial = n_radial
        self.n_angular = n_angular
        self.n_nodes = n_radial * n_angular

        # Create polar grid coordinates
        r_1d = np.linspace(r_inner, r_outer, n_radial)
        theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)

        # Meshgrid: theta varies fast (angular index), r varies slow (radial index)
        # Shape after flatten: [n_angular * n_radial] with angular varying fastest
        Theta, R = np.meshgrid(theta_1d, r_1d)  # R shape: [n_radial, n_angular]

        # Flatten: radial index varies slowest
        self.r = R.flatten().astype(np.float32)
        self.theta = Theta.flatten().astype(np.float32)
        self.x = (self.r * np.cos(self.theta)).astype(np.float32)
        self.y = (self.r * np.sin(self.theta)).astype(np.float32)

        # Build graph edges
        self.edge_index, self.edge_attr = self._build_polar_edges()

        # Precompute normalization factors
        self._r_mean = (r_inner + r_outer) / 2
        self._r_std = (r_outer - r_inner) / 2

    def _node_index(self, i_r: int, i_theta: int) -> int:
        """Convert (radial_index, angular_index) to flat node index."""
        # Angular wraps around
        i_theta = i_theta % self.n_angular
        return i_r * self.n_angular + i_theta

    def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create graph edges respecting polar topology.

        Edge types:
        1. Radial edges: Connect adjacent radial rings
        2. Angular edges: Connect adjacent angular positions (with wrap-around)
        3. Diagonal edges: Connect diagonal neighbors for better message passing

        Returns:
            edge_index: [2, n_edges] array of (source, target) pairs
            edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
        """
        edges = []
        edge_features = []

        for i_r in range(self.n_radial):
            for i_theta in range(self.n_angular):
                node = self._node_index(i_r, i_theta)

                # Radial neighbor (outward)
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])  # Bidirectional

                    # Edge features: (dr, dtheta, distance, relative_angle)
                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 0.0
                    dist = abs(dr)
                    angle = 0.0  # Radial direction
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, dtheta, dist, np.pi])  # Reverse

                # Angular neighbor (counterclockwise, with wrap-around)
                neighbor = self._node_index(i_r, i_theta + 1)
                edges.append([node, neighbor])
                edges.append([neighbor, node])  # Bidirectional

                # Edge features for angular edge
                dr = 0.0
                dtheta = 2 * np.pi / self.n_angular
                # Arc length at this radius
                dist = self.r[node] * dtheta
                angle = np.pi / 2  # Tangential direction
                edge_features.append([dr, dtheta, dist, angle])
                edge_features.append([dr, -dtheta, dist, -np.pi / 2])  # Reverse

                # Diagonal neighbor (outward + counterclockwise) for better connectivity
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta + 1)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])

                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 2 * np.pi / self.n_angular
                    dx = self.x[neighbor] - self.x[node]
                    dy = self.y[neighbor] - self.y[node]
                    dist = np.sqrt(dx**2 + dy**2)
                    # Measured in the global Cartesian frame, unlike the
                    # radial-relative angles (0, ±π/2, π) used for the edges above
                    angle = np.arctan2(dy, dx)
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, -dtheta, dist, angle + np.pi])

        edge_index = np.array(edges, dtype=np.int64).T  # [2, n_edges]
        edge_attr = np.array(edge_features, dtype=np.float32)  # [n_edges, 4]
        return edge_index, edge_attr

    def get_node_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get node features for GNN input.

        Features: (r, theta, x, y) - polar and Cartesian coordinates

        Args:
            normalized: If True, normalize features to ~[-1, 1] range

        Returns:
            Node features [n_nodes, 4]
        """
        if normalized:
            r_norm = (self.r - self._r_mean) / self._r_std
            theta_norm = self.theta / np.pi - 1  # [0, 2π) → [-1, 1)
            x_norm = self.x / self.r_outer
            y_norm = self.y / self.r_outer
            return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
        else:
            return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)

    def get_edge_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get edge features for GNN input.

        Features: (dr, dtheta, distance, angle)

        Args:
            normalized: If True, normalize features

        Returns:
            Edge features [n_edges, 4]
        """
        if normalized:
            edge_attr = self.edge_attr.copy()
            edge_attr[:, 0] /= self._r_std   # dr
            edge_attr[:, 1] /= np.pi         # dtheta
            edge_attr[:, 2] /= self.r_outer  # distance
            edge_attr[:, 3] /= np.pi         # angle
            return edge_attr
        else:
            return self.edge_attr

    def interpolate_from_mesh(
        self,
        mesh_coords: np.ndarray,
        mesh_values: np.ndarray,
        method: str = 'rbf'
    ) -> np.ndarray:
        """
        Interpolate FEA results from mesh nodes to fixed polar grid.

        Args:
            mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z])
            mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
            method: Interpolation method ('rbf', 'linear', 'clough_tocher')

        Returns:
            Interpolated values on polar grid [n_nodes] or [n_nodes, n_features]
        """
        if not HAS_SCIPY:
            raise ImportError("scipy required for interpolation: pip install scipy")

        # Use only X, Y coordinates
        xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords

        # Handle multi-dimensional values
        values_1d = mesh_values.ndim == 1
        if values_1d:
            mesh_values = mesh_values.reshape(-1, 1)

        # Target coordinates
        target_xy = np.column_stack([self.x, self.y])

        result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)

        for i in range(mesh_values.shape[1]):
            vals = mesh_values[:, i]

            if method == 'rbf':
                # RBF interpolation - smooth, handles scattered data well
                interp = RBFInterpolator(
                    xy, vals,
                    kernel='thin_plate_spline',
                    smoothing=0.0
                )
                result[:, i] = interp(target_xy)

            elif method == 'linear':
                # Linear interpolation via Delaunay triangulation
                interp = LinearNDInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN (points outside convex hull) with nearest neighbor
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]

            elif method == 'clough_tocher':
                # Clough-Tocher (C1 smooth) interpolation
                interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]

            else:
                raise ValueError(f"Unknown interpolation method: {method}")

        return result[:, 0] if values_1d else result

    def interpolate_field_data(
        self,
        field_data: Dict[str, Any],
        subcases: List[int] = [1, 2, 3, 4],
        method: str = 'linear'  # Changed from 'rbf' - much faster
    ) -> Dict[str, np.ndarray]:
        """
        Interpolate field data from extract_displacement_field() to polar grid.

        Args:
            field_data: Output from extract_displacement_field()
            subcases: List of subcases to interpolate
            method: Interpolation method

        Returns:
            Dictionary with:
            - z_displacement: [n_nodes, n_subcases] array
            - original_n_nodes: Number of FEA nodes
        """
        mesh_coords = field_data['node_coords']
        z_disp_dict = field_data['z_displacement']

        # Stack subcases
        z_disp_list = []
        for sc in subcases:
            if sc in z_disp_dict:
                z_disp_list.append(z_disp_dict[sc])
            else:
                raise KeyError(f"Subcase {sc} not found in field_data")

        # [n_fea_nodes, n_subcases]
        z_disp_mesh = np.column_stack(z_disp_list)

        # Interpolate to polar grid
        z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)

        return {
            'z_displacement': z_disp_grid,  # [n_nodes, n_subcases]
            'original_n_nodes': len(mesh_coords),
        }

    def to_pyg_data(
        self,
        z_displacement: np.ndarray,
        design_vars: np.ndarray,
        objectives: Optional[Dict[str, float]] = None
    ):
        """
        Convert to PyTorch Geometric Data object.

        Args:
            z_displacement: [n_nodes, n_subcases] displacement field
            design_vars: [n_design_vars] design parameters
            objectives: Optional dict of objective values (ground truth)

        Returns:
            torch_geometric.data.Data object
        """
        if not HAS_TORCH:
            raise ImportError("PyTorch required: pip install torch")

        try:
            from torch_geometric.data import Data
        except ImportError as exc:
            # Chain the original error so the underlying cause stays visible
            raise ImportError("PyTorch Geometric required: pip install torch-geometric") from exc

        # Node features: (r, theta, x, y)
        node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)

        # Edge index and features
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)

        # Target: Z-displacement field
        y = torch.tensor(z_displacement, dtype=torch.float32)

        # Design variables (global feature)
        design = torch.tensor(design_vars, dtype=torch.float32)

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            design=design,
        )

        # Add objectives if provided
        if objectives:
            for key, value in objectives.items():
                setattr(data, key, torch.tensor([value], dtype=torch.float32))

        return data

    def save(self, path: Path) -> None:
        """Save grid parameters to a JSON file and the arrays to a companion .npz file."""
        path = Path(path)

        data = {
            'r_inner': self.r_inner,
            'r_outer': self.r_outer,
            'n_radial': self.n_radial,
            'n_angular': self.n_angular,
            'n_nodes': self.n_nodes,
            'n_edges': self.edge_index.shape[1],
        }

        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

        # Save arrays separately for efficiency
        np.savez_compressed(
            path.with_suffix('.npz'),
            r=self.r,
            theta=self.theta,
            x=self.x,
            y=self.y,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
        )

    @classmethod
    def load(cls, path: Path) -> 'PolarMirrorGraph':
        """
        Load graph structure from the JSON file written by save().

        The grid and edges are rebuilt from the stored parameters; the
        companion .npz file is not read.
        """
        path = Path(path)

        with open(path, 'r') as f:
            data = json.load(f)

        return cls(
            r_inner=data['r_inner'],
            r_outer=data['r_outer'],
            n_radial=data['n_radial'],
            n_angular=data['n_angular'],
        )

    def __repr__(self) -> str:
        return (
            f"PolarMirrorGraph("
            f"r=[{self.r_inner}, {self.r_outer}]mm, "
            f"grid={self.n_radial}×{self.n_angular}, "
            f"nodes={self.n_nodes}, "
            f"edges={self.edge_index.shape[1]})"
        )


# =============================================================================
# Convenience functions
# =============================================================================

def create_mirror_dataset(
    study_dir: Path,
    polar_graph: Optional[PolarMirrorGraph] = None,
    subcases: List[int] = [1, 2, 3, 4],
    verbose: bool = True
) -> List[Dict[str, Any]]:
    """
    Create GNN dataset from a study's gnn_data folder.

    Args:
        study_dir: Path to study directory
        polar_graph: PolarMirrorGraph instance (created if None)
        subcases: Subcases to include
        verbose: Print progress

    Returns:
        List of data dictionaries, each containing:
        - z_displacement: [n_nodes, n_subcases]
        - design_vars: [n_vars]
        - design_names: list of design variable names
        - trial_number: int
        - original_n_nodes: int
    """
    from optimization_engine.gnn.extract_displacement_field import load_field

    study_dir = Path(study_dir)
    gnn_data_dir = study_dir / "gnn_data"

    if not gnn_data_dir.exists():
        raise FileNotFoundError(f"No gnn_data folder in {study_dir}")

    # Load index
    index_path = gnn_data_dir / "dataset_index.json"
    with open(index_path, 'r') as f:
        index = json.load(f)

    if polar_graph is None:
        polar_graph = PolarMirrorGraph()

    dataset = []

    for trial_num, trial_info in index['trials'].items():
        if trial_info.get('status') != 'success':
            continue

        trial_dir = study_dir / trial_info['trial_dir']

        # Find field file
        field_path = None
        for ext in ['.h5', '.npz']:
            candidate = trial_dir / f"displacement_field{ext}"
            if candidate.exists():
                field_path = candidate
                break

        if field_path is None:
            if verbose:
                print(f"[WARN] No field file for trial {trial_num}")
            continue

        try:
            # Load field data
            field_data = load_field(field_path)

            # Interpolate to polar grid
            interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)

            # Get design parameters (missing or empty params give an empty float32 array)
            params = trial_info.get('params') or {}
            design_vars = np.array(list(params.values()), dtype=np.float32)

            dataset.append({
                'z_displacement': interp_result['z_displacement'],
                'design_vars': design_vars,
                'design_names': list(params.keys()),
                'trial_number': int(trial_num),
                'original_n_nodes': interp_result['original_n_nodes'],
            })

            if verbose:
                print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} -> {polar_graph.n_nodes} nodes")

        except Exception as e:
            if verbose:
                print(f"[ERR] Trial {trial_num}: {e}")

    return dataset


# =============================================================================
# CLI
# =============================================================================

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
    parser.add_argument('--test', action='store_true', help='Run basic tests')
    parser.add_argument('--study', type=Path, help='Create dataset from study')
    args = parser.parse_args()

    if args.test:
        print("="*60)
        print("TESTING PolarMirrorGraph")
        print("="*60)

        # Create graph
        graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"\n{graph}")

        # Check node features
        node_feat = graph.get_node_features(normalized=True)
        print(f"\nNode features shape: {node_feat.shape}")
        print(f"  r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
        print(f"  theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")

        # Check edge features
        edge_feat = graph.get_edge_features(normalized=True)
        print(f"\nEdge features shape: {edge_feat.shape}")
        print(f"  dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
        print(f"  distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")

        # Test interpolation with synthetic data
        print("\n--- Testing Interpolation ---")

        # Create fake mesh data (random points in annulus)
        np.random.seed(42)
        n_mesh = 5000
        r_mesh = np.random.uniform(100, 650, n_mesh)
        theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
        x_mesh = r_mesh * np.cos(theta_mesh)
        y_mesh = r_mesh * np.sin(theta_mesh)
        mesh_coords = np.column_stack([x_mesh, y_mesh])

        # Synthetic displacement: smooth function
        mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)

        # Interpolate
        grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
        print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
        print(f"  Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
        print(f"  Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")

        print("\n✓ All tests passed!")

    elif args.study:
        # Create dataset from study
        print(f"Creating dataset from: {args.study}")
        graph = PolarMirrorGraph()
        dataset = create_mirror_dataset(args.study, polar_graph=graph)
        print(f"\nDataset: {len(dataset)} samples")
        if dataset:
            print(f"  Z-displacement shape: {dataset[0]['z_displacement'].shape}")
            print(f"  Design vars: {len(dataset[0]['design_vars'])} variables")

    else:
        # Default: just show info
        graph = PolarMirrorGraph()
        print(graph)
        print(f"\nNode features: {graph.get_node_features().shape}")
        print(f"Edge index: {graph.edge_index.shape}")
        print(f"Edge features: {graph.edge_attr.shape}")
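
The edge layout built by `_build_polar_edges` can be cross-checked with a vectorized count. The sketch below is not part of `polar_graph.py`: `polar_edge_index` is a hypothetical helper that assumes the same node numbering (`i_r * n_angular + i_theta`, with the angular index wrapping around) and the same three edge types. It emits the edges in a different order and does not compute edge features.

```python
import numpy as np


def polar_edge_index(n_radial: int, n_angular: int) -> np.ndarray:
    """Vectorized sketch of the radial/angular/diagonal connectivity.

    Hypothetical helper for illustration only. Returns a [2, n_edges]
    array of (source, target) pairs with both directions stored.
    """
    i_r, i_t = np.meshgrid(np.arange(n_radial), np.arange(n_angular), indexing="ij")
    node = i_r * n_angular + i_t
    next_t = (i_t + 1) % n_angular  # angular wrap-around

    # Angular edges exist on every ring.
    angular = np.stack([node.ravel(), (i_r * n_angular + next_t).ravel()])

    # Radial and diagonal edges exist only where an outer ring is available.
    inner = node[:-1]
    radial = np.stack([inner.ravel(), (inner + n_angular).ravel()])
    diagonal = np.stack(
        [inner.ravel(), ((i_r[:-1] + 1) * n_angular + next_t[:-1]).ravel()]
    )

    forward = np.concatenate([radial, angular, diagonal], axis=1)
    # Swapping the two rows gives the reverse direction of every edge.
    return np.concatenate([forward, forward[::-1]], axis=1)


edge_index = polar_edge_index(n_radial=50, n_angular=60)
n_expected = 2 * (2 * (50 - 1) * 60 + 50 * 60)
print(edge_index.shape, n_expected)
```

For the default 50 × 60 grid this gives 2 × (2 × 49 × 60 + 50 × 60) = 17760 directed edges, which should match the `edges=` value reported by `PolarMirrorGraph.__repr__`.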