""" field_predictor.py Graph Neural Network for predicting complete FEA field results AtomizerField Field Predictor v2.0 Uses Graph Neural Networks to learn the physics of structural response. Key Innovation: Instead of: parameters → FEA → max_stress (scalar) We learn: parameters → Neural Network → complete stress field (N values) This enables 1000x faster optimization with physics understanding. """ import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import MessagePassing, global_mean_pool from torch_geometric.data import Data import numpy as np class MeshGraphConv(MessagePassing): """ Custom Graph Convolution for FEA meshes This layer propagates information along mesh edges (element connectivity) to learn how forces flow through the structure. Key insight: Stress and displacement fields follow mesh topology. Adjacent elements influence each other through equilibrium. """ def __init__(self, in_channels, out_channels, edge_dim=None): """ Args: in_channels (int): Input node feature dimension out_channels (int): Output node feature dimension edge_dim (int): Edge feature dimension (optional) """ super().__init__(aggr='mean') # Mean aggregation of neighbor messages self.in_channels = in_channels self.out_channels = out_channels # Message function: how to combine node and edge features if edge_dim is not None: self.message_mlp = nn.Sequential( nn.Linear(2 * in_channels + edge_dim, out_channels), nn.LayerNorm(out_channels), nn.ReLU(), nn.Linear(out_channels, out_channels) ) else: self.message_mlp = nn.Sequential( nn.Linear(2 * in_channels, out_channels), nn.LayerNorm(out_channels), nn.ReLU(), nn.Linear(out_channels, out_channels) ) # Update function: how to update node features self.update_mlp = nn.Sequential( nn.Linear(in_channels + out_channels, out_channels), nn.LayerNorm(out_channels), nn.ReLU(), nn.Linear(out_channels, out_channels) ) self.edge_dim = edge_dim def forward(self, x, edge_index, edge_attr=None): """ Propagate messages through the mesh graph Args: x: Node features [num_nodes, in_channels] edge_index: Edge connectivity [2, num_edges] edge_attr: Edge features [num_edges, edge_dim] (optional) Returns: Updated node features [num_nodes, out_channels] """ return self.propagate(edge_index, x=x, edge_attr=edge_attr) def message(self, x_i, x_j, edge_attr=None): """ Construct messages from neighbors Args: x_i: Target node features x_j: Source node features edge_attr: Edge features """ if edge_attr is not None: # Combine source node, target node, and edge features msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1) else: msg_input = torch.cat([x_i, x_j], dim=-1) return self.message_mlp(msg_input) def update(self, aggr_out, x): """ Update node features with aggregated messages Args: aggr_out: Aggregated messages from neighbors x: Original node features """ # Combine original features with aggregated messages update_input = torch.cat([x, aggr_out], dim=-1) return self.update_mlp(update_input) class FieldPredictorGNN(nn.Module): """ Graph Neural Network for predicting complete FEA fields Architecture: 1. Node Encoder: Encode node positions, BCs, loads 2. Edge Encoder: Encode element connectivity, material properties 3. Message Passing: Propagate information through mesh (multiple layers) 4. Field Decoder: Predict displacement/stress at each node/element This architecture respects physics: - Uses mesh topology (forces flow through connected elements) - Incorporates boundary conditions (fixed/loaded nodes) - Learns material behavior (E, nu → stress-strain relationship) """ def __init__( self, node_feature_dim=3, # Node coordinates (x, y, z) edge_feature_dim=5, # Material properties (E, nu, rho, etc.) hidden_dim=128, num_layers=6, output_dim=6, # 6 DOF displacement (3 translation + 3 rotation) dropout=0.1 ): """ Initialize field predictor Args: node_feature_dim (int): Dimension of node features (position + BCs + loads) edge_feature_dim (int): Dimension of edge features (material properties) hidden_dim (int): Hidden layer dimension num_layers (int): Number of message passing layers output_dim (int): Output dimension per node (6 for displacement) dropout (float): Dropout rate """ super().__init__() self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim self.hidden_dim = hidden_dim self.num_layers = num_layers self.output_dim = output_dim # Node encoder: embed node coordinates + BCs + loads self.node_encoder = nn.Sequential( nn.Linear(node_feature_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, hidden_dim) ) # Edge encoder: embed material properties self.edge_encoder = nn.Sequential( nn.Linear(edge_feature_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim // 2) ) # Message passing layers (the physics learning happens here) self.conv_layers = nn.ModuleList([ MeshGraphConv( in_channels=hidden_dim, out_channels=hidden_dim, edge_dim=hidden_dim // 2 ) for _ in range(num_layers) ]) self.layer_norms = nn.ModuleList([ nn.LayerNorm(hidden_dim) for _ in range(num_layers) ]) self.dropouts = nn.ModuleList([ nn.Dropout(dropout) for _ in range(num_layers) ]) # Field decoder: predict displacement at each node self.field_decoder = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Linear(hidden_dim // 2, output_dim) ) # Physics-informed constraint layer (optional, ensures equilibrium) self.physics_scale = nn.Parameter(torch.ones(1)) def forward(self, data): """ Forward pass: mesh → displacement field Args: data (torch_geometric.data.Data): Batch of mesh graphs containing: - x: Node features [num_nodes, node_feature_dim] - edge_index: Connectivity [2, num_edges] - edge_attr: Edge features [num_edges, edge_feature_dim] - batch: Batch assignment [num_nodes] Returns: displacement_field: Predicted displacement [num_nodes, output_dim] """ x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr # Encode nodes (positions + BCs + loads) x = self.node_encoder(x) # [num_nodes, hidden_dim] # Encode edges (material properties) if edge_attr is not None: edge_features = self.edge_encoder(edge_attr) # [num_edges, hidden_dim//2] else: edge_features = None # Message passing: learn how forces propagate through mesh for i, (conv, norm, dropout) in enumerate(zip( self.conv_layers, self.layer_norms, self.dropouts )): # Graph convolution x_new = conv(x, edge_index, edge_features) # Residual connection (helps gradients flow) x = x + dropout(x_new) # Layer normalization x = norm(x) # Decode to displacement field displacement = self.field_decoder(x) # [num_nodes, output_dim] # Apply physics-informed scaling displacement = displacement * self.physics_scale return displacement def predict_stress_from_displacement(self, displacement, data, material_props): """ Convert predicted displacement to stress using constitutive law This implements: σ = C : ε = C : (∇u) Where C is the material stiffness matrix Args: displacement: Predicted displacement [num_nodes, 6] data: Mesh graph data material_props: Material properties (E, nu) Returns: stress_field: Predicted stress [num_elements, n_components] """ # This would compute strain from displacement gradients # then apply material constitutive law # For now, we'll predict displacement and train a separate stress predictor raise NotImplementedError("Stress prediction implemented in StressPredictor") class StressPredictor(nn.Module): """ Predicts stress field from displacement field This can be: 1. Physics-based: Compute strain from displacement, apply constitutive law 2. Learned: Train neural network to predict stress from displacement We use learned approach for flexibility with nonlinear materials. """ def __init__(self, displacement_dim=6, hidden_dim=128, stress_components=6): """ Args: displacement_dim (int): Displacement DOFs per node hidden_dim (int): Hidden layer size stress_components (int): Stress tensor components (6 for 3D) """ super().__init__() # Stress predictor network self.stress_net = nn.Sequential( nn.Linear(displacement_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, stress_components) ) def forward(self, displacement): """ Predict stress from displacement Args: displacement: [num_nodes, displacement_dim] Returns: stress: [num_nodes, stress_components] """ return self.stress_net(displacement) class AtomizerFieldModel(nn.Module): """ Complete AtomizerField model: predicts both displacement and stress fields This is the main model you'll use for training and inference. """ def __init__( self, node_feature_dim=10, # 3 (xyz) + 6 (BC DOFs) + 1 (load magnitude) edge_feature_dim=5, # E, nu, rho, G, alpha hidden_dim=128, num_layers=6, dropout=0.1 ): """ Initialize complete field prediction model Args: node_feature_dim (int): Node features (coords + BCs + loads) edge_feature_dim (int): Edge features (material properties) hidden_dim (int): Hidden dimension num_layers (int): Message passing layers dropout (float): Dropout rate """ super().__init__() # Displacement predictor (main GNN) self.displacement_predictor = FieldPredictorGNN( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, hidden_dim=hidden_dim, num_layers=num_layers, output_dim=6, # 6 DOF displacement dropout=dropout ) # Stress predictor (from displacement) self.stress_predictor = StressPredictor( displacement_dim=6, hidden_dim=hidden_dim, stress_components=6 # σxx, σyy, σzz, τxy, τyz, τxz ) def forward(self, data, return_stress=True): """ Predict displacement and stress fields Args: data: Mesh graph data return_stress (bool): Whether to predict stress Returns: dict with: - displacement: [num_nodes, 6] - stress: [num_nodes, 6] (if return_stress=True) - von_mises: [num_nodes] (if return_stress=True) """ # Predict displacement displacement = self.displacement_predictor(data) results = {'displacement': displacement} if return_stress: # Predict stress from displacement stress = self.stress_predictor(displacement) # Calculate von Mises stress # σ_vm = sqrt(0.5 * ((σxx-σyy)² + (σyy-σzz)² + (σzz-σxx)² + 6(τxy² + τyz² + τxz²))) sxx, syy, szz, txy, tyz, txz = stress[:, 0], stress[:, 1], stress[:, 2], \ stress[:, 3], stress[:, 4], stress[:, 5] von_mises = torch.sqrt( 0.5 * ( (sxx - syy)**2 + (syy - szz)**2 + (szz - sxx)**2 + 6 * (txy**2 + tyz**2 + txz**2) ) ) results['stress'] = stress results['von_mises'] = von_mises return results def get_max_values(self, results): """ Extract maximum values (for compatibility with scalar optimization) Args: results: Output from forward() Returns: dict with max_displacement, max_stress """ max_displacement = torch.max(torch.norm(results['displacement'][:, :3], dim=1)) max_stress = torch.max(results['von_mises']) if 'von_mises' in results else None return { 'max_displacement': max_displacement, 'max_stress': max_stress } def create_model(config=None): """ Factory function to create AtomizerField model Args: config (dict): Model configuration Returns: AtomizerFieldModel instance """ if config is None: config = { 'node_feature_dim': 10, 'edge_feature_dim': 5, 'hidden_dim': 128, 'num_layers': 6, 'dropout': 0.1 } model = AtomizerFieldModel(**config) # Initialize weights def init_weights(m): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) model.apply(init_weights) return model if __name__ == "__main__": # Test model creation print("Testing AtomizerField Model Creation...") model = create_model() print(f"Model created: {sum(p.numel() for p in model.parameters()):,} parameters") # Create dummy data num_nodes = 100 num_edges = 300 x = torch.randn(num_nodes, 10) # Node features edge_index = torch.randint(0, num_nodes, (2, num_edges)) # Edge connectivity edge_attr = torch.randn(num_edges, 5) # Edge features batch = torch.zeros(num_nodes, dtype=torch.long) # Batch assignment data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch) # Forward pass with torch.no_grad(): results = model(data) print(f"\nTest forward pass:") print(f" Displacement shape: {results['displacement'].shape}") print(f" Stress shape: {results['stress'].shape}") print(f" Von Mises shape: {results['von_mises'].shape}") max_vals = model.get_max_values(results) print(f"\nMax values:") print(f" Max displacement: {max_vals['max_displacement']:.6f}") print(f" Max stress: {max_vals['max_stress']:.2f}") print("\nModel test passed!")