"""
train_parametric.py
Training script for AtomizerField parametric predictor
AtomizerField Parametric Training Pipeline v2.0
Trains design-conditioned GNN to predict optimization objectives directly.
Usage:
python train_parametric.py --train_dir ./training_data --val_dir ./validation_data
Key Differences from train.py (field predictor):
- Predicts scalar objectives (mass, frequency, displacement, stress) instead of fields
- Uses design parameters as conditioning input
- Multi-objective loss function for all 4 outputs
- Faster training due to simpler output structure
Output:
checkpoint_best.pt containing:
- model_state_dict: Trained weights
- config: Model configuration
- normalization: Normalization statistics for inference
- design_var_names: Names of design variables
- best_val_loss: Best validation loss achieved
"""
import argparse
import json
from pathlib import Path
import time
from typing import Dict, List, Any, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import h5py
# Try to import tensorboard, but make it optional
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
from neural_models.parametric_predictor import create_parametric_model
class ParametricDataset(Dataset):
"""
PyTorch Dataset for parametric training.
Loads training data exported by Atomizer's TrainingDataExporter
and prepares it for the parametric predictor.
Expected directory structure:
training_data/
trial_0000/
metadata.json (design params + results)
input/
neural_field_data.h5 (mesh data)
trial_0001/
...
"""
def __init__(
self,
data_dir: Path,
normalize: bool = True,
cache_in_memory: bool = False
):
"""
Initialize dataset.
Args:
data_dir: Directory containing trial_* subdirectories
normalize: Whether to normalize inputs/outputs
cache_in_memory: Cache all data in RAM (faster but memory-intensive)
"""
self.data_dir = Path(data_dir)
self.normalize = normalize
self.cache_in_memory = cache_in_memory
# Find all valid trial directories
self.trial_dirs = sorted([
d for d in self.data_dir.glob("trial_*")
if d.is_dir() and self._is_valid_trial(d)
])
print(f"Found {len(self.trial_dirs)} valid trials in {data_dir}")
if len(self.trial_dirs) == 0:
raise ValueError(f"No valid trial directories found in {data_dir}")
# Extract design variable names from first trial
self.design_var_names = self._get_design_var_names()
print(f"Design variables: {self.design_var_names}")
# Compute normalization statistics
if normalize:
self._compute_normalization_stats()
# Cache data if requested
self.cache = {}
if cache_in_memory:
print("Caching data in memory...")
for idx in range(len(self.trial_dirs)):
self.cache[idx] = self._load_trial(idx)
print("Cache complete!")
def _is_valid_trial(self, trial_dir: Path) -> bool:
"""Check if trial directory has required files."""
metadata_file = trial_dir / "metadata.json"
# Check for metadata
if not metadata_file.exists():
return False
# Check metadata has required fields
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
has_design = 'design_parameters' in metadata
has_results = 'results' in metadata
return has_design and has_results
        except (OSError, json.JSONDecodeError):
            return False
def _get_design_var_names(self) -> List[str]:
"""Extract design variable names from first trial."""
metadata_file = self.trial_dirs[0] / "metadata.json"
with open(metadata_file, 'r') as f:
metadata = json.load(f)
return list(metadata['design_parameters'].keys())
def _compute_normalization_stats(self):
"""Compute normalization statistics across all trials."""
print("Computing normalization statistics...")
all_design_params = []
all_mass = []
all_disp = []
all_stiffness = []
for trial_dir in self.trial_dirs:
with open(trial_dir / "metadata.json", 'r') as f:
metadata = json.load(f)
# Design parameters
design_params = [metadata['design_parameters'][name]
for name in self.design_var_names]
all_design_params.append(design_params)
            # Results (objectives may be nested under 'objectives' or stored flat)
results = metadata.get('results', {})
objectives = results.get('objectives', results)
if 'mass' in objectives:
all_mass.append(objectives['mass'])
if 'max_displacement' in results:
all_disp.append(results['max_displacement'])
elif 'max_displacement' in objectives:
all_disp.append(objectives['max_displacement'])
if 'stiffness' in objectives:
all_stiffness.append(objectives['stiffness'])
# Convert to numpy arrays
all_design_params = np.array(all_design_params)
# Compute statistics
self.design_mean = torch.from_numpy(all_design_params.mean(axis=0)).float()
self.design_std = torch.from_numpy(all_design_params.std(axis=0)).float()
self.design_std = torch.clamp(self.design_std, min=1e-6) # Prevent division by zero
# Output statistics
self.mass_mean = np.mean(all_mass) if all_mass else 0.1
self.mass_std = np.std(all_mass) if all_mass else 0.05
self.mass_std = max(self.mass_std, 1e-6)
self.disp_mean = np.mean(all_disp) if all_disp else 0.01
self.disp_std = np.std(all_disp) if all_disp else 0.005
self.disp_std = max(self.disp_std, 1e-6)
self.stiffness_mean = np.mean(all_stiffness) if all_stiffness else 20000.0
self.stiffness_std = np.std(all_stiffness) if all_stiffness else 5000.0
self.stiffness_std = max(self.stiffness_std, 1e-6)
# Frequency and stress defaults (if not available in data)
self.freq_mean = 18.0
self.freq_std = 5.0
self.stress_mean = 200.0
self.stress_std = 50.0
print(f" Design mean: {self.design_mean.numpy()}")
print(f" Design std: {self.design_std.numpy()}")
print(f" Mass: {self.mass_mean:.4f} +/- {self.mass_std:.4f}")
print(f" Displacement: {self.disp_mean:.6f} +/- {self.disp_std:.6f}")
print(f" Stiffness: {self.stiffness_mean:.2f} +/- {self.stiffness_std:.2f}")
def __len__(self) -> int:
return len(self.trial_dirs)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
if self.cache_in_memory and idx in self.cache:
return self.cache[idx]
return self._load_trial(idx)
def _load_trial(self, idx: int) -> Dict[str, torch.Tensor]:
"""Load and process a single trial."""
trial_dir = self.trial_dirs[idx]
# Load metadata
with open(trial_dir / "metadata.json", 'r') as f:
metadata = json.load(f)
# Extract design parameters
design_params = [metadata['design_parameters'][name]
for name in self.design_var_names]
design_tensor = torch.tensor(design_params, dtype=torch.float32)
# Normalize design parameters
if self.normalize:
design_tensor = (design_tensor - self.design_mean) / self.design_std
# Extract results
results = metadata.get('results', {})
objectives = results.get('objectives', results)
# Get targets (with fallbacks)
mass = objectives.get('mass', 0.1)
stiffness = objectives.get('stiffness', 20000.0)
max_displacement = results.get('max_displacement',
objectives.get('max_displacement', 0.01))
# Frequency and stress might not be available
frequency = objectives.get('frequency', self.freq_mean)
max_stress = objectives.get('max_stress', self.stress_mean)
# Create target tensor
targets = torch.tensor([mass, frequency, max_displacement, max_stress],
dtype=torch.float32)
# Normalize targets
if self.normalize:
targets[0] = (targets[0] - self.mass_mean) / self.mass_std
targets[1] = (targets[1] - self.freq_mean) / self.freq_std
targets[2] = (targets[2] - self.disp_mean) / self.disp_std
targets[3] = (targets[3] - self.stress_mean) / self.stress_std
# Try to load mesh data if available
mesh_data = self._load_mesh_data(trial_dir)
return {
'design_params': design_tensor,
'targets': targets,
'mesh_data': mesh_data,
'trial_dir': str(trial_dir)
}
def _load_mesh_data(self, trial_dir: Path) -> Optional[Dict[str, torch.Tensor]]:
"""Load mesh data from H5 file if available."""
h5_paths = [
trial_dir / "input" / "neural_field_data.h5",
trial_dir / "neural_field_data.h5",
]
for h5_path in h5_paths:
if h5_path.exists():
try:
with h5py.File(h5_path, 'r') as f:
node_coords = torch.from_numpy(f['mesh/node_coordinates'][:]).float()
# Build simple edge index from connectivity if available
# For now, return just coordinates
return {
'node_coords': node_coords,
'num_nodes': node_coords.shape[0]
}
except Exception as e:
print(f"Warning: Could not load mesh from {h5_path}: {e}")
return None
def get_normalization_stats(self) -> Dict[str, Any]:
"""Return normalization statistics for saving with model."""
return {
'design_mean': self.design_mean.numpy().tolist(),
'design_std': self.design_std.numpy().tolist(),
'mass_mean': float(self.mass_mean),
'mass_std': float(self.mass_std),
'freq_mean': float(self.freq_mean),
'freq_std': float(self.freq_std),
'max_disp_mean': float(self.disp_mean),
'max_disp_std': float(self.disp_std),
'max_stress_mean': float(self.stress_mean),
'max_stress_std': float(self.stress_std),
}
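
    # Illustrative sketch: undoing the z-score normalization at inference,
    # given stats = checkpoint['normalization'] with the keys returned above:
    #
    #   mass = pred['mass'] * stats['mass_std'] + stats['mass_mean']
    #   disp = pred['max_displacement'] * stats['max_disp_std'] + stats['max_disp_mean']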
def create_reference_graph(num_nodes: int = 500, device: Optional[torch.device] = None):
"""
Create a reference graph structure for the GNN.
In production, this would come from the actual mesh.
For now, create a simple grid-like structure.
"""
if device is None:
device = torch.device('cpu')
# Create simple node features (placeholder)
x = torch.randn(num_nodes, 12, device=device)
# Create grid-like connectivity - ensure valid indices
edges = []
grid_size = int(np.ceil(np.sqrt(num_nodes)))
for i in range(num_nodes):
row = i // grid_size
col = i % grid_size
# Right neighbor (same row)
right = i + 1
if col < grid_size - 1 and right < num_nodes:
edges.append([i, right])
edges.append([right, i])
# Bottom neighbor (next row)
bottom = i + grid_size
if bottom < num_nodes:
edges.append([i, bottom])
edges.append([bottom, i])
    # Ensure we have at least some edges
    if len(edges) == 0:
        # Fallback: densely connect nearby nodes for very small graphs
        for i in range(num_nodes):
            for j in range(i + 1, min(i + 5, num_nodes)):
                edges.append([i, j])
                edges.append([j, i])
    if len(edges) == 0:
        # Degenerate case (num_nodes <= 1): a self-loop keeps edge_index valid
        edges.append([0, 0])
edge_index = torch.tensor(edges, dtype=torch.long, device=device).t().contiguous()
edge_attr = torch.randn(edge_index.shape[1], 5, device=device)
from torch_geometric.data import Data
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
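
# For example, create_reference_graph(num_nodes=9) builds a 3x3 grid with
# 12 undirected grid edges, i.e. edge_index of shape [2, 24] once both
# directions are appended.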
class ParametricTrainer:
"""
Training manager for parametric predictor models.
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize trainer.
Args:
config: Training configuration
"""
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n{'='*60}")
print("AtomizerField Parametric Training Pipeline v2.0")
print(f"{'='*60}")
print(f"Device: {self.device}")
# Create model
print("\nCreating parametric model...")
model_config = config.get('model', {})
self.model = create_parametric_model(model_config)
self.model = self.model.to(self.device)
num_params = self.model.get_num_parameters()
print(f"Model created: {num_params:,} parameters")
# Create optimizer
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=config.get('learning_rate', 1e-3),
weight_decay=config.get('weight_decay', 1e-5)
)
# Learning rate scheduler
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
            patience=10  # `verbose` omitted: deprecated in recent PyTorch
        )
# Multi-objective loss weights
self.loss_weights = config.get('loss_weights', {
'mass': 1.0,
'frequency': 1.0,
'displacement': 1.0,
'stress': 1.0
})
# Training state
self.start_epoch = 0
self.best_val_loss = float('inf')
self.epochs_without_improvement = 0
# Create output directories
self.output_dir = Path(config.get('output_dir', './runs/parametric'))
self.output_dir.mkdir(parents=True, exist_ok=True)
# TensorBoard logging (optional)
self.writer = None
if TENSORBOARD_AVAILABLE:
self.writer = SummaryWriter(log_dir=self.output_dir / 'tensorboard')
        # Create reference graph shared by all forward passes (train/val/inference)
self.reference_graph = create_reference_graph(
num_nodes=config.get('reference_nodes', 500),
device=self.device
)
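        # NOTE: node/edge features in this graph are random placeholders, and
        # the graph is not stored in the checkpoint, so resumed or inference
        # runs see different features unless the RNG is seeded identically.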
# Save config
with open(self.output_dir / 'config.json', 'w') as f:
json.dump(config, f, indent=2)
def compute_loss(
self,
predictions: Dict[str, torch.Tensor],
targets: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Compute multi-objective loss.
Args:
predictions: Model outputs (mass, frequency, max_displacement, max_stress)
targets: Target values [batch, 4]
Returns:
total_loss: Combined loss tensor
loss_dict: Individual losses for logging
"""
# MSE losses for each objective
mass_loss = nn.functional.mse_loss(predictions['mass'], targets[:, 0])
freq_loss = nn.functional.mse_loss(predictions['frequency'], targets[:, 1])
disp_loss = nn.functional.mse_loss(predictions['max_displacement'], targets[:, 2])
stress_loss = nn.functional.mse_loss(predictions['max_stress'], targets[:, 3])
# Weighted combination
total_loss = (
self.loss_weights['mass'] * mass_loss +
self.loss_weights['frequency'] * freq_loss +
self.loss_weights['displacement'] * disp_loss +
self.loss_weights['stress'] * stress_loss
)
loss_dict = {
'total': total_loss.item(),
'mass': mass_loss.item(),
'frequency': freq_loss.item(),
'displacement': disp_loss.item(),
'stress': stress_loss.item()
}
return total_loss, loss_dict
def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train for one epoch."""
self.model.train()
total_losses = {'total': 0, 'mass': 0, 'frequency': 0, 'displacement': 0, 'stress': 0}
num_batches = 0
for batch_idx, batch in enumerate(train_loader):
# Get data
design_params = batch['design_params'].to(self.device)
targets = batch['targets'].to(self.device)
# Zero gradients
self.optimizer.zero_grad()
# Forward pass (using reference graph)
predictions = self.model(self.reference_graph, design_params)
# Compute loss
loss, loss_dict = self.compute_loss(predictions, targets)
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update weights
self.optimizer.step()
# Accumulate losses
for key in total_losses:
total_losses[key] += loss_dict[key]
num_batches += 1
# Print progress
if batch_idx % 10 == 0:
print(f" Batch {batch_idx}/{len(train_loader)}: Loss={loss_dict['total']:.6f}")
# Average losses
return {k: v / num_batches for k, v in total_losses.items()}
def validate(self, val_loader: DataLoader) -> Dict[str, float]:
"""Validate model."""
self.model.eval()
total_losses = {'total': 0, 'mass': 0, 'frequency': 0, 'displacement': 0, 'stress': 0}
num_batches = 0
with torch.no_grad():
for batch in val_loader:
design_params = batch['design_params'].to(self.device)
targets = batch['targets'].to(self.device)
predictions = self.model(self.reference_graph, design_params)
_, loss_dict = self.compute_loss(predictions, targets)
for key in total_losses:
total_losses[key] += loss_dict[key]
num_batches += 1
return {k: v / num_batches for k, v in total_losses.items()}
def train(
self,
train_loader: DataLoader,
val_loader: DataLoader,
num_epochs: int,
train_dataset: ParametricDataset
):
"""
Main training loop.
Args:
train_loader: Training data loader
val_loader: Validation data loader
num_epochs: Number of epochs
train_dataset: Training dataset (for normalization stats)
"""
print(f"\n{'='*60}")
print(f"Starting training for {num_epochs} epochs")
print(f"{'='*60}\n")
for epoch in range(self.start_epoch, num_epochs):
epoch_start = time.time()
print(f"Epoch {epoch + 1}/{num_epochs}")
print("-" * 60)
# Train
train_metrics = self.train_epoch(train_loader, epoch)
# Validate
val_metrics = self.validate(val_loader)
epoch_time = time.time() - epoch_start
# Print metrics
print(f"\nEpoch {epoch + 1} Results:")
print(f" Training Loss: {train_metrics['total']:.6f}")
print(f" Mass: {train_metrics['mass']:.6f}, Freq: {train_metrics['frequency']:.6f}")
print(f" Disp: {train_metrics['displacement']:.6f}, Stress: {train_metrics['stress']:.6f}")
print(f" Validation Loss: {val_metrics['total']:.6f}")
print(f" Mass: {val_metrics['mass']:.6f}, Freq: {val_metrics['frequency']:.6f}")
print(f" Disp: {val_metrics['displacement']:.6f}, Stress: {val_metrics['stress']:.6f}")
print(f" Time: {epoch_time:.1f}s")
# Log to TensorBoard
if self.writer:
self.writer.add_scalar('Loss/train', train_metrics['total'], epoch)
self.writer.add_scalar('Loss/val', val_metrics['total'], epoch)
for key in ['mass', 'frequency', 'displacement', 'stress']:
self.writer.add_scalar(f'{key}/train', train_metrics[key], epoch)
self.writer.add_scalar(f'{key}/val', val_metrics[key], epoch)
# Learning rate scheduling
self.scheduler.step(val_metrics['total'])
# Save checkpoint
is_best = val_metrics['total'] < self.best_val_loss
if is_best:
self.best_val_loss = val_metrics['total']
self.epochs_without_improvement = 0
print(f" New best validation loss: {self.best_val_loss:.6f}")
else:
self.epochs_without_improvement += 1
self.save_checkpoint(epoch, val_metrics, train_dataset, is_best)
# Early stopping
patience = self.config.get('early_stopping_patience', 50)
if self.epochs_without_improvement >= patience:
print(f"\nEarly stopping after {patience} epochs without improvement")
break
print()
print(f"\n{'='*60}")
print("Training complete!")
print(f"Best validation loss: {self.best_val_loss:.6f}")
print(f"{'='*60}\n")
if self.writer:
self.writer.close()
def save_checkpoint(
self,
epoch: int,
metrics: Dict[str, float],
dataset: ParametricDataset,
is_best: bool = False
):
"""Save model checkpoint with all required metadata."""
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'best_val_loss': self.best_val_loss,
'config': self.model.config,
'normalization': dataset.get_normalization_stats(),
'design_var_names': dataset.design_var_names,
'metrics': metrics
}
# Save latest
torch.save(checkpoint, self.output_dir / 'checkpoint_latest.pt')
# Save best
if is_best:
best_path = self.output_dir / 'checkpoint_best.pt'
torch.save(checkpoint, best_path)
print(f" Saved best model to {best_path}")
# Periodic checkpoint
if (epoch + 1) % 10 == 0:
torch.save(checkpoint, self.output_dir / f'checkpoint_epoch_{epoch + 1}.pt')
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
"""Custom collate function for DataLoader."""
design_params = torch.stack([item['design_params'] for item in batch])
targets = torch.stack([item['targets'] for item in batch])
return {
'design_params': design_params,
'targets': targets
}
def main():
"""Main training entry point."""
parser = argparse.ArgumentParser(
description='Train AtomizerField parametric predictor'
)
# Data arguments
parser.add_argument('--train_dir', type=str, required=True,
help='Directory containing training trial_* subdirs')
parser.add_argument('--val_dir', type=str, default=None,
help='Directory containing validation data (uses split if not provided)')
parser.add_argument('--val_split', type=float, default=0.2,
help='Validation split ratio if val_dir not provided')
# Training arguments
parser.add_argument('--epochs', type=int, default=200,
help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=16,
help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-3,
help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-5,
help='Weight decay')
# Model arguments
parser.add_argument('--hidden_channels', type=int, default=128,
help='Hidden dimension')
parser.add_argument('--num_layers', type=int, default=4,
help='Number of GNN layers')
parser.add_argument('--dropout', type=float, default=0.1,
help='Dropout rate')
# Output arguments
parser.add_argument('--output_dir', type=str, default='./runs/parametric',
help='Output directory')
parser.add_argument('--resume', type=str, default=None,
help='Path to checkpoint to resume from')
args = parser.parse_args()
# Create datasets
print("\nLoading training data...")
train_dataset = ParametricDataset(args.train_dir, normalize=True)
if args.val_dir:
val_dataset = ParametricDataset(args.val_dir, normalize=True)
        # Share ALL normalization stats so train/val losses live in the same space
        val_dataset.design_mean = train_dataset.design_mean
        val_dataset.design_std = train_dataset.design_std
        for attr in ('mass_mean', 'mass_std', 'disp_mean', 'disp_std',
                     'freq_mean', 'freq_std', 'stress_mean', 'stress_std',
                     'stiffness_mean', 'stiffness_std'):
            setattr(val_dataset, attr, getattr(train_dataset, attr))
else:
# Split training data
n_total = len(train_dataset)
n_val = int(n_total * args.val_split)
n_train = n_total - n_val
train_dataset, val_dataset = torch.utils.data.random_split(
train_dataset, [n_train, n_val]
)
print(f"Split: {n_train} train, {n_val} validation")
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=0
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=collate_fn,
num_workers=0
)
# Get design dimension from dataset
if hasattr(train_dataset, 'design_var_names'):
design_dim = len(train_dataset.design_var_names)
else:
# For split dataset, access underlying dataset
design_dim = len(train_dataset.dataset.design_var_names)
# Build configuration
config = {
'model': {
'input_channels': 12,
'edge_dim': 5,
'hidden_channels': args.hidden_channels,
'num_layers': args.num_layers,
'design_dim': design_dim,
'dropout': args.dropout
},
'learning_rate': args.learning_rate,
'weight_decay': args.weight_decay,
'batch_size': args.batch_size,
'num_epochs': args.epochs,
'output_dir': args.output_dir,
'early_stopping_patience': 50,
'loss_weights': {
'mass': 1.0,
'frequency': 1.0,
'displacement': 1.0,
'stress': 1.0
}
}
# Create trainer
trainer = ParametricTrainer(config)
# Resume if specified
if args.resume:
        # Full training checkpoints carry optimizer/scheduler state; newer
        # PyTorch (>= 2.6) may need weights_only=False to deserialize them.
        checkpoint = torch.load(args.resume, map_location=trainer.device,
                                weights_only=False)
trainer.model.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
trainer.start_epoch = checkpoint['epoch'] + 1
trainer.best_val_loss = checkpoint['best_val_loss']
print(f"Resumed from epoch {checkpoint['epoch']}")
# Get base dataset for normalization stats
base_dataset = train_dataset
if hasattr(train_dataset, 'dataset'):
base_dataset = train_dataset.dataset
# Train
trainer.train(train_loader, val_loader, args.epochs, base_dataset)
if __name__ == "__main__":
main()