374 lines
12 KiB
Python
374 lines
12 KiB
Python
|
|
"""
|
||
|
|
predict.py

Inference script for AtomizerField trained models.

AtomizerField Inference v2.0

Uses a trained GNN to predict FEA fields 1000x faster than traditional simulation.

Usage:
    python predict.py --model checkpoint_best.pt --input case_001
    python predict.py --model checkpoint_best.pt --input cases/ --batch --output_dir results

This enables:
- Rapid design exploration (milliseconds vs hours per analysis)
- Real-time optimization
- Interactive design feedback
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
import time
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
import h5py
|
||
|
|
|
||
|
|
from neural_models.field_predictor import AtomizerFieldModel
|
||
|
|
from neural_models.data_loader import FEAMeshDataset
|
||
|
|
|
||
|
|
|
||
|
|
class FieldPredictor:
    """
    Inference engine for trained field prediction models.

    Loads a training checkpoint, rebuilds the model from the configuration
    stored in it, and provides helpers to predict fields for a parsed FEA
    case, save the predictions (HDF5 + JSON summary), and compare them with
    FEA ground truth.
    """

    def __init__(self, checkpoint_path, device=None):
        """
        Initialize predictor.

        Args:
            checkpoint_path (str): Path to trained model checkpoint. It must
                contain 'config' (with a 'model' dict of constructor kwargs),
                'model_state_dict', 'epoch' and 'best_val_loss'.
            device (str): Device to run on ('cuda' or 'cpu'). Defaults to
                CUDA when available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("\nAtomizerField Inference Engine v2.0")
        print(f"Device: {self.device}")

        # Load checkpoint.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Recent PyTorch versions default to weights_only=True,
        # which may reject non-primitive entries in 'config'; confirm against
        # the PyTorch version in use.
        print(f"Loading model from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Rebuild the model from the hyper-parameters saved with the weights
        self.config = checkpoint['config']
        self.model = AtomizerFieldModel(**self.config['model'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        # Switch to evaluation mode (affects layers such as dropout or
        # batch-norm, if the model has any)
        self.model.eval()

        print(f"Model loaded (epoch {checkpoint['epoch']}, val_loss={checkpoint['best_val_loss']:.6f})")

    def predict(self, case_directory):
        """
        Predict displacement and stress fields for a case.

        Args:
            case_directory (str): Path to parsed FEA case

        Returns:
            dict: 'displacement', 'stress' and 'von_mises' as numpy arrays,
            plus 'inference_time_ms', 'max_displacement' (largest norm of
            the first three displacement columns) and 'max_stress' (largest
            von Mises value) as floats.

        Raises:
            ValueError: If the dataset cannot load the case.
        """
        print(f"\nPredicting fields for {Path(case_directory).name}...")

        # Build the model input for this single case
        dataset = FEAMeshDataset(
            [case_directory],
            normalize=True,
            include_stress=False  # Don't need ground truth for prediction
        )

        if len(dataset) == 0:
            raise ValueError(f"Could not load case from {case_directory}")

        data = dataset[0].to(self.device)

        # NOTE(review): inputs are built with normalize=True, but the outputs
        # below are not de-normalized here. Confirm the model returns
        # physical units (printed as mm / MPa) or add the inverse transform.

        # Time only the forward pass. perf_counter is a monotonic,
        # high-resolution clock. CUDA kernels run asynchronously, so the
        # device is synchronized before reading the clock; otherwise only
        # the kernel launch would be measured.
        start_time = time.perf_counter()
        with torch.no_grad():
            predictions = self.model(data, return_stress=True)
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)
        inference_time = time.perf_counter() - start_time

        print(f"Prediction complete in {inference_time*1000:.1f} ms")

        # Convert to numpy
        results = {
            'displacement': predictions['displacement'].cpu().numpy(),
            'stress': predictions['stress'].cpu().numpy(),
            'von_mises': predictions['von_mises'].cpu().numpy(),
            'inference_time_ms': inference_time * 1000
        }

        # Peak values. Only the first three displacement columns enter the
        # magnitude (presumably the translational DOFs; confirm).
        max_disp = np.max(np.linalg.norm(results['displacement'][:, :3], axis=1))
        max_stress = np.max(results['von_mises'])

        results['max_displacement'] = float(max_disp)
        results['max_stress'] = float(max_stress)

        print("\nResults:")
        print(f" Max displacement: {max_disp:.6f} mm")
        print(f" Max von Mises stress: {max_stress:.2f} MPa")

        return results

    def save_predictions(self, predictions, case_directory, output_name='predicted'):
        """
        Save predictions next to the case as HDF5 fields plus a JSON summary.

        Writes '<output_name>_fields.h5' with top-level 'displacement',
        'stress' and 'von_mises' datasets (gzip-compressed) and scalar file
        attributes. It also writes '<output_name>_summary.json'. The HDF5
        layout is flat; it does not use the 'results/...' group layout read
        from the ground-truth file.

        Args:
            predictions (dict): Prediction results as returned by predict()
            case_directory (str): Case directory to write into
            output_name (str): Output file name prefix
        """
        case_dir = Path(case_directory)
        output_file = case_dir / f"{output_name}_fields.h5"

        print(f"\nSaving predictions to {output_file}...")

        with h5py.File(output_file, 'w') as f:
            # Full fields
            for field in ('displacement', 'stress', 'von_mises'):
                f.create_dataset(field, data=predictions[field], compression='gzip')

            # Scalar metadata
            for key in ('max_displacement', 'max_stress', 'inference_time_ms'):
                f.attrs[key] = predictions[key]

        print("Predictions saved!")

        # Also save a small JSON summary
        summary_file = case_dir / f"{output_name}_summary.json"
        summary = {
            'max_displacement': predictions['max_displacement'],
            'max_stress': predictions['max_stress'],
            'inference_time_ms': predictions['inference_time_ms'],
            'num_nodes': len(predictions['displacement'])
        }

        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"Summary saved to {summary_file}")

    def compare_with_ground_truth(self, predictions, case_directory):
        """
        Compare predictions with FEA ground truth.

        Ground truth is read from '<case_directory>/neural_field_data.h5':
        the 'results/displacement' dataset and, if present, the 'data'
        dataset of the first entry under 'results/stress'.

        Displacement metrics use the first three displacement columns. These
        are the same components predict() uses for 'max_displacement', so
        predicted and ground-truth peaks are computed the same way.
        NOTE(review): these are presumably the translational DOFs; confirm.

        Args:
            predictions (dict): Model predictions as returned by predict()
            case_directory (str): Case directory with ground truth

        Returns:
            dict: Comparison metrics, or None if no ground-truth file exists.

        Raises:
            ValueError: If predicted and ground-truth displacement shapes
                (after taking the first three columns) differ.
        """
        case_dir = Path(case_directory)
        h5_file = case_dir / "neural_field_data.h5"

        if not h5_file.exists():
            print("No ground truth available for comparison")
            return None

        print("\nComparing with FEA ground truth...")

        # Load ground truth
        with h5py.File(h5_file, 'r') as f:
            gt_displacement = f['results/displacement'][:]

            # Use the first stored stress type, if any
            gt_stress = None
            if 'results/stress' in f:
                stress_group = f['results/stress']
                first_type = next(iter(stress_group.keys()), None)
                if first_type is not None:
                    gt_stress = stress_group[first_type]['data'][:]

        # Displacement errors on matching components
        pred_disp = np.asarray(predictions['displacement'])[:, :3]
        gt_disp = gt_displacement[:, :3]
        if pred_disp.shape != gt_disp.shape:
            raise ValueError(
                f"Predicted displacement shape {pred_disp.shape} does not match "
                f"ground truth shape {gt_disp.shape}"
            )

        disp_error = np.linalg.norm(pred_disp - gt_disp, axis=1)
        disp_magnitude = np.linalg.norm(gt_disp, axis=1)
        # Per-node relative error. Nodes whose ground-truth displacement is
        # near zero can dominate this mean.
        rel_disp_error = disp_error / (disp_magnitude + 1e-8)

        metrics = {
            'displacement': {
                'mae': float(np.mean(disp_error)),
                'rmse': float(np.sqrt(np.mean(disp_error**2))),
                'relative_error': float(np.mean(rel_disp_error)),
                'max_error': float(np.max(disp_error))
            }
        }

        # Compare max values (both computed from the same components)
        pred_max_disp = predictions['max_displacement']
        gt_max_disp = float(np.max(disp_magnitude))
        metrics['max_displacement_error'] = abs(pred_max_disp - gt_max_disp)
        metrics['max_displacement_relative_error'] = metrics['max_displacement_error'] / (gt_max_disp + 1e-8)

        if gt_stress is not None:
            pred_stress = np.asarray(predictions['stress'])
            if pred_stress.shape != gt_stress.shape:
                # Ground truth may be stored on a different support or with
                # different components. Subtracting would either fail or
                # silently broadcast, so skip instead of failing the case.
                print(f"Skipping stress comparison: predicted shape {pred_stress.shape} "
                      f"!= ground-truth shape {gt_stress.shape}")
            else:
                stress_diff = pred_stress - gt_stress
                if stress_diff.ndim == 1:
                    # One value per row: treat it as a single component
                    stress_diff = stress_diff[:, np.newaxis]
                stress_error = np.linalg.norm(stress_diff, axis=1)

                metrics['stress'] = {
                    'mae': float(np.mean(stress_error)),
                    'rmse': float(np.sqrt(np.mean(stress_error**2))),
                }

        # Print comparison
        print("\nComparison Results:")
        print(f" Displacement MAE: {metrics['displacement']['mae']:.6f} mm")
        print(f" Displacement RMSE: {metrics['displacement']['rmse']:.6f} mm")
        print(f" Displacement Relative Error: {metrics['displacement']['relative_error']*100:.2f}%")
        print(f" Max Displacement Error: {metrics['max_displacement_error']:.6f} mm ({metrics['max_displacement_relative_error']*100:.2f}%)")

        if 'stress' in metrics:
            print(f" Stress MAE: {metrics['stress']['mae']:.2f} MPa")
            print(f" Stress RMSE: {metrics['stress']['rmse']:.2f} MPa")

        return metrics
|
||
|
|
|
||
|
|
|
||
|
|
def batch_predict(predictor, case_directories, output_dir=None):
    """
    Predict, save and evaluate several cases one after another.

    For every case, ``predictor.predict`` is run, the result is written next
    to the case with ``predictor.save_predictions``, and it is compared with
    ground truth via ``predictor.compare_with_ground_truth``. Any exception
    raised while handling a case is printed and recorded; processing then
    continues with the next case.

    Args:
        predictor (FieldPredictor): Initialized predictor
        case_directories (list): Case directories to process, in order
        output_dir (str): Optional directory (created if missing) that
            receives ``batch_predictions.json`` with all per-case records

    Returns:
        list: One dict per case with 'case' and 'status'. Successful records
        also hold 'predictions' (peak values and timing) and 'comparison';
        failed records hold the 'error' message.
    """
    separator = '=' * 60
    total = len(case_directories)
    summary_keys = ('max_displacement', 'max_stress', 'inference_time_ms')

    def process(case_path):
        # Full pipeline for one case; any failure propagates to the caller.
        predicted = predictor.predict(case_path)
        predictor.save_predictions(predicted, case_path)
        comparison = predictor.compare_with_ground_truth(predicted, case_path)
        return {
            'case': str(case_path),
            'status': 'success',
            'predictions': {key: predicted[key] for key in summary_keys},
            'comparison': comparison,
        }

    print(f"\n{separator}")
    print(f"Batch Prediction: {total} cases")
    print(separator)

    records = []
    for position, case_path in enumerate(case_directories, start=1):
        print(f"\n[{position}/{total}] Processing {Path(case_path).name}...")
        try:
            records.append(process(case_path))
        except Exception as exc:
            # One bad case must not abort the whole batch
            print(f"ERROR: {exc}")
            records.append({
                'case': str(case_path),
                'status': 'failed',
                'error': str(exc),
            })

    if output_dir:
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        results_file = target_dir / 'batch_predictions.json'
        with open(results_file, 'w') as handle:
            json.dump(records, handle, indent=2)

        print(f"\nBatch results saved to {results_file}")

    succeeded = [record for record in records if record['status'] == 'success']

    print(f"\n{separator}")
    print("Batch Prediction Summary")
    print(separator)
    print(f"Successful: {len(succeeded)}/{len(records)}")

    if succeeded:
        avg_time = np.mean([record['predictions']['inference_time_ms'] for record in succeeded])
        print(f"Average inference time: {avg_time:.1f} ms")

    return records
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """
    Main inference entry point.

    Single-case mode predicts one case directory, saves the predictions into
    it and, with ``--compare``, compares them with ground truth. ``--batch``
    treats every immediate subdirectory of ``--input`` as a case and passes
    them, sorted by path, to ``batch_predict``. That function always runs
    the ground-truth comparison, so ``--compare`` has no effect in batch mode.
    """
    parser = argparse.ArgumentParser(description='Predict FEA fields using trained model')

    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input case directory or directory containing multiple cases')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory for batch results')
    parser.add_argument('--batch', action='store_true',
                        help='Process all subdirectories as separate cases')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--compare', action='store_true',
                        help='Compare predictions with ground truth')

    args = parser.parse_args()

    input_path = Path(args.input)
    # Fail fast on a bad path, before the (possibly slow) model load. In both
    # modes --input is a directory: a case directory in single mode, or a
    # directory of case directories with --batch.
    if not input_path.is_dir():
        parser.error(f"--input is not an existing directory: {args.input}")

    # Create predictor
    predictor = FieldPredictor(args.model, device=args.device)

    if args.batch:
        # Path.iterdir() yields entries in arbitrary order; sort so cases are
        # processed and reported deterministically.
        case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
        batch_predict(predictor, case_dirs, args.output_dir)

    else:
        # Single prediction
        predictions = predictor.predict(args.input)

        # Save predictions
        predictor.save_predictions(predictions, args.input)

        # Compare with ground truth if requested
        if args.compare:
            predictor.compare_with_ground_truth(predictions, args.input)

    print("\nInference complete!")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script (python predict.py ...);
    # importing this module does not start inference.
    main()
|