"""
predict.py
Inference script for AtomizerField trained models

AtomizerField Inference v2.0
Uses a trained GNN to predict FEA fields 1000x faster than traditional simulation.

Usage:
    python predict.py --model checkpoint_best.pt --input case_001
    python predict.py --model checkpoint_best.pt --input cases/ --batch --output_dir results/

This enables:
- Rapid design exploration (milliseconds vs hours per analysis)
- Real-time optimization
- Interactive design feedback
"""

import argparse
import json
from pathlib import Path
import time

import torch
import numpy as np
import h5py

from neural_models.field_predictor import AtomizerFieldModel
from neural_models.data_loader import FEAMeshDataset


def _translational_magnitude(displacement):
    """
    Return the per-node Euclidean norm of the first three displacement components.

    Both predicted and ground-truth maxima go through this helper so they are
    measured the same way. The first three columns are presumably the
    translational DOFs (any further columns would be rotations) — confirm
    against the model's output layout.
    """
    return np.linalg.norm(np.asarray(displacement)[:, :3], axis=1)


class FieldPredictor:
    """
    Inference engine for trained field prediction models.

    Loads an ``AtomizerFieldModel`` from a training checkpoint and runs it on
    parsed FEA cases loaded through ``FEAMeshDataset``.
    """

    def __init__(self, checkpoint_path, device=None):
        """
        Initialize predictor.

        Args:
            checkpoint_path (str): Path to trained model checkpoint. The
                checkpoint is read for the keys 'config' (whose 'model' entry
                holds the model constructor kwargs), 'model_state_dict',
                'epoch' and 'best_val_loss'.
            device (str): Device to run on ('cuda' or 'cpu'). When None,
                CUDA is used if available, otherwise CPU.

        Raises:
            KeyError: If the checkpoint lacks one of the keys listed above.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("\nAtomizerField Inference Engine v2.0")
        print(f"Device: {self.device}")

        # Load checkpoint.
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted sources. On newer PyTorch releases the
        # default weights_only=True may reject checkpoints that store
        # non-tensor objects — confirm against the training script.
        print(f"Loading model from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Rebuild the architecture from the stored config, then load weights.
        model_config = checkpoint['config']['model']
        self.model = AtomizerFieldModel(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()  # inference mode for dropout/normalization layers

        self.config = checkpoint['config']

        print(f"Model loaded (epoch {checkpoint['epoch']}, val_loss={checkpoint['best_val_loss']:.6f})")

    def _synchronize(self):
        """Block until queued CUDA work on self.device finishes (no-op on CPU)."""
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    def predict(self, case_directory):
        """
        Predict displacement and stress fields for a case.

        Args:
            case_directory (str): Path to parsed FEA case

        Returns:
            dict: NumPy arrays under 'displacement', 'stress' and 'von_mises',
            plus the scalars 'inference_time_ms', 'max_displacement' and
            'max_stress'.

        Raises:
            ValueError: If the dataset yields no sample for the case.
        """
        print(f"\nPredicting fields for {Path(case_directory).name}...")

        # Load data. Ground-truth stress is not needed to run the model.
        # NOTE(review): inputs are normalized here; whether the model's
        # outputs are de-normalized back to physical units is not visible in
        # this file — confirm before trusting the printed mm/MPa values.
        dataset = FEAMeshDataset(
            [case_directory],
            normalize=True,
            include_stress=False,
        )

        if len(dataset) == 0:
            raise ValueError(f"Could not load case from {case_directory}")

        data = dataset[0].to(self.device)

        # Predict. CUDA kernels run asynchronously, so synchronize before each
        # clock read; otherwise the timing measures kernel launch, not
        # execution. perf_counter is monotonic and high-resolution.
        with torch.no_grad():
            self._synchronize()
            start_time = time.perf_counter()
            predictions = self.model(data, return_stress=True)
            self._synchronize()
            inference_time = time.perf_counter() - start_time

        print(f"Prediction complete in {inference_time*1000:.1f} ms")

        # Convert to numpy
        results = {
            'displacement': predictions['displacement'].cpu().numpy(),
            'stress': predictions['stress'].cpu().numpy(),
            'von_mises': predictions['von_mises'].cpu().numpy(),
            'inference_time_ms': inference_time * 1000,
        }

        # Compute max values
        max_disp = np.max(_translational_magnitude(results['displacement']))
        max_stress = np.max(results['von_mises'])

        results['max_displacement'] = float(max_disp)
        results['max_stress'] = float(max_stress)

        print("\nResults:")
        print(f" Max displacement: {max_disp:.6f} mm")
        print(f" Max von Mises stress: {max_stress:.2f} MPa")

        return results

    def save_predictions(self, predictions, case_directory, output_name='predicted'):
        """
        Save predictions in the same format as ground truth.

        Writes ``<output_name>_fields.h5`` (gzip-compressed field datasets
        plus scalar attributes) and ``<output_name>_summary.json`` into
        ``case_directory``, overwriting existing files.

        Args:
            predictions (dict): Prediction results as returned by ``predict``
            case_directory (str): Case directory
            output_name (str): Output file name prefix
        """
        case_dir = Path(case_directory)
        output_file = case_dir / f"{output_name}_fields.h5"

        print(f"\nSaving predictions to {output_file}...")

        with h5py.File(output_file, 'w') as f:
            # Full fields
            f.create_dataset('displacement', data=predictions['displacement'], compression='gzip')
            f.create_dataset('stress', data=predictions['stress'], compression='gzip')
            f.create_dataset('von_mises', data=predictions['von_mises'], compression='gzip')

            # Scalar metadata
            f.attrs['max_displacement'] = predictions['max_displacement']
            f.attrs['max_stress'] = predictions['max_stress']
            f.attrs['inference_time_ms'] = predictions['inference_time_ms']

        print("Predictions saved!")

        # Also save a small JSON summary
        summary_file = case_dir / f"{output_name}_summary.json"
        summary = {
            'max_displacement': predictions['max_displacement'],
            'max_stress': predictions['max_stress'],
            'inference_time_ms': predictions['inference_time_ms'],
            'num_nodes': len(predictions['displacement']),
        }

        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"Summary saved to {summary_file}")

    def compare_with_ground_truth(self, predictions, case_directory):
        """
        Compare predictions with FEA ground truth.

        Ground truth is read from ``neural_field_data.h5`` in the case
        directory: 'results/displacement', and optionally the first stress
        type found under 'results/stress'.

        Args:
            predictions (dict): Model predictions as returned by ``predict``
            case_directory (str): Case directory with ground truth

        Returns:
            dict: Comparison metrics, or None if no ground-truth file exists.

        Raises:
            ValueError: If predicted and ground-truth displacement shapes differ.
        """
        case_dir = Path(case_directory)
        h5_file = case_dir / "neural_field_data.h5"

        if not h5_file.exists():
            print("No ground truth available for comparison")
            return None

        print("\nComparing with FEA ground truth...")

        # Load ground truth
        with h5py.File(h5_file, 'r') as f:
            gt_displacement = f['results/displacement'][:]

            # Stress is optional; use the first stress type stored.
            # NOTE(review): which stress type comes first depends on the
            # HDF5 group ordering — confirm it matches what the model predicts.
            gt_stress = None
            if 'results/stress' in f:
                stress_group = f['results/stress']
                stress_types = list(stress_group.keys())
                if stress_types:
                    gt_stress = stress_group[stress_types[0]]['data'][:]

        pred_disp = predictions['displacement']
        if pred_disp.shape != gt_displacement.shape:
            # Fail clearly instead of broadcasting silently or raising an
            # opaque numpy error.
            raise ValueError(
                f"Displacement shape mismatch: predicted {pred_disp.shape}, "
                f"ground truth {gt_displacement.shape}"
            )

        # Per-node errors over all displacement components
        disp_error = np.linalg.norm(pred_disp - gt_displacement, axis=1)
        disp_magnitude = np.linalg.norm(gt_displacement, axis=1)
        rel_disp_error = disp_error / (disp_magnitude + 1e-8)  # epsilon avoids /0 at fixed nodes

        metrics = {
            'displacement': {
                'mae': float(np.mean(disp_error)),
                'rmse': float(np.sqrt(np.mean(disp_error**2))),
                'relative_error': float(np.mean(rel_disp_error)),
                'max_error': float(np.max(disp_error)),
            }
        }

        # Compare max values. The ground-truth max is measured exactly like
        # the predicted one (first three components only); otherwise any
        # extra components would inflate the ground-truth value.
        pred_max_disp = predictions['max_displacement']
        gt_max_disp = float(np.max(_translational_magnitude(gt_displacement)))
        metrics['max_displacement_error'] = abs(pred_max_disp - gt_max_disp)
        metrics['max_displacement_relative_error'] = metrics['max_displacement_error'] / (gt_max_disp + 1e-8)

        if gt_stress is not None:
            pred_stress = predictions['stress']
            if pred_stress.shape != gt_stress.shape:
                print(f"Skipping stress comparison: predicted shape {pred_stress.shape} "
                      f"!= ground truth shape {gt_stress.shape}")
            else:
                stress_error = np.linalg.norm(pred_stress - gt_stress, axis=1)
                metrics['stress'] = {
                    'mae': float(np.mean(stress_error)),
                    'rmse': float(np.sqrt(np.mean(stress_error**2))),
                }

        # Print comparison
        print("\nComparison Results:")
        print(f" Displacement MAE: {metrics['displacement']['mae']:.6f} mm")
        print(f" Displacement RMSE: {metrics['displacement']['rmse']:.6f} mm")
        print(f" Displacement Relative Error: {metrics['displacement']['relative_error']*100:.2f}%")
        print(f" Max Displacement Error: {metrics['max_displacement_error']:.6f} mm ({metrics['max_displacement_relative_error']*100:.2f}%)")

        if 'stress' in metrics:
            print(f" Stress MAE: {metrics['stress']['mae']:.2f} MPa")
            print(f" Stress RMSE: {metrics['stress']['rmse']:.2f} MPa")

        return metrics


def batch_predict(predictor, case_directories, output_dir=None):
    """
    Run predictions on multiple cases.

    Each case is predicted, saved into its own directory and compared with
    ground truth when available. A failure in one case is recorded and does
    not stop the batch.

    Args:
        predictor (FieldPredictor): Initialized predictor
        case_directories (list): List of case directories
        output_dir (str): Optional output directory for the aggregated
            ``batch_predictions.json``

    Returns:
        list: One result dict per case, with 'status' set to 'success' or 'failed'
    """
    print(f"\n{'='*60}")
    print(f"Batch Prediction: {len(case_directories)} cases")
    print(f"{'='*60}")

    results = []

    for i, case_dir in enumerate(case_directories, 1):
        print(f"\n[{i}/{len(case_directories)}] Processing {Path(case_dir).name}...")

        try:
            predictions = predictor.predict(case_dir)
            predictor.save_predictions(predictions, case_dir)
            comparison = predictor.compare_with_ground_truth(predictions, case_dir)

            results.append({
                'case': str(case_dir),
                'status': 'success',
                'predictions': {
                    'max_displacement': predictions['max_displacement'],
                    'max_stress': predictions['max_stress'],
                    'inference_time_ms': predictions['inference_time_ms'],
                },
                'comparison': comparison,
            })

        except Exception as e:
            # Per-case boundary: record the failure and continue the batch.
            print(f"ERROR: {e}")
            results.append({
                'case': str(case_dir),
                'status': 'failed',
                'error': str(e),
            })

    # Save batch results
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        results_file = output_path / 'batch_predictions.json'
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print(f"\nBatch results saved to {results_file}")

    # Print summary
    print(f"\n{'='*60}")
    print("Batch Prediction Summary")
    print(f"{'='*60}")
    successful = [r for r in results if r['status'] == 'success']
    print(f"Successful: {len(successful)}/{len(results)}")

    if successful:
        avg_time = np.mean([r['predictions']['inference_time_ms'] for r in successful])
        print(f"Average inference time: {avg_time:.1f} ms")

    return results


def main():
    """
    Main inference entry point.

    Parses CLI arguments, builds a FieldPredictor and runs either a single
    prediction or a batch over every subdirectory of --input.
    """
    parser = argparse.ArgumentParser(description='Predict FEA fields using trained model')

    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input case directory or directory containing multiple cases')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory for batch results')
    parser.add_argument('--batch', action='store_true',
                        help='Process all subdirectories as separate cases')
    parser.add_argument('--device', type=str, default=None, choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--compare', action='store_true',
                        help='Compare predictions with ground truth')

    args = parser.parse_args()

    predictor = FieldPredictor(args.model, device=args.device)

    input_path = Path(args.input)

    if args.batch:
        # Sorted so batch order (and the results file) is deterministic.
        # Note: batch mode always compares with ground truth when available,
        # regardless of --compare.
        case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
        batch_predict(predictor, case_dirs, args.output_dir)
    else:
        predictions = predictor.predict(args.input)
        predictor.save_predictions(predictions, args.input)

        if args.compare:
            predictor.compare_with_ground_truth(predictions, args.input)

    print("\nInference complete!")


if __name__ == "__main__":
    main()