452 lines
15 KiB
Python
452 lines
15 KiB
Python
|
|
"""
|
||
|
|
train.py
|
||
|
|
Training script for AtomizerField neural field predictor
|
||
|
|
|
||
|
|
AtomizerField Training Pipeline v2.0
|
||
|
|
Trains Graph Neural Networks to predict complete FEA field results.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python train.py --train_dir ./training_data --val_dir ./validation_data
|
||
|
|
|
||
|
|
Key Features:
|
||
|
|
- Multi-GPU support
|
||
|
|
- Checkpoint saving/loading
|
||
|
|
- TensorBoard logging
|
||
|
|
- Early stopping
|
||
|
|
- Learning rate scheduling
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
import time
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.optim as optim
|
||
|
|
from torch.utils.tensorboard import SummaryWriter
|
||
|
|
|
||
|
|
from neural_models.field_predictor import create_model, AtomizerFieldModel
|
||
|
|
from neural_models.physics_losses import create_loss_function
|
||
|
|
from neural_models.data_loader import create_dataloaders
|
||
|
|
|
||
|
|
|
||
|
|
class Trainer:
    """
    Training manager for AtomizerField models.

    Owns the model, loss, optimizer, scheduler, TensorBoard writer and
    checkpoint files, and runs the train/validate/early-stop loop.
    """

    def __init__(self, config):
        """
        Initialize trainer.

        Args:
            config (dict): Training configuration. Recognized keys:
                'model' (dict), 'loss' (dict with a 'type' key),
                'learning_rate', 'weight_decay', 'output_dir',
                'early_stopping_patience'.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"\n{'='*60}")
        print("AtomizerField Training Pipeline v2.0")
        print(f"{'='*60}")
        print(f"Device: {self.device}")

        # Create model
        print("\nCreating model...")
        self.model = create_model(config.get('model', {}))
        self.model = self.model.to(self.device)

        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Model created: {num_params:,} parameters")

        # Create loss function.
        # BUG FIX: the original popped 'type' directly from config['loss'],
        # mutating self.config so the config.json saved below lost the loss
        # type. Pop from a shallow copy instead.
        loss_config = dict(config.get('loss', {}))
        loss_type = loss_config.pop('type', 'mse')
        self.criterion = create_loss_function(loss_type, loss_config)
        print(f"Loss function: {loss_type}")

        # Create optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.get('learning_rate', 1e-3),
            weight_decay=config.get('weight_decay', 1e-5)
        )

        # Learning rate scheduler.
        # NOTE: the 'verbose' argument was removed here because it is
        # deprecated and has been dropped in recent PyTorch releases
        # (passing it raises TypeError there); the current LR is logged to
        # TensorBoard every epoch instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=10
        )

        # Training state
        self.start_epoch = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

        # Create output directories
        self.output_dir = Path(config.get('output_dir', './runs'))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard logging
        self.writer = SummaryWriter(
            log_dir=str(self.output_dir / 'tensorboard')
        )

        # Save config for reproducibility (now includes the loss 'type',
        # since we no longer mutate it above).
        with open(self.output_dir / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)

    def _compute_losses(self, batch):
        """
        Forward pass + loss computation shared by training and validation.

        Args:
            batch: A graph batch already moved to self.device, with a
                   `y_displacement` attribute and optionally `y_stress`.

        Returns:
            dict: Loss tensors from the criterion; always contains
                  'total_loss', may contain 'displacement_loss' and
                  'stress_loss'.
        """
        predictions = self.model(batch, return_stress=True)

        targets = {
            'displacement': batch.y_displacement,
        }
        if hasattr(batch, 'y_stress'):
            targets['stress'] = batch.y_stress

        return self.criterion(predictions, targets, batch)

    def train_epoch(self, train_loader, epoch):
        """
        Train for one epoch.

        Args:
            train_loader: Training data loader.
            epoch (int): Current epoch number (kept for interface
                compatibility; not used inside the loop).

        Returns:
            dict: Averaged training metrics with keys 'total_loss',
                  'displacement_loss' and 'stress_loss'.
        """
        self.model.train()

        total_loss = 0.0
        total_disp_loss = 0.0
        total_stress_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Move batch to device
            batch = batch.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass + loss
            losses = self._compute_losses(batch)

            # Backward pass
            losses['total_loss'].backward()

            # Gradient clipping (prevents exploding gradients)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Update weights
            self.optimizer.step()

            # Accumulate metrics
            total_loss += losses['total_loss'].item()
            if 'displacement_loss' in losses:
                total_disp_loss += losses['displacement_loss'].item()
            if 'stress_loss' in losses:
                total_stress_loss += losses['stress_loss'].item()
            num_batches += 1

            # Print progress
            if batch_idx % 10 == 0:
                print(f" Batch {batch_idx}/{len(train_loader)}: "
                      f"Loss={losses['total_loss'].item():.6f}")

        # Average metrics; guard against an empty loader so we never
        # divide by zero (metrics are simply 0.0 in that case).
        denom = max(num_batches, 1)
        metrics = {
            'total_loss': total_loss / denom,
            'displacement_loss': total_disp_loss / denom,
            'stress_loss': total_stress_loss / denom
        }

        return metrics

    def validate(self, val_loader):
        """
        Validate the model on the held-out set (no gradient updates).

        Args:
            val_loader: Validation data loader.

        Returns:
            dict: Averaged validation metrics with keys 'total_loss',
                  'displacement_loss' and 'stress_loss'.
        """
        self.model.eval()

        total_loss = 0.0
        total_disp_loss = 0.0
        total_stress_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                # Move batch to device
                batch = batch.to(self.device)

                # Forward pass + loss
                losses = self._compute_losses(batch)

                # Accumulate metrics
                total_loss += losses['total_loss'].item()
                if 'displacement_loss' in losses:
                    total_disp_loss += losses['displacement_loss'].item()
                if 'stress_loss' in losses:
                    total_stress_loss += losses['stress_loss'].item()
                num_batches += 1

        # Average metrics; empty-loader guard as in train_epoch.
        denom = max(num_batches, 1)
        metrics = {
            'total_loss': total_loss / denom,
            'displacement_loss': total_disp_loss / denom,
            'stress_loss': total_stress_loss / denom
        }

        return metrics

    def train(self, train_loader, val_loader, num_epochs):
        """
        Main training loop: alternates training and validation epochs,
        logs to TensorBoard, saves checkpoints and applies early stopping.

        Args:
            train_loader: Training data loader.
            val_loader: Validation data loader.
            num_epochs (int): Number of epochs to train.
        """
        print(f"\n{'='*60}")
        print(f"Starting training for {num_epochs} epochs")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, num_epochs):
            epoch_start_time = time.time()

            print(f"Epoch {epoch + 1}/{num_epochs}")
            print("-" * 60)

            # Train
            train_metrics = self.train_epoch(train_loader, epoch)

            # Validate
            val_metrics = self.validate(val_loader)

            epoch_time = time.time() - epoch_start_time

            # Print metrics
            print(f"\nEpoch {epoch + 1} Results:")
            print(f" Training Loss: {train_metrics['total_loss']:.6f}")
            print(f" Displacement: {train_metrics['displacement_loss']:.6f}")
            print(f" Stress: {train_metrics['stress_loss']:.6f}")
            print(f" Validation Loss: {val_metrics['total_loss']:.6f}")
            print(f" Displacement: {val_metrics['displacement_loss']:.6f}")
            print(f" Stress: {val_metrics['stress_loss']:.6f}")
            print(f" Time: {epoch_time:.1f}s")

            # Log to TensorBoard (includes the LR, replacing the removed
            # scheduler 'verbose' output)
            self.writer.add_scalar('Loss/train', train_metrics['total_loss'], epoch)
            self.writer.add_scalar('Loss/val', val_metrics['total_loss'], epoch)
            self.writer.add_scalar('DisplacementLoss/train', train_metrics['displacement_loss'], epoch)
            self.writer.add_scalar('DisplacementLoss/val', val_metrics['displacement_loss'], epoch)
            self.writer.add_scalar('StressLoss/train', train_metrics['stress_loss'], epoch)
            self.writer.add_scalar('StressLoss/val', val_metrics['stress_loss'], epoch)
            self.writer.add_scalar('LearningRate', self.optimizer.param_groups[0]['lr'], epoch)

            # Learning rate scheduling (plateau detection on val loss)
            self.scheduler.step(val_metrics['total_loss'])

            # Save checkpoint
            is_best = val_metrics['total_loss'] < self.best_val_loss
            if is_best:
                self.best_val_loss = val_metrics['total_loss']
                self.epochs_without_improvement = 0
                print(f" New best validation loss: {self.best_val_loss:.6f}")
            else:
                self.epochs_without_improvement += 1

            self.save_checkpoint(epoch, val_metrics, is_best)

            # Early stopping
            patience = self.config.get('early_stopping_patience', 50)
            if self.epochs_without_improvement >= patience:
                print(f"\nEarly stopping after {patience} epochs without improvement")
                break

            print()

        print(f"\n{'='*60}")
        print("Training complete!")
        print(f"Best validation loss: {self.best_val_loss:.6f}")
        print(f"{'='*60}\n")

        self.writer.close()

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """
        Save model checkpoint.

        Writes 'checkpoint_latest.pt' every call, 'checkpoint_best.pt'
        when is_best, and a periodic 'checkpoint_epoch_N.pt' every 10
        epochs.

        Args:
            epoch (int): Current epoch.
            metrics (dict): Validation metrics.
            is_best (bool): Whether this is the best model so far.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config,
            'metrics': metrics
        }

        # Save latest checkpoint
        checkpoint_path = self.output_dir / 'checkpoint_latest.pt'
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / 'checkpoint_best.pt'
            torch.save(checkpoint, best_path)
            print(f" Saved best model to {best_path}")

        # Save periodic checkpoint
        if (epoch + 1) % 10 == 0:
            periodic_path = self.output_dir / f'checkpoint_epoch_{epoch + 1}.pt'
            torch.save(checkpoint, periodic_path)

    def load_checkpoint(self, checkpoint_path):
        """
        Load model checkpoint and restore training state so training can
        resume from the following epoch.

        Args:
            checkpoint_path (str): Path to checkpoint file.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Resume at the epoch after the one that was saved.
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        print(f"Best validation loss: {self.best_val_loss:.6f}")
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """
    Main training entry point.

    Parses CLI arguments, builds the training configuration, discovers
    case directories, constructs data loaders and the Trainer, optionally
    resumes from a checkpoint, then runs training.

    Raises:
        SystemExit: with status 1 if no training/validation cases are found.
    """
    parser = argparse.ArgumentParser(description='Train AtomizerField neural field predictor')

    # Data arguments
    parser.add_argument('--train_dir', type=str, required=True,
                        help='Directory containing training cases')
    parser.add_argument('--val_dir', type=str, required=True,
                        help='Directory containing validation cases')

    # Training arguments
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5,
                        help='Weight decay')

    # Model arguments
    parser.add_argument('--hidden_dim', type=int, default=128,
                        help='Hidden dimension')
    parser.add_argument('--num_layers', type=int, default=6,
                        help='Number of GNN layers')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout rate')

    # Loss arguments
    parser.add_argument('--loss_type', type=str, default='mse',
                        choices=['mse', 'relative', 'physics', 'max'],
                        help='Loss function type')

    # Other arguments
    parser.add_argument('--output_dir', type=str, default='./runs',
                        help='Output directory for checkpoints and logs')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='Number of data loading workers')

    args = parser.parse_args()

    # Build configuration
    config = {
        'model': {
            'node_feature_dim': 12,  # 3 coords + 6 BCs + 3 loads
            'edge_feature_dim': 5,   # E, nu, rho, G, alpha
            'hidden_dim': args.hidden_dim,
            'num_layers': args.num_layers,
            'dropout': args.dropout
        },
        'loss': {
            'type': args.loss_type
        },
        'learning_rate': args.lr,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'num_epochs': args.epochs,
        'output_dir': args.output_dir,
        'early_stopping_patience': 50
    }

    # Find all case directories. FIX: sorted() makes the case ordering
    # deterministic across filesystems (glob order is unspecified), so
    # runs are reproducible.
    train_cases = sorted(Path(args.train_dir).glob('*/'))
    val_cases = sorted(Path(args.val_dir).glob('*/'))

    print(f"Found {len(train_cases)} training cases")
    print(f"Found {len(val_cases)} validation cases")

    if not train_cases or not val_cases:
        print("ERROR: No training or validation cases found!")
        print("Please ensure your directories contain parsed FEA data.")
        # FIX: exit with a nonzero status so shell scripts / schedulers can
        # detect the failure (the original returned with status 0).
        raise SystemExit(1)

    # Create data loaders
    train_loader, val_loader = create_dataloaders(
        train_cases,
        val_cases,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        normalize=True,
        include_stress=True
    )

    # Create trainer
    trainer = Trainer(config)

    # Resume from checkpoint if specified
    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Train
    trainer.train(train_loader, val_loader, args.epochs)
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: only run training when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|