# GNN Architecture Deep Dive
**Technical documentation for AtomizerField Graph Neural Networks**
---
## Overview
AtomizerField uses Graph Neural Networks (GNNs) to learn physics from FEA simulations. This document explains the architecture in detail.
---
## Why Graph Neural Networks?
FEA meshes are naturally graphs:
- **Nodes** = Grid points (GRID cards in Nastran)
- **Edges** = Element connectivity (CTETRA, CQUAD, etc.)
- **Node features** = Position, BCs, material properties
- **Edge features** = Element type, length, direction
Traditional neural networks (MLPs, CNNs) assume fixed-size vectors or regular grids, so they cannot handle this irregular, variable-size structure. GNNs operate on graphs directly.
```
FEA Mesh                        Graph
═══════════════════════════════════════════════
  o───o                     (N1)──(N2)
 /│   │\                     │ ╲  ╱ │
o─┼───┼─o        →          (N3)──(N4)
 \│   │/                     │ ╱  ╲ │
  o───o                     (N5)──(N6)
```
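One common convention for turning element connectivity into a graph is to connect every pair of nodes that share an element, in both directions. This is a sketch of that convention, not necessarily what AtomizerField's loader does:

```python
from itertools import combinations

def mesh_to_edges(elements):
    """Build a directed edge list from element connectivity.

    elements: list of node-id tuples, one per element (e.g. 4 ids per CTETRA).
    Every pair of nodes sharing an element becomes an edge in both directions.
    """
    edges = set()
    for elem in elements:
        for a, b in combinations(elem, 2):
            edges.add((a, b))
            edges.add((b, a))
    return sorted(edges)

# One CTETRA (4 nodes) -> C(4, 2) = 6 node pairs -> 12 directed edges
edges = mesh_to_edges([(1, 2, 3, 4)])
```

Because a set is used, edges on shared faces are deduplicated automatically: two tetrahedra sharing a face contribute 18 directed edges, not 24.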
---
## Model Architectures
### 1. Field Predictor GNN
Predicts complete displacement and stress fields.
```
Field Predictor GNN
═══════════════════

Input Encoding
    Node features (12D per node):
      • Position (x, y, z)                   [3D]
      • Material (E, nu, rho)                [3D]
      • Boundary conditions (fixed per DOF)  [6D]
    Edge features (5D per edge):
      • Edge length                          [1D]
      • Direction vector                     [3D]
      • Element type                         [1D]
        ↓
Message Passing (6 layers)
    for layer in range(6):
        h = MeshGraphConv(h, edge_index, edge_attr)
        h = LayerNorm(h)
        h = ReLU(h)
        h = Dropout(h, p=0.1)
        h = h + residual   # Skip connection
        ↓
Output Heads
    Displacement head:  Linear(hidden → 64 → 6)   # 6 DOF per node
    Stress head:        Linear(hidden → 64 → 1)   # Von Mises stress

Output: [N_nodes, 7]  (6 displacement + 1 stress)
```
**Parameters**: 718,221 trainable
### 2. Parametric Field Predictor GNN
Predicts scalar objectives directly from design parameters.
```
Parametric Field Predictor GNN
══════════════════════════════

Design Parameter Encoding
    Design params (4D):
      • beam_half_core_thickness
      • beam_face_thickness
      • holes_diameter
      • hole_count
    Design encoder MLP:
        Linear(4 → 64) → ReLU → Linear(64 → 128)
        ↓
Design-Conditioned GNN
    # Broadcast design encoding to all nodes
    node_features = node_features + design_encoding
    for layer in range(4):
        h = GraphConv(h, edge_index)
        h = BatchNorm(h)
        h = ReLU(h)
        ↓
Global Pooling
    mean_pool = global_mean_pool(h)   # [batch, 128]
    max_pool  = global_max_pool(h)    # [batch, 128]
    design    = design_encoding       # [batch, 128]
    global_features = concat([mean_pool, max_pool, design])   # [batch, 384]
        ↓
Scalar Prediction Heads
    MLP: Linear(384 → 128 → 64 → 4)
    Output:
      [0] = mass (grams)
      [1] = frequency (Hz)
      [2] = max_displacement (mm)
      [3] = max_stress (MPa)

Output: [batch, 4]  (4 objectives)
```
**Parameters**: ~500,000 trainable
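The 384D global feature vector in the pooling stage is just the concatenation of the two 128D pooled vectors and the 128D design encoding; as a plain-Python sketch:

```python
def build_global_features(mean_pool, max_pool, design_encoding):
    """Concatenate pooled graph features with the design encoding."""
    return list(mean_pool) + list(max_pool) + list(design_encoding)

# 128 + 128 + 128 = 384 global features per graph
features = build_global_features([0.0] * 128, [0.1] * 128, [0.2] * 128)
```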
---
## Message Passing
The core of GNNs is message passing. Here's how it works:
### Standard Message Passing
```python
def message_passing(node_features, edge_index, edge_attr):
    """
    node_features: [N_nodes, D_node]
    edge_index:    [2, N_edges]   # Source → Target
    edge_attr:     [N_edges, D_edge]
    """
    # Step 1: Compute a message for every edge
    source_nodes = node_features[edge_index[0]]  # [N_edges, D_node]
    target_nodes = node_features[edge_index[1]]  # [N_edges, D_node]
    messages = MLP([source_nodes, target_nodes, edge_attr])  # [N_edges, D_msg]

    # Step 2: Aggregate messages at each target node
    aggregated = scatter_add(messages, edge_index[1])  # [N_nodes, D_msg]

    # Step 3: Update node features
    updated = MLP([node_features, aggregated])  # [N_nodes, D_node]
    return updated
```
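Step 2's scatter-add needs no framework to understand: each message is summed into its target node's slot. A plain-Python stand-in for the aggregation:

```python
def scatter_add_py(messages, targets, n_nodes):
    """Sum each edge message into its target node's slot (step 2 above)."""
    out = [0.0] * n_nodes
    for msg, tgt in zip(messages, targets):
        out[tgt] += msg
    return out

# Edges 0->0, 1->1, 2->1 carrying scalar messages
scatter_add_py([1.0, 2.0, 3.0], [0, 1, 1], n_nodes=2)  # → [1.0, 5.0]
```

Nodes with no incoming edges simply keep a zero aggregate, which is why the update step also feeds the node's own features back in.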
### Custom MeshGraphConv
We use a custom convolution that respects FEA mesh structure:
```python
# Assumes: import torch; from torch_geometric.nn import MessagePassing;
#          from torch_scatter import scatter_add
class MeshGraphConv(MessagePassing):
    """
    Custom message passing for FEA meshes.

    Accounts for:
    - Edge lengths (stiffness depends on distance)
    - Element types (different physics for solid/shell/beam)
    - Direction vectors (anisotropic behavior)
    """
    def message(self, x_i, x_j, edge_attr):
        # x_i: Target node features
        # x_j: Source node features
        # edge_attr: Edge features (length, direction, type)
        edge_length = edge_attr[:, 0:1]
        edge_direction = edge_attr[:, 1:4]
        element_type = edge_attr[:, 4:5]

        # Scale by inverse distance (like stiffness)
        distance_weight = 1.0 / (edge_length + 1e-6)

        # Combine source and target features with the edge features
        combined = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.mlp(combined) * distance_weight

    def aggregate(self, messages, index):
        # Sum messages at each node (like force equilibrium)
        return scatter_add(messages, index, dim=0)
```
---
## Feature Engineering
### Node Features (12D)
| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Position (x, y, z) | 3 | Normalized | Node coordinates |
| Material E | 1 | Log-scaled | Young's modulus |
| Material nu | 1 | [0, 0.5] | Poisson's ratio |
| Material rho | 1 | Log-scaled | Density |
| BC_x, BC_y, BC_z | 3 | {0, 1} | Fixed translation |
| BC_rx, BC_ry, BC_rz | 3 | {0, 1} | Fixed rotation |
### Edge Features (5D)
| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Length | 1 | Normalized | Edge length |
| Direction | 3 | [-1, 1] | Unit direction vector |
| Element type | 1 | Encoded | CTETRA=0, CHEXA=1, etc. |
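Under these conventions the 5D edge feature vector can be assembled from the two endpoint positions. A sketch, where `length_scale` is a hypothetical stand-in for whatever normalization constant the dataset statistics provide:

```python
import math

def edge_features(p_src, p_dst, element_type_id, length_scale=1.0):
    """Assemble the 5D edge feature vector: [length, direction(3), type]."""
    delta = [b - a for a, b in zip(p_src, p_dst)]
    length = math.sqrt(sum(c * c for c in delta))
    direction = [c / length for c in delta]  # unit vector, src -> dst
    return [length / length_scale] + direction + [float(element_type_id)]
```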
### Normalization
```python
def normalize_features(node_features, edge_features, stats):
    """Normalize to zero mean, unit variance."""
    # Node features
    node_features = (node_features - stats['node_mean']) / stats['node_std']

    # Edge features (length uses log normalization)
    edge_features[:, 0] = torch.log(edge_features[:, 0] + 1e-6)
    edge_features = (edge_features - stats['edge_mean']) / stats['edge_std']
    return node_features, edge_features
```
---
## Training Details
### Optimizer
```python
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)
```
### Learning Rate Schedule
```python
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,      # First restart after 50 epochs
    T_mult=2,    # Period doubles after each restart (50, then 100, then 200, ...)
    eta_min=1e-6
)
```
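With `T_0=50` and `T_mult=2`, restarts land at epochs 50, 150, 350, and so on. An epoch-granularity sketch of the cosine formula PyTorch applies (the real scheduler can also anneal within epochs):

```python
import math

def cosine_warm_restart_lr(epoch, base_lr=1e-3, eta_min=1e-6, t0=50, t_mult=2):
    """LR at a given (integer) epoch under CosineAnnealingWarmRestarts."""
    t_cur, t_i = epoch, t0
    # Walk forward through restart cycles until the epoch falls inside one
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))
```

At each restart the LR snaps back to `base_lr`, then decays toward `eta_min` over an ever-longer cycle.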
### Data Augmentation
```python
def augment_graph(data):
    """Random augmentation for better generalization."""
    # Random rotation about z (the physics is rotation-invariant)
    if random.random() < 0.5:
        angle = random.uniform(0, 2 * math.pi)
        data = rotate_graph(data, angle, axis='z')

    # Small random noise on node features (robustness)
    if random.random() < 0.3:
        data.x += torch.randn_like(data.x) * 0.01
    return data
```
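`rotate_graph` is assumed to apply a standard rotation matrix to the node coordinates; for the `axis='z'` case used above, the coordinate transform is:

```python
import math

def rotate_z(points, angle):
    """Rotate a list of (x, y, z) points about the z axis by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y, s * x + c * y, z) for x, y, z in points]
```

Because edge lengths and relative angles are preserved, labels such as stress magnitudes remain valid after rotation.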
### Batch Processing
```python
from torch_geometric.loader import DataLoader  # moved here in PyG 2.0; torch_geometric.data.DataLoader is deprecated

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

for batch in loader:
    # batch.x:          [total_nodes, D_node]
    # batch.edge_index: [2, total_edges]
    # batch.batch:      [total_nodes] - maps each node to its graph
    predictions = model(batch)
```
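The `batch` vector is what makes per-graph pooling possible after graphs of different sizes are stacked into one big graph. A plain-Python stand-in for `global_mean_pool` over scalar node features:

```python
def global_mean_pool_py(node_feats, batch):
    """Average scalar node features per graph, grouped by the batch vector."""
    sums, counts = {}, {}
    for feat, graph_id in zip(node_feats, batch):
        sums[graph_id] = sums.get(graph_id, 0.0) + feat
        counts[graph_id] = counts.get(graph_id, 0) + 1
    return [sums[g] / counts[g] for g in sorted(sums)]

# Nodes 0 and 1 belong to graph 0; node 2 to graph 1
global_mean_pool_py([1.0, 3.0, 10.0], [0, 0, 1])  # → [2.0, 10.0]
```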
---
## Model Comparison
| Model | Parameters | Inference | Output | Use Case |
|-------|------------|-----------|--------|----------|
| Field Predictor | 718K | 50ms | Full field | When you need field visualization |
| Parametric | 500K | 4.5ms | 4 scalars | Direct optimization (fastest) |
| Ensemble (5x) | 2.5M | 25ms | 4 scalars + uncertainty | When confidence matters |
---
## Implementation Notes
### PyTorch Geometric
We use [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) for GNN operations:
```python
import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool
# Version requirements
# torch >= 2.0
# torch_geometric >= 2.3
```
### GPU Memory
| Model | Batch Size | GPU Memory |
|-------|------------|------------|
| Field Predictor | 16 | 4 GB |
| Parametric | 32 | 2 GB |
| Training | 16 | 8 GB |
### Checkpoints
```python
# Save checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': model_config,
    'normalization_stats': stats,
    'epoch': epoch,
    'best_val_loss': best_loss
}, 'checkpoint.pt')

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model = ParametricFieldPredictor(**checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
```
---
## Physics Interpretation
### Why GNNs Work for FEA
1. **Locality**: FEA solutions are local (nodes only affect neighbors)
2. **Superposition**: Linear FEA is additive (sum of effects)
3. **Equilibrium**: Force balance at each node (sum of messages = 0)
The GNN learns these principles:
- Message passing ≈ Force distribution through elements
- Aggregation ≈ Force equilibrium at nodes
- Multiple layers ≈ Load path propagation
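One consequence of the layer count: after k message-passing layers, a node's prediction can only depend on its k-hop neighborhood. A small breadth-first search makes the receptive field concrete:

```python
from collections import deque

def receptive_field(adj, start, n_layers):
    """Nodes whose features can influence `start` after n message-passing layers."""
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == n_layers:
            continue
        for nb in adj.get(node, []):
            if nb not in seen:
                seen.add(nb)
                frontier.append((nb, depth + 1))
    return seen

# A 4-node chain: two layers reach exactly two hops
chain = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
```

This is why the Field Predictor stacks 6 layers: load paths spanning many elements need that many hops to propagate.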
### Physical Constraints
The architecture enforces physics:
```python
# Displacement at fixed nodes = 0 (hard constraint)
displacement = model(data)
fixed_mask = data.boundary_conditions > 0
displacement[fixed_mask] = 0.0

# Stress-strain relationship (implicit):
# learned by the network through training
```
---
## Extension Points
### Adding New Element Types
```python
# In data_loader.py
ELEMENT_TYPES = {
    'CTETRA': 0,
    'CHEXA': 1,
    'CPENTA': 2,
    'CQUAD4': 3,
    'CTRIA3': 4,
    'CBAR': 5,
    'CBEAM': 6,
    # Add new types here
    'CTETRA10': 7,  # New 10-node tetrahedron
}
```
### Custom Output Heads
```python
class CustomFieldPredictor(FieldPredictorGNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Add a custom head for thermal analysis
        # (assumes the base class exposes its hidden width as `hidden_channels`)
        self.temperature_head = nn.Linear(self.hidden_channels, 1)

    def forward(self, data):
        # Assumes the base forward returns the node embeddings h
        h = super().forward(data)
        # Append the temperature prediction to the existing outputs
        temperature = self.temperature_head(h)
        return torch.cat([h, temperature], dim=-1)
```
---
## References
1. Battaglia et al. (2018) "Relational inductive biases, deep learning, and graph networks"
2. Pfaff et al. (2021) "Learning Mesh-Based Simulation with Graph Networks"
3. Sanchez-Gonzalez et al. (2020) "Learning to Simulate Complex Physics with Graph Networks"
---
## See Also
- [Neural Features Complete](NEURAL_FEATURES_COMPLETE.md) - Overview of all features
- [Physics Loss Guide](PHYSICS_LOSS_GUIDE.md) - Loss function selection
- [Neural Workflow Tutorial](NEURAL_WORKFLOW_TUTORIAL.md) - Step-by-step guide