Atomizer/optimization_engine/gnn/zernike_gnn.py
Antoine 96b196de58 feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.
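The mechanism can be sketched in a few lines. This is a minimal stand-in with made-up dimensions, not the committed `DesignConditionedConv`: a global design vector is concatenated into each edge message and also gates the aggregated result.

```python
import torch
import torch.nn as nn

class TinyDesignConditionedConv(nn.Module):
    """Toy layer: a global design vector conditions and gates edge messages."""

    def __init__(self, node_dim: int, design_dim: int):
        super().__init__()
        self.msg = nn.Linear(2 * node_dim + design_dim, node_dim)
        self.gate = nn.Sequential(nn.Linear(design_dim, node_dim), nn.Sigmoid())

    def forward(self, x, edge_index, design):
        dst, src = edge_index                      # [n_edges] each
        d = design.expand(src.size(0), -1)         # broadcast design to every edge
        m = self.msg(torch.cat([x[dst], x[src], d], dim=-1))
        out = torch.zeros_like(x)
        out.index_add_(0, dst, m)                  # sum-aggregate messages per target node
        return out * self.gate(design)             # design modulates the aggregated update

x = torch.randn(5, 8)                              # 5 nodes, 8 features (illustrative sizes)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # row 0 = targets, row 1 = sources
conv = TinyDesignConditionedConv(node_dim=8, design_dim=3)
y = conv(x, edge_index, torch.randn(3))            # y: [5, 8]
```

The actual module adds edge features, LayerNorm, residual connections, and a PyG fallback, but the conditioning idea is the same.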

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py
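The screen-then-validate loop might look roughly like this (a hypothetical sketch; function names, bounds, and the surrogate callable are illustrative, not the actual `run_gnn_turbo.py` API):

```python
import numpy as np

def turbo_screen(predict_rms, n_candidates=5000, n_top=20, n_vars=11, seed=0):
    """Score many random designs with a cheap surrogate, keep the best few."""
    rng = np.random.default_rng(seed)
    designs = rng.uniform(0.0, 1.0, size=(n_candidates, n_vars))  # normalized bounds
    scores = np.array([predict_rms(d) for d in designs])          # surrogate objective
    top = np.argsort(scores)[:n_top]                              # lower RMS is better
    return designs[top], scores[top]

# Stand-in surrogate: any callable mapping a design vector to a scalar RMS estimate.
best_designs, best_scores = turbo_screen(lambda d: float(np.sum(d ** 2)))
# best_designs would then be re-evaluated with full FEA for validation.
```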

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]
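Extracting the ground-truth Pareto front from the FEA trial objectives can be sketched as follows (hypothetical helper, not the repository's code; both objectives are assumed to be minimized):

```python
import numpy as np

def pareto_mask(objectives: np.ndarray) -> np.ndarray:
    """Boolean mask of non-dominated rows (all objectives minimized)."""
    n = objectives.shape[0]
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        if not mask[i]:
            continue
        # Point i is dominated if another point is <= in all objectives
        # and strictly < in at least one.
        dominated_by = (np.all(objectives <= objectives[i], axis=1)
                        & np.any(objectives < objectives[i], axis=1))
        if np.any(dominated_by):
            mask[i] = False
    return mask

# Toy 2-objective data: the third point is dominated by the second.
objs = np.array([[1.0, 4.0], [2.0, 3.0], [3.0, 3.5], [4.0, 1.0]])
front = objs[pareto_mask(objs)]
```

The GNN's predicted front can then be compared against this FEA-derived set to quantify surrogate accuracy.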

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00

"""
Zernike GNN Model for Mirror Surface Deformation Prediction
============================================================
This module implements a Graph Neural Network specifically designed for predicting
mirror surface displacement fields from design parameters. The key innovation is
using design-conditioned message passing on a polar grid graph.
Architecture:
Design Variables [11]
Design Encoder [11 → 128]
└──────────────────┐
Node Features │
[r, θ, x, y] │
│ │
▼ │
Node Encoder │
[4 → 128] │
│ │
└─────────┬────────┘
┌─────────────────────────────┐
│ Design-Conditioned │
│ Message Passing (× 6) │
│ │
│ • Polar-aware edges │
│ • Design modulates messages │
│ • Residual connections │
└─────────────┬───────────────┘
Per-Node Decoder [128 → 4]
Z-Displacement Field [3000, 4]
(one value per node per subcase)
Usage:
from optimization_engine.gnn.zernike_gnn import ZernikeGNN
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph()
model = ZernikeGNN(n_design_vars=11, n_subcases=4)
# Forward pass
z_disp = model(
node_features=graph.get_node_features(),
edge_index=graph.edge_index,
edge_attr=graph.get_edge_features(),
design_vars=design_tensor
)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

try:
    from torch_geometric.nn import MessagePassing
    HAS_PYG = True
except ImportError:
    HAS_PYG = False
    MessagePassing = nn.Module  # Fallback for type hints

class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
    """
    Message passing layer conditioned on global design parameters.

    This layer propagates information through the polar graph while
    conditioning on design parameters. The design embedding modulates
    how messages flow between nodes.

    Key insight: design parameters affect the stiffness distribution
    in the mirror support structure. This layer learns how those changes
    propagate spatially through the optical surface.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        design_channels: int,
        edge_channels: int = 4,
        aggr: str = 'mean'
    ):
        """
        Args:
            in_channels: Input node feature dimension
            out_channels: Output node feature dimension
            design_channels: Design embedding dimension
            edge_channels: Edge feature dimension
            aggr: Aggregation method ('mean', 'sum', 'max')
        """
        if HAS_PYG:
            super().__init__(aggr=aggr)
        else:
            super().__init__()
            self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message network: source node + target node + design + edge
        msg_input_dim = 2 * in_channels + design_channels + edge_channels
        self.message_net = nn.Sequential(
            nn.Linear(msg_input_dim, out_channels * 2),
            nn.LayerNorm(out_channels * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(out_channels * 2, out_channels),
        )

        # Update network: combines aggregated messages with original features
        self.update_net = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.SiLU(),
        )

        # Design gate: allows design to modulate message importance
        self.design_gate = nn.Sequential(
            nn.Linear(design_channels, out_channels),
            nn.Sigmoid(),
        )
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with design conditioning.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design_embed: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        if HAS_PYG:
            # Use PyG's message passing
            out = self.propagate(
                edge_index, x=x, edge_attr=edge_attr, design=design_embed
            )
        else:
            # Fallback implementation without PyG
            out = self._manual_propagate(x, edge_index, edge_attr, design_embed)

        # Apply design-based gating
        gate = self.design_gate(design_embed)
        out = out * gate
        return out
    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute messages from source (j) to target (i) nodes.

        Args:
            x_i: Target node features [n_edges, in_channels]
            x_j: Source node features [n_edges, in_channels]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding, broadcast to edges

        Returns:
            Messages [n_edges, out_channels]
        """
        # Broadcast design to all edges
        design_broadcast = design.expand(x_i.size(0), -1)
        # Concatenate all inputs
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        return self.message_net(msg_input)

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Aggregated messages [n_nodes, out_channels]
            x: Original node features [n_nodes, in_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        return self.update_net(torch.cat([x, aggr_out], dim=-1))
    def _manual_propagate(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """Fallback message passing without PyG."""
        row, col = edge_index  # row = target, col = source

        # Gather features
        x_i = x[row]  # Target features
        x_j = x[col]  # Source features

        # Compute messages
        design_broadcast = design.expand(x_i.size(0), -1)
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        messages = self.message_net(msg_input)

        # Aggregate (mean)
        n_nodes = x.size(0)
        aggr_out = torch.zeros(n_nodes, messages.size(-1), device=x.device)
        count = torch.zeros(n_nodes, 1, device=x.device)
        aggr_out.scatter_add_(0, row.unsqueeze(-1).expand_as(messages), messages)
        count.scatter_add_(
            0, row.unsqueeze(-1),
            torch.ones_like(row, dtype=torch.float).unsqueeze(-1)
        )
        count = count.clamp(min=1)
        aggr_out = aggr_out / count

        # Update
        return self.update_net(torch.cat([x, aggr_out], dim=-1))

class ZernikeGNN(nn.Module):
    """
    Graph Neural Network for mirror surface displacement prediction.

    This model learns to predict Z-displacement fields for all 4 gravity
    subcases from 11 design parameters. It uses a fixed polar grid graph
    structure and design-conditioned message passing.

    The key advantages over an MLP:
    1. Spatial awareness through message passing
    2. Design conditioning modulates spatial information flow
    3. Predicts the full field (enabling correct relative computation)
    4. Respects physics: smooth fields, radial/angular structure
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 128,
        n_layers: int = 6,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4,
        dropout: float = 0.1
    ):
        """
        Args:
            n_design_vars: Number of design parameters (11 for mirror)
            n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension (r, theta, x, y)
            edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
            dropout: Dropout rate
        """
        super().__init__()
        self.n_design_vars = n_design_vars
        self.n_subcases = n_subcases
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # === Design Encoder ===
        # Maps design parameters to hidden space
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Node Encoder ===
        # Maps polar coordinates to hidden space
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Edge Encoder ===
        # Maps edge features (dr, dtheta, distance, angle) to hidden space
        edge_hidden = hidden_dim // 2
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feat_dim, edge_hidden),
            nn.SiLU(),
            nn.Linear(edge_hidden, edge_hidden),
        )

        # === Message Passing Layers ===
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                design_channels=hidden_dim,
                edge_channels=edge_hidden,
            )
            for _ in range(n_layers)
        ])

        # Layer norms for residual connections
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(n_layers)
        ])

        # === Displacement Decoder ===
        # Predicts Z-displacement for each subcase
        self.displacement_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with Xavier/Glorot initialization."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: design parameters → displacement field.

        Args:
            node_features: [n_nodes, 4] - (r, theta, x, y), normalized
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, 4] - edge features, normalized
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            z_displacement: [n_nodes, n_subcases] - Z-disp per subcase,
                or [batch, n_nodes, n_subcases] if batched
        """
        # Handle batched vs single design
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)  # [1, n_design_vars]
        batch_size = design_vars.size(0)

        # Encode inputs
        design_h = self.design_encoder(design_vars)  # [batch, hidden]
        node_h = self.node_encoder(node_features)    # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)        # [n_edges, edge_hidden]

        # Process each batch item
        outputs = []
        for b in range(batch_size):
            h = node_h.clone()  # Start fresh for each design

            # Message passing with residual connections
            for conv, norm in zip(self.conv_layers, self.layer_norms):
                h_new = conv(h, edge_index, edge_h, design_h[b])
                h = norm(h + h_new)  # Residual + LayerNorm

            # Decode to displacement
            z_disp = self.displacement_decoder(h)  # [n_nodes, n_subcases]
            outputs.append(z_disp)

        # Stack outputs
        if is_batched:
            return torch.stack(outputs, dim=0)  # [batch, n_nodes, n_subcases]
        else:
            return outputs[0]  # [n_nodes, n_subcases]

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class ZernikeGNNLite(nn.Module):
    """
    Lightweight version of ZernikeGNN for faster training/inference.

    Uses fewer layers and a smaller hidden dimension, suitable for
    initial experiments or when training data is limited.
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 64,
        n_layers: int = 4
    ):
        super().__init__()
        self.n_subcases = n_subcases

        # Simpler design encoder
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Simpler node encoder
        self.node_encoder = nn.Sequential(
            nn.Linear(4, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Edge encoder
        self.edge_encoder = nn.Linear(4, hidden_dim // 2)

        # Message passing
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
            for _ in range(n_layers)
        ])

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

    def forward(self, node_features, edge_index, edge_attr, design_vars):
        """Forward pass (single design; expects unbatched design_vars)."""
        design_h = self.design_encoder(design_vars)
        node_h = self.node_encoder(node_features)
        edge_h = self.edge_encoder(edge_attr)
        for conv in self.conv_layers:
            node_h = node_h + conv(node_h, edge_index, edge_h, design_h)
        return self.decoder(node_h)

# =============================================================================
# Utility functions
# =============================================================================

def create_model(
    n_design_vars: int = 11,
    n_subcases: int = 4,
    model_type: str = 'full',
    **kwargs
) -> nn.Module:
    """
    Factory function to create a GNN model.

    Args:
        n_design_vars: Number of design parameters
        n_subcases: Number of subcases
        model_type: 'full' or 'lite'
        **kwargs: Additional arguments passed to the model

    Returns:
        GNN model instance
    """
    if model_type == 'lite':
        return ZernikeGNNLite(n_design_vars, n_subcases, **kwargs)
    else:
        return ZernikeGNN(n_design_vars, n_subcases, **kwargs)


def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
    """
    Load a trained model from a checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file
        device: Device to load the model onto

    Returns:
        Loaded model in eval mode
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Get model config
    config = checkpoint.get('config', {})
    model_type = config.get('model_type', 'full')

    # Create model
    model = create_model(
        n_design_vars=config.get('n_design_vars', 11),
        n_subcases=config.get('n_subcases', 4),
        model_type=model_type,
        hidden_dim=config.get('hidden_dim', 128),
        n_layers=config.get('n_layers', 6),
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
    print("=" * 60)
    print("Testing ZernikeGNN")
    print("=" * 60)

    # Create model
    model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")

    # Create dummy inputs
    n_nodes = 3000
    n_edges = 17760
    node_features = torch.randn(n_nodes, 4)
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    edge_attr = torch.randn(n_edges, 4)
    design_vars = torch.randn(11)

    # Forward pass
    print("\n--- Single Forward Pass ---")
    with torch.no_grad():
        output = model(node_features, edge_index, edge_attr, design_vars)
    print(f"Input design: {design_vars.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")

    # Batched forward pass
    print("\n--- Batched Forward Pass ---")
    batch_design = torch.randn(8, 11)
    with torch.no_grad():
        output_batch = model(node_features, edge_index, edge_attr, batch_design)
    print(f"Batch design: {batch_design.shape}")
    print(f"Batch output: {output_batch.shape}")

    # Test lite model
    print("\n--- Lite Model ---")
    model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
    print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")
    with torch.no_grad():
        output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
    print(f"Lite output shape: {output_lite.shape}")

    print("\n" + "=" * 60)
    print("✓ All tests passed!")
    print("=" * 60)