""" train.py Training script for AtomizerField neural field predictor AtomizerField Training Pipeline v2.0 Trains Graph Neural Networks to predict complete FEA field results. Usage: python train.py --train_dir ./training_data --val_dir ./validation_data Key Features: - Multi-GPU support - Checkpoint saving/loading - TensorBoard logging - Early stopping - Learning rate scheduling """ import argparse import json from pathlib import Path import time from datetime import datetime import torch import torch.nn as nn import torch.optim as optim from torch.utils.tensorboard import SummaryWriter from neural_models.field_predictor import create_model, AtomizerFieldModel from neural_models.physics_losses import create_loss_function from neural_models.data_loader import create_dataloaders class Trainer: """ Training manager for AtomizerField models """ def __init__(self, config): """ Initialize trainer Args: config (dict): Training configuration """ self.config = config self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"\n{'='*60}") print("AtomizerField Training Pipeline v2.0") print(f"{'='*60}") print(f"Device: {self.device}") # Create model print("\nCreating model...") self.model = create_model(config.get('model', {})) self.model = self.model.to(self.device) num_params = sum(p.numel() for p in self.model.parameters()) print(f"Model created: {num_params:,} parameters") # Create loss function loss_config = config.get('loss', {}) loss_type = loss_config.pop('type', 'mse') self.criterion = create_loss_function(loss_type, loss_config) print(f"Loss function: {loss_type}") # Create optimizer self.optimizer = optim.AdamW( self.model.parameters(), lr=config.get('learning_rate', 1e-3), weight_decay=config.get('weight_decay', 1e-5) ) # Learning rate scheduler self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode='min', factor=0.5, patience=10, verbose=True ) # Training state self.start_epoch = 0 self.best_val_loss = float('inf') self.epochs_without_improvement = 0 # Create output directories self.output_dir = Path(config.get('output_dir', './runs')) self.output_dir.mkdir(parents=True, exist_ok=True) # TensorBoard logging self.writer = SummaryWriter( log_dir=self.output_dir / 'tensorboard' ) # Save config with open(self.output_dir / 'config.json', 'w') as f: json.dump(config, f, indent=2) def train_epoch(self, train_loader, epoch): """ Train for one epoch Args: train_loader: Training data loader epoch (int): Current epoch number Returns: dict: Training metrics """ self.model.train() total_loss = 0.0 total_disp_loss = 0.0 total_stress_loss = 0.0 num_batches = 0 for batch_idx, batch in enumerate(train_loader): # Move batch to device batch = batch.to(self.device) # Zero gradients self.optimizer.zero_grad() # Forward pass predictions = self.model(batch, return_stress=True) # Prepare targets targets = { 'displacement': batch.y_displacement, } if hasattr(batch, 'y_stress'): targets['stress'] = batch.y_stress # Compute loss losses = self.criterion(predictions, targets, batch) # Backward pass losses['total_loss'].backward() # Gradient clipping (prevents exploding gradients) torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # Update weights self.optimizer.step() # Accumulate metrics total_loss += losses['total_loss'].item() if 'displacement_loss' in losses: total_disp_loss += losses['displacement_loss'].item() if 'stress_loss' in losses: total_stress_loss += losses['stress_loss'].item() num_batches += 1 # Print progress if batch_idx % 10 == 0: 
print(f" Batch {batch_idx}/{len(train_loader)}: " f"Loss={losses['total_loss'].item():.6f}") # Average metrics metrics = { 'total_loss': total_loss / num_batches, 'displacement_loss': total_disp_loss / num_batches, 'stress_loss': total_stress_loss / num_batches } return metrics def validate(self, val_loader): """ Validate model Args: val_loader: Validation data loader Returns: dict: Validation metrics """ self.model.eval() total_loss = 0.0 total_disp_loss = 0.0 total_stress_loss = 0.0 num_batches = 0 with torch.no_grad(): for batch in val_loader: # Move batch to device batch = batch.to(self.device) # Forward pass predictions = self.model(batch, return_stress=True) # Prepare targets targets = { 'displacement': batch.y_displacement, } if hasattr(batch, 'y_stress'): targets['stress'] = batch.y_stress # Compute loss losses = self.criterion(predictions, targets, batch) # Accumulate metrics total_loss += losses['total_loss'].item() if 'displacement_loss' in losses: total_disp_loss += losses['displacement_loss'].item() if 'stress_loss' in losses: total_stress_loss += losses['stress_loss'].item() num_batches += 1 # Average metrics metrics = { 'total_loss': total_loss / num_batches, 'displacement_loss': total_disp_loss / num_batches, 'stress_loss': total_stress_loss / num_batches } return metrics def train(self, train_loader, val_loader, num_epochs): """ Main training loop Args: train_loader: Training data loader val_loader: Validation data loader num_epochs (int): Number of epochs to train """ print(f"\n{'='*60}") print(f"Starting training for {num_epochs} epochs") print(f"{'='*60}\n") for epoch in range(self.start_epoch, num_epochs): epoch_start_time = time.time() print(f"Epoch {epoch + 1}/{num_epochs}") print("-" * 60) # Train train_metrics = self.train_epoch(train_loader, epoch) # Validate val_metrics = self.validate(val_loader) epoch_time = time.time() - epoch_start_time # Print metrics print(f"\nEpoch {epoch + 1} Results:") print(f" Training Loss: {train_metrics['total_loss']:.6f}") print(f" Displacement: {train_metrics['displacement_loss']:.6f}") print(f" Stress: {train_metrics['stress_loss']:.6f}") print(f" Validation Loss: {val_metrics['total_loss']:.6f}") print(f" Displacement: {val_metrics['displacement_loss']:.6f}") print(f" Stress: {val_metrics['stress_loss']:.6f}") print(f" Time: {epoch_time:.1f}s") # Log to TensorBoard self.writer.add_scalar('Loss/train', train_metrics['total_loss'], epoch) self.writer.add_scalar('Loss/val', val_metrics['total_loss'], epoch) self.writer.add_scalar('DisplacementLoss/train', train_metrics['displacement_loss'], epoch) self.writer.add_scalar('DisplacementLoss/val', val_metrics['displacement_loss'], epoch) self.writer.add_scalar('StressLoss/train', train_metrics['stress_loss'], epoch) self.writer.add_scalar('StressLoss/val', val_metrics['stress_loss'], epoch) self.writer.add_scalar('LearningRate', self.optimizer.param_groups[0]['lr'], epoch) # Learning rate scheduling self.scheduler.step(val_metrics['total_loss']) # Save checkpoint is_best = val_metrics['total_loss'] < self.best_val_loss if is_best: self.best_val_loss = val_metrics['total_loss'] self.epochs_without_improvement = 0 print(f" New best validation loss: {self.best_val_loss:.6f}") else: self.epochs_without_improvement += 1 self.save_checkpoint(epoch, val_metrics, is_best) # Early stopping patience = self.config.get('early_stopping_patience', 50) if self.epochs_without_improvement >= patience: print(f"\nEarly stopping after {patience} epochs without improvement") break print() 
print(f"\n{'='*60}") print("Training complete!") print(f"Best validation loss: {self.best_val_loss:.6f}") print(f"{'='*60}\n") self.writer.close() def save_checkpoint(self, epoch, metrics, is_best=False): """ Save model checkpoint Args: epoch (int): Current epoch metrics (dict): Validation metrics is_best (bool): Whether this is the best model so far """ checkpoint = { 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'best_val_loss': self.best_val_loss, 'config': self.config, 'metrics': metrics } # Save latest checkpoint checkpoint_path = self.output_dir / 'checkpoint_latest.pt' torch.save(checkpoint, checkpoint_path) # Save best checkpoint if is_best: best_path = self.output_dir / 'checkpoint_best.pt' torch.save(checkpoint, best_path) print(f" Saved best model to {best_path}") # Save periodic checkpoint if (epoch + 1) % 10 == 0: periodic_path = self.output_dir / f'checkpoint_epoch_{epoch + 1}.pt' torch.save(checkpoint, periodic_path) def load_checkpoint(self, checkpoint_path): """ Load model checkpoint Args: checkpoint_path (str): Path to checkpoint file """ checkpoint = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.start_epoch = checkpoint['epoch'] + 1 self.best_val_loss = checkpoint['best_val_loss'] print(f"Loaded checkpoint from epoch {checkpoint['epoch']}") print(f"Best validation loss: {self.best_val_loss:.6f}") def main(): """ Main training entry point """ parser = argparse.ArgumentParser(description='Train AtomizerField neural field predictor') # Data arguments parser.add_argument('--train_dir', type=str, required=True, help='Directory containing training cases') parser.add_argument('--val_dir', type=str, required=True, help='Directory containing validation cases') # Training arguments parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs') parser.add_argument('--batch_size', type=int, default=4, help='Batch size') parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay') # Model arguments parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension') parser.add_argument('--num_layers', type=int, default=6, help='Number of GNN layers') parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate') # Loss arguments parser.add_argument('--loss_type', type=str, default='mse', choices=['mse', 'relative', 'physics', 'max'], help='Loss function type') # Other arguments parser.add_argument('--output_dir', type=str, default='./runs', help='Output directory for checkpoints and logs') parser.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from') parser.add_argument('--num_workers', type=int, default=0, help='Number of data loading workers') args = parser.parse_args() # Build configuration config = { 'model': { 'node_feature_dim': 12, # 3 coords + 6 BCs + 3 loads 'edge_feature_dim': 5, # E, nu, rho, G, alpha 'hidden_dim': args.hidden_dim, 'num_layers': args.num_layers, 'dropout': args.dropout }, 'loss': { 'type': args.loss_type }, 'learning_rate': args.lr, 'weight_decay': args.weight_decay, 'batch_size': args.batch_size, 'num_epochs': args.epochs, 'output_dir': 

    # Find all case directories (each immediate subdirectory is one case;
    # sorted for deterministic ordering)
    train_cases = sorted(p for p in Path(args.train_dir).iterdir() if p.is_dir())
    val_cases = sorted(p for p in Path(args.val_dir).iterdir() if p.is_dir())

    print(f"Found {len(train_cases)} training cases")
    print(f"Found {len(val_cases)} validation cases")

    if not train_cases or not val_cases:
        print("ERROR: No training or validation cases found!")
        print("Please ensure your directories contain parsed FEA data.")
        return

    # Create data loaders
    train_loader, val_loader = create_dataloaders(
        train_cases,
        val_cases,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        normalize=True,
        include_stress=True
    )

    # Create trainer
    trainer = Trainer(config)

    # Resume from checkpoint if specified
    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Train
    trainer.train(train_loader, val_loader, args.epochs)


if __name__ == "__main__":
    main()
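
# Monitoring note: the training curves written by SummaryWriter above can be
# viewed with TensorBoard, e.g.
#
#   tensorboard --logdir ./runs/tensorboard
#
# (the path assumes the default --output_dir of ./runs)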