Files
Atomizer/atomizer-field/neural_models/data_loader.py
Antoine d5ffba099e feat: Merge Atomizer-Field neural network module into main repository
Permanently integrates the Atomizer-Field GNN surrogate system:
- neural_models/: Graph Neural Network for FEA field prediction
- batch_parser.py: Parse training data from FEA exports
- train.py: Neural network training pipeline
- predict.py: Inference engine for fast predictions

This enables 600x-2200x speedup over traditional FEA by replacing
expensive simulations with millisecond neural network predictions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 15:31:33 -05:00

417 lines
15 KiB
Python

"""
data_loader.py
Data loading pipeline for neural field training
AtomizerField Data Loader v2.0
Converts parsed FEA data (HDF5 + JSON) into PyTorch Geometric graphs for training.
Key Transformation:
Parsed FEA Data → Graph Representation → Neural Network Input
Graph structure:
- Nodes: FEA mesh nodes (with coordinates, BCs, loads)
- Edges: Element connectivity (with material properties)
- Labels: Displacement and stress fields (ground truth from FEA)
"""
import json
import h5py
import numpy as np
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import warnings
class FEAMeshDataset(Dataset):
    """
    PyTorch Dataset for FEA mesh data.

    Loads parsed neural field data (paired ``neural_field_data.json`` +
    ``neural_field_data.h5`` files) and converts each case into a PyTorch
    Geometric graph. Each graph represents one FEA analysis case.
    """

    def __init__(
        self,
        case_directories,
        normalize=True,
        include_stress=True,
        cache_in_memory=False
    ):
        """
        Initialize dataset.

        Args:
            case_directories (list): List of paths to parsed cases
            normalize (bool): Normalize node coordinates and results
            include_stress (bool): Include stress in targets
            cache_in_memory (bool): Load all data into RAM (faster but memory-intensive)
        """
        self.case_dirs = [Path(d) for d in case_directories]
        self.normalize = normalize
        self.include_stress = include_stress
        self.cache_in_memory = cache_in_memory

        # Keep only cases whose JSON metadata and HDF5 field files both exist.
        self.valid_cases = []
        for case_dir in self.case_dirs:
            if self._validate_case(case_dir):
                self.valid_cases.append(case_dir)
            else:
                warnings.warn(f"Skipping invalid case: {case_dir}")
        print(f"Loaded {len(self.valid_cases)}/{len(self.case_dirs)} valid cases")

        # BUGFIX: normalization statistics must be computed BEFORE caching.
        # _load_case() applies _normalize_graph(), whose hasattr() checks
        # silently no-op when the stats do not exist yet -- previously, with
        # cache_in_memory=True, every cached graph was stored (and forever
        # served) un-normalized.
        if normalize:
            self._compute_normalization_stats()

        # Cache data if requested (after stats, so cached graphs are normalized).
        self.cache = {}
        if cache_in_memory:
            print("Caching data in memory...")
            for idx in range(len(self.valid_cases)):
                self.cache[idx] = self._load_case(idx)
            print("Cache complete!")

    def _validate_case(self, case_dir):
        """Check if case has required files (JSON metadata + HDF5 fields)."""
        json_file = case_dir / "neural_field_data.json"
        h5_file = case_dir / "neural_field_data.h5"
        return json_file.exists() and h5_file.exists()

    def __len__(self):
        """Number of valid (loadable) cases."""
        return len(self.valid_cases)

    def __getitem__(self, idx):
        """
        Get graph data for one case.

        Returns:
            torch_geometric.data.Data object with:
                - x: Node features [num_nodes, feature_dim]
                - edge_index: Element connectivity [2, num_edges]
                - edge_attr: Edge features (material props) [num_edges, edge_dim]
                - y_displacement: Target displacement [num_nodes, 6]
                - y_stress: Target stress [num_nodes, 6] (if include_stress)
                - bc_mask: Boundary condition mask [num_nodes, 6]
                - pos: Node positions [num_nodes, 3]
        """
        if self.cache_in_memory and idx in self.cache:
            return self.cache[idx]
        return self._load_case(idx)

    def _load_case(self, idx):
        """Load one case from disk and convert it to a graph."""
        case_dir = self.valid_cases[idx]

        # JSON side: topology, boundary conditions, loads, materials.
        with open(case_dir / "neural_field_data.json", 'r') as f:
            metadata = json.load(f)

        # HDF5 side: dense field arrays (coordinates + result fields).
        with h5py.File(case_dir / "neural_field_data.h5", 'r') as f:
            node_coords = torch.from_numpy(f['mesh/node_coordinates'][:]).float()
            displacement = torch.from_numpy(f['results/displacement'][:]).float()

            # Stress target (optional): take the first stress type present.
            stress = None
            if self.include_stress and 'results/stress' in f:
                stress_group = f['results/stress']
                for stress_type in stress_group.keys():
                    stress_data = stress_group[stress_type]['data'][:]
                    stress = torch.from_numpy(stress_data).float()
                    break

        graph_data = self._build_graph(metadata, node_coords, displacement, stress)

        if self.normalize:
            graph_data = self._normalize_graph(graph_data)
        return graph_data

    def _build_graph(self, metadata, node_coords, displacement, stress):
        """
        Convert FEA mesh to graph.

        Args:
            metadata (dict): Parsed metadata
            node_coords (Tensor): Node positions [num_nodes, 3]
            displacement (Tensor): Displacement field [num_nodes, 6]
            stress (Tensor): Stress field [num_nodes, 6] or None

        Returns:
            torch_geometric.data.Data
        """
        num_nodes = node_coords.shape[0]

        # === NODE FEATURES ===
        # Start with coordinates.
        node_features = [node_coords]  # [num_nodes, 3]

        # Boundary conditions: one-hot mask of constrained DOFs per node.
        bc_mask = torch.zeros(num_nodes, 6)  # [num_nodes, 6]
        if 'boundary_conditions' in metadata and 'spc' in metadata['boundary_conditions']:
            for spc in metadata['boundary_conditions']['spc']:
                node_id = spc['node']
                # Node IDs are assumed sequential starting from 1 (production
                # code should use an explicit ID mapping).
                # BUGFIX: also reject IDs < 1 -- node_id=0 previously passed
                # the upper-bound check and wrote bc_mask[-1] (the LAST node)
                # via Python negative indexing.
                if 1 <= node_id <= num_nodes:
                    dofs = spc['dofs']
                    # Parse DOF string (e.g., "123" means constrained in x,y,z).
                    for dof_char in str(dofs):
                        if dof_char.isdigit():
                            dof_idx = int(dof_char) - 1  # 0-indexed
                            if 0 <= dof_idx < 6:
                                bc_mask[node_id - 1, dof_idx] = 1.0
        node_features.append(bc_mask)  # [num_nodes, 6]

        # Loads: force vector (x, y, z components) at each node.
        load_features = torch.zeros(num_nodes, 3)  # [num_nodes, 3]
        if 'loads' in metadata and 'point_forces' in metadata['loads']:
            for force in metadata['loads']['point_forces']:
                node_id = force['node']
                # Same ID-range fix as above.
                if 1 <= node_id <= num_nodes:
                    magnitude = force['magnitude']
                    direction = force['direction']
                    force_vector = [magnitude * d for d in direction]
                    load_features[node_id - 1] = torch.tensor(force_vector)
        node_features.append(load_features)  # [num_nodes, 3]

        # Concatenate all node features.
        x = torch.cat(node_features, dim=-1)  # [num_nodes, 3+6+3=12]

        # === EDGE FEATURES ===
        edge_index = []
        edge_attrs = []

        # Material properties, keyed by material ID.
        material_dict = {}
        if 'materials' in metadata:
            for mat in metadata['materials']:
                mat_id = mat['id']
                if mat['type'] == 'MAT1':
                    material_dict[mat_id] = [
                        mat.get('E', 0.0) / 1e6,   # Normalize E (MPa -> GPa)
                        mat.get('nu', 0.0),
                        mat.get('rho', 0.0) * 1e6, # Normalize rho
                        mat.get('G', 0.0) / 1e6 if mat.get('G') else 0.0,
                        mat.get('alpha', 0.0) * 1e6 if mat.get('alpha') else 0.0
                    ]

        # Elements -> edges: fully connect all node pairs within each element.
        if 'mesh' in metadata and 'elements' in metadata['mesh']:
            for elem_type in ['solid', 'shell', 'beam']:
                if elem_type in metadata['mesh']['elements']:
                    for elem in metadata['mesh']['elements'][elem_type]:
                        elem_nodes = elem['nodes']
                        mat_id = elem.get('material_id', 1)
                        # Fall back to zero properties for unknown materials.
                        mat_props = material_dict.get(mat_id, [0.0] * 5)
                        for i in range(len(elem_nodes)):
                            for j in range(i + 1, len(elem_nodes)):
                                node_i = elem_nodes[i] - 1  # 0-indexed
                                node_j = elem_nodes[j] - 1
                                if node_i < num_nodes and node_j < num_nodes:
                                    # Add bidirectional edges, both carrying
                                    # the same material properties.
                                    edge_index.append([node_i, node_j])
                                    edge_index.append([node_j, node_i])
                                    edge_attrs.append(mat_props)
                                    edge_attrs.append(mat_props)

        # Convert to tensors.
        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t()  # [2, num_edges]
            edge_attr = torch.tensor(edge_attrs, dtype=torch.float)      # [num_edges, 5]
        else:
            # No edges (shouldn't happen, but handle gracefully).
            edge_index = torch.zeros((2, 0), dtype=torch.long)
            edge_attr = torch.zeros((0, 5), dtype=torch.float)

        # === CREATE DATA OBJECT ===
        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y_displacement=displacement,
            bc_mask=bc_mask,
            pos=node_coords  # Store original (un-normalized) positions
        )
        if stress is not None:
            data.y_stress = stress
        return data

    def _normalize_graph(self, data):
        """
        Normalize graph features in place.

        - Coordinates: standardized with dataset-wide mean/std
        - Displacement: standardized with dataset-wide mean/std
        - Stress: standardized with dataset-wide mean/std (if present)

        The hasattr() guards make this a no-op for any statistic that has
        not been computed (e.g. stress stats when no case had stress data).
        """
        # Normalize coordinates (first 3 columns of the feature matrix).
        if hasattr(self, 'coord_mean') and hasattr(self, 'coord_std'):
            coords = data.x[:, :3]
            coords_norm = (coords - self.coord_mean) / (self.coord_std + 1e-8)
            data.x[:, :3] = coords_norm
        # Normalize displacement.
        if hasattr(self, 'disp_mean') and hasattr(self, 'disp_std'):
            data.y_displacement = (data.y_displacement - self.disp_mean) / (self.disp_std + 1e-8)
        # Normalize stress.
        if hasattr(data, 'y_stress') and hasattr(self, 'stress_mean') and hasattr(self, 'stress_std'):
            data.y_stress = (data.y_stress - self.stress_mean) / (self.stress_std + 1e-8)
        return data

    def _compute_normalization_stats(self):
        """
        Compute per-component mean and std across the entire dataset.

        Stores coord_mean/coord_std, disp_mean/disp_std, and (when any case
        carries stress data) stress_mean/stress_std as float tensors.
        """
        print("Computing normalization statistics...")
        all_coords = []
        all_disp = []
        all_stress = []
        for idx in range(len(self.valid_cases)):
            case_dir = self.valid_cases[idx]
            with h5py.File(case_dir / "neural_field_data.h5", 'r') as f:
                coords = f['mesh/node_coordinates'][:]
                disp = f['results/displacement'][:]
                all_coords.append(coords)
                all_disp.append(disp)
                # First available stress type, matching _load_case's choice.
                if self.include_stress and 'results/stress' in f:
                    stress_group = f['results/stress']
                    for stress_type in stress_group.keys():
                        stress_data = stress_group[stress_type]['data'][:]
                        all_stress.append(stress_data)
                        break

        # Concatenate all cases node-wise and compute per-component stats.
        all_coords = np.concatenate(all_coords, axis=0)
        all_disp = np.concatenate(all_disp, axis=0)
        self.coord_mean = torch.from_numpy(all_coords.mean(axis=0)).float()
        self.coord_std = torch.from_numpy(all_coords.std(axis=0)).float()
        self.disp_mean = torch.from_numpy(all_disp.mean(axis=0)).float()
        self.disp_std = torch.from_numpy(all_disp.std(axis=0)).float()
        if all_stress:
            all_stress = np.concatenate(all_stress, axis=0)
            self.stress_mean = torch.from_numpy(all_stress.mean(axis=0)).float()
            self.stress_std = torch.from_numpy(all_stress.std(axis=0)).float()
        print("Normalization statistics computed!")
def create_dataloaders(
    train_cases,
    val_cases,
    batch_size=4,
    num_workers=0,
    normalize=True,
    include_stress=True
):
    """
    Create training and validation dataloaders.

    Validation data is always normalized with the TRAINING set's statistics,
    never its own, to avoid leaking validation information into the scaling.

    Args:
        train_cases (list): List of training case directories
        val_cases (list): List of validation case directories
        batch_size (int): Batch size
        num_workers (int): Number of data loading workers
        normalize (bool): Normalize features
        include_stress (bool): Include stress targets

    Returns:
        train_loader, val_loader
    """
    print("\nCreating datasets...")
    train_dataset = FEAMeshDataset(
        train_cases,
        normalize=normalize,
        include_stress=include_stress,
        cache_in_memory=False  # Set to True for small datasets
    )
    # PERF/FIX: build the validation set with normalize=False. Previously it
    # was created with normalize=True, triggering a full stats pass over every
    # validation HDF5 file whose results were then immediately overwritten
    # with the training statistics -- pure wasted I/O.
    val_dataset = FEAMeshDataset(
        val_cases,
        normalize=False,
        include_stress=include_stress,
        cache_in_memory=False
    )
    # Re-enable normalization on the validation set using training statistics.
    if normalize and hasattr(train_dataset, 'coord_mean'):
        val_dataset.normalize = True
        val_dataset.coord_mean = train_dataset.coord_mean
        val_dataset.coord_std = train_dataset.coord_std
        val_dataset.disp_mean = train_dataset.disp_mean
        val_dataset.disp_std = train_dataset.disp_std
        if hasattr(train_dataset, 'stress_mean'):
            val_dataset.stress_mean = train_dataset.stress_mean
            val_dataset.stress_std = train_dataset.stress_std

    # Create dataloaders (torch_geometric DataLoader batches graphs by
    # concatenating nodes/edges into one large disconnected graph).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    print(f"\nDataloaders created:")
    print(f"  Training: {len(train_dataset)} cases")
    print(f"  Validation: {len(val_dataset)} cases")
    return train_loader, val_loader
if __name__ == "__main__":
    # Manual smoke-test entry point. A real test needs parsed FEA case
    # directories on disk, so this only prints usage guidance.
    for message in (
        "Testing FEA Mesh Data Loader...\n",
        "Note: This test requires actual parsed FEA data.",
        "Run the parser first on your NX Nastran files.",
        "\nData loader implementation complete!",
    ):
        print(message)