""" train_parametric.py Training script for AtomizerField parametric predictor AtomizerField Parametric Training Pipeline v2.0 Trains design-conditioned GNN to predict optimization objectives directly. Usage: python train_parametric.py --train_dir ./training_data --val_dir ./validation_data Key Differences from train.py (field predictor): - Predicts scalar objectives (mass, frequency, displacement, stress) instead of fields - Uses design parameters as conditioning input - Multi-objective loss function for all 4 outputs - Faster training due to simpler output structure Output: checkpoint_best.pt containing: - model_state_dict: Trained weights - config: Model configuration - normalization: Normalization statistics for inference - design_var_names: Names of design variables - best_val_loss: Best validation loss achieved """ import argparse import json import sys from pathlib import Path import time from datetime import datetime from typing import Dict, List, Any, Optional, Tuple import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import h5py # Try to import tensorboard, but make it optional try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: TENSORBOARD_AVAILABLE = False from neural_models.parametric_predictor import create_parametric_model, ParametricFieldPredictor class ParametricDataset(Dataset): """ PyTorch Dataset for parametric training. Loads training data exported by Atomizer's TrainingDataExporter and prepares it for the parametric predictor. Expected directory structure: training_data/ ├── trial_0000/ │ ├── metadata.json (design params + results) │ └── input/ │ └── neural_field_data.h5 (mesh data) ├── trial_0001/ │ └── ... """ def __init__( self, data_dir: Path, normalize: bool = True, cache_in_memory: bool = False ): """ Initialize dataset. 

    def __init__(
        self,
        data_dir: Path,
        normalize: bool = True,
        cache_in_memory: bool = False
    ):
        """
        Initialize dataset.

        Args:
            data_dir: Directory containing trial_* subdirectories
            normalize: Whether to normalize inputs/outputs
            cache_in_memory: Cache all data in RAM (faster but memory-intensive)
        """
        self.data_dir = Path(data_dir)
        self.normalize = normalize
        self.cache_in_memory = cache_in_memory

        # Find all valid trial directories
        self.trial_dirs = sorted([
            d for d in self.data_dir.glob("trial_*")
            if d.is_dir() and self._is_valid_trial(d)
        ])

        print(f"Found {len(self.trial_dirs)} valid trials in {data_dir}")

        if len(self.trial_dirs) == 0:
            raise ValueError(f"No valid trial directories found in {data_dir}")

        # Extract design variable names from first trial
        self.design_var_names = self._get_design_var_names()
        print(f"Design variables: {self.design_var_names}")

        # Compute normalization statistics
        if normalize:
            self._compute_normalization_stats()

        # Cache data if requested
        self.cache = {}
        if cache_in_memory:
            print("Caching data in memory...")
            for idx in range(len(self.trial_dirs)):
                self.cache[idx] = self._load_trial(idx)
            print("Cache complete!")

    def _is_valid_trial(self, trial_dir: Path) -> bool:
        """Check if trial directory has required files."""
        metadata_file = trial_dir / "metadata.json"

        # Check for metadata
        if not metadata_file.exists():
            return False

        # Check metadata has required fields
        try:
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)
            has_design = 'design_parameters' in metadata
            has_results = 'results' in metadata
            return has_design and has_results
        except Exception:
            # Unreadable or malformed metadata: treat the trial as invalid
            return False

    def _get_design_var_names(self) -> List[str]:
        """Extract design variable names from first trial."""
        metadata_file = self.trial_dirs[0] / "metadata.json"
        with open(metadata_file, 'r') as f:
            metadata = json.load(f)
        return list(metadata['design_parameters'].keys())

    def _compute_normalization_stats(self):
        """Compute normalization statistics across all trials."""
        print("Computing normalization statistics...")

        all_design_params = []
        all_mass = []
        all_disp = []
        all_stiffness = []

        for trial_dir in self.trial_dirs:
            with open(trial_dir / "metadata.json", 'r') as f:
                metadata = json.load(f)

            # Design parameters
            design_params = [metadata['design_parameters'][name]
                             for name in self.design_var_names]
            all_design_params.append(design_params)

            # Results
            results = metadata.get('results', {})
            objectives = results.get('objectives', results)

            if 'mass' in objectives:
                all_mass.append(objectives['mass'])
            if 'max_displacement' in results:
                all_disp.append(results['max_displacement'])
            elif 'max_displacement' in objectives:
                all_disp.append(objectives['max_displacement'])
            if 'stiffness' in objectives:
                all_stiffness.append(objectives['stiffness'])

        # Convert to numpy arrays
        all_design_params = np.array(all_design_params)

        # Compute statistics
        self.design_mean = torch.from_numpy(all_design_params.mean(axis=0)).float()
        self.design_std = torch.from_numpy(all_design_params.std(axis=0)).float()
        self.design_std = torch.clamp(self.design_std, min=1e-6)  # Prevent division by zero

        # Output statistics
        self.mass_mean = np.mean(all_mass) if all_mass else 0.1
        self.mass_std = np.std(all_mass) if all_mass else 0.05
        self.mass_std = max(self.mass_std, 1e-6)

        self.disp_mean = np.mean(all_disp) if all_disp else 0.01
        self.disp_std = np.std(all_disp) if all_disp else 0.005
        self.disp_std = max(self.disp_std, 1e-6)

        self.stiffness_mean = np.mean(all_stiffness) if all_stiffness else 20000.0
        self.stiffness_std = np.std(all_stiffness) if all_stiffness else 5000.0
        self.stiffness_std = max(self.stiffness_std, 1e-6)

        # Frequency and stress defaults (if not available in data)
        self.freq_mean = 18.0
        self.freq_std = 5.0
        self.stress_mean = 200.0
        self.stress_std = 50.0

        print(f"  Design mean: {self.design_mean.numpy()}")
        print(f"  Design std: {self.design_std.numpy()}")
        print(f"  Mass: {self.mass_mean:.4f} +/- {self.mass_std:.4f}")
        print(f"  Displacement: {self.disp_mean:.6f} +/- {self.disp_std:.6f}")
        print(f"  Stiffness: {self.stiffness_mean:.2f} +/- {self.stiffness_std:.2f}")

    def __len__(self) -> int:
        return len(self.trial_dirs)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if self.cache_in_memory and idx in self.cache:
            return self.cache[idx]
        return self._load_trial(idx)

    def _load_trial(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load and process a single trial."""
        trial_dir = self.trial_dirs[idx]

        # Load metadata
        with open(trial_dir / "metadata.json", 'r') as f:
            metadata = json.load(f)

        # Extract design parameters
        design_params = [metadata['design_parameters'][name]
                         for name in self.design_var_names]
        design_tensor = torch.tensor(design_params, dtype=torch.float32)

        # Normalize design parameters
        if self.normalize:
            design_tensor = (design_tensor - self.design_mean) / self.design_std

        # Extract results
        results = metadata.get('results', {})
        objectives = results.get('objectives', results)

        # Get targets (with fallbacks)
        mass = objectives.get('mass', 0.1)
        stiffness = objectives.get('stiffness', 20000.0)
        max_displacement = results.get('max_displacement',
                                       objectives.get('max_displacement', 0.01))

        # Frequency and stress might not be available
        frequency = objectives.get('frequency', self.freq_mean)
        max_stress = objectives.get('max_stress', self.stress_mean)

        # Create target tensor
        targets = torch.tensor([mass, frequency, max_displacement, max_stress],
                               dtype=torch.float32)

        # Normalize targets
        if self.normalize:
            targets[0] = (targets[0] - self.mass_mean) / self.mass_std
            targets[1] = (targets[1] - self.freq_mean) / self.freq_std
            targets[2] = (targets[2] - self.disp_mean) / self.disp_std
            targets[3] = (targets[3] - self.stress_mean) / self.stress_std

        # Try to load mesh data if available
        mesh_data = self._load_mesh_data(trial_dir)

        return {
            'design_params': design_tensor,
            'targets': targets,
            'mesh_data': mesh_data,
            'trial_dir': str(trial_dir)
        }

    def _load_mesh_data(self, trial_dir: Path) -> Optional[Dict[str, torch.Tensor]]:
        """Load mesh data from H5 file if available."""
        h5_paths = [
            trial_dir / "input" / "neural_field_data.h5",
            trial_dir / "neural_field_data.h5",
        ]

        for h5_path in h5_paths:
            if h5_path.exists():
                try:
                    with h5py.File(h5_path, 'r') as f:
                        node_coords = torch.from_numpy(f['mesh/node_coordinates'][:]).float()
                        # Build simple edge index from connectivity if available
                        # For now, return just coordinates
                        return {
                            'node_coords': node_coords,
                            'num_nodes': node_coords.shape[0]
                        }
                except Exception as e:
                    print(f"Warning: Could not load mesh from {h5_path}: {e}")

        return None

    def get_normalization_stats(self) -> Dict[str, Any]:
        """Return normalization statistics for saving with model."""
        return {
            'design_mean': self.design_mean.numpy().tolist(),
            'design_std': self.design_std.numpy().tolist(),
            'mass_mean': float(self.mass_mean),
            'mass_std': float(self.mass_std),
            'freq_mean': float(self.freq_mean),
            'freq_std': float(self.freq_std),
            'max_disp_mean': float(self.disp_mean),
            'max_disp_std': float(self.disp_std),
            'max_stress_mean': float(self.stress_mean),
            'max_stress_std': float(self.stress_std),
        }
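

# Note on normalization: design parameters and targets are standardized as
# x_norm = (x - mean) / std using the statistics above, so model outputs are in
# normalized units and must be de-normalized at inference time. A minimal sketch
# (illustrative only; `dataset` and `pred_norm` are placeholders) using the dict
# returned by get_normalization_stats():
#
#     stats = dataset.get_normalization_stats()
#     mass_pred = pred_norm * stats['mass_std'] + stats['mass_mean']
#
# The same pattern applies to frequency, max_disp, and max_stress with their
# respective *_mean / *_std entries.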


def create_reference_graph(num_nodes: int = 500, device: torch.device = None):
    """
    Create a reference graph structure for the GNN.

    In production, this would come from the actual mesh.
    For now, create a simple grid-like structure.
    """
    if device is None:
        device = torch.device('cpu')

    # Create simple node features (placeholder)
    x = torch.randn(num_nodes, 12, device=device)

    # Create grid-like connectivity - ensure valid indices
    edges = []
    grid_size = int(np.ceil(np.sqrt(num_nodes)))

    for i in range(num_nodes):
        row = i // grid_size
        col = i % grid_size

        # Right neighbor (same row)
        right = i + 1
        if col < grid_size - 1 and right < num_nodes:
            edges.append([i, right])
            edges.append([right, i])

        # Bottom neighbor (next row)
        bottom = i + grid_size
        if bottom < num_nodes:
            edges.append([i, bottom])
            edges.append([bottom, i])

    # Ensure we have at least some edges
    if len(edges) == 0:
        # Fallback: fully connected for very small graphs
        for i in range(num_nodes):
            for j in range(i + 1, min(i + 5, num_nodes)):
                edges.append([i, j])
                edges.append([j, i])

    edge_index = torch.tensor(edges, dtype=torch.long, device=device).t().contiguous()
    edge_attr = torch.randn(edge_index.shape[1], 5, device=device)

    from torch_geometric.data import Data
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


class ParametricTrainer:
    """
    Training manager for parametric predictor models.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize trainer.

        Args:
            config: Training configuration
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"\n{'='*60}")
        print("AtomizerField Parametric Training Pipeline v2.0")
        print(f"{'='*60}")
        print(f"Device: {self.device}")

        # Create model
        print("\nCreating parametric model...")
        model_config = config.get('model', {})
        self.model = create_parametric_model(model_config)
        self.model = self.model.to(self.device)

        num_params = self.model.get_num_parameters()
        print(f"Model created: {num_params:,} parameters")

        # Create optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.get('learning_rate', 1e-3),
            weight_decay=config.get('weight_decay', 1e-5)
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=10,
            verbose=True
        )

        # Multi-objective loss weights
        self.loss_weights = config.get('loss_weights', {
            'mass': 1.0,
            'frequency': 1.0,
            'displacement': 1.0,
            'stress': 1.0
        })

        # Training state
        self.start_epoch = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

        # Create output directories
        self.output_dir = Path(config.get('output_dir', './runs/parametric'))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard logging (optional)
        self.writer = None
        if TENSORBOARD_AVAILABLE:
            self.writer = SummaryWriter(log_dir=self.output_dir / 'tensorboard')

        # Create reference graph for inference
        self.reference_graph = create_reference_graph(
            num_nodes=config.get('reference_nodes', 500),
            device=self.device
        )

        # Save config
        with open(self.output_dir / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)
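
    # The loss below is a weighted sum of per-objective MSE terms:
    #     L = w_mass * MSE(mass) + w_freq * MSE(frequency)
    #       + w_disp * MSE(max_displacement) + w_stress * MSE(max_stress)
    # with weights taken from config['loss_weights'] (all 1.0 by default).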
    def compute_loss(
        self,
        predictions: Dict[str, torch.Tensor],
        targets: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute multi-objective loss.

        Args:
            predictions: Model outputs (mass, frequency, max_displacement, max_stress)
            targets: Target values [batch, 4]

        Returns:
            total_loss: Combined loss tensor
            loss_dict: Individual losses for logging
        """
        # MSE losses for each objective
        mass_loss = nn.functional.mse_loss(predictions['mass'], targets[:, 0])
        freq_loss = nn.functional.mse_loss(predictions['frequency'], targets[:, 1])
        disp_loss = nn.functional.mse_loss(predictions['max_displacement'], targets[:, 2])
        stress_loss = nn.functional.mse_loss(predictions['max_stress'], targets[:, 3])

        # Weighted combination
        total_loss = (
            self.loss_weights['mass'] * mass_loss +
            self.loss_weights['frequency'] * freq_loss +
            self.loss_weights['displacement'] * disp_loss +
            self.loss_weights['stress'] * stress_loss
        )

        loss_dict = {
            'total': total_loss.item(),
            'mass': mass_loss.item(),
            'frequency': freq_loss.item(),
            'displacement': disp_loss.item(),
            'stress': stress_loss.item()
        }

        return total_loss, loss_dict

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train for one epoch."""
        self.model.train()

        total_losses = {'total': 0, 'mass': 0, 'frequency': 0, 'displacement': 0, 'stress': 0}
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Get data
            design_params = batch['design_params'].to(self.device)
            targets = batch['targets'].to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass (using reference graph)
            predictions = self.model(self.reference_graph, design_params)

            # Compute loss
            loss, loss_dict = self.compute_loss(predictions, targets)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Update weights
            self.optimizer.step()

            # Accumulate losses
            for key in total_losses:
                total_losses[key] += loss_dict[key]
            num_batches += 1

            # Print progress
            if batch_idx % 10 == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)}: Loss={loss_dict['total']:.6f}")

        # Average losses
        return {k: v / num_batches for k, v in total_losses.items()}

    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validate model."""
        self.model.eval()

        total_losses = {'total': 0, 'mass': 0, 'frequency': 0, 'displacement': 0, 'stress': 0}
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                design_params = batch['design_params'].to(self.device)
                targets = batch['targets'].to(self.device)

                predictions = self.model(self.reference_graph, design_params)
                _, loss_dict = self.compute_loss(predictions, targets)

                for key in total_losses:
                    total_losses[key] += loss_dict[key]
                num_batches += 1

        return {k: v / num_batches for k, v in total_losses.items()}
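
    # The loop below alternates train/validate epochs, steps the plateau LR
    # scheduler on validation loss, checkpoints every epoch (plus best and
    # periodic copies), and stops early after `early_stopping_patience` epochs
    # without improvement (default 50).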
    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        num_epochs: int,
        train_dataset: ParametricDataset
    ):
        """
        Main training loop.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Number of epochs
            train_dataset: Training dataset (for normalization stats)
        """
        print(f"\n{'='*60}")
        print(f"Starting training for {num_epochs} epochs")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, num_epochs):
            epoch_start = time.time()

            print(f"Epoch {epoch + 1}/{num_epochs}")
            print("-" * 60)

            # Train
            train_metrics = self.train_epoch(train_loader, epoch)

            # Validate
            val_metrics = self.validate(val_loader)

            epoch_time = time.time() - epoch_start

            # Print metrics
            print(f"\nEpoch {epoch + 1} Results:")
            print(f"  Training Loss:   {train_metrics['total']:.6f}")
            print(f"    Mass: {train_metrics['mass']:.6f}, Freq: {train_metrics['frequency']:.6f}")
            print(f"    Disp: {train_metrics['displacement']:.6f}, Stress: {train_metrics['stress']:.6f}")
            print(f"  Validation Loss: {val_metrics['total']:.6f}")
            print(f"    Mass: {val_metrics['mass']:.6f}, Freq: {val_metrics['frequency']:.6f}")
            print(f"    Disp: {val_metrics['displacement']:.6f}, Stress: {val_metrics['stress']:.6f}")
            print(f"  Time: {epoch_time:.1f}s")

            # Log to TensorBoard
            if self.writer:
                self.writer.add_scalar('Loss/train', train_metrics['total'], epoch)
                self.writer.add_scalar('Loss/val', val_metrics['total'], epoch)
                for key in ['mass', 'frequency', 'displacement', 'stress']:
                    self.writer.add_scalar(f'{key}/train', train_metrics[key], epoch)
                    self.writer.add_scalar(f'{key}/val', val_metrics[key], epoch)

            # Learning rate scheduling
            self.scheduler.step(val_metrics['total'])

            # Save checkpoint
            is_best = val_metrics['total'] < self.best_val_loss
            if is_best:
                self.best_val_loss = val_metrics['total']
                self.epochs_without_improvement = 0
                print(f"  New best validation loss: {self.best_val_loss:.6f}")
            else:
                self.epochs_without_improvement += 1

            self.save_checkpoint(epoch, val_metrics, train_dataset, is_best)

            # Early stopping
            patience = self.config.get('early_stopping_patience', 50)
            if self.epochs_without_improvement >= patience:
                print(f"\nEarly stopping after {patience} epochs without improvement")
                break

            print()

        print(f"\n{'='*60}")
        print("Training complete!")
        print(f"Best validation loss: {self.best_val_loss:.6f}")
        print(f"{'='*60}\n")

        if self.writer:
            self.writer.close()

    def save_checkpoint(
        self,
        epoch: int,
        metrics: Dict[str, float],
        dataset: ParametricDataset,
        is_best: bool = False
    ):
        """Save model checkpoint with all required metadata."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.model.config,
            'normalization': dataset.get_normalization_stats(),
            'design_var_names': dataset.design_var_names,
            'metrics': metrics
        }

        # Save latest
        torch.save(checkpoint, self.output_dir / 'checkpoint_latest.pt')

        # Save best
        if is_best:
            best_path = self.output_dir / 'checkpoint_best.pt'
            torch.save(checkpoint, best_path)
            print(f"  Saved best model to {best_path}")

        # Periodic checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save(checkpoint, self.output_dir / f'checkpoint_epoch_{epoch + 1}.pt')


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Custom collate function for DataLoader."""
    design_params = torch.stack([item['design_params'] for item in batch])
    targets = torch.stack([item['targets'] for item in batch])
    return {
        'design_params': design_params,
        'targets': targets
    }
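

# Minimal inference sketch (illustrative only, not executed here). Assumes the
# config stored in the checkpoint can be passed back to create_parametric_model
# and that the same reference-graph construction is reused; the design values
# are placeholders, ordered as ckpt['design_var_names']:
#
#     ckpt = torch.load('runs/parametric/checkpoint_best.pt', map_location='cpu')
#     model = create_parametric_model(ckpt['config'])
#     model.load_state_dict(ckpt['model_state_dict'])
#     model.eval()
#
#     graph = create_reference_graph(num_nodes=500)
#     stats = ckpt['normalization']
#     design = torch.tensor([[0.01, 0.05]], dtype=torch.float32)
#     design = (design - torch.tensor(stats['design_mean'])) / torch.tensor(stats['design_std'])
#
#     with torch.no_grad():
#         preds = model(graph, design)
#     mass = preds['mass'].item() * stats['mass_std'] + stats['mass_mean']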


def main():
    """Main training entry point."""
    parser = argparse.ArgumentParser(
        description='Train AtomizerField parametric predictor'
    )

    # Data arguments
    parser.add_argument('--train_dir', type=str, required=True,
                        help='Directory containing training trial_* subdirs')
    parser.add_argument('--val_dir', type=str, default=None,
                        help='Directory containing validation data (uses split if not provided)')
    parser.add_argument('--val_split', type=float, default=0.2,
                        help='Validation split ratio if val_dir not provided')

    # Training arguments
    parser.add_argument('--epochs', type=int, default=200,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5,
                        help='Weight decay')

    # Model arguments
    parser.add_argument('--hidden_channels', type=int, default=128,
                        help='Hidden dimension')
    parser.add_argument('--num_layers', type=int, default=4,
                        help='Number of GNN layers')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout rate')

    # Output arguments
    parser.add_argument('--output_dir', type=str, default='./runs/parametric',
                        help='Output directory')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')

    args = parser.parse_args()

    # Create datasets
    print("\nLoading training data...")
    train_dataset = ParametricDataset(args.train_dir, normalize=True)

    if args.val_dir:
        val_dataset = ParametricDataset(args.val_dir, normalize=True)
        # Share normalization stats so validation targets use the same scaling as training
        for attr in ('design_mean', 'design_std',
                     'mass_mean', 'mass_std',
                     'disp_mean', 'disp_std',
                     'stiffness_mean', 'stiffness_std',
                     'freq_mean', 'freq_std',
                     'stress_mean', 'stress_std'):
            setattr(val_dataset, attr, getattr(train_dataset, attr))
    else:
        # Split training data
        n_total = len(train_dataset)
        n_val = int(n_total * args.val_split)
        n_train = n_total - n_val
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [n_train, n_val]
        )
        print(f"Split: {n_train} train, {n_val} validation")

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0
    )

    # Get design dimension from dataset
    if hasattr(train_dataset, 'design_var_names'):
        design_dim = len(train_dataset.design_var_names)
    else:
        # For split dataset, access underlying dataset
        design_dim = len(train_dataset.dataset.design_var_names)

    # Build configuration
    config = {
        'model': {
            'input_channels': 12,
            'edge_dim': 5,
            'hidden_channels': args.hidden_channels,
            'num_layers': args.num_layers,
            'design_dim': design_dim,
            'dropout': args.dropout
        },
        'learning_rate': args.learning_rate,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'num_epochs': args.epochs,
        'output_dir': args.output_dir,
        'early_stopping_patience': 50,
        'loss_weights': {
            'mass': 1.0,
            'frequency': 1.0,
            'displacement': 1.0,
            'stress': 1.0
        }
    }

    # Create trainer
    trainer = ParametricTrainer(config)

    # Resume if specified
    if args.resume:
        checkpoint = torch.load(args.resume, map_location=trainer.device)
        trainer.model.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        trainer.start_epoch = checkpoint['epoch'] + 1
        trainer.best_val_loss = checkpoint['best_val_loss']
        print(f"Resumed from epoch {checkpoint['epoch']}")

    # Get base dataset for normalization stats
    base_dataset = train_dataset
    if hasattr(train_dataset, 'dataset'):
        base_dataset = train_dataset.dataset

    # Train
    trainer.train(train_loader, val_loader, args.epochs, base_dataset)


if __name__ == "__main__":
    main()
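
# Example invocations (based on the arguments defined above):
#
#     python train_parametric.py --train_dir ./training_data --val_dir ./validation_data
#     python train_parametric.py --train_dir ./training_data --val_split 0.2 \
#         --epochs 200 --batch_size 16 --resume ./runs/parametric/checkpoint_latest.pt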