Files
Atomizer/atomizer-field/predict.py
Antoine d5ffba099e feat: Merge Atomizer-Field neural network module into main repository
Permanently integrates the Atomizer-Field GNN surrogate system:
- neural_models/: Graph Neural Network for FEA field prediction
- batch_parser.py: Parse training data from FEA exports
- train.py: Neural network training pipeline
- predict.py: Inference engine for fast predictions

This enables 600x-2200x speedup over traditional FEA by replacing
expensive simulations with millisecond neural network predictions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 15:31:33 -05:00

374 lines
12 KiB
Python

"""
predict.py
Inference script for AtomizerField trained models
AtomizerField Inference v2.0
Uses trained GNN to predict FEA fields 1000x faster than traditional simulation.
Usage:
python predict.py --model checkpoint_best.pt --input case_001
This enables:
- Rapid design exploration (milliseconds vs hours per analysis)
- Real-time optimization
- Interactive design feedback
"""
import argparse
import json
from pathlib import Path
import time
import torch
import numpy as np
import h5py
from neural_models.field_predictor import AtomizerFieldModel
from neural_models.data_loader import FEAMeshDataset
class FieldPredictor:
    """
    Inference engine for trained field prediction models.

    Loads an AtomizerFieldModel from a training checkpoint and provides:
      - predict(): run the model on one parsed FEA case
      - save_predictions(): write predicted fields to HDF5 plus a JSON summary
      - compare_with_ground_truth(): error metrics against the case's FEA results
    """

    def __init__(self, checkpoint_path, device=None):
        """
        Initialize predictor.

        Args:
            checkpoint_path (str): Path to trained model checkpoint
            device (str): Device to run on ('cuda' or 'cpu'). When None, CUDA
                is used if available, otherwise CPU.

        Raises:
            KeyError: If the checkpoint lacks 'config', 'model_state_dict',
                'epoch' or 'best_val_loss'.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"\nAtomizerField Inference Engine v2.0")
        print(f"Device: {self.device}")

        # Load checkpoint.
        # NOTE(security): torch.load unpickles the file (fully so on older
        # PyTorch versions); only load checkpoints from trusted sources.
        print(f"Loading model from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Rebuild the architecture from the saved training config, then load weights.
        self.config = checkpoint['config']
        self.model = AtomizerFieldModel(**self.config['model'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        # Switch to inference mode (affects layers such as dropout/batch-norm, if present).
        self.model.eval()

        print(f"Model loaded (epoch {checkpoint['epoch']}, val_loss={checkpoint['best_val_loss']:.6f})")

    @staticmethod
    def _max_translation(displacement):
        """
        Return the largest Euclidean norm over the first three columns of a
        2-D displacement array (0.0 for an empty array).

        NOTE(review): the first three columns are presumably the x/y/z
        translations (results are reported in mm); confirm against the parser.
        """
        disp = np.asarray(displacement)
        if disp.size == 0:
            return 0.0
        return float(np.max(np.linalg.norm(disp[:, :3], axis=1)))

    @classmethod
    def _displacement_metrics(cls, pred_disp, gt_disp, pred_max_disp):
        """
        Compute displacement error metrics between prediction and ground truth.

        Per-node errors are the L2 norm of the difference over all displacement
        columns. The max-displacement comparison uses the same first-three-column
        magnitude (_max_translation) for both sides, so it matches the value that
        predict() reports as 'max_displacement'.

        Args:
            pred_disp (array-like): Predicted displacement, one row per node
            gt_disp (array-like): Ground-truth displacement, same shape
            pred_max_disp (float): Predicted max displacement from predict()

        Returns:
            dict: {'displacement': {mae, rmse, relative_error, max_error},
                'max_displacement_error', 'max_displacement_relative_error'}

        Raises:
            ValueError: If shapes differ or there are no rows to compare.
        """
        pred_disp = np.asarray(pred_disp)
        gt_disp = np.asarray(gt_disp)
        # Refuse to compare mismatched arrays: numpy broadcasting could otherwise
        # produce plausible-looking but meaningless numbers.
        if pred_disp.shape != gt_disp.shape:
            raise ValueError(
                f"Displacement shape mismatch: predicted {pred_disp.shape} "
                f"vs ground truth {gt_disp.shape}"
            )
        if pred_disp.ndim != 2 or pred_disp.shape[0] == 0:
            raise ValueError(f"No displacement data to compare (shape {pred_disp.shape})")

        disp_error = np.linalg.norm(pred_disp - gt_disp, axis=1)
        disp_magnitude = np.linalg.norm(gt_disp, axis=1)
        # Small epsilon avoids division by zero at fixed (zero-displacement) nodes.
        rel_disp_error = disp_error / (disp_magnitude + 1e-8)

        metrics = {
            'displacement': {
                'mae': float(np.mean(disp_error)),
                'rmse': float(np.sqrt(np.mean(disp_error**2))),
                'relative_error': float(np.mean(rel_disp_error)),
                'max_error': float(np.max(disp_error)),
            }
        }

        gt_max_disp = cls._max_translation(gt_disp)
        max_error = abs(float(pred_max_disp) - gt_max_disp)
        metrics['max_displacement_error'] = max_error
        metrics['max_displacement_relative_error'] = max_error / (gt_max_disp + 1e-8)
        return metrics

    def predict(self, case_directory):
        """
        Predict displacement and stress fields for a case.

        Args:
            case_directory (str): Path to parsed FEA case

        Returns:
            dict: 'displacement', 'stress', 'von_mises' (numpy arrays),
                'inference_time_ms', 'max_displacement', 'max_stress' (floats)

        Raises:
            ValueError: If the case could not be loaded.
        """
        print(f"\nPredicting fields for {Path(case_directory).name}...")

        # Load data; ground-truth stress is not needed to run the model.
        # NOTE(review): inputs are normalized here; confirm whether the model's
        # outputs are in physical units (mm / MPa) or need de-normalizing.
        dataset = FEAMeshDataset(
            [case_directory],
            normalize=True,
            include_stress=False,
        )
        if len(dataset) == 0:
            raise ValueError(f"Could not load case from {case_directory}")
        data = dataset[0].to(self.device)

        # Predict. perf_counter is monotonic and high-resolution. CUDA kernels
        # launch asynchronously, so wait for them before stopping the clock;
        # otherwise only the launch overhead would be measured.
        start_time = time.perf_counter()
        with torch.no_grad():
            predictions = self.model(data, return_stress=True)
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
        inference_time = time.perf_counter() - start_time
        print(f"Prediction complete in {inference_time*1000:.1f} ms")

        # Convert to numpy
        results = {
            'displacement': predictions['displacement'].cpu().numpy(),
            'stress': predictions['stress'].cpu().numpy(),
            'von_mises': predictions['von_mises'].cpu().numpy(),
            'inference_time_ms': inference_time * 1000,
        }

        # Compute max values
        max_disp = self._max_translation(results['displacement'])
        max_stress = float(np.max(results['von_mises']))
        results['max_displacement'] = max_disp
        results['max_stress'] = max_stress

        print(f"\nResults:")
        print(f" Max displacement: {max_disp:.6f} mm")
        print(f" Max von Mises stress: {max_stress:.2f} MPa")
        return results

    def save_predictions(self, predictions, case_directory, output_name='predicted'):
        """
        Save predictions to '<output_name>_fields.h5' and '<output_name>_summary.json'
        inside the case directory.

        Args:
            predictions (dict): Prediction results from predict()
            case_directory (str): Case directory
            output_name (str): Output file name prefix
        """
        case_dir = Path(case_directory)
        output_file = case_dir / f"{output_name}_fields.h5"
        print(f"\nSaving predictions to {output_file}...")

        with h5py.File(output_file, 'w') as f:
            # Field arrays, gzip-compressed.
            for field in ('displacement', 'stress', 'von_mises'):
                f.create_dataset(field, data=predictions[field], compression='gzip')
            # Scalar metadata as file attributes.
            for key in ('max_displacement', 'max_stress', 'inference_time_ms'):
                f.attrs[key] = predictions[key]
        print(f"Predictions saved!")

        # Also save JSON summary
        summary_file = case_dir / f"{output_name}_summary.json"
        summary = {
            'max_displacement': predictions['max_displacement'],
            'max_stress': predictions['max_stress'],
            'inference_time_ms': predictions['inference_time_ms'],
            'num_nodes': len(predictions['displacement']),
        }
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        print(f"Summary saved to {summary_file}")

    def compare_with_ground_truth(self, predictions, case_directory):
        """
        Compare predictions with FEA ground truth from 'neural_field_data.h5'.

        Args:
            predictions (dict): Model predictions from predict()
            case_directory (str): Case directory with ground truth

        Returns:
            dict | None: Comparison metrics, or None when no ground-truth file exists.
                A 'stress' entry is included only when ground-truth stress is
                present and has the same shape as the prediction.

        Raises:
            ValueError: If predicted and ground-truth displacement shapes differ.
        """
        case_dir = Path(case_directory)
        h5_file = case_dir / "neural_field_data.h5"
        if not h5_file.exists():
            print("No ground truth available for comparison")
            return None

        print("\nComparing with FEA ground truth...")

        # Load ground truth
        with h5py.File(h5_file, 'r') as f:
            gt_displacement = f['results/displacement'][:]
            # Use the first stress dataset found, if any.
            # NOTE(review): which stress type this is depends on HDF5 key order.
            gt_stress = None
            if 'results/stress' in f:
                stress_group = f['results/stress']
                first_type = next(iter(stress_group.keys()), None)
                if first_type is not None:
                    gt_stress = stress_group[first_type]['data'][:]

        metrics = self._displacement_metrics(
            predictions['displacement'], gt_displacement, predictions['max_displacement']
        )

        if gt_stress is not None:
            pred_stress = np.asarray(predictions['stress'])
            if pred_stress.shape == gt_stress.shape:
                stress_error = np.linalg.norm(pred_stress - gt_stress, axis=1)
                metrics['stress'] = {
                    'mae': float(np.mean(stress_error)),
                    'rmse': float(np.sqrt(np.mean(stress_error**2))),
                }
            else:
                # Stress layouts can differ (e.g. per-element vs per-node); skip
                # rather than fail the whole comparison.
                print(f" Skipping stress comparison: predicted shape {pred_stress.shape} "
                      f"!= ground truth shape {gt_stress.shape}")

        # Print comparison
        print("\nComparison Results:")
        print(f" Displacement MAE: {metrics['displacement']['mae']:.6f} mm")
        print(f" Displacement RMSE: {metrics['displacement']['rmse']:.6f} mm")
        print(f" Displacement Relative Error: {metrics['displacement']['relative_error']*100:.2f}%")
        print(f" Max Displacement Error: {metrics['max_displacement_error']:.6f} mm ({metrics['max_displacement_relative_error']*100:.2f}%)")
        if 'stress' in metrics:
            print(f" Stress MAE: {metrics['stress']['mae']:.2f} MPa")
            print(f" Stress RMSE: {metrics['stress']['rmse']:.2f} MPa")
        return metrics
def batch_predict(predictor, case_directories, output_dir=None):
    """
    Predict, save, and evaluate several cases in sequence.

    Each case goes through predictor.predict, save_predictions and
    compare_with_ground_truth. Any exception raised for a case is printed and
    recorded, and processing continues with the next case.

    Args:
        predictor (FieldPredictor): Initialized predictor
        case_directories (list): Case directories to process
        output_dir (str): If given, results are also written to
            <output_dir>/batch_predictions.json

    Returns:
        list: One dict per case with 'case' and 'status' ('success' or
            'failed'), plus 'predictions'/'comparison' or 'error'.
    """
    total = len(case_directories)
    banner = '=' * 60
    summary_keys = ('max_displacement', 'max_stress', 'inference_time_ms')

    print(f"\n{banner}")
    print(f"Batch Prediction: {total} cases")
    print(f"{banner}")

    results = []
    for index, case in enumerate(case_directories, start=1):
        print(f"\n[{index}/{total}] Processing {Path(case).name}...")
        # Everything per case, including building the record, stays inside the
        # try so that any failure is recorded instead of aborting the batch.
        try:
            pred = predictor.predict(case)
            predictor.save_predictions(pred, case)
            comparison = predictor.compare_with_ground_truth(pred, case)
            entry = {
                'case': str(case),
                'status': 'success',
                'predictions': {key: pred[key] for key in summary_keys},
                'comparison': comparison,
            }
        except Exception as exc:
            print(f"ERROR: {exc}")
            entry = {'case': str(case), 'status': 'failed', 'error': str(exc)}
        results.append(entry)

    if output_dir:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        results_file = out_dir / 'batch_predictions.json'
        with open(results_file, 'w') as fh:
            json.dump(results, fh, indent=2)
        print(f"\nBatch results saved to {results_file}")

    print(f"\n{banner}")
    print("Batch Prediction Summary")
    print(f"{banner}")

    timings = [r['predictions']['inference_time_ms'] for r in results if r['status'] == 'success']
    print(f"Successful: {len(timings)}/{len(results)}")
    if timings:
        print(f"Average inference time: {np.mean(timings):.1f} ms")

    return results
def main():
    """
    Main inference entry point.

    Single-case mode predicts and saves one case, then compares it against
    ground truth when --compare is given. Batch mode (--batch) runs every
    immediate subdirectory of --input through batch_predict, which always
    attempts a ground-truth comparison.
    """
    parser = argparse.ArgumentParser(description='Predict FEA fields using trained model')
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input case directory or directory containing multiple cases')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory for batch results')
    parser.add_argument('--batch', action='store_true',
                        help='Process all subdirectories as separate cases')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--compare', action='store_true',
                        help='Compare predictions with ground truth')
    args = parser.parse_args()

    # Validate the input path before paying the cost of loading the model.
    input_path = Path(args.input)
    if not input_path.exists():
        parser.error(f"input path does not exist: {input_path}")
    if args.batch and not input_path.is_dir():
        parser.error(f"--batch requires a directory: {input_path}")

    # Create predictor
    predictor = FieldPredictor(args.model, device=args.device)

    if args.batch:
        # Sort so cases are processed and reported in a deterministic order
        # (Path.iterdir yields entries in arbitrary order).
        case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
        batch_predict(predictor, case_dirs, args.output_dir)
    else:
        # Single prediction
        predictions = predictor.predict(args.input)
        predictor.save_predictions(predictions, args.input)
        # Compare with ground truth if requested
        if args.compare:
            predictor.compare_with_ground_truth(predictions, args.input)

    print("\nInference complete!")


if __name__ == "__main__":
    main()