Files
Atomizer/atomizer-field/neural_models/field_predictor.py

491 lines
16 KiB
Python
Raw Normal View History

"""
field_predictor.py
Graph Neural Network for predicting complete FEA field results
AtomizerField Field Predictor v2.0
Uses Graph Neural Networks to learn the physics of structural response.
Key Innovation:
Instead of: parameters → FEA → max_stress (scalar)
We learn: parameters → Neural Network → complete stress field (N values)
This enables 1000x faster optimization with physics understanding.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data
import numpy as np
class MeshGraphConv(MessagePassing):
    """
    Custom Graph Convolution for FEA meshes.

    This layer propagates information along mesh edges (element connectivity)
    to learn how forces flow through the structure.

    Key insight: Stress and displacement fields follow mesh topology.
    Adjacent elements influence each other through equilibrium.
    """

    def __init__(self, in_channels, out_channels, edge_dim=None):
        """
        Args:
            in_channels (int): Input node feature dimension
            out_channels (int): Output node feature dimension
            edge_dim (int): Edge feature dimension (optional). When set, the
                message MLP is sized for edge features of this width; if a
                call then supplies no edge features, zeros are used instead.
        """
        super().__init__(aggr='mean')  # Mean aggregation of neighbor messages
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim

        # Message function: MLP over [x_i, x_j] (+ edge features when configured).
        message_in = 2 * in_channels + (edge_dim if edge_dim is not None else 0)
        self.message_mlp = nn.Sequential(
            nn.Linear(message_in, out_channels),
            nn.LayerNorm(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

        # Update function: MLP over [original node features, aggregated messages].
        self.update_mlp = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index, edge_attr=None):
        """
        Propagate messages through the mesh graph.

        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Edge connectivity [2, num_edges]
            edge_attr: Edge features [num_edges, edge_dim] (optional)

        Returns:
            Updated node features [num_nodes, out_channels]
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr=None):
        """
        Construct one message per edge.

        Args:
            x_i: Target node features (one row per edge)
            x_j: Source node features (one row per edge)
            edge_attr: Edge features, or None

        Raises:
            ValueError: if edge features are given but the layer was built
                without ``edge_dim``.
        """
        if self.edge_dim is None:
            if edge_attr is not None:
                raise ValueError(
                    "MeshGraphConv was built without edge_dim but received edge_attr"
                )
            msg_input = torch.cat([x_i, x_j], dim=-1)
        else:
            if edge_attr is None:
                # The message MLP was sized for edge features; substitute
                # zeros rather than failing with a Linear shape mismatch.
                edge_attr = x_i.new_zeros((x_i.size(0), self.edge_dim))
            msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.message_mlp(msg_input)

    def update(self, aggr_out, x):
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Mean-aggregated messages [num_nodes, out_channels]
            x: Original node features passed to ``propagate``
        """
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_mlp(update_input)
class FieldPredictorGNN(nn.Module):
    """
    Graph Neural Network for predicting complete FEA fields.

    Architecture:
        1. Node Encoder: Encode node positions, BCs, loads
        2. Edge Encoder: Encode element connectivity, material properties
        3. Message Passing: Propagate information through mesh (multiple layers)
        4. Field Decoder: Predict an ``output_dim``-wide field value per node
           (displacement by default)

    Design intent (physics motivation):
        - Uses mesh topology (forces flow through connected elements)
        - Incorporates boundary conditions (fixed/loaded nodes) via node features
        - Lets material behavior (E, nu -> stress-strain relationship) enter
          through edge features
    """

    def __init__(
        self,
        node_feature_dim=3,   # Node coordinates (x, y, z)
        edge_feature_dim=5,   # Material properties (E, nu, rho, etc.)
        hidden_dim=128,
        num_layers=6,
        output_dim=6,         # 6 DOF displacement (3 translation + 3 rotation)
        dropout=0.1
    ):
        """
        Initialize field predictor.

        Args:
            node_feature_dim (int): Dimension of node features (position + BCs + loads)
            edge_feature_dim (int): Dimension of edge features (material properties)
            hidden_dim (int): Hidden layer dimension
            num_layers (int): Number of message passing layers
            output_dim (int): Output dimension per node (6 for displacement)
            dropout (float): Dropout rate
        """
        super().__init__()
        self.node_feature_dim = node_feature_dim
        self.edge_feature_dim = edge_feature_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim

        # Node encoder: embed node coordinates + BCs + loads
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Edge encoder: embed material properties into hidden_dim // 2 features
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2)
        )

        # Message passing layers; each expects hidden_dim // 2 edge features.
        self.conv_layers = nn.ModuleList([
            MeshGraphConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=hidden_dim // 2
            )
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(num_layers)
        ])
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout)
            for _ in range(num_layers)
        ])

        # Field decoder: per-node output of width output_dim
        self.field_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, output_dim)
        )

        # Learnable global output scale: a single scalar multiplied into every
        # predicted value. It does not enforce any physical constraint.
        self.physics_scale = nn.Parameter(torch.ones(1))

    def forward(self, data):
        """
        Forward pass: mesh graph → per-node displacement field.

        Args:
            data (torch_geometric.data.Data): Batch of mesh graphs containing:
                - x: Node features [num_nodes, node_feature_dim]
                - edge_index: Connectivity [2, num_edges]
                - edge_attr: Edge features [num_edges, edge_feature_dim], or None
                - batch: Batch assignment [num_nodes] (not read here)

        Returns:
            displacement_field: Predicted displacement [num_nodes, output_dim]
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Encode nodes (positions + BCs + loads)
        x = self.node_encoder(x)  # [num_nodes, hidden_dim]

        # Encode edges (material properties)
        if edge_attr is not None:
            edge_features = self.edge_encoder(edge_attr)  # [num_edges, hidden_dim // 2]
        else:
            # Every conv layer was built with edge_dim=hidden_dim // 2, so its
            # message MLP always expects edge features; passing None would raise
            # a shape error. Use zeros as a neutral "no edge information" input.
            edge_features = x.new_zeros((edge_index.size(1), self.hidden_dim // 2))

        # Message passing with residual connections and post-normalization
        for conv, norm, drop in zip(self.conv_layers, self.layer_norms, self.dropouts):
            x_new = conv(x, edge_index, edge_features)
            x = norm(x + drop(x_new))

        # Decode to displacement field and apply the learnable global scale
        displacement = self.field_decoder(x)  # [num_nodes, output_dim]
        return displacement * self.physics_scale

    def predict_stress_from_displacement(self, displacement, data, material_props):
        """
        Convert predicted displacement to stress using a constitutive law.

        Intended to implement σ = C : ε with ε = ½(∇u + ∇uᵀ), where C is the
        material stiffness tensor. Not implemented here; stress is instead
        predicted by the separate StressPredictor module.

        Args:
            displacement: Predicted displacement [num_nodes, 6]
            data: Mesh graph data
            material_props: Material properties (E, nu)

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Stress prediction implemented in StressPredictor")
class StressPredictor(nn.Module):
    """
    Map each node's displacement vector to a stress vector.

    Two strategies are possible in principle:
        1. Physics-based: compute strain from displacement, apply constitutive law
        2. Learned: train a neural network to predict stress from displacement
    This class implements the learned approach (for flexibility with nonlinear
    materials): a three-layer MLP applied independently to every row.

    NOTE(review): the MLP sees only a node's own displacement values, not
    displacement gradients, so it has no direct access to strain.
    """

    def __init__(self, displacement_dim=6, hidden_dim=128, stress_components=6):
        """
        Args:
            displacement_dim (int): Displacement DOFs per node (input width)
            hidden_dim (int): Width of the two hidden layers
            stress_components (int): Stress tensor components (6 for 3D)
        """
        super().__init__()

        def hidden_stage(width_in):
            # Linear -> LayerNorm -> ReLU, projecting to hidden_dim.
            return [nn.Linear(width_in, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU()]

        layers = hidden_stage(displacement_dim) + hidden_stage(hidden_dim)
        layers.append(nn.Linear(hidden_dim, stress_components))

        # A single Sequential named `stress_net` keeps state_dict keys
        # (stress_net.0.weight, ...) stable for saved checkpoints.
        self.stress_net = nn.Sequential(*layers)

    def forward(self, displacement):
        """
        Predict stress from displacement.

        Args:
            displacement: [num_nodes, displacement_dim]

        Returns:
            stress: [num_nodes, stress_components]
        """
        stress = self.stress_net(displacement)
        return stress
class AtomizerFieldModel(nn.Module):
    """
    Complete AtomizerField model: predicts both displacement and stress fields.

    This is the main model you'll use for training and inference. It chains a
    FieldPredictorGNN (graph -> per-node 6-DOF displacement) with a
    StressPredictor (per-node displacement -> per-node 6-component stress) and
    derives the von Mises equivalent stress from the latter.
    """

    # Lower bound on the squared von Mises stress before taking the sqrt.
    # d/dx sqrt(x) is infinite at x = 0, so a node whose stress state is exactly
    # hydrostatic (or zero) would otherwise produce NaN gradients during
    # training. With 1e-12 the smallest reported value is 1e-6.
    VON_MISES_EPS = 1e-12

    def __init__(
        self,
        node_feature_dim=10,  # 3 (xyz) + 6 (BC DOFs) + 1 (load magnitude)
        edge_feature_dim=5,   # E, nu, rho, G, alpha
        hidden_dim=128,
        num_layers=6,
        dropout=0.1
    ):
        """
        Initialize complete field prediction model.

        Args:
            node_feature_dim (int): Node features (coords + BCs + loads)
            edge_feature_dim (int): Edge features (material properties)
            hidden_dim (int): Hidden dimension
            num_layers (int): Message passing layers
            dropout (float): Dropout rate
        """
        super().__init__()

        # Displacement predictor (main GNN)
        self.displacement_predictor = FieldPredictorGNN(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            output_dim=6,  # 6 DOF displacement
            dropout=dropout
        )

        # Stress predictor (from displacement)
        self.stress_predictor = StressPredictor(
            displacement_dim=6,
            hidden_dim=hidden_dim,
            stress_components=6  # σxx, σyy, σzz, τxy, τyz, τxz
        )

    @staticmethod
    def _von_mises(stress, eps=1e-12):
        """
        Von Mises equivalent stress for each row of ``stress``.

        σ_vm = sqrt(0.5 * ((σxx-σyy)² + (σyy-σzz)² + (σzz-σxx)² + 6(τxy² + τyz² + τxz²)))

        Args:
            stress: [num_nodes, >=6], columns ordered σxx, σyy, σzz, τxy, τyz, τxz
            eps (float): lower clamp on the squared value so the sqrt has a
                finite gradient everywhere.

        Returns:
            [num_nodes] tensor of von Mises stresses.
        """
        sxx, syy, szz = stress[:, 0], stress[:, 1], stress[:, 2]
        txy, tyz, txz = stress[:, 3], stress[:, 4], stress[:, 5]
        vm_squared = 0.5 * (
            (sxx - syy) ** 2 + (syy - szz) ** 2 + (szz - sxx) ** 2 +
            6 * (txy ** 2 + tyz ** 2 + txz ** 2)
        )
        return torch.sqrt(vm_squared.clamp_min(eps))

    def forward(self, data, return_stress=True):
        """
        Predict displacement and stress fields.

        Args:
            data: Mesh graph data
            return_stress (bool): Whether to predict stress

        Returns:
            dict with:
                - displacement: [num_nodes, 6]
                - stress: [num_nodes, 6] (if return_stress=True)
                - von_mises: [num_nodes] (if return_stress=True)
        """
        displacement = self.displacement_predictor(data)
        results = {'displacement': displacement}

        if return_stress:
            stress = self.stress_predictor(displacement)
            results['stress'] = stress
            results['von_mises'] = self._von_mises(stress, self.VON_MISES_EPS)

        return results

    def get_max_values(self, results):
        """
        Extract maximum values (for compatibility with scalar optimization).

        Args:
            results: Output from forward()

        Returns:
            dict with:
                - max_displacement: largest Euclidean norm of the first three
                  (translational) displacement components over all nodes
                - max_stress: largest von Mises stress, or None if stress was
                  not computed
        """
        max_displacement = torch.max(torch.norm(results['displacement'][:, :3], dim=1))
        max_stress = torch.max(results['von_mises']) if 'von_mises' in results else None
        return {
            'max_displacement': max_displacement,
            'max_stress': max_stress
        }
def create_model(config=None):
    """
    Build an AtomizerFieldModel and apply Kaiming-normal initialisation.

    Args:
        config (dict, optional): Keyword arguments forwarded to
            AtomizerFieldModel. When None, the configuration
            node_feature_dim=10, edge_feature_dim=5, hidden_dim=128,
            num_layers=6, dropout=0.1 is used.

    Returns:
        AtomizerFieldModel: the freshly initialised model.
    """
    default_config = {
        'node_feature_dim': 10,
        'edge_feature_dim': 5,
        'hidden_dim': 128,
        'num_layers': 6,
        'dropout': 0.1
    }
    model = AtomizerFieldModel(**(default_config if config is None else config))

    def _kaiming_linear(module):
        # Only nn.Linear layers are re-initialised; other modules and
        # standalone parameters keep their constructor defaults.
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    model.apply(_kaiming_linear)
    return model
def _run_smoke_test():
    """
    Create a default model and run one inference pass on a random graph.

    Prints the parameter count, the output tensor shapes and the scalar
    maxima returned by ``get_max_values``.
    """
    print("Testing AtomizerField Model Creation...")
    torch.manual_seed(0)  # reproducible weights and dummy graph
    model = create_model()
    print(f"Model created: {sum(p.numel() for p in model.parameters()):,} parameters")

    # Create dummy data (random connectivity; may include self-loops/duplicates)
    num_nodes = 100
    num_edges = 300
    x = torch.randn(num_nodes, 10)  # Node features (default node_feature_dim)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))  # Edge connectivity
    edge_attr = torch.randn(num_edges, 5)  # Edge features (default edge_feature_dim)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # Single graph in the batch
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)

    # Inference: eval() disables dropout (no_grad alone does not), and
    # no_grad() skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        results = model(data)

    print("\nTest forward pass:")
    print(f"  Displacement shape: {results['displacement'].shape}")
    print(f"  Stress shape: {results['stress'].shape}")
    print(f"  Von Mises shape: {results['von_mises'].shape}")

    max_vals = model.get_max_values(results)
    print("\nMax values:")
    print(f"  Max displacement: {max_vals['max_displacement']:.6f}")
    print(f"  Max stress: {max_vals['max_stress']:.2f}")
    print("\nModel test passed!")


if __name__ == "__main__":
    _run_smoke_test()