"""
Zernike GNN Model for Mirror Surface Deformation Prediction
============================================================

This module implements a Graph Neural Network specifically designed for
predicting mirror surface displacement fields from design parameters.

The key innovation is using design-conditioned message passing on a polar
grid graph.

Architecture:
    Design Variables [11]
          │
          ▼
    Design Encoder [11 → 128]
          │
          └──────────────────┐
                             │
    Node Features            │
    [r, θ, x, y]             │
          │                  │
          ▼                  │
    Node Encoder             │
    [4 → 128]                │
          │                  │
          └─────────┬────────┘
                    │
                    ▼
    ┌─────────────────────────────┐
    │ Design-Conditioned          │
    │ Message Passing (× 6)       │
    │                             │
    │ • Polar-aware edges         │
    │ • Design modulates messages │
    │ • Residual connections      │
    └─────────────┬───────────────┘
                  │
                  ▼
    Per-Node Decoder [128 → 4]
                  │
                  ▼
    Z-Displacement Field [3000, 4]
    (one value per node per subcase)

Usage:
    from optimization_engine.gnn.zernike_gnn import ZernikeGNN
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    model = ZernikeGNN(n_design_vars=11, n_subcases=4)

    # Forward pass
    z_disp = model(
        node_features=graph.get_node_features(),
        edge_index=graph.edge_index,
        edge_attr=graph.get_edge_features(),
        design_vars=design_tensor
    )
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

try:
    from torch_geometric.nn import MessagePassing
    HAS_PYG = True
except ImportError:
    HAS_PYG = False
    MessagePassing = nn.Module  # Fallback for type hints


class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
    """
    Message passing layer conditioned on global design parameters.

    This layer propagates information through the polar graph while
    conditioning on design parameters. The design embedding modulates how
    messages flow between nodes.

    Key insight: Design parameters affect the stiffness distribution in the
    mirror support structure. This layer learns how those changes propagate
    spatially through the optical surface.

    When PyTorch Geometric is installed the layer uses
    ``MessagePassing.propagate``; otherwise it uses a pure-PyTorch fallback
    (``_manual_propagate``) that follows the same edge convention:
    ``edge_index[0]`` holds source nodes (j) and ``edge_index[1]`` holds
    target nodes (i), and messages are aggregated at the targets.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        design_channels: int,
        edge_channels: int = 4,
        aggr: str = 'mean',
        dropout: float = 0.1,
    ):
        """
        Args:
            in_channels: Input node feature dimension
            out_channels: Output node feature dimension
            design_channels: Design embedding dimension
            edge_channels: Edge feature dimension
            aggr: Aggregation method ('mean', 'sum', 'max')
            dropout: Dropout rate used inside the message network
        """
        if HAS_PYG:
            super().__init__(aggr=aggr)
        else:
            super().__init__()
            self.aggr = aggr

        # Kept separately so the pure-PyTorch fallback can honour the
        # requested aggregation regardless of how the base class stores it.
        self._fallback_aggr = aggr

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message network: source node + target node + design + edge
        msg_input_dim = 2 * in_channels + design_channels + edge_channels
        self.message_net = nn.Sequential(
            nn.Linear(msg_input_dim, out_channels * 2),
            nn.LayerNorm(out_channels * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(out_channels * 2, out_channels),
        )

        # Update network: combines aggregated messages with original features
        self.update_net = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.SiLU(),
        )

        # Design gate: allows design to modulate message importance
        self.design_gate = nn.Sequential(
            nn.Linear(design_channels, out_channels),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with design conditioning.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design_embed: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        if HAS_PYG:
            # Use PyG's message passing
            out = self.propagate(
                edge_index,
                x=x,
                edge_attr=edge_attr,
                design=design_embed
            )
        else:
            # Fallback implementation without PyG
            out = self._manual_propagate(x, edge_index, edge_attr, design_embed)

        # Apply design-based gating (sigmoid gate broadcast over all nodes)
        gate = self.design_gate(design_embed)
        return out * gate

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute messages from source (j) to target (i) nodes.

        Args:
            x_i: Target node features [n_edges, in_channels]
            x_j: Source node features [n_edges, in_channels]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding, broadcast to edges

        Returns:
            Messages [n_edges, out_channels]
        """
        # Broadcast design to all edges
        design_broadcast = design.expand(x_i.size(0), -1)

        # Concatenate all inputs
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        return self.message_net(msg_input)

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Aggregated messages [n_nodes, out_channels]
            x: Original node features [n_nodes, in_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        return self.update_net(torch.cat([x, aggr_out], dim=-1))

    def _fallback_aggregate(
        self,
        messages: torch.Tensor,
        index: torch.Tensor,
        n_nodes: int
    ) -> torch.Tensor:
        """
        Aggregate per-edge messages onto their target nodes without PyG.

        Args:
            messages: Per-edge messages [n_edges, channels]
            index: Target node index of every edge [n_edges]
            n_nodes: Number of nodes in the graph

        Returns:
            Aggregated messages [n_nodes, channels]; nodes that receive no
            message get zeros.

        Raises:
            ValueError: If the configured aggregation is not one of
                'sum', 'add', 'mean', 'max', 'min'.
        """
        aggr = self._fallback_aggr
        # new_zeros inherits dtype and device from the messages, so the layer
        # also works after .double() / .half() or on non-default devices.
        out = messages.new_zeros((n_nodes, messages.size(-1)))

        if aggr in ('sum', 'add', 'mean'):
            out.index_add_(0, index, messages)
            if aggr == 'mean':
                count = messages.new_zeros((n_nodes, 1))
                count.index_add_(0, index, messages.new_ones((index.numel(), 1)))
                # clamp avoids division by zero for nodes without incoming edges
                out = out / count.clamp(min=1)
            return out

        if aggr in ('max', 'min'):
            expanded = index.unsqueeze(-1).expand_as(messages)
            # include_self=False: the zero initialisation must not take part
            # in the reduction, it only fills nodes that receive nothing.
            return out.scatter_reduce(
                0,
                expanded,
                messages,
                reduce='amax' if aggr == 'max' else 'amin',
                include_self=False,
            )

        raise ValueError(
            f"Unsupported aggregation {aggr!r} for the non-PyG fallback; "
            "expected one of 'sum', 'add', 'mean', 'max', 'min'."
        )

    def _manual_propagate(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Fallback message passing without PyG.

        Mirrors PyG's default ``flow='source_to_target'`` convention so both
        code paths agree: ``edge_index[0]`` are sources (j), ``edge_index[1]``
        are targets (i), and messages are aggregated at the targets.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels] (before gating)
        """
        src, dst = edge_index

        # Gather per-edge endpoint features
        x_j = x[src]  # Source features
        x_i = x[dst]  # Target features

        # Reuse the exact message/update code the PyG path runs
        messages = self.message(x_i, x_j, edge_attr, design)
        aggr_out = self._fallback_aggregate(messages, dst, x.size(0))
        return self.update(aggr_out, x)


class ZernikeGNN(nn.Module):
    """
    Graph Neural Network for mirror surface displacement prediction.

    This model learns to predict Z-displacement fields for all 4 gravity
    subcases from 11 design parameters. It uses a fixed polar grid graph
    structure and design-conditioned message passing.

    The key advantages over MLP:
    1. Spatial awareness through message passing
    2. Design conditioning modulates spatial information flow
    3. Predicts full field (enabling correct relative computation)
    4. Respects physics: smooth fields, radial/angular structure
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 128,
        n_layers: int = 6,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4,
        dropout: float = 0.1
    ):
        """
        Args:
            n_design_vars: Number of design parameters (11 for mirror)
            n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension (r, theta, x, y)
            edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
            dropout: Dropout rate (encoders, message passing layers, decoder)
        """
        super().__init__()

        self.n_design_vars = n_design_vars
        self.n_subcases = n_subcases
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # === Design Encoder ===
        # Maps design parameters to hidden space
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Node Encoder ===
        # Maps polar coordinates to hidden space
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Edge Encoder ===
        # Maps edge features (dr, dtheta, distance, angle) to hidden space
        edge_hidden = hidden_dim // 2
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feat_dim, edge_hidden),
            nn.SiLU(),
            nn.Linear(edge_hidden, edge_hidden),
        )

        # === Message Passing Layers ===
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                design_channels=hidden_dim,
                edge_channels=edge_hidden,
                dropout=dropout,
            )
            for _ in range(n_layers)
        ])

        # Layer norms for residual connections
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(n_layers)
        ])

        # === Displacement Decoder ===
        # Predicts Z-displacement for each subcase
        self.displacement_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with Xavier/Glorot initialization."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: design parameters → displacement field.

        Args:
            node_features: [n_nodes, 4] - (r, theta, x, y) normalized
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, 4] - edge features normalized
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            z_displacement: [n_nodes, n_subcases] - Z-disp per subcase
                or [batch, n_nodes, n_subcases] if batched
        """
        # Handle batched vs single design
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)  # [1, n_design_vars]

        # Encode inputs; node and edge encodings are shared by all designs
        design_h = self.design_encoder(design_vars)  # [batch, hidden]
        node_h = self.node_encoder(node_features)    # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)        # [n_edges, edge_hidden]

        # Process each design in turn on the same graph
        outputs = []
        for design_embed in design_h:
            # No copy needed: nothing below modifies `h` in place.
            h = node_h

            # Message passing with residual connections
            for conv, norm in zip(self.conv_layers, self.layer_norms):
                h_new = conv(h, edge_index, edge_h, design_embed)
                h = norm(h + h_new)  # Residual + LayerNorm

            # Decode to displacement
            outputs.append(self.displacement_decoder(h))  # [n_nodes, n_subcases]

        # Stack outputs
        if is_batched:
            return torch.stack(outputs, dim=0)  # [batch, n_nodes, n_subcases]
        return outputs[0]  # [n_nodes, n_subcases]

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class ZernikeGNNLite(nn.Module):
    """
    Lightweight version of ZernikeGNN for faster training/inference.

    Uses fewer layers and smaller hidden dimension, suitable for
    initial experiments or when training data is limited.

    Unlike ZernikeGNN it has no LayerNorm/Dropout in its encoders, no
    per-layer norm on the residual connection, and its forward pass does not
    loop over a batch of designs.
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 64,
        n_layers: int = 4,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4
    ):
        """
        Args:
            n_design_vars: Number of design parameters
            n_subcases: Number of output values per node
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension
            edge_feat_dim: Edge feature dimension
        """
        super().__init__()

        self.n_subcases = n_subcases

        # Simpler design encoder
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Simpler node encoder
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Edge encoder
        self.edge_encoder = nn.Linear(edge_feat_dim, hidden_dim // 2)

        # Message passing
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
            for _ in range(n_layers)
        ])

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for a single design.

        Args:
            node_features: [n_nodes, node_feat_dim]
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, edge_feat_dim]
            design_vars: [n_design_vars] (a single design; this class does
                not iterate over a batch dimension)

        Returns:
            [n_nodes, n_subcases]
        """
        design_h = self.design_encoder(design_vars)
        node_h = self.node_encoder(node_features)
        edge_h = self.edge_encoder(edge_attr)

        # Plain residual connection around every message passing layer
        for conv in self.conv_layers:
            node_h = node_h + conv(node_h, edge_index, edge_h, design_h)

        return self.decoder(node_h)


# =============================================================================
# Utility functions
# =============================================================================

def create_model(
    n_design_vars: int = 11,
    n_subcases: int = 4,
    model_type: str = 'full',
    **kwargs
) -> nn.Module:
    """
    Factory function to create GNN model.

    Args:
        n_design_vars: Number of design parameters
        n_subcases: Number of subcases
        model_type: 'full' or 'lite'
        **kwargs: Additional arguments passed to model

    Returns:
        GNN model instance. Any model_type other than 'lite' yields the full
        model; a warning is emitted if it is not exactly 'full'.
    """
    if model_type == 'lite':
        return ZernikeGNNLite(n_design_vars, n_subcases, **kwargs)

    if model_type != 'full':
        # Keep the historical fallback, but do not let a typo pass silently.
        import warnings
        warnings.warn(
            f"Unknown model_type {model_type!r}; falling back to 'full'.",
            stacklevel=2,
        )
    return ZernikeGNN(n_design_vars, n_subcases, **kwargs)


def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
    """
    Load trained model from checkpoint.

    The checkpoint is expected to be a dict with a 'model_state_dict' entry
    and an optional 'config' dict (keys read: 'model_type', 'n_design_vars',
    'n_subcases', 'hidden_dim', 'n_layers').

    Args:
        checkpoint_path: Path to .pt checkpoint file
        device: Device to load model to

    Returns:
        Loaded model in eval mode, moved to `device`

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Get model config
    config = checkpoint.get('config', {})
    model_type = config.get('model_type', 'full')

    # Forward only the sizes actually stored, so a missing key falls back to
    # the default of the selected model class (full: 128/6, lite: 64/4)
    # instead of always using the full-model defaults.
    size_kwargs = {
        key: config[key] for key in ('hidden_dim', 'n_layers') if key in config
    }

    # Create model
    model = create_model(
        n_design_vars=config.get('n_design_vars', 11),
        n_subcases=config.get('n_subcases', 4),
        model_type=model_type,
        **size_kwargs,
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    # map_location only places the loaded tensors; the module itself still
    # has to be moved to the requested device.
    model.to(device)
    model.eval()

    return model


# =============================================================================
# Testing
# =============================================================================

def _run_smoke_test() -> None:
    """Run forward passes of both models on random inputs and print shapes."""
    print("="*60)
    print("Testing ZernikeGNN")
    print("="*60)

    # Create model
    model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")

    # Create dummy inputs
    n_nodes = 3000
    n_edges = 17760
    node_features = torch.randn(n_nodes, 4)
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    edge_attr = torch.randn(n_edges, 4)
    design_vars = torch.randn(11)

    # Forward pass
    print("\n--- Single Forward Pass ---")
    with torch.no_grad():
        output = model(node_features, edge_index, edge_attr, design_vars)
    print(f"Input design: {design_vars.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")

    # Batched forward pass
    print("\n--- Batched Forward Pass ---")
    batch_design = torch.randn(8, 11)
    with torch.no_grad():
        output_batch = model(node_features, edge_index, edge_attr, batch_design)
    print(f"Batch design: {batch_design.shape}")
    print(f"Batch output: {output_batch.shape}")

    # Test lite model
    print("\n--- Lite Model ---")
    model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
    print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")
    with torch.no_grad():
        output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
    print(f"Lite output shape: {output_lite.shape}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)


if __name__ == '__main__':
    _run_smoke_test()