This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
583 lines
19 KiB
Python
583 lines
19 KiB
Python
"""
|
||
Zernike GNN Model for Mirror Surface Deformation Prediction
|
||
============================================================
|
||
|
||
This module implements a Graph Neural Network specifically designed for predicting
|
||
mirror surface displacement fields from design parameters. The key innovation is
|
||
using design-conditioned message passing on a polar grid graph.
|
||
|
||
Architecture:
|
||
Design Variables [11]
|
||
│
|
||
▼
|
||
Design Encoder [11 → 128]
|
||
│
|
||
└──────────────────┐
|
||
│
|
||
Node Features │
|
||
[r, θ, x, y] │
|
||
│ │
|
||
▼ │
|
||
Node Encoder │
|
||
[4 → 128] │
|
||
│ │
|
||
└─────────┬────────┘
|
||
│
|
||
▼
|
||
┌─────────────────────────────┐
|
||
│ Design-Conditioned │
|
||
│ Message Passing (× 6) │
|
||
│ │
|
||
│ • Polar-aware edges │
|
||
│ • Design modulates messages │
|
||
│ • Residual connections │
|
||
└─────────────┬───────────────┘
|
||
│
|
||
▼
|
||
Per-Node Decoder [128 → 4]
|
||
│
|
||
▼
|
||
Z-Displacement Field [3000, 4]
|
||
(one value per node per subcase)
|
||
|
||
Usage:
|
||
from optimization_engine.gnn.zernike_gnn import ZernikeGNN
|
||
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
||
|
||
graph = PolarMirrorGraph()
|
||
model = ZernikeGNN(n_design_vars=11, n_subcases=4)
|
||
|
||
# Forward pass
|
||
z_disp = model(
|
||
node_features=graph.get_node_features(),
|
||
edge_index=graph.edge_index,
|
||
edge_attr=graph.get_edge_features(),
|
||
design_vars=design_tensor
|
||
)
|
||
"""
|
||
|
||
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

# torch_geometric (PyG) is an optional dependency. When it is installed we
# subclass its optimized MessagePassing base; when it is missing, the code
# falls back to a manual scatter-based propagation (see
# DesignConditionedConv._manual_propagate) and MessagePassing is aliased to
# nn.Module purely so class definitions and type hints still resolve.
try:
    from torch_geometric.nn import MessagePassing
    HAS_PYG = True
except ImportError:
    HAS_PYG = False
    MessagePassing = nn.Module  # Fallback for type hints
|
||
|
||
|
||
class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
    """
    Message passing layer conditioned on global design parameters.

    This layer propagates information through the polar graph while
    conditioning on design parameters. The design embedding modulates
    how messages flow between nodes.

    Key insight: Design parameters affect the stiffness distribution
    in the mirror support structure. This layer learns how those changes
    propagate spatially through the optical surface.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        design_channels: int,
        edge_channels: int = 4,
        aggr: str = 'mean'
    ):
        """
        Args:
            in_channels: Input node feature dimension
            out_channels: Output node feature dimension
            design_channels: Design embedding dimension
            edge_channels: Edge feature dimension
            aggr: Aggregation method ('mean', 'sum', 'max')
        """
        # With PyG, the aggregation is handled by the MessagePassing base
        # class; without it, we store `aggr` ourselves (NOTE: the manual
        # fallback below always performs mean aggregation regardless of
        # this value — only the PyG path honors 'sum'/'max').
        if HAS_PYG:
            super().__init__(aggr=aggr)
        else:
            super().__init__()
            self.aggr = aggr

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message network: source node + target node + design + edge
        msg_input_dim = 2 * in_channels + design_channels + edge_channels
        self.message_net = nn.Sequential(
            nn.Linear(msg_input_dim, out_channels * 2),
            nn.LayerNorm(out_channels * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(out_channels * 2, out_channels),
        )

        # Update network: combines aggregated messages with original features
        self.update_net = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.SiLU(),
        )

        # Design gate: allows design to modulate message importance
        # (sigmoid output in (0, 1) scales each output channel).
        self.design_gate = nn.Sequential(
            nn.Linear(design_channels, out_channels),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with design conditioning.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design_embed: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        if HAS_PYG:
            # Use PyG's message passing. `design` has no _i/_j suffix, so PyG
            # forwards it to message() unchanged rather than indexing it per
            # edge endpoint.
            out = self.propagate(
                edge_index, x=x, edge_attr=edge_attr, design=design_embed
            )
        else:
            # Fallback implementation without PyG
            out = self._manual_propagate(x, edge_index, edge_attr, design_embed)

        # Apply design-based gating (broadcasts the [out_channels] gate
        # across all nodes).
        gate = self.design_gate(design_embed)
        out = out * gate

        return out

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute messages from source (j) to target (i) nodes.

        Args:
            x_i: Target node features [n_edges, in_channels]
            x_j: Source node features [n_edges, in_channels]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding, broadcast to edges

        Returns:
            Messages [n_edges, out_channels]
        """
        # Broadcast design to all edges. NOTE(review): expand() assumes
        # `design` is 1-D [design_channels] (or [1, design_channels]); a
        # batched 2-D design with batch > 1 would fail here — callers pass
        # one design at a time (see ZernikeGNN.forward).
        design_broadcast = design.expand(x_i.size(0), -1)

        # Concatenate all inputs
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)

        return self.message_net(msg_input)

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Aggregated messages [n_nodes, out_channels]
            x: Original node features [n_nodes, in_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        return self.update_net(torch.cat([x, aggr_out], dim=-1))

    def _manual_propagate(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """Fallback message passing without PyG (always mean aggregation)."""
        row, col = edge_index  # row = target, col = source

        # Gather features
        x_i = x[row]  # Target features
        x_j = x[col]  # Source features

        # Compute messages
        design_broadcast = design.expand(x_i.size(0), -1)
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        messages = self.message_net(msg_input)

        # Aggregate (mean): scatter-sum messages onto target nodes, then
        # divide by per-node in-degree.
        # NOTE(review): torch.zeros here uses the default dtype (float32);
        # if x were float64/half this would raise in scatter_add_ — assumed
        # fine since the model runs in float32. TODO confirm.
        n_nodes = x.size(0)
        aggr_out = torch.zeros(n_nodes, messages.size(-1), device=x.device)
        count = torch.zeros(n_nodes, 1, device=x.device)

        aggr_out.scatter_add_(0, row.unsqueeze(-1).expand_as(messages), messages)
        count.scatter_add_(0, row.unsqueeze(-1), torch.ones_like(row, dtype=torch.float).unsqueeze(-1))
        # clamp avoids division by zero for isolated nodes (degree 0).
        count = count.clamp(min=1)
        aggr_out = aggr_out / count

        # Update
        return self.update_net(torch.cat([x, aggr_out], dim=-1))
|
||
|
||
|
||
class ZernikeGNN(nn.Module):
    """
    Graph Neural Network for mirror surface displacement prediction.

    This model learns to predict Z-displacement fields for all 4 gravity
    subcases from 11 design parameters. It uses a fixed polar grid graph
    structure and design-conditioned message passing.

    The key advantages over MLP:
    1. Spatial awareness through message passing
    2. Design conditioning modulates spatial information flow
    3. Predicts full field (enabling correct relative computation)
    4. Respects physics: smooth fields, radial/angular structure
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 128,
        n_layers: int = 6,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4,
        dropout: float = 0.1
    ):
        """
        Args:
            n_design_vars: Number of design parameters (11 for mirror)
            n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension (r, theta, x, y)
            edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
            dropout: Dropout rate
        """
        super().__init__()

        self.n_design_vars = n_design_vars
        self.n_subcases = n_subcases
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # === Design Encoder ===
        # Maps design parameters to hidden space
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Node Encoder ===
        # Maps polar coordinates to hidden space
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Edge Encoder ===
        # Maps edge features (dr, dtheta, distance, angle) to hidden space
        edge_hidden = hidden_dim // 2
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feat_dim, edge_hidden),
            nn.SiLU(),
            nn.Linear(edge_hidden, edge_hidden),
        )

        # === Message Passing Layers ===
        # Each conv consumes the encoded edge features (edge_hidden wide),
        # hence edge_channels=edge_hidden rather than edge_feat_dim.
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                design_channels=hidden_dim,
                edge_channels=edge_hidden,
            )
            for _ in range(n_layers)
        ])

        # Layer norms for residual connections
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(n_layers)
        ])

        # === Displacement Decoder ===
        # Predicts Z-displacement for each subcase
        self.displacement_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with Xavier/Glorot initialization."""
        # Applies to every nn.Linear in the model (encoders, conv message/
        # update/gate nets, decoder); biases are zeroed.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: design parameters → displacement field.

        Args:
            node_features: [n_nodes, 4] - (r, theta, x, y) normalized
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, 4] - edge features normalized
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            z_displacement: [n_nodes, n_subcases] - Z-disp per subcase
                or [batch, n_nodes, n_subcases] if batched
        """
        # Handle batched vs single design
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)  # [1, n_design_vars]

        batch_size = design_vars.size(0)
        n_nodes = node_features.size(0)

        # Encode inputs once — the graph (nodes/edges) is shared across the
        # batch; only the design embedding differs per batch item.
        design_h = self.design_encoder(design_vars)   # [batch, hidden]
        node_h = self.node_encoder(node_features)      # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)          # [n_edges, edge_hidden]

        # Process each batch item. NOTE: this is a Python loop — one full
        # graph pass per design — so cost scales linearly with batch size.
        outputs = []
        for b in range(batch_size):
            h = node_h.clone()  # Start fresh for each design

            # Message passing with residual connections
            for conv, norm in zip(self.conv_layers, self.layer_norms):
                h_new = conv(h, edge_index, edge_h, design_h[b])
                h = norm(h + h_new)  # Residual + LayerNorm

            # Decode to displacement
            z_disp = self.displacement_decoder(h)  # [n_nodes, n_subcases]
            outputs.append(z_disp)

        # Stack outputs; unbatched input gets an unbatched output so the
        # caller's shape expectations are preserved.
        if is_batched:
            return torch.stack(outputs, dim=0)  # [batch, n_nodes, n_subcases]
        else:
            return outputs[0]  # [n_nodes, n_subcases]

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||
|
||
|
||
class ZernikeGNNLite(nn.Module):
    """
    Lightweight version of ZernikeGNN for faster training/inference.

    Uses fewer layers and smaller hidden dimension, suitable for
    initial experiments or when training data is limited.
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 64,
        n_layers: int = 4
    ):
        super().__init__()

        self.n_subcases = n_subcases

        # Compact two-layer MLP for the design embedding
        # (no normalization or dropout, unlike the full model).
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Compact two-layer MLP for the (r, theta, x, y) node features.
        self.node_encoder = nn.Sequential(
            nn.Linear(4, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Single linear projection of the 4 raw edge features.
        self.edge_encoder = nn.Linear(4, hidden_dim // 2)

        # Stack of design-conditioned message-passing layers; edge channel
        # width matches the edge encoder output.
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
            for _ in range(n_layers)
        ])

        # Per-node read-out head: hidden → n_subcases displacements.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

    def forward(self, node_features, edge_index, edge_attr, design_vars):
        """Predict per-node Z-displacements for every subcase (one design)."""
        g = self.design_encoder(design_vars)
        h = self.node_encoder(node_features)
        e = self.edge_encoder(edge_attr)

        # Plain residual message passing — no per-layer LayerNorm here.
        for layer in self.conv_layers:
            h = h + layer(h, edge_index, e, g)

        return self.decoder(h)
|
||
|
||
|
||
# =============================================================================
|
||
# Utility functions
|
||
# =============================================================================
|
||
|
||
def create_model(
    n_design_vars: int = 11,
    n_subcases: int = 4,
    model_type: str = 'full',
    **kwargs
) -> nn.Module:
    """
    Factory function to create GNN model.

    Args:
        n_design_vars: Number of design parameters
        n_subcases: Number of subcases
        model_type: 'full' or 'lite'
        **kwargs: Additional arguments passed to model

    Returns:
        GNN model instance
    """
    # Any value other than 'lite' selects the full model.
    model_cls = ZernikeGNNLite if model_type == 'lite' else ZernikeGNN
    return model_cls(n_design_vars, n_subcases, **kwargs)
|
||
|
||
|
||
def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
    """
    Load trained model from checkpoint.

    Args:
        checkpoint_path: Path to .pt checkpoint file
        device: Device to load model to

    Returns:
        Loaded model in eval mode

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    # NOTE(security): weights_only=False deserializes via pickle, so this
    # must only be used on trusted checkpoint files. It is spelled out
    # explicitly because torch >= 2.6 changed the default to True, which
    # can reject checkpoints that store auxiliary objects (e.g. the
    # 'config' dict saved by the trainer) alongside the tensors.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Model hyperparameters recorded at training time; fall back to the
    # full-model defaults when a key is missing.
    config = checkpoint.get('config', {})
    model_type = config.get('model_type', 'full')

    # Recreate the architecture, then restore the trained weights.
    model = create_model(
        n_design_vars=config.get('n_design_vars', 11),
        n_subcases=config.get('n_subcases', 4),
        model_type=model_type,
        hidden_dim=config.get('hidden_dim', 128),
        n_layers=config.get('n_layers', 6),
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout

    return model
|
||
|
||
|
||
# =============================================================================
|
||
# Testing
|
||
# =============================================================================
|
||
|
||
if __name__ == '__main__':
    # Smoke test: build both model variants, run single and batched forward
    # passes on random data, and print shapes/parameter counts. No assertions
    # beyond "does not crash".
    print("="*60)
    print("Testing ZernikeGNN")
    print("="*60)

    # Create model
    model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")

    # Create dummy inputs sized like the real polar grid (3000 nodes;
    # 17760 edges matches PolarMirrorGraph — random connectivity here).
    n_nodes = 3000
    n_edges = 17760

    node_features = torch.randn(n_nodes, 4)
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    edge_attr = torch.randn(n_edges, 4)
    design_vars = torch.randn(11)

    # Forward pass (single design → [n_nodes, n_subcases])
    print("\n--- Single Forward Pass ---")
    with torch.no_grad():
        output = model(node_features, edge_index, edge_attr, design_vars)
    print(f"Input design: {design_vars.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")

    # Batched forward pass (8 designs → [8, n_nodes, n_subcases])
    print("\n--- Batched Forward Pass ---")
    batch_design = torch.randn(8, 11)
    with torch.no_grad():
        output_batch = model(node_features, edge_index, edge_attr, batch_design)
    print(f"Batch design: {batch_design.shape}")
    print(f"Batch output: {output_batch.shape}")

    # Test lite model (single design only)
    print("\n--- Lite Model ---")
    model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
    print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")

    with torch.no_grad():
        output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
    print(f"Lite output shape: {output_lite.shape}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)
|