"""
Zernike GNN Model for Mirror Surface Deformation Prediction
============================================================

This module implements a Graph Neural Network specifically designed for predicting
mirror surface displacement fields from design parameters. The key innovation is
using design-conditioned message passing on a polar grid graph.

Architecture:
    Design Variables [11]
            │
            ▼
    Design Encoder [11 → 128]
            │
            └──────────────────┐
                               │
    Node Features              │
    [r, θ, x, y]               │
        │                      │
        ▼                      │
    Node Encoder               │
    [4 → 128]                  │
        │                      │
        └─────────┬────────────┘
                  │
                  ▼
    ┌─────────────────────────────┐
    │ Design-Conditioned          │
    │ Message Passing (× 6)       │
    │                             │
    │ • Polar-aware edges         │
    │ • Design modulates messages │
    │ • Residual connections     │
    └─────────────┬───────────────┘
                  │
                  ▼
    Per-Node Decoder [128 → 4]
                  │
                  ▼
    Z-Displacement Field [3000, 4]
    (one value per node per subcase)

Usage:
    from optimization_engine.gnn.zernike_gnn import ZernikeGNN
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    model = ZernikeGNN(n_design_vars=11, n_subcases=4)

    # Forward pass
    z_disp = model(
        node_features=graph.get_node_features(),
        edge_index=graph.edge_index,
        edge_attr=graph.get_edge_features(),
        design_vars=design_tensor
    )
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

# torch_geometric is an optional dependency: when available we inherit from
# its optimized MessagePassing base class; otherwise DesignConditionedConv
# falls back to a pure-PyTorch scatter implementation (see _manual_propagate).
try:
    from torch_geometric.nn import MessagePassing
    HAS_PYG = True
except ImportError:
    HAS_PYG = False
    MessagePassing = nn.Module  # Fallback for type hints
class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
    """
    Message passing layer conditioned on global design parameters.

    This layer propagates information through the polar graph while
    conditioning on design parameters. The design embedding modulates
    how messages flow between nodes.

    Key insight: Design parameters affect the stiffness distribution
    in the mirror support structure. This layer learns how those changes
    propagate spatially through the optical surface.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        design_channels: int,
        edge_channels: int = 4,
        aggr: str = 'mean'
    ):
        """
        Args:
            in_channels: Input node feature dimension
            out_channels: Output node feature dimension
            design_channels: Design embedding dimension
            edge_channels: Edge feature dimension
            aggr: Aggregation method ('mean', 'sum', 'max')
        """
        if HAS_PYG:
            super().__init__(aggr=aggr)
        else:
            super().__init__()
            self.aggr = aggr

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message network: source node + target node + design + edge
        msg_input_dim = 2 * in_channels + design_channels + edge_channels
        self.message_net = nn.Sequential(
            nn.Linear(msg_input_dim, out_channels * 2),
            nn.LayerNorm(out_channels * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(out_channels * 2, out_channels),
        )

        # Update network: combines aggregated messages with original features
        self.update_net = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.SiLU(),
        )

        # Design gate: allows design to modulate message importance
        self.design_gate = nn.Sequential(
            nn.Linear(design_channels, out_channels),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with design conditioning.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design_embed: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        if HAS_PYG:
            # Use PyG's message passing
            out = self.propagate(
                edge_index, x=x, edge_attr=edge_attr, design=design_embed
            )
        else:
            # Fallback implementation without PyG
            out = self._manual_propagate(x, edge_index, edge_attr, design_embed)

        # Apply design-based gating (broadcasts [out_channels] over nodes)
        gate = self.design_gate(design_embed)
        out = out * gate

        return out

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute messages from source (j) to target (i) nodes.

        Args:
            x_i: Target node features [n_edges, in_channels]
            x_j: Source node features [n_edges, in_channels]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding, broadcast to edges

        Returns:
            Messages [n_edges, out_channels]
        """
        # Broadcast design to all edges
        design_broadcast = design.expand(x_i.size(0), -1)

        # Concatenate all inputs
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)

        return self.message_net(msg_input)

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Aggregated messages [n_nodes, out_channels]
            x: Original node features [n_nodes, in_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        return self.update_net(torch.cat([x, aggr_out], dim=-1))

    def _manual_propagate(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Fallback message passing without PyG.

        FIX: the previous implementation treated edge_index[0] as the
        target node, while PyG's default 'source_to_target' flow sends
        messages from edge_index[0] (source) to edge_index[1] (target).
        It also always used mean aggregation regardless of ``self.aggr``.
        Both paths now agree with the PyG semantics.
        """
        # PyG convention: row 0 = source (sender), row 1 = target (receiver)
        src, dst = edge_index

        x_j = x[src]  # Source features
        x_i = x[dst]  # Target features

        # Compute messages (same inputs as message())
        design_broadcast = design.expand(x_i.size(0), -1)
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        messages = self.message_net(msg_input)

        # Aggregate messages at the *target* node, honoring self.aggr
        n_nodes = x.size(0)
        index = dst.unsqueeze(-1).expand_as(messages)

        if self.aggr == 'max':
            aggr_out = messages.new_full((n_nodes, messages.size(-1)), float('-inf'))
            aggr_out.scatter_reduce_(0, index, messages, reduce='amax', include_self=True)
            # Nodes with no incoming edges received no message: zero them out
            aggr_out = aggr_out.masked_fill(torch.isinf(aggr_out), 0.0)
        else:
            aggr_out = messages.new_zeros(n_nodes, messages.size(-1))
            aggr_out.scatter_add_(0, index, messages)
            if self.aggr == 'mean':
                count = messages.new_zeros(n_nodes, 1)
                count.scatter_add_(
                    0, dst.unsqueeze(-1),
                    torch.ones_like(dst, dtype=messages.dtype).unsqueeze(-1)
                )
                # clamp avoids division by zero for isolated nodes
                aggr_out = aggr_out / count.clamp(min=1)

        # Update
        return self.update_net(torch.cat([x, aggr_out], dim=-1))
class ZernikeGNN(nn.Module):
    """
    Graph Neural Network for mirror surface displacement prediction.

    Maps 11 design parameters to Z-displacement fields for all 4 gravity
    subcases on a fixed polar-grid graph, using design-conditioned
    message passing.

    Why a GNN instead of an MLP:
    1. Spatial awareness through message passing
    2. Design conditioning modulates spatial information flow
    3. Predicts the full field (enabling correct relative computation)
    4. Respects physics: smooth fields, radial/angular structure
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 128,
        n_layers: int = 6,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4,
        dropout: float = 0.1
    ):
        """
        Args:
            n_design_vars: Number of design parameters (11 for mirror)
            n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension (r, theta, x, y)
            edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
            dropout: Dropout rate
        """
        super().__init__()

        self.n_design_vars = n_design_vars
        self.n_subcases = n_subcases
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # Design encoder: lifts the raw design vector into hidden space.
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # Node encoder: lifts per-node polar coordinates into hidden space.
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # Edge encoder: embeds (dr, dtheta, distance, angle) features.
        edge_hidden = hidden_dim // 2
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feat_dim, edge_hidden),
            nn.SiLU(),
            nn.Linear(edge_hidden, edge_hidden),
        )

        # Stack of design-conditioned message passing layers.
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                design_channels=hidden_dim,
                edge_channels=edge_hidden,
            )
            for _ in range(n_layers)
        ])

        # One LayerNorm per conv layer, applied after each residual add.
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(n_layers)
        ])

        # Per-node head: hidden state -> one Z-displacement per subcase.
        self.displacement_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize all Linear layers with Xavier/Glorot, biases at zero."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: design parameters → displacement field.

        Args:
            node_features: [n_nodes, 4] - (r, theta, x, y) normalized
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, 4] - edge features normalized
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            z_displacement: [n_nodes, n_subcases] - Z-disp per subcase
                            or [batch, n_nodes, n_subcases] if batched
        """
        # Normalize to a batched design tensor; remember original shape.
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)  # [1, n_design_vars]

        # Shared encodings (graph structure is identical for every design).
        design_h = self.design_encoder(design_vars)   # [batch, hidden]
        base_h = self.node_encoder(node_features)     # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)         # [n_edges, edge_hidden]

        # Run the message-passing stack once per design in the batch.
        fields = []
        for d in design_h:
            h = base_h
            for conv, norm in zip(self.conv_layers, self.layer_norms):
                # Residual connection followed by LayerNorm.
                h = norm(h + conv(h, edge_index, edge_h, d))
            fields.append(self.displacement_decoder(h))  # [n_nodes, n_subcases]

        if is_batched:
            return torch.stack(fields, dim=0)  # [batch, n_nodes, n_subcases]
        return fields[0]                       # [n_nodes, n_subcases]

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        total = 0
        for p in self.parameters():
            if p.requires_grad:
                total += p.numel()
        return total
class ZernikeGNNLite(nn.Module):
    """
    Lightweight version of ZernikeGNN for faster training/inference.

    Uses fewer layers and smaller hidden dimension, suitable for
    initial experiments or when training data is limited.
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 64,
        n_layers: int = 4
    ):
        """
        Args:
            n_design_vars: Number of design parameters
            n_subcases: Number of gravity subcases
            hidden_dim: Hidden layer dimension (smaller than full model)
            n_layers: Number of message passing layers
        """
        super().__init__()

        self.n_subcases = n_subcases

        # Simpler design encoder (no LayerNorm/Dropout)
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Simpler node encoder
        self.node_encoder = nn.Sequential(
            nn.Linear(4, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Edge encoder
        self.edge_encoder = nn.Linear(4, hidden_dim // 2)

        # Message passing
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
            for _ in range(n_layers)
        ])

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

    def forward(self, node_features, edge_index, edge_attr, design_vars):
        """
        Forward pass: design parameters → displacement field.

        FIX/generalization: batched designs ([batch, n_design_vars]) are
        now handled the same way as in ZernikeGNN. Previously a batched
        design tensor was passed straight into the conv layers, where the
        per-edge broadcast assumed a 1-D design embedding.

        Args:
            node_features: [n_nodes, 4]
            edge_index: [2, n_edges]
            edge_attr: [n_edges, 4]
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            [n_nodes, n_subcases], or [batch, n_nodes, n_subcases] if batched
        """
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)

        design_h = self.design_encoder(design_vars)   # [batch, hidden]
        base_h = self.node_encoder(node_features)     # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)         # [n_edges, hidden // 2]

        outputs = []
        for b in range(design_h.size(0)):
            node_h = base_h
            for conv in self.conv_layers:
                # Plain residual connection (no LayerNorm in the lite model)
                node_h = node_h + conv(node_h, edge_index, edge_h, design_h[b])
            outputs.append(self.decoder(node_h))

        return torch.stack(outputs, dim=0) if is_batched else outputs[0]
# =============================================================================
# Utility functions
# =============================================================================
def create_model(
    n_design_vars: int = 11,
    n_subcases: int = 4,
    model_type: str = 'full',
    **kwargs
) -> nn.Module:
    """
    Factory function to create GNN model.

    Args:
        n_design_vars: Number of design parameters
        n_subcases: Number of subcases
        model_type: 'full' or 'lite'
        **kwargs: Additional arguments passed to model

    Returns:
        GNN model instance

    Raises:
        ValueError: If model_type is not 'full' or 'lite'.
    """
    if model_type == 'lite':
        return ZernikeGNNLite(n_design_vars, n_subcases, **kwargs)
    if model_type == 'full':
        return ZernikeGNN(n_design_vars, n_subcases, **kwargs)
    # FIX: previously any unrecognized model_type (e.g. a typo) silently
    # fell through to the full model; fail loudly instead.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'full' or 'lite'")
def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
    """
    Load trained model from checkpoint.

    Args:
        checkpoint_path: Path to .pt checkpoint file
        device: Device to load model to

    Returns:
        Loaded model in eval mode, on `device`
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Get model config
    config = checkpoint.get('config', {})
    model_type = config.get('model_type', 'full')

    # FIX: defaults must mirror each class's constructor defaults; the lite
    # model defaults to hidden_dim=64 / n_layers=4, not 128 / 6.
    if model_type == 'lite':
        default_hidden, default_layers = 64, 4
    else:
        default_hidden, default_layers = 128, 6

    # Create model
    model = create_model(
        n_design_vars=config.get('n_design_vars', 11),
        n_subcases=config.get('n_subcases', 4),
        model_type=model_type,
        hidden_dim=config.get('hidden_dim', default_hidden),
        n_layers=config.get('n_layers', default_layers),
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])

    # FIX: map_location only relocates the checkpoint tensors, and
    # load_state_dict copies them into the (CPU-resident) parameters —
    # the model itself must still be moved to the requested device.
    model.to(device)
    model.eval()

    return model
# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
    banner = "=" * 60
    print(banner)
    print("Testing ZernikeGNN")
    print(banner)

    # Build the full-size model and report its footprint.
    model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")

    # Synthetic graph sized like the real polar grid.
    n_nodes, n_edges = 3000, 17760
    node_features = torch.randn(n_nodes, 4)
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    edge_attr = torch.randn(n_edges, 4)
    design_vars = torch.randn(11)

    # Single-design forward pass.
    print("\n--- Single Forward Pass ---")
    with torch.no_grad():
        output = model(node_features, edge_index, edge_attr, design_vars)
    print(f"Input design: {design_vars.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")

    # Batched forward pass (8 designs at once).
    print("\n--- Batched Forward Pass ---")
    batch_design = torch.randn(8, 11)
    with torch.no_grad():
        output_batch = model(node_features, edge_index, edge_attr, batch_design)
    print(f"Batch design: {batch_design.shape}")
    print(f"Batch output: {output_batch.shape}")

    # Sanity-check the lightweight variant.
    print("\n--- Lite Model ---")
    model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
    print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")
    with torch.no_grad():
        output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
    print(f"Lite output shape: {output_lite.shape}")

    print("\n" + banner)
    print("✓ All tests passed!")
    print(banner)