
GNN Architecture Deep Dive

Technical documentation for AtomizerField Graph Neural Networks


Overview

AtomizerField uses Graph Neural Networks (GNNs) to learn physics from FEA simulations. This document explains the architecture in detail.


Why Graph Neural Networks?

FEA meshes are naturally graphs:

  • Nodes = Grid points (GRID cards in Nastran)
  • Edges = Element connectivity (CTETRA, CQUAD, etc.)
  • Node features = Position, BCs, material properties
  • Edge features = Element type, length, direction

Traditional neural networks struggle with this irregular structure: MLPs require fixed-size inputs and CNNs assume a regular grid. GNNs operate directly on arbitrary connectivity.

FEA Mesh                          Graph
═══════════════════════════════════════════════
         o───o                    (N1)──(N2)
        /│   │\                    │╲    │╱
       o─┼───┼─o       →         (N3)──(N4)
        \│   │/                    │╱    │╲
         o───o                    (N5)──(N6)
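Deriving the edge set from element connectivity is a pairwise expansion: every pair of nodes sharing an element gets an edge. A minimal sketch in plain Python (hypothetical helper; production code would parse GRID/CTETRA cards from the Nastran deck):

```python
from itertools import combinations

def elements_to_edges(elements):
    """Turn element connectivity (tuples of node IDs) into a unique,
    undirected edge list: nodes sharing an element are connected."""
    edges = set()
    for elem in elements:
        for a, b in combinations(sorted(elem), 2):
            edges.add((a, b))
    return sorted(edges)

# Two tetrahedra sharing a face (nodes 2, 3, 4)
elements = [(1, 2, 3, 4), (2, 3, 4, 5)]
print(elements_to_edges(elements))
```

Note that the shared face contributes its edges only once, since the edge set is deduplicated.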

Model Architectures

1. Field Predictor GNN

Predicts complete displacement and stress fields.

┌─────────────────────────────────────────────────────────┐
│                  Field Predictor GNN                     │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  Input Encoding                                          │
│  ┌──────────────────────────────────────────────────┐  │
│  │ Node Features (12D per node):                     │  │
│  │ • Position (x, y, z)                    [3D]      │  │
│  │ • Material (E, nu, rho)                 [3D]      │  │
│  │ • Boundary conditions (fixed per DOF)   [6D]      │  │
│  │                                                    │  │
│  │ Edge Features (5D per edge):                      │  │
│  │ • Edge length                           [1D]      │  │
│  │ • Direction vector                      [3D]      │  │
│  │ • Element type                          [1D]      │  │
│  └──────────────────────────────────────────────────┘  │
│                         ↓                                │
│  Message Passing Layers (6 layers)                      │
│  ┌──────────────────────────────────────────────────┐  │
│  │ for layer in range(6):                            │  │
│  │   h = MeshGraphConv(h, edge_index, edge_attr)    │  │
│  │   h = LayerNorm(h)                                │  │
│  │   h = ReLU(h)                                     │  │
│  │   h = Dropout(h, p=0.1)                          │  │
│  │   h = h + residual  # Skip connection            │  │
│  └──────────────────────────────────────────────────┘  │
│                         ↓                                │
│  Output Heads                                           │
│  ┌──────────────────────────────────────────────────┐  │
│  │ Displacement Head:                                │  │
│  │   Linear(hidden → 64 → 6)  # 6 DOF per node      │  │
│  │                                                    │  │
│  │ Stress Head:                                      │  │
│  │   Linear(hidden → 64 → 1)  # Von Mises stress    │  │
│  └──────────────────────────────────────────────────┘  │
│                                                          │
│  Output: [N_nodes, 7] (6 displacement + 1 stress)       │
└─────────────────────────────────────────────────────────┘

Parameters: 718,221 trainable

2. Parametric Field Predictor GNN

Predicts scalar objectives directly from design parameters.

┌─────────────────────────────────────────────────────────┐
│              Parametric Field Predictor GNN              │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  Design Parameter Encoding                               │
│  ┌──────────────────────────────────────────────────┐  │
│  │ Design Params (4D):                               │  │
│  │ • beam_half_core_thickness                        │  │
│  │ • beam_face_thickness                             │  │
│  │ • holes_diameter                                  │  │
│  │ • hole_count                                      │  │
│  │                                                    │  │
│  │ Design Encoder MLP:                               │  │
│  │   Linear(4 → 64) → ReLU → Linear(64 → 128)       │  │
│  └──────────────────────────────────────────────────┘  │
│                         ↓                                │
│  Design-Conditioned GNN                                 │
│  ┌──────────────────────────────────────────────────┐  │
│  │ # Broadcast design encoding to all nodes          │  │
│  │ node_features = node_features + design_encoding   │  │
│  │                                                    │  │
│  │ for layer in range(4):                            │  │
│  │   h = GraphConv(h, edge_index)                    │  │
│  │   h = BatchNorm(h)                                │  │
│  │   h = ReLU(h)                                     │  │
│  └──────────────────────────────────────────────────┘  │
│                         ↓                                │
│  Global Pooling                                         │
│  ┌──────────────────────────────────────────────────┐  │
│  │ mean_pool = global_mean_pool(h)  # [batch, 128]  │  │
│  │ max_pool = global_max_pool(h)    # [batch, 128]  │  │
│  │ design = design_encoding         # [batch, 128]  │  │
│  │                                                    │  │
│  │ global_features = concat([mean_pool, max_pool,   │  │
│  │                           design])  # [batch, 384]│  │
│  └──────────────────────────────────────────────────┘  │
│                         ↓                                │
│  Scalar Prediction Heads                                │
│  ┌──────────────────────────────────────────────────┐  │
│  │ MLP: Linear(384 → 128 → 64 → 4)                  │  │
│  │                                                    │  │
│  │ Output:                                           │  │
│  │   [0] = mass (grams)                              │  │
│  │   [1] = frequency (Hz)                            │  │
│  │   [2] = max_displacement (mm)                     │  │
│  │   [3] = max_stress (MPa)                          │  │
│  └──────────────────────────────────────────────────┘  │
│                                                          │
│  Output: [batch, 4] (4 objectives)                      │
└─────────────────────────────────────────────────────────┘

Parameters: ~500,000 trainable
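The global pooling stage above reduces per-node embeddings to one vector per graph via a batch index, as global_mean_pool does in PyTorch Geometric. A pure-Python sketch of the mean variant:

```python
def mean_pool(node_feats, batch, num_graphs):
    """Average node features per graph. `batch` maps each node to its
    graph index, exactly like batch.batch in PyTorch Geometric."""
    sums = [[0.0] * len(node_feats[0]) for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for feats, g in zip(node_feats, batch):
        counts[g] += 1
        for d, v in enumerate(feats):
            sums[g][d] += v
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Two graphs in one batch: nodes 0-1 form graph 0, node 2 forms graph 1
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_pool(x, batch=[0, 0, 1], num_graphs=2))  # [[2.0, 3.0], [5.0, 6.0]]
```

Max pooling works identically with `max` in place of the running sum; the two pooled vectors and the design encoding are then concatenated.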


Message Passing

The core of GNNs is message passing. Here's how it works:

Standard Message Passing

import torch

def message_passing(node_features, edge_index, edge_attr,
                    message_mlp, update_mlp):
    """
    node_features: [N_nodes, D_node]
    edge_index:    [2, N_edges]   # Row 0 = source, row 1 = target
    edge_attr:     [N_edges, D_edge]
    """
    # Step 1: Compute messages from each (source, target, edge) triple
    source_nodes = node_features[edge_index[0]]  # [N_edges, D_node]
    target_nodes = node_features[edge_index[1]]  # [N_edges, D_node]

    messages = message_mlp(
        torch.cat([source_nodes, target_nodes, edge_attr], dim=-1)
    )  # [N_edges, D_msg]

    # Step 2: Aggregate (scatter-add) messages at each target node
    aggregated = torch.zeros(
        node_features.size(0), messages.size(1), device=messages.device
    )
    aggregated.index_add_(0, edge_index[1], messages)  # [N_nodes, D_msg]

    # Step 3: Update node features from the old state and the aggregate
    updated = update_mlp(
        torch.cat([node_features, aggregated], dim=-1)
    )  # [N_nodes, D_node]

    return updated
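The aggregation step is a scatter-add: each edge's message is accumulated into the row of its target node. In pure Python the primitive looks like this (illustrative sketch, not the optimized tensor kernel):

```python
def scatter_add(messages, index, num_nodes):
    """Accumulate each edge message into the row of its target node."""
    out = [[0.0] * len(messages[0]) for _ in range(num_nodes)]
    for msg, tgt in zip(messages, index):
        for d, v in enumerate(msg):
            out[tgt][d] += v
    return out

# Three edges; edges 1 and 2 both target node 1, so their messages sum
print(scatter_add([[1.0], [2.0], [4.0]], index=[0, 1, 1], num_nodes=2))
# -> [[1.0], [6.0]]
```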

Custom MeshGraphConv

We use a custom convolution that respects FEA mesh structure:

import torch
from torch import nn
from torch_geometric.nn import MessagePassing

class MeshGraphConv(MessagePassing):
    """
    Custom message passing for FEA meshes.
    Accounts for:
    - Edge lengths (stiffness depends on distance)
    - Element types (different physics for solid/shell/beam)
    - Direction vectors (anisotropic behavior)
    """

    def __init__(self, in_channels, edge_dim, out_channels):
        # 'add' aggregation: sum messages at each node (like force equilibrium)
        super().__init__(aggr='add')
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels + edge_dim, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # x_i: Target node features
        # x_j: Source node features
        # edge_attr: Edge features (length, direction, type)
        edge_length = edge_attr[:, 0:1]

        # Scale by inverse distance (like stiffness)
        distance_weight = 1.0 / (edge_length + 1e-6)

        # Combine source and target features with the full edge descriptor;
        # direction and element type enter the message through the MLP
        combined = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.mlp(combined) * distance_weight

Feature Engineering

Node Features (12D)

Feature               Dims   Range        Description
Position (x, y, z)    3      Normalized   Node coordinates
Material E            1      Log-scaled   Young's modulus
Material nu           1      [0, 0.5]     Poisson's ratio
Material rho          1      Log-scaled   Density
BC_x, BC_y, BC_z      3      {0, 1}       Fixed translation
BC_rx, BC_ry, BC_rz   3      {0, 1}       Fixed rotation

Edge Features (5D)

Feature        Dims   Range        Description
Length         1      Normalized   Edge length
Direction      3      [-1, 1]      Unit direction vector
Element type   1      Encoded      CTETRA=0, CHEXA=1, etc.
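Putting the two tables together, the raw (pre-normalization) feature vectors can be assembled as below. The helper names and material values are hypothetical; in practice the values come from the parsed deck:

```python
import math

# Hypothetical aluminum-like material
E, nu, rho = 70e9, 0.33, 2700.0

def node_feature(pos, fixed_dofs):
    """12-D node feature: position, material (E and rho log-scaled),
    and six {0, 1} boundary-condition flags."""
    bc = [1.0 if d in fixed_dofs else 0.0 for d in range(6)]
    return list(pos) + [math.log(E), nu, math.log(rho)] + bc

def edge_feature(p_src, p_dst, elem_type):
    """5-D edge feature: length, unit direction vector, element-type code."""
    d = [b - a for a, b in zip(p_src, p_dst)]
    length = math.sqrt(sum(c * c for c in d))
    return [length] + [c / length for c in d] + [float(elem_type)]

nf = node_feature((0.0, 0.0, 0.0), fixed_dofs={0, 1, 2, 3, 4, 5})  # clamped
ef = edge_feature((0.0, 0.0, 0.0), (3.0, 4.0, 0.0), elem_type=0)   # CTETRA
print(len(nf), len(ef), ef[0])  # 12 5 5.0
```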

Normalization

def normalize_features(node_features, edge_features, stats):
    """Normalize to zero mean, unit variance"""

    # Node features
    node_features = (node_features - stats['node_mean']) / stats['node_std']

    # Edge features (length uses log normalization)
    edge_features[:, 0] = torch.log(edge_features[:, 0] + 1e-6)
    edge_features = (edge_features - stats['edge_mean']) / stats['edge_std']

    return node_features, edge_features
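The `stats` dict must be computed once over the training set and reused unchanged at inference. A pure-Python sketch of the per-dimension statistics (the small floor keeps constant features from producing a divide-by-zero):

```python
def feature_stats(rows):
    """Per-dimension mean and population std over a feature matrix;
    the 1e-8 floor guards against zero-variance (constant) features."""
    n, d = len(rows), len(rows[0])
    mean = [sum(r[j] for r in rows) / n for j in range(d)]
    std = [max((sum((r[j] - mean[j]) ** 2 for r in rows) / n) ** 0.5, 1e-8)
           for j in range(d)]
    return mean, std

mean, std = feature_stats([[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]])
print(mean)  # [2.0, 10.0]
```

Here the second dimension is constant across samples, so its std is clamped rather than zero.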

Training Details

Optimizer

from torch.optim import AdamW

optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)

Learning Rate Schedule

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,      # Restart every 50 epochs
    T_mult=2,    # Double period after each restart
    eta_min=1e-6
)

Data Augmentation

import math
import random

import torch

def augment_graph(data):
    """Random augmentation for better generalization"""

    # Random rotation (physics is rotation-invariant)
    if random.random() < 0.5:
        angle = random.uniform(0, 2 * math.pi)
        data = rotate_graph(data, angle, axis='z')

    # Random noise (robustness)
    if random.random() < 0.3:
        data.x += torch.randn_like(data.x) * 0.01

    return data
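The rotate_graph call above hinges on an ordinary rotation of node coordinates. A minimal z-axis version (positions only; a full implementation must also rotate direction-valued edge features and displacement targets):

```python
import math

def rotate_z(points, angle):
    """Rotate (x, y, z) points about the z-axis by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y, s * x + c * y, z) for x, y, z in points]

# A quarter turn takes (1, 0, 0) to (0, 1, 0), up to float rounding
print(rotate_z([(1.0, 0.0, 0.0)], math.pi / 2))
```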

Batch Processing

from torch_geometric.loader import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

for batch in loader:
    # batch.x: [total_nodes, D_node]
    # batch.edge_index: [2, total_edges]
    # batch.batch: [total_nodes] - maps nodes to graphs
    predictions = model(batch)

Model Comparison

Model             Parameters   Inference   Output                    Use Case
Field Predictor   718K         50 ms       Full field                When you need field visualization
Parametric        500K         4.5 ms      4 scalars                 Direct optimization (fastest)
Ensemble (5x)     2.5M         25 ms       4 scalars + uncertainty   When confidence matters
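The ensemble row reports uncertainty as disagreement between members: query every model and take the per-objective mean and standard deviation. A sketch with hypothetical stand-in callables in place of trained Parametric GNNs:

```python
import statistics

def ensemble_predict(models, params):
    """Query every ensemble member and report the per-objective mean
    prediction plus a standard-deviation uncertainty."""
    preds = [m(params) for m in models]          # [n_models][n_objectives]
    cols = list(zip(*preds))
    mean = [statistics.fmean(c) for c in cols]
    std = [statistics.pstdev(c) for c in cols]
    return mean, std

# Stand-ins that agree on objective 1 and disagree slightly on objective 0
members = [lambda p, b=b: [100.0 + b, 50.0] for b in (-1.0, 0.0, 1.0)]
mean, std = ensemble_predict(members, params=[1.0, 2.0, 3.0, 4.0])
print(mean)  # [100.0, 50.0]
```

A large std on an objective signals that the surrogate should not be trusted there and a real FEA run is warranted.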

Implementation Notes

PyTorch Geometric

We use PyTorch Geometric for GNN operations:

import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool

# Version requirements
# torch >= 2.0
# torch_geometric >= 2.3

GPU Memory

Model             Batch Size   GPU Memory
Field Predictor   16           4 GB
Parametric        32           2 GB
Training          16           8 GB

Checkpoints

# Save checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': model_config,
    'normalization_stats': stats,
    'epoch': epoch,
    'best_val_loss': best_loss
}, 'checkpoint.pt')

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model = ParametricFieldPredictor(**checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])

Physics Interpretation

Why GNNs Work for FEA

  1. Locality: FEA solutions are local (nodes only affect neighbors)
  2. Superposition: Linear FEA is additive (sum of effects)
  3. Equilibrium: Force balance at each node (sum of messages = 0)

The GNN learns these principles:

  • Message passing ≈ Force distribution through elements
  • Aggregation ≈ Force equilibrium at nodes
  • Multiple layers ≈ Load path propagation
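The force-equilibrium analogy can be checked on a toy 1-D spring chain: each message is the spring force k*(u_other - u_node), and summing the messages arriving at an interior node of a linear (equilibrium) displacement field gives zero:

```python
def net_force(u, k, edges, node):
    """Sum the spring-force 'messages' k*(u_other - u_node) arriving at
    `node` -- message aggregation read as force equilibrium."""
    f = 0.0
    for a, b in edges:
        if a == node:
            f += k * (u[b] - u[a])
        elif b == node:
            f += k * (u[a] - u[b])
    return f

# Three-node chain with a linear displacement field: interior node balances
u = [0.0, 1.0, 2.0]
edges = [(0, 1), (1, 2)]
print(net_force(u, k=10.0, edges=edges, node=1))  # 0.0
```

End nodes see a nonzero net force, which is exactly where external loads or reactions must act.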

Physical Constraints

The architecture enforces physics:

# Displacement at fixed nodes = 0
displacement = model(data)
fixed_mask = data.boundary_conditions > 0
displacement[fixed_mask] = 0.0  # Hard constraint

# Stress-strain relationship (implicit)
# Learned by the network through training

Extension Points

Adding New Element Types

# In data_loader.py
ELEMENT_TYPES = {
    'CTETRA': 0,
    'CHEXA': 1,
    'CPENTA': 2,
    'CQUAD4': 3,
    'CTRIA3': 4,
    'CBAR': 5,
    'CBEAM': 6,
    # Add new types here
    'CTETRA10': 7,  # New 10-node tetrahedron
}

Custom Output Heads

class CustomFieldPredictor(FieldPredictorGNN):
    def __init__(self, *args, hidden_channels=128, **kwargs):
        super().__init__(*args, hidden_channels=hidden_channels, **kwargs)

        # Add custom head for thermal analysis
        self.temperature_head = nn.Linear(hidden_channels, 1)

    def forward(self, data):
        h = super().forward(data)

        # Add temperature prediction
        temperature = self.temperature_head(h)
        return torch.cat([h, temperature], dim=-1)

References

  1. Battaglia et al. (2018) "Relational inductive biases, deep learning, and graph networks"
  2. Pfaff et al. (2021) "Learning Mesh-Based Simulation with Graph Networks"
  3. Sanchez-Gonzalez et al. (2020) "Learning to Simulate Complex Physics with Graph Networks"

See Also