# --- Scraped page chrome (not part of the script), preserved as comments: ---
# Files
# Atomizer/atomizer-field/predict.py
# 374 lines | 12 KiB | Python
# Raw Normal View History

"""
predict.py
Inference script for AtomizerField trained models
AtomizerField Inference v2.0
Uses trained GNN to predict FEA fields 1000x faster than traditional simulation.
Usage:
python predict.py --model checkpoint_best.pt --input case_001
This enables:
- Rapid design exploration (milliseconds vs hours per analysis)
- Real-time optimization
- Interactive design feedback
"""
import argparse
import json
from pathlib import Path
import time
import torch
import numpy as np
import h5py
from neural_models.field_predictor import AtomizerFieldModel
from neural_models.data_loader import FEAMeshDataset
class FieldPredictor:
    """
    Inference engine for trained field prediction models.

    Loads an AtomizerField checkpoint and provides:
      - predict():                    run the GNN on one parsed FEA case
      - save_predictions():           persist fields to HDF5 + a JSON summary
      - compare_with_ground_truth():  error metrics vs. stored FEA results
    """
    def __init__(self, checkpoint_path, device=None):
        """
        Load a trained checkpoint and prepare the model for inference.

        Args:
            checkpoint_path (str): Path to trained model checkpoint.
            device (str): Device to run on ('cuda' or 'cpu'); when None,
                CUDA is used if available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"\nAtomizerField Inference Engine v2.0")
        print(f"Device: {self.device}")
        # Load checkpoint.
        # weights_only=False is required: the checkpoint carries a plain
        # Python config dict (and other non-tensor objects) next to the
        # state_dict, and PyTorch >= 2.6 flipped the default to
        # weights_only=True, which would make this load fail. This unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        print(f"Loading model from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device,
                                weights_only=False)
        # Rebuild the architecture from the saved config, then restore weights.
        model_config = checkpoint['config']['model']
        self.model = AtomizerFieldModel(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout / BN updates
        self.config = checkpoint['config']
        print(f"Model loaded (epoch {checkpoint['epoch']}, val_loss={checkpoint['best_val_loss']:.6f})")
    def predict(self, case_directory):
        """
        Predict displacement and stress fields for a case.

        Args:
            case_directory (str): Path to parsed FEA case.

        Returns:
            dict: 'displacement', 'stress', 'von_mises' numpy arrays plus
                'inference_time_ms', 'max_displacement', 'max_stress'.

        Raises:
            ValueError: If the case cannot be loaded into the dataset.
        """
        print(f"\nPredicting fields for {Path(case_directory).name}...")
        # Load through the training data pipeline so the same normalization
        # is applied at inference time.
        dataset = FEAMeshDataset(
            [case_directory],
            normalize=True,
            include_stress=False  # Don't need ground truth for prediction
        )
        if len(dataset) == 0:
            raise ValueError(f"Could not load case from {case_directory}")
        data = dataset[0].to(self.device)
        # Time the forward pass.
        # NOTE(review): on CUDA this wall-clock timing can under-report —
        # kernels launch asynchronously, so an exact number would need a
        # torch.cuda.synchronize() before reading the clock.
        start_time = time.time()
        with torch.no_grad():
            predictions = self.model(data, return_stress=True)
        inference_time = time.time() - start_time
        print(f"Prediction complete in {inference_time*1000:.1f} ms")
        # Convert to numpy
        results = {
            'displacement': predictions['displacement'].cpu().numpy(),
            'stress': predictions['stress'].cpu().numpy(),
            'von_mises': predictions['von_mises'].cpu().numpy(),
            'inference_time_ms': inference_time * 1000
        }
        # Peak response values. Only the first 3 displacement components
        # enter the magnitude — presumably translational DOFs when the model
        # outputs more than 3 components; TODO confirm the output layout.
        max_disp = np.max(np.linalg.norm(results['displacement'][:, :3], axis=1))
        max_stress = np.max(results['von_mises'])
        results['max_displacement'] = float(max_disp)
        results['max_stress'] = float(max_stress)
        print(f"\nResults:")
        print(f" Max displacement: {max_disp:.6f} mm")
        print(f" Max von Mises stress: {max_stress:.2f} MPa")
        return results
    def save_predictions(self, predictions, case_directory, output_name='predicted'):
        """
        Save predictions in same format as ground truth.

        Writes gzip-compressed field datasets plus scalar metadata to
        `<output_name>_fields.h5`, and a small human-readable
        `<output_name>_summary.json` next to it.

        Args:
            predictions (dict): Prediction results from predict().
            case_directory (str): Case directory to write into.
            output_name (str): Output file name prefix.
        """
        case_dir = Path(case_directory)
        output_file = case_dir / f"{output_name}_fields.h5"
        print(f"\nSaving predictions to {output_file}...")
        with h5py.File(output_file, 'w') as f:
            # Per-node field arrays, gzip-compressed to keep files small
            f.create_dataset('displacement',
                             data=predictions['displacement'],
                             compression='gzip')
            f.create_dataset('stress',
                             data=predictions['stress'],
                             compression='gzip')
            f.create_dataset('von_mises',
                             data=predictions['von_mises'],
                             compression='gzip')
            # Scalar metadata stored as HDF5 attributes
            f.attrs['max_displacement'] = predictions['max_displacement']
            f.attrs['max_stress'] = predictions['max_stress']
            f.attrs['inference_time_ms'] = predictions['inference_time_ms']
        print(f"Predictions saved!")
        # Also save JSON summary
        summary_file = case_dir / f"{output_name}_summary.json"
        summary = {
            'max_displacement': predictions['max_displacement'],
            'max_stress': predictions['max_stress'],
            'inference_time_ms': predictions['inference_time_ms'],
            'num_nodes': len(predictions['displacement'])
        }
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)
        print(f"Summary saved to {summary_file}")
    def compare_with_ground_truth(self, predictions, case_directory):
        """
        Compare predictions with FEA ground truth, when available.

        Args:
            predictions (dict): Model predictions from predict().
            case_directory (str): Case directory with ground truth
                (expects a `neural_field_data.h5` file).

        Returns:
            dict: Comparison metrics, or None when no ground truth exists.
        """
        case_dir = Path(case_directory)
        h5_file = case_dir / "neural_field_data.h5"
        if not h5_file.exists():
            print("No ground truth available for comparison")
            return None
        print("\nComparing with FEA ground truth...")
        # Load ground truth. Assumes 'results/displacement' exists in the
        # file — TODO confirm against the exporter's schema.
        with h5py.File(h5_file, 'r') as f:
            gt_displacement = f['results/displacement'][:]
            # Stress is optional; only the first stored stress type is used.
            gt_stress = None
            if 'results/stress' in f:
                stress_group = f['results/stress']
                for stress_type in stress_group.keys():
                    gt_stress = stress_group[stress_type]['data'][:]
                    break
        # Per-node displacement error (Euclidean norm across components).
        # NOTE(review): assumes predicted and ground-truth displacement
        # arrays have matching shapes — confirm for models that emit
        # extra output components.
        pred_disp = predictions['displacement']
        disp_error = np.linalg.norm(pred_disp - gt_displacement, axis=1)
        disp_magnitude = np.linalg.norm(gt_displacement, axis=1)
        # Epsilon guards against division by zero at fixed (zero-motion) nodes
        rel_disp_error = disp_error / (disp_magnitude + 1e-8)
        metrics = {
            'displacement': {
                'mae': float(np.mean(disp_error)),
                'rmse': float(np.sqrt(np.mean(disp_error**2))),
                'relative_error': float(np.mean(rel_disp_error)),
                'max_error': float(np.max(disp_error))
            }
        }
        # Compare peak displacement magnitudes
        pred_max_disp = predictions['max_displacement']
        gt_max_disp = float(np.max(disp_magnitude))
        metrics['max_displacement_error'] = abs(pred_max_disp - gt_max_disp)
        metrics['max_displacement_relative_error'] = metrics['max_displacement_error'] / (gt_max_disp + 1e-8)
        if gt_stress is not None:
            pred_stress = predictions['stress']
            stress_error = np.linalg.norm(pred_stress - gt_stress, axis=1)
            metrics['stress'] = {
                'mae': float(np.mean(stress_error)),
                'rmse': float(np.sqrt(np.mean(stress_error**2))),
            }
        # Print comparison
        print("\nComparison Results:")
        print(f" Displacement MAE: {metrics['displacement']['mae']:.6f} mm")
        print(f" Displacement RMSE: {metrics['displacement']['rmse']:.6f} mm")
        print(f" Displacement Relative Error: {metrics['displacement']['relative_error']*100:.2f}%")
        print(f" Max Displacement Error: {metrics['max_displacement_error']:.6f} mm ({metrics['max_displacement_relative_error']*100:.2f}%)")
        if 'stress' in metrics:
            print(f" Stress MAE: {metrics['stress']['mae']:.2f} MPa")
            print(f" Stress RMSE: {metrics['stress']['rmse']:.2f} MPa")
        return metrics
def batch_predict(predictor, case_directories, output_dir=None):
    """
    Run predictions on multiple cases and collect per-case outcomes.

    Args:
        predictor (FieldPredictor): Initialized predictor.
        case_directories (list): List of case directories.
        output_dir (str): Optional output directory for results.

    Returns:
        list: One result dict per case ('success' or 'failed').
    """
    separator = '=' * 60
    total = len(case_directories)
    print(f"\n{separator}")
    print(f"Batch Prediction: {total} cases")
    print(f"{separator}")
    results = []
    for index, case_dir in enumerate(case_directories, 1):
        print(f"\n[{index}/{total}] Processing {Path(case_dir).name}...")
        try:
            # Predict, persist, then score against any available ground truth
            predictions = predictor.predict(case_dir)
            predictor.save_predictions(predictions, case_dir)
            comparison = predictor.compare_with_ground_truth(predictions, case_dir)
            results.append({
                'case': str(case_dir),
                'status': 'success',
                'predictions': {
                    'max_displacement': predictions['max_displacement'],
                    'max_stress': predictions['max_stress'],
                    'inference_time_ms': predictions['inference_time_ms']
                },
                'comparison': comparison
            })
        except Exception as exc:
            # Record the failure and keep going with the remaining cases
            print(f"ERROR: {exc}")
            results.append({
                'case': str(case_dir),
                'status': 'failed',
                'error': str(exc)
            })
    # Persist the batch results when an output directory was given
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        results_file = output_path / 'batch_predictions.json'
        with open(results_file, 'w') as handle:
            json.dump(results, handle, indent=2)
        print(f"\nBatch results saved to {results_file}")
    # Summary banner
    print(f"\n{separator}")
    print("Batch Prediction Summary")
    print(f"{separator}")
    successful = sum(r['status'] == 'success' for r in results)
    print(f"Successful: {successful}/{len(results)}")
    if successful > 0:
        timings = [r['predictions']['inference_time_ms']
                   for r in results if r['status'] == 'success']
        print(f"Average inference time: {np.mean(timings):.1f} ms")
    return results
def main():
    """
    Main inference entry point.

    Single-case mode predicts, saves, and (with --compare) scores one case;
    --batch mode treats every subdirectory of --input as a separate case.
    """
    parser = argparse.ArgumentParser(description='Predict FEA fields using trained model')
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input case directory or directory containing multiple cases')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory for batch results')
    parser.add_argument('--batch', action='store_true',
                        help='Process all subdirectories as separate cases')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--compare', action='store_true',
                        help='Compare predictions with ground truth')
    args = parser.parse_args()
    # Create predictor (loads the checkpoint and moves the model to device)
    predictor = FieldPredictor(args.model, device=args.device)
    input_path = Path(args.input)
    if args.batch:
        # Batch prediction over every subdirectory.
        # Sorted for a deterministic, reproducible processing order —
        # Path.iterdir() yields entries in arbitrary filesystem order.
        case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
        batch_predict(predictor, case_dirs, args.output_dir)
    else:
        # Single prediction
        predictions = predictor.predict(args.input)
        # Save predictions
        predictor.save_predictions(predictions, args.input)
        # Compare with ground truth only when explicitly requested
        if args.compare:
            predictor.compare_with_ground_truth(predictions, args.input)
    print("\nInference complete!")


if __name__ == "__main__":
    main()