# GNN Architecture Deep Dive

**Technical documentation for AtomizerField Graph Neural Networks**

---

## Overview

AtomizerField uses Graph Neural Networks (GNNs) to learn physics from FEA simulations. This document explains the architecture in detail.

---

## Why Graph Neural Networks?

FEA meshes are naturally graphs:

- **Nodes** = Grid points (GRID cards in Nastran)
- **Edges** = Element connectivity (CTETRA, CQUAD4, etc.)
- **Node features** = Position, BCs, material properties
- **Edge features** = Element type, length, direction

Traditional architectures (MLPs, CNNs) assume fixed-size or regular grid inputs, so they cannot handle this irregular connectivity. GNNs operate on the graph directly.

```
     FEA Mesh                         Graph
═══════════════════════════════════════════════════
   o───o                          (N1)──(N2)
  /│   │\                           │╲    │╱
 o─┼───┼─o            →           (N3)──(N4)
  \│   │/                           │╱    │╲
   o───o                          (N5)──(N6)
```

---

## Model Architectures

### 1. Field Predictor GNN

Predicts complete displacement and stress fields.

```
                 Field Predictor GNN
═════════════════════════════════════════════════════

Input Encoding
  Node Features (12D per node):
    • Position (x, y, z)                    [3D]
    • Material (E, nu, rho)                 [3D]
    • Boundary conditions (fixed per DOF)   [6D]

  Edge Features (5D per edge):
    • Edge length                           [1D]
    • Direction vector                      [3D]
    • Element type                          [1D]
                        ↓
Message Passing Layers (6 layers)
  for layer in range(6):
      h = MeshGraphConv(h, edge_index, edge_attr)
      h = LayerNorm(h)
      h = ReLU(h)
      h = Dropout(h, p=0.1)
      h = h + residual   # Skip connection
                        ↓
Output Heads
  Displacement Head:
    Linear(hidden → 64 → 6)   # 6 DOF per node

  Stress Head:
    Linear(hidden → 64 → 1)   # Von Mises stress

Output: [N_nodes, 7]  (6 displacement + 1 stress)
```

**Parameters**: 718,221 trainable

### 2. Parametric Field Predictor GNN

Predicts scalar objectives directly from design parameters.

```
            Parametric Field Predictor GNN
═════════════════════════════════════════════════════

Design Parameter Encoding
  Design Params (4D):
    • beam_half_core_thickness
    • beam_face_thickness
    • holes_diameter
    • hole_count

  Design Encoder MLP:
    Linear(4 → 64) → ReLU → Linear(64 → 128)
                        ↓
Design-Conditioned GNN
  # Broadcast design encoding to all nodes
  node_features = node_features + design_encoding

  for layer in range(4):
      h = GraphConv(h, edge_index)
      h = BatchNorm(h)
      h = ReLU(h)
                        ↓
Global Pooling
  mean_pool = global_mean_pool(h)   # [batch, 128]
  max_pool  = global_max_pool(h)    # [batch, 128]
  design    = design_encoding       # [batch, 128]

  global_features = concat([mean_pool, max_pool, design])   # [batch, 384]
                        ↓
Scalar Prediction Heads
  MLP: Linear(384 → 128 → 64 → 4)

  Output:
    [0] = mass (grams)
    [1] = frequency (Hz)
    [2] = max_displacement (mm)
    [3] = max_stress (MPa)

Output: [batch, 4]  (4 objectives)
```

**Parameters**: ~500,000 trainable

---

## Message Passing

The core of GNNs is message passing. Here's how it works:

### Standard Message Passing

```python
def message_passing(node_features, edge_index, edge_attr):
    """
    node_features: [N_nodes, D_node]
    edge_index:    [2, N_edges]   # row 0 = source, row 1 = target
    edge_attr:     [N_edges, D_edge]
    """
    # Step 1: Compute one message per edge
    source_nodes = node_features[edge_index[0]]   # [N_edges, D_node]
    target_nodes = node_features[edge_index[1]]   # [N_edges, D_node]
    messages = message_mlp(
        torch.cat([source_nodes, target_nodes, edge_attr], dim=-1)
    )                                             # [N_edges, D_msg]

    # Step 2: Aggregate incoming messages at each target node
    aggregated = scatter_add(
        messages, edge_index[1], dim=0, dim_size=node_features.size(0)
    )                                             # [N_nodes, D_msg]

    # Step 3: Update node features
    updated = update_mlp(
        torch.cat([node_features, aggregated], dim=-1)
    )                                             # [N_nodes, D_node]
    return updated
```

### Custom MeshGraphConv

We use a custom convolution that respects the FEA mesh structure:

```python
class MeshGraphConv(MessagePassing):
    """
    Custom message passing for FEA meshes.
    Accounts for:
      - Edge lengths (stiffness depends on distance)
      - Element types (different physics for solid/shell/beam)
      - Direction vectors (anisotropic behavior)
    """

    def message(self, x_i, x_j, edge_attr):
        # x_i: target node features
        # x_j: source node features
        # edge_attr: [length (1) | direction (3) | element type (1)]
        edge_length = edge_attr[:, 0:1]
        edge_direction = edge_attr[:, 1:4]   # enters via the concatenated edge_attr
        element_type = edge_attr[:, 4:5]     # enters via the concatenated edge_attr

        # Scale by inverse distance (like stiffness)
        distance_weight = 1.0 / (edge_length + 1e-6)

        # Combine source and target features with edge features
        combined = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.mlp(combined) * distance_weight

    def aggregate(self, inputs, index, dim_size=None):
        # Sum messages at each node (like force equilibrium)
        return scatter_add(inputs, index, dim=0, dim_size=dim_size)
```

---

## Feature Engineering

### Node Features (12D)

| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Position (x, y, z) | 3 | Normalized | Node coordinates |
| Material E | 1 | Log-scaled | Young's modulus |
| Material nu | 1 | [0, 0.5] | Poisson's ratio |
| Material rho | 1 | Log-scaled | Density |
| BC_x, BC_y, BC_z | 3 | {0, 1} | Fixed translation |
| BC_rx, BC_ry, BC_rz | 3 | {0, 1} | Fixed rotation |

### Edge Features (5D)

| Feature | Dimensions | Range | Description |
|---------|------------|-------|-------------|
| Length | 1 | Normalized | Edge length |
| Direction | 3 | [-1, 1] | Unit direction vector |
| Element type | 1 | Encoded | CTETRA=0, CHEXA=1, etc. |
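Read together, the two tables fix a 12D node layout and a 5D edge layout. As a minimal sketch of how such vectors could be assembled — plain Python for illustration; `build_node_features` and `build_edge_features` are hypothetical helpers, not the actual AtomizerField API:

```python
import math

def build_node_features(position, young_modulus, poisson, density, fixed_dofs):
    """position: (x, y, z); fixed_dofs: six {0, 1} flags (tx, ty, tz, rx, ry, rz)."""
    return [
        *position,                 # [3D] coordinates (normalized downstream)
        math.log(young_modulus),   # [1D] log-scaled Young's modulus
        poisson,                   # [1D] Poisson's ratio in [0, 0.5]
        math.log(density),         # [1D] log-scaled density
        *fixed_dofs,               # [6D] boundary-condition flags
    ]

def build_edge_features(p_a, p_b, element_type_id):
    """5D edge feature: [length, unit direction (3), element type]."""
    delta = [b - a for a, b in zip(p_a, p_b)]
    length = math.sqrt(sum(c * c for c in delta))
    direction = [c / length for c in delta]   # unit vector, components in [-1, 1]
    return [length, *direction, element_type_id]

features = build_node_features(
    position=(0.0, 10.0, 0.0),
    young_modulus=70e9,              # aluminum, Pa
    poisson=0.33,
    density=2700.0,                  # kg/m^3
    fixed_dofs=(1, 1, 1, 0, 0, 0),   # clamped translations, free rotations
)
assert len(features) == 12
```

The {0, 1} boundary-condition flags need no normalization; only the continuous entries pass through the `normalize_features` step described below.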
### Normalization

```python
def normalize_features(node_features, edge_features, stats):
    """Normalize to zero mean, unit variance."""
    # Node features
    node_features = (node_features - stats['node_mean']) / stats['node_std']

    # Edge features (length uses log normalization)
    edge_features[:, 0] = torch.log(edge_features[:, 0] + 1e-6)
    edge_features = (edge_features - stats['edge_mean']) / stats['edge_std']

    return node_features, edge_features
```

---

## Training Details

### Optimizer

```python
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)
```

### Learning Rate Schedule

```python
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,        # Restart every 50 epochs
    T_mult=2,      # Double the period after each restart
    eta_min=1e-6
)
```

### Data Augmentation

```python
def augment_graph(data):
    """Random augmentation for better generalization."""
    # Random rotation (physics is rotation-invariant)
    if random.random() < 0.5:
        angle = random.uniform(0, 2 * math.pi)
        data = rotate_graph(data, angle, axis='z')

    # Random noise (robustness)
    if random.random() < 0.3:
        data.x += torch.randn_like(data.x) * 0.01

    return data
```

### Batch Processing

```python
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

for batch in loader:
    # batch.x:          [total_nodes, D_node]
    # batch.edge_index: [2, total_edges]
    # batch.batch:      [total_nodes] - maps each node to its graph
    predictions = model(batch)
```

---

## Model Comparison

| Model | Parameters | Inference | Output | Use Case |
|-------|------------|-----------|--------|----------|
| Field Predictor | 718K | 50 ms | Full field | When you need field visualization |
| Parametric | 500K | 4.5 ms | 4 scalars | Direct optimization (fastest) |
| Ensemble (5x) | 2.5M | 25 ms | 4 scalars + uncertainty | When confidence matters |

---

## Implementation Notes

### PyTorch Geometric

We use [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) for GNN operations:

```python
import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool

# Version requirements
# torch >= 2.0
# torch_geometric >= 2.3
```

### GPU Memory

| Model | Batch Size | GPU Memory |
|-------|------------|------------|
| Field Predictor (inference) | 16 | 4 GB |
| Parametric (inference) | 32 | 2 GB |
| Training | 16 | 8 GB |

### Checkpoints

```python
# Save checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': model_config,
    'normalization_stats': stats,
    'epoch': epoch,
    'best_val_loss': best_loss
}, 'checkpoint.pt')

# Load checkpoint
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model = ParametricFieldPredictor(**checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
```

---

## Physics Interpretation

### Why GNNs Work for FEA

1. **Locality**: FEA solutions are local (nodes directly affect only their neighbors)
2. **Superposition**: Linear FEA is additive (the response is a sum of effects)
3. **Equilibrium**: Forces balance at each node

The GNN mirrors these principles:

- Message passing ≈ force distribution through elements
- Aggregation ≈ force equilibrium at nodes
- Multiple layers ≈ load-path propagation

### Physical Constraints

The architecture enforces physics:

```python
# Displacement at fixed nodes = 0
displacement = model(data)
fixed_mask = data.boundary_conditions > 0
displacement[fixed_mask] = 0.0   # Hard constraint

# Stress-strain relationship (implicit)
# Learned by the network through training
```

---

## Extension Points

### Adding New Element Types

```python
# In data_loader.py
ELEMENT_TYPES = {
    'CTETRA': 0,
    'CHEXA': 1,
    'CPENTA': 2,
    'CQUAD4': 3,
    'CTRIA3': 4,
    'CBAR': 5,
    'CBEAM': 6,
    # Add new types here
    'CTETRA10': 7,   # New 10-node tetrahedron
}
```

### Custom Output Heads

```python
class CustomFieldPredictor(FieldPredictorGNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Add custom head for thermal analysis
        self.temperature_head = nn.Linear(self.hidden_channels,
                                          1)

    def forward(self, data):
        h = super().forward(data)
        # Add temperature prediction
        temperature = self.temperature_head(h)
        return torch.cat([h, temperature], dim=-1)
```

---

## References

1. Battaglia et al. (2018). "Relational inductive biases, deep learning, and graph networks."
2. Pfaff et al. (2021). "Learning Mesh-Based Simulation with Graph Networks."
3. Sanchez-Gonzalez et al. (2020). "Learning to Simulate Complex Physics with Graph Networks."

---

## See Also

- [Neural Features Complete](NEURAL_FEATURES_COMPLETE.md) - Overview of all features
- [Physics Loss Guide](PHYSICS_LOSS_GUIDE.md) - Loss function selection
- [Neural Workflow Tutorial](NEURAL_WORKFLOW_TUTORIAL.md) - Step-by-step guide
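---

## Appendix: Minimal Message-Passing Example

The three-step loop in the Message Passing section can be exercised without any framework. The sketch below is a toy, dependency-free illustration of one round of sum aggregation on a 4-node path graph — plain Python lists stand in for PyTorch tensors, and all names are hypothetical:

```python
# One round of message passing with sum aggregation, mirroring the
# three steps (message, aggregate, update) described above.

def message_passing_round(node_features, edges):
    """node_features: list of scalars (1D features); edges: list of (src, dst)."""
    # Step 1: one message per directed edge (here simply the source feature)
    messages = [(dst, node_features[src]) for src, dst in edges]

    # Step 2: sum-aggregate incoming messages at each node (like scatter_add)
    aggregated = [0.0] * len(node_features)
    for dst, msg in messages:
        aggregated[dst] += msg

    # Step 3: update each node from its own feature and its aggregate
    return [h + agg for h, agg in zip(node_features, aggregated)]

# 4-node path graph: 0 - 1 - 2 - 3, stored as directed edge pairs,
# the same convention as an undirected edge_index in PyTorch Geometric
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
h = [1.0, 2.0, 3.0, 4.0]
h = message_passing_round(h, edges)
# Node 1 receives from nodes 0 and 2: 2.0 + (1.0 + 3.0) = 6.0
```

Stacking several such rounds is what lets information — and, in the FEA analogy, load-path influence — propagate beyond immediate neighbors.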