# GNN Architecture Deep Dive

**Technical documentation for AtomizerField Graph Neural Networks**

---

## Overview

AtomizerField uses Graph Neural Networks (GNNs) to learn physics from FEA simulations. This document explains the architecture in detail.

---

## Why Graph Neural Networks?

FEA meshes are naturally graphs:

- **Nodes** = Grid points (GRID cards in Nastran)
- **Edges** = Element connectivity (CTETRA, CQUAD, etc.)
- **Node features** = Position, BCs, material properties
- **Edge features** = Element type, length, direction

Traditional neural networks (MLPs, CNNs) expect fixed-size vectors or regular grids and can't handle this irregular structure. GNNs can.

```
  FEA Mesh                     Graph
═══════════════════════════════════════════════

  o───o                    (N1)──(N2)
 /│   │\                    │╲    │╱
o─┼───┼─o        →         (N3)──(N4)
 \│   │/                    │╱    │╲
  o───o                    (N5)──(N6)
```
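Building `edge_index` from mesh connectivity is mechanical: every pair of grid points that shares an element becomes a bidirectional edge. A minimal sketch (the helper name `elements_to_edge_index` is illustrative, not part of the codebase):

```python
import itertools

import torch

def elements_to_edge_index(elements):
    """Turn element connectivity (lists of node ids, e.g. 4 ids per
    CTETRA) into a [2, N_edges] edge_index tensor containing both
    directions of every unique node pair that shares an element."""
    pairs = set()
    for elem in elements:
        for a, b in itertools.combinations(elem, 2):
            pairs.add((a, b))  # source → target
            pairs.add((b, a))  # and the reverse direction
    return torch.tensor(sorted(pairs)).t().contiguous()
```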

---

## Model Architectures

### 1. Field Predictor GNN

Predicts complete displacement and stress fields.

```
┌─────────────────────────────────────────────────────────┐
│                   Field Predictor GNN                   │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  Input Encoding                                         │
│  ┌──────────────────────────────────────────────────┐   │
│  │ Node Features (12D per node):                    │   │
│  │   • Position (x, y, z)                   [3D]    │   │
│  │   • Material (E, nu, rho)                [3D]    │   │
│  │   • Boundary conditions (fixed per DOF)  [6D]    │   │
│  │                                                  │   │
│  │ Edge Features (5D per edge):                     │   │
│  │   • Edge length                          [1D]    │   │
│  │   • Direction vector                     [3D]    │   │
│  │   • Element type                         [1D]    │   │
│  └──────────────────────────────────────────────────┘   │
│                            ↓                            │
│  Message Passing Layers (6 layers)                      │
│  ┌──────────────────────────────────────────────────┐   │
│  │ for layer in range(6):                           │   │
│  │     h = MeshGraphConv(h, edge_index, edge_attr)  │   │
│  │     h = LayerNorm(h)                             │   │
│  │     h = ReLU(h)                                  │   │
│  │     h = Dropout(h, p=0.1)                        │   │
│  │     h = h + residual    # Skip connection        │   │
│  └──────────────────────────────────────────────────┘   │
│                            ↓                            │
│  Output Heads                                           │
│  ┌──────────────────────────────────────────────────┐   │
│  │ Displacement Head:                               │   │
│  │   Linear(hidden → 64 → 6)   # 6 DOF per node     │   │
│  │                                                  │   │
│  │ Stress Head:                                     │   │
│  │   Linear(hidden → 64 → 1)   # Von Mises stress   │   │
│  └──────────────────────────────────────────────────┘   │
│                                                         │
│  Output: [N_nodes, 7]  (6 displacement + 1 stress)      │
└─────────────────────────────────────────────────────────┘
```

**Parameters**: 718,221 trainable

### 2. Parametric Field Predictor GNN

Predicts scalar objectives directly from design parameters.

```
┌─────────────────────────────────────────────────────────┐
│             Parametric Field Predictor GNN              │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  Design Parameter Encoding                              │
│  ┌──────────────────────────────────────────────────┐   │
│  │ Design Params (4D):                              │   │
│  │   • beam_half_core_thickness                     │   │
│  │   • beam_face_thickness                          │   │
│  │   • holes_diameter                               │   │
│  │   • hole_count                                   │   │
│  │                                                  │   │
│  │ Design Encoder MLP:                              │   │
│  │   Linear(4 → 64) → ReLU → Linear(64 → 128)       │   │
│  └──────────────────────────────────────────────────┘   │
│                            ↓                            │
│  Design-Conditioned GNN                                 │
│  ┌──────────────────────────────────────────────────┐   │
│  │ # Broadcast design encoding to all nodes         │   │
│  │ node_features = node_features + design_encoding  │   │
│  │                                                  │   │
│  │ for layer in range(4):                           │   │
│  │     h = GraphConv(h, edge_index)                 │   │
│  │     h = BatchNorm(h)                             │   │
│  │     h = ReLU(h)                                  │   │
│  └──────────────────────────────────────────────────┘   │
│                            ↓                            │
│  Global Pooling                                         │
│  ┌──────────────────────────────────────────────────┐   │
│  │ mean_pool = global_mean_pool(h)  # [batch, 128]  │   │
│  │ max_pool  = global_max_pool(h)   # [batch, 128]  │   │
│  │ design    = design_encoding      # [batch, 128]  │   │
│  │                                                  │   │
│  │ global_features = concat([mean_pool,             │   │
│  │     max_pool, design])           # [batch, 384]  │   │
│  └──────────────────────────────────────────────────┘   │
│                            ↓                            │
│  Scalar Prediction Heads                                │
│  ┌──────────────────────────────────────────────────┐   │
│  │ MLP: Linear(384 → 128 → 64 → 4)                  │   │
│  │                                                  │   │
│  │ Output:                                          │   │
│  │   [0] = mass (grams)                             │   │
│  │   [1] = frequency (Hz)                           │   │
│  │   [2] = max_displacement (mm)                    │   │
│  │   [3] = max_stress (MPa)                         │   │
│  └──────────────────────────────────────────────────┘   │
│                                                         │
│  Output: [batch, 4]  (4 objectives)                     │
└─────────────────────────────────────────────────────────┘
```

**Parameters**: ~500,000 trainable
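The global-pooling step above can be written without PyTorch Geometric; the sketch below mirrors what `global_mean_pool` / `global_max_pool` do, using plain tensor ops (the function name is illustrative, not the real API):

```python
import torch

def pool_and_concat(h, batch, design_encoding):
    """
    h:               [total_nodes, D] node embeddings
    batch:           [total_nodes] graph id for each node
    design_encoding: [num_graphs, D]
    Returns [num_graphs, 3*D], as in the Global Pooling stage above.
    """
    num_graphs = int(batch.max()) + 1

    # Mean pool: scatter-sum per graph, then divide by node counts
    sums = torch.zeros(num_graphs, h.size(1)).index_add_(0, batch, h)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    mean_pool = sums / counts.unsqueeze(1).to(h.dtype)

    # Max pool: scatter-reduce with 'amax' starting from -inf
    idx = batch.unsqueeze(1).repeat(1, h.size(1))
    max_pool = torch.full_like(sums, float('-inf'))
    max_pool = max_pool.scatter_reduce(0, idx, h, reduce='amax')

    return torch.cat([mean_pool, max_pool, design_encoding], dim=-1)
```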

---

## Message Passing

The core of GNNs is message passing. Here's how it works:

### Standard Message Passing

```python
import torch

def message_passing(node_features, edge_index, edge_attr, msg_mlp, update_mlp):
    """
    node_features: [N_nodes, D_node]
    edge_index:    [2, N_edges]  (row 0 = source, row 1 = target)
    edge_attr:     [N_edges, D_edge]
    msg_mlp:       maps [2*D_node + D_edge] → D_msg
    update_mlp:    maps [D_node + D_msg]   → D_node
    """
    # Step 1: Compute a message for every edge
    source_nodes = node_features[edge_index[0]]  # [N_edges, D_node]
    target_nodes = node_features[edge_index[1]]  # [N_edges, D_node]
    messages = msg_mlp(torch.cat([source_nodes, target_nodes, edge_attr], dim=-1))

    # Step 2: Aggregate (sum) incoming messages at each target node
    aggregated = torch.zeros(node_features.size(0), messages.size(1),
                             dtype=messages.dtype, device=messages.device)
    aggregated.index_add_(0, edge_index[1], messages)  # [N_nodes, D_msg]

    # Step 3: Update each node from its old state and its aggregated messages
    return update_mlp(torch.cat([node_features, aggregated], dim=-1))
```

### Custom MeshGraphConv

We use a custom convolution that respects FEA mesh structure:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

class MeshGraphConv(MessagePassing):
    """
    Custom message passing for FEA meshes.

    Accounts for:
    - Edge lengths (stiffness depends on distance)
    - Element types (different physics for solid/shell/beam)
    - Direction vectors (anisotropic behavior)
    """

    def __init__(self, in_channels, out_channels, edge_dim=5):
        # aggr='add': sum messages at each node (like force equilibrium)
        super().__init__(aggr='add')
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels + edge_dim, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # x_i: target node features
        # x_j: source node features
        # edge_attr: [length, direction (3), element type]; direction and
        # element type reach the MLP through edge_attr itself

        # Scale by inverse distance (like stiffness)
        edge_length = edge_attr[:, 0:1]
        distance_weight = 1.0 / (edge_length + 1e-6)

        # Combine source and target features with the edge features
        combined = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.mlp(combined) * distance_weight
```

---

## Feature Engineering

### Node Features (12D)

| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Position (x, y, z) | 3 | Normalized | Node coordinates |
| Material E | 1 | Log-scaled | Young's modulus |
| Material nu | 1 | [0, 0.5] | Poisson's ratio |
| Material rho | 1 | Log-scaled | Density |
| BC_x, BC_y, BC_z | 3 | {0, 1} | Fixed translation |
| BC_rx, BC_ry, BC_rz | 3 | {0, 1} | Fixed rotation |

### Edge Features (5D)

| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Length | 1 | Normalized | Edge length |
| Direction | 3 | [-1, 1] | Unit direction vector |
| Element type | 1 | Encoded | CTETRA=0, CHEXA=1, etc. |
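The two tables above assemble into the 12-D node vector in a fixed order. A minimal sketch, assuming log10 scaling for E and rho (the base of the log and the helper name are illustrative assumptions):

```python
import torch

def build_node_features(pos, E, nu, rho, fixed_dofs):
    """
    pos:        [N, 3] normalized node coordinates
    E, rho:     [N] Young's modulus and density (log-scaled here)
    nu:         [N] Poisson's ratio
    fixed_dofs: [N, 6] 0/1 flags, one per DOF
    Returns the [N, 12] node-feature matrix from the table above.
    """
    material = torch.stack([torch.log10(E), nu, torch.log10(rho)], dim=1)
    return torch.cat([pos, material, fixed_dofs.to(pos.dtype)], dim=1)
```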

### Normalization

```python
import torch

def normalize_features(node_features, edge_features, stats):
    """Normalize to zero mean, unit variance."""

    # Node features
    node_features = (node_features - stats['node_mean']) / stats['node_std']

    # Edge features (length uses log normalization)
    edge_features = edge_features.clone()
    edge_features[:, 0] = torch.log(edge_features[:, 0] + 1e-6)
    edge_features = (edge_features - stats['edge_mean']) / stats['edge_std']

    return node_features, edge_features
```
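The `stats` dictionary is assumed to be computed once over the training set. A sketch of one consistent way to do it (hypothetical helper, not the actual AtomizerField API; note the log transform on edge length must happen *before* the statistics are taken, to match the function above):

```python
import torch

def compute_normalization_stats(node_features, edge_features):
    # Apply the same log transform used at normalization time
    edge_features = edge_features.clone()
    edge_features[:, 0] = torch.log(edge_features[:, 0] + 1e-6)
    return {
        'node_mean': node_features.mean(dim=0),
        'node_std':  node_features.std(dim=0).clamp(min=1e-8),
        'edge_mean': edge_features.mean(dim=0),
        'edge_std':  edge_features.std(dim=0).clamp(min=1e-8),
    }
```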

---

## Training Details

### Optimizer

```python
from torch.optim import AdamW

optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)
```

### Learning Rate Schedule

```python
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,        # Restart every 50 epochs
    T_mult=2,      # Double the period after each restart
    eta_min=1e-6
)
```

### Data Augmentation

```python
import math
import random

import torch

def augment_graph(data):
    """Random augmentation for better generalization."""

    # Random rotation (the physics is rotation-invariant)
    if random.random() < 0.5:
        angle = random.uniform(0, 2 * math.pi)
        data = rotate_graph(data, angle, axis='z')

    # Random noise (robustness)
    if random.random() < 0.3:
        data.x += torch.randn_like(data.x) * 0.01

    return data
```
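`rotate_graph` is not shown above; here is a minimal sketch of what it might do, assuming the first three node-feature columns are the (x, y, z) positions (as in the feature table). Edge direction features would need the same rotation, which is omitted here:

```python
import math

import torch

def rotate_graph(data, angle, axis='z'):
    """Hypothetical helper: rotate nodal coordinates about the z axis."""
    assert axis == 'z', "sketch only handles z-axis rotation"
    c, s = math.cos(angle), math.sin(angle)
    R = torch.tensor([[c,  -s,  0.0],
                      [s,   c,  0.0],
                      [0.0, 0.0, 1.0]])
    data.x = data.x.clone()
    # Rotate only the position columns; other features are invariant
    data.x[:, :3] = data.x[:, :3] @ R.T
    return data
```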

### Batch Processing

```python
from torch_geometric.loader import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

for batch in loader:
    # batch.x:          [total_nodes, D_node]
    # batch.edge_index: [2, total_edges]
    # batch.batch:      [total_nodes] - maps each node to its graph
    predictions = model(batch)
```

---

## Model Comparison

| Model | Parameters | Inference | Output | Use Case |
|-------|------------|-----------|--------|----------|
| Field Predictor | 718K | 50 ms | Full field | When you need field visualization |
| Parametric | 500K | 4.5 ms | 4 scalars | Direct optimization (fastest) |
| Ensemble (5x) | 2.5M | 25 ms | 4 scalars + uncertainty | When confidence matters |
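For the ensemble row, the uncertainty estimate can come from the spread of predictions across the five models. A minimal sketch (the function name is illustrative):

```python
import torch

def ensemble_predict(models, batch):
    """Mean prediction plus per-objective uncertainty (std) across models."""
    preds = torch.stack([model(batch) for model in models])  # [K, batch, 4]
    return preds.mean(dim=0), preds.std(dim=0)
```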

---

## Implementation Notes

### PyTorch Geometric

We use [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) for GNN operations:

```python
import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool

# Version requirements:
#   torch >= 2.0
#   torch_geometric >= 2.3
```

### GPU Memory

| Workload | Batch Size | GPU Memory |
|----------|------------|------------|
| Field Predictor (inference) | 16 | 4 GB |
| Parametric (inference) | 32 | 2 GB |
| Training | 16 | 8 GB |

### Checkpoints

```python
import torch

# Save checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': model_config,
    'normalization_stats': stats,
    'epoch': epoch,
    'best_val_loss': best_loss
}, 'checkpoint.pt')

# Load checkpoint (weights_only=False because the file holds config
# dicts and normalization stats, not just tensors)
checkpoint = torch.load('checkpoint.pt', weights_only=False)
model = ParametricFieldPredictor(**checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
```

---

## Physics Interpretation

### Why GNNs Work for FEA

1. **Locality**: FEA solutions are local (each node directly interacts only with its neighbors)
2. **Superposition**: Linear FEA is additive (the response is a sum of effects)
3. **Equilibrium**: Force balance at each node (internal forces sum to zero)

The GNN mirrors these principles:

- Message passing ≈ force distribution through elements
- Aggregation ≈ force equilibrium at nodes
- Multiple layers ≈ load path propagation

### Physical Constraints

The architecture enforces physics where it can:

```python
# Displacement at fixed nodes = 0
displacement = model(data)                  # [N_nodes, 7]
fixed_mask = data.boundary_conditions > 0   # [N_nodes, 6] boolean, per DOF
displacement[:, :6][fixed_mask] = 0.0       # Hard constraint on fixed DOFs

# The stress-strain relationship is implicit:
# it is learned by the network during training
```

---

## Extension Points

### Adding New Element Types

```python
# In data_loader.py
ELEMENT_TYPES = {
    'CTETRA': 0,
    'CHEXA': 1,
    'CPENTA': 2,
    'CQUAD4': 3,
    'CTRIA3': 4,
    'CBAR': 5,
    'CBEAM': 6,
    # Add new types here
    'CTETRA10': 7,  # New 10-node tetrahedron
}
```

### Custom Output Heads

```python
import torch
import torch.nn as nn

class CustomFieldPredictor(FieldPredictorGNN):
    def __init__(self, *args, hidden_channels=128, **kwargs):
        super().__init__(*args, hidden_channels=hidden_channels, **kwargs)

        # Add a custom head for thermal analysis
        self.temperature_head = nn.Linear(hidden_channels, 1)

    def forward(self, data):
        # Assumes the base forward exposes node embeddings of size
        # hidden_channels for the new head to consume
        h = super().forward(data)

        # Append the temperature prediction to the existing outputs
        temperature = self.temperature_head(h)
        return torch.cat([h, temperature], dim=-1)
```

---

## References

1. Battaglia et al. (2018), "Relational inductive biases, deep learning, and graph networks"
2. Pfaff et al. (2021), "Learning Mesh-Based Simulation with Graph Networks"
3. Sanchez-Gonzalez et al. (2020), "Learning to Simulate Complex Physics with Graph Networks"

---

## See Also

- [Neural Features Complete](NEURAL_FEATURES_COMPLETE.md) - Overview of all features
- [Physics Loss Guide](PHYSICS_LOSS_GUIDE.md) - Loss function selection
- [Neural Workflow Tutorial](NEURAL_WORKFLOW_TUTORIAL.md) - Step-by-step guide