This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
601 lines
20 KiB
Python
601 lines
20 KiB
Python
"""
|
|
Training Pipeline for ZernikeGNN
|
|
=================================
|
|
|
|
This module provides the complete training pipeline for the Zernike GNN surrogate.
|
|
|
|
Training Flow:
|
|
1. Load displacement field data from gnn_data/ folders
|
|
2. Interpolate to fixed polar grid
|
|
3. Normalize inputs (design vars) and outputs (displacements)
|
|
4. Train with multi-task loss (field + objectives)
|
|
5. Validate on held-out data
|
|
6. Save best model checkpoint
|
|
|
|
Usage:
|
|
# Command line
|
|
python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200
|
|
|
|
# Python API
|
|
from optimization_engine.gnn.train_zernike_gnn import ZernikeGNNTrainer
|
|
|
|
trainer = ZernikeGNNTrainer(['V11', 'V12'])
|
|
trainer.train(epochs=200)
|
|
trainer.save_checkpoint('model.pt')
|
|
"""
|
|
|
|
import json
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime
|
|
import sys
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
|
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
|
|
from optimization_engine.gnn.zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model
|
|
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer, ZernikeRMSLoss
|
|
|
|
|
|
class MirrorDataset(Dataset):
    """PyTorch Dataset for mirror displacement fields.

    Wraps the sample list produced by ``create_mirror_dataset()`` and applies
    two preprocessing steps on access:

    1. Z-score normalization of the design variables, using statistics
       computed from this dataset (or supplied externally, so a validation
       set can reuse the training set's statistics).
    2. Scaling of the displacement field by ``disp_scale`` for numerical
       stability during training.
    """

    def __init__(
        self,
        data_list: List[Dict[str, Any]],
        design_mean: Optional[torch.Tensor] = None,
        design_std: Optional[torch.Tensor] = None,
        disp_scale: float = 1e6  # mm -> nm; 1e6 is nm, not um (matches nm-unit reporting)
    ):
        """
        Args:
            data_list: Output from create_mirror_dataset()
            design_mean: Mean for design normalization (computed if None)
            design_std: Std for design normalization (computed if None)
            disp_scale: Scale factor for displacements (default 1e6, i.e. mm -> nm)
        """
        self.data_list = data_list
        self.disp_scale = disp_scale

        # Stack all design vectors so mean/std are computed per design variable.
        all_designs = np.stack([d['design_vars'] for d in data_list])

        if design_mean is None:
            self.design_mean = torch.tensor(np.mean(all_designs, axis=0), dtype=torch.float32)
        else:
            self.design_mean = design_mean

        if design_std is None:
            # +1e-6 guards against division by zero for constant design variables.
            self.design_std = torch.tensor(np.std(all_designs, axis=0) + 1e-6, dtype=torch.float32)
        else:
            self.design_std = design_std

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data_list[idx]

        # Normalize design variables with the (possibly shared) statistics.
        design = torch.tensor(item['design_vars'], dtype=torch.float32)
        design_norm = (design - self.design_mean) / self.design_std

        # Scale displacements for numerical stability.
        z_disp = torch.tensor(item['z_displacement'], dtype=torch.float32)
        z_disp_scaled = z_disp * self.disp_scale

        return {
            'design': design_norm,
            'design_raw': design,
            'z_displacement': z_disp_scaled,
            'trial_number': item['trial_number'],
        }
|
|
|
|
|
|
class ZernikeGNNTrainer:
    """
    Complete training pipeline for ZernikeGNN.

    Handles:
    - Data loading and preprocessing
    - Model initialization
    - Training loop with validation
    - Checkpointing
    - Metrics tracking
    """

    def __init__(
        self,
        study_versions: List[str],
        base_dir: Optional[Path] = None,
        model_type: str = 'full',
        hidden_dim: int = 128,
        n_layers: int = 6,
        device: str = 'auto'
    ):
        """
        Args:
            study_versions: List of study versions (e.g., ['V11', 'V12'])
            base_dir: Base Atomizer directory
            model_type: 'full' or 'lite'
            hidden_dim: Model hidden dimension
            n_layers: Number of message passing layers
            device: 'cpu', 'cuda', or 'auto'
        """
        # Default to the repository root (three levels up from this module).
        if base_dir is None:
            base_dir = Path(__file__).parent.parent.parent
        self.base_dir = Path(base_dir)
        self.study_versions = study_versions

        # Resolve compute device; 'auto' prefers CUDA when available.
        self.device = (
            torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            if device == 'auto'
            else torch.device(device)
        )
        print(f"[TRAINER] Device: {self.device}", flush=True)

        # Fixed polar graph: the node layout is identical for every design,
        # so graph tensors are built once and reused for all samples.
        self.polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(
            f"[TRAINER] Polar graph: {self.polar_graph.n_nodes} nodes, "
            f"{self.polar_graph.edge_index.shape[1]} edges",
            flush=True,
        )

        # Constant graph tensors, moved to the target device up front.
        self.node_features = torch.tensor(
            self.polar_graph.get_node_features(normalized=True), dtype=torch.float32
        ).to(self.device)
        self.edge_index = torch.tensor(
            self.polar_graph.edge_index, dtype=torch.long
        ).to(self.device)
        self.edge_attr = torch.tensor(
            self.polar_graph.get_edge_features(normalized=True), dtype=torch.float32
        ).to(self.device)

        # Load FEA samples and build the train/val datasets.
        self._load_data()

        # Model configuration is kept so checkpoints can rebuild the model.
        self.model_config = {
            'model_type': model_type,
            'n_design_vars': len(self.train_dataset.data_list[0]['design_vars']),
            'n_subcases': 4,
            'hidden_dim': hidden_dim,
            'n_layers': n_layers,
        }
        self.model = create_model(**self.model_config).to(self.device)
        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"[TRAINER] Model: {self.model.__class__.__name__} with {n_params:,} parameters", flush=True)

        # Differentiable Zernike layer used to derive objectives from fields.
        self.objective_layer = ZernikeObjectiveLayer(self.polar_graph, n_modes=50)

        # Training state.
        self.best_val_loss = float('inf')
        self.history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
|
|
|
|
def _load_data(self):
|
|
"""Load and prepare training data from studies."""
|
|
all_data = []
|
|
|
|
for version in self.study_versions:
|
|
study_dir = self.base_dir / "studies" / f"m1_mirror_adaptive_{version}"
|
|
|
|
if not study_dir.exists():
|
|
print(f"[WARN] Study not found: {study_dir}", flush=True)
|
|
continue
|
|
|
|
print(f"[TRAINER] Loading data from {study_dir.name}...", flush=True)
|
|
dataset = create_mirror_dataset(study_dir, polar_graph=self.polar_graph, verbose=True)
|
|
print(f"[TRAINER] Loaded {len(dataset)} samples", flush=True)
|
|
all_data.extend(dataset)
|
|
|
|
if not all_data:
|
|
raise ValueError("No data loaded!")
|
|
|
|
print(f"[TRAINER] Total samples: {len(all_data)}", flush=True)
|
|
|
|
# Train/val split (80/20)
|
|
np.random.seed(42)
|
|
indices = np.random.permutation(len(all_data))
|
|
n_train = int(0.8 * len(all_data))
|
|
|
|
train_data = [all_data[i] for i in indices[:n_train]]
|
|
val_data = [all_data[i] for i in indices[n_train:]]
|
|
|
|
print(f"[TRAINER] Train: {len(train_data)}, Val: {len(val_data)}", flush=True)
|
|
|
|
# Create datasets
|
|
self.train_dataset = MirrorDataset(train_data)
|
|
self.val_dataset = MirrorDataset(
|
|
val_data,
|
|
design_mean=self.train_dataset.design_mean,
|
|
design_std=self.train_dataset.design_std
|
|
)
|
|
|
|
# Store normalization params for inference
|
|
self.design_mean = self.train_dataset.design_mean
|
|
self.design_std = self.train_dataset.design_std
|
|
self.disp_scale = self.train_dataset.disp_scale
|
|
|
|
def train(
|
|
self,
|
|
epochs: int = 200,
|
|
lr: float = 1e-3,
|
|
weight_decay: float = 1e-5,
|
|
batch_size: int = 4,
|
|
field_weight: float = 1.0,
|
|
patience: int = 50,
|
|
verbose: bool = True
|
|
):
|
|
"""
|
|
Train the GNN model.
|
|
|
|
Args:
|
|
epochs: Number of training epochs
|
|
lr: Learning rate
|
|
weight_decay: Weight decay for regularization
|
|
batch_size: Training batch size
|
|
field_weight: Weight for field loss
|
|
patience: Early stopping patience
|
|
verbose: Print training progress
|
|
"""
|
|
# Create data loaders
|
|
train_loader = DataLoader(
|
|
self.train_dataset, batch_size=batch_size, shuffle=True
|
|
)
|
|
val_loader = DataLoader(
|
|
self.val_dataset, batch_size=batch_size, shuffle=False
|
|
)
|
|
|
|
# Optimizer
|
|
optimizer = torch.optim.AdamW(
|
|
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
|
)
|
|
|
|
# Learning rate scheduler
|
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
|
|
|
# Training loop
|
|
no_improve = 0
|
|
|
|
for epoch in range(epochs):
|
|
# Training
|
|
self.model.train()
|
|
train_loss = 0.0
|
|
|
|
for batch in train_loader:
|
|
optimizer.zero_grad()
|
|
|
|
# Move to device
|
|
design = batch['design'].to(self.device)
|
|
z_disp_true = batch['z_displacement'].to(self.device)
|
|
|
|
# Forward pass for each sample in batch
|
|
batch_loss = 0.0
|
|
for i in range(design.size(0)):
|
|
z_disp_pred = self.model(
|
|
self.node_features,
|
|
self.edge_index,
|
|
self.edge_attr,
|
|
design[i]
|
|
)
|
|
|
|
# MSE loss on displacement field
|
|
loss = F.mse_loss(z_disp_pred, z_disp_true[i])
|
|
batch_loss = batch_loss + loss
|
|
|
|
batch_loss = batch_loss / design.size(0)
|
|
batch_loss.backward()
|
|
optimizer.step()
|
|
|
|
train_loss += batch_loss.item()
|
|
|
|
train_loss /= len(train_loader)
|
|
scheduler.step()
|
|
|
|
# Validation
|
|
val_loss, val_metrics = self._validate(val_loader)
|
|
|
|
# Track history
|
|
self.history['train_loss'].append(train_loss)
|
|
self.history['val_loss'].append(val_loss)
|
|
self.history['val_r2'].append(val_metrics.get('r2_mean', 0))
|
|
|
|
# Early stopping
|
|
if val_loss < self.best_val_loss:
|
|
self.best_val_loss = val_loss
|
|
self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
|
|
no_improve = 0
|
|
else:
|
|
no_improve += 1
|
|
|
|
if verbose and epoch % 10 == 0:
|
|
print(f"[Epoch {epoch:3d}] Train: {train_loss:.6f}, Val: {val_loss:.6f}, "
|
|
f"R²: {val_metrics.get('r2_mean', 0):.4f}, LR: {scheduler.get_last_lr()[0]:.2e}", flush=True)
|
|
|
|
if no_improve >= patience:
|
|
print(f"[TRAINER] Early stopping at epoch {epoch}", flush=True)
|
|
break
|
|
|
|
# Restore best model
|
|
self.model.load_state_dict(self.best_model_state)
|
|
print(f"[TRAINER] Training complete. Best val loss: {self.best_val_loss:.6f}", flush=True)
|
|
|
|
def _validate(self, val_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
|
|
"""Run validation and compute metrics."""
|
|
self.model.eval()
|
|
val_loss = 0.0
|
|
|
|
all_pred = []
|
|
all_true = []
|
|
|
|
with torch.no_grad():
|
|
for batch in val_loader:
|
|
design = batch['design'].to(self.device)
|
|
z_disp_true = batch['z_displacement'].to(self.device)
|
|
|
|
for i in range(design.size(0)):
|
|
z_disp_pred = self.model(
|
|
self.node_features,
|
|
self.edge_index,
|
|
self.edge_attr,
|
|
design[i]
|
|
)
|
|
|
|
loss = F.mse_loss(z_disp_pred, z_disp_true[i])
|
|
val_loss += loss.item()
|
|
|
|
all_pred.append(z_disp_pred.cpu())
|
|
all_true.append(z_disp_true[i].cpu())
|
|
|
|
val_loss /= len(self.val_dataset)
|
|
|
|
# Compute R² for each subcase
|
|
all_pred = torch.stack(all_pred) # [n_val, n_nodes, 4]
|
|
all_true = torch.stack(all_true)
|
|
|
|
r2_per_subcase = []
|
|
for sc in range(4):
|
|
pred_flat = all_pred[:, :, sc].flatten()
|
|
true_flat = all_true[:, :, sc].flatten()
|
|
|
|
ss_res = ((true_flat - pred_flat) ** 2).sum()
|
|
ss_tot = ((true_flat - true_flat.mean()) ** 2).sum()
|
|
r2 = 1 - ss_res / (ss_tot + 1e-8)
|
|
r2_per_subcase.append(r2.item())
|
|
|
|
metrics = {
|
|
'r2_mean': np.mean(r2_per_subcase),
|
|
'r2_per_subcase': r2_per_subcase,
|
|
}
|
|
|
|
return val_loss, metrics
|
|
|
|
def evaluate_objectives(self) -> Dict[str, Any]:
|
|
"""
|
|
Evaluate objective prediction accuracy on validation set.
|
|
|
|
Returns:
|
|
Dictionary with per-objective metrics
|
|
"""
|
|
self.model.eval()
|
|
|
|
obj_pred_all = {k: [] for k in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']}
|
|
obj_true_all = {k: [] for k in obj_pred_all}
|
|
|
|
# Move objective layer to CPU for now (small dataset)
|
|
with torch.no_grad():
|
|
for i in range(len(self.val_dataset)):
|
|
item = self.val_dataset[i]
|
|
|
|
design = item['design'].to(self.device)
|
|
z_disp_true = item['z_displacement'] # Already scaled
|
|
|
|
# Predict
|
|
z_disp_pred = self.model(
|
|
self.node_features,
|
|
self.edge_index,
|
|
self.edge_attr,
|
|
design
|
|
).cpu()
|
|
|
|
# Unscale for objective computation
|
|
z_disp_pred_mm = z_disp_pred / self.disp_scale
|
|
z_disp_true_mm = z_disp_true / self.disp_scale
|
|
|
|
# Compute objectives
|
|
obj_pred = self.objective_layer(z_disp_pred_mm)
|
|
obj_true = self.objective_layer(z_disp_true_mm)
|
|
|
|
for k in obj_pred_all:
|
|
obj_pred_all[k].append(obj_pred[k].item())
|
|
obj_true_all[k].append(obj_true[k].item())
|
|
|
|
# Compute metrics per objective
|
|
results = {}
|
|
for k in obj_pred_all:
|
|
pred = np.array(obj_pred_all[k])
|
|
true = np.array(obj_true_all[k])
|
|
|
|
mae = np.mean(np.abs(pred - true))
|
|
mape = np.mean(np.abs(pred - true) / (np.abs(true) + 1e-6)) * 100
|
|
|
|
ss_res = np.sum((true - pred) ** 2)
|
|
ss_tot = np.sum((true - np.mean(true)) ** 2)
|
|
r2 = 1 - ss_res / (ss_tot + 1e-8)
|
|
|
|
results[k] = {
|
|
'mae': mae,
|
|
'mape': mape,
|
|
'r2': r2,
|
|
'pred_range': [pred.min(), pred.max()],
|
|
'true_range': [true.min(), true.max()],
|
|
}
|
|
|
|
return results
|
|
|
|
def save_checkpoint(self, path: Path) -> None:
|
|
"""Save model checkpoint."""
|
|
path = Path(path)
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
checkpoint = {
|
|
'model_state_dict': self.model.state_dict(),
|
|
'config': self.model_config,
|
|
'design_mean': self.design_mean,
|
|
'design_std': self.design_std,
|
|
'disp_scale': self.disp_scale,
|
|
'history': self.history,
|
|
'best_val_loss': self.best_val_loss,
|
|
'study_versions': self.study_versions,
|
|
'timestamp': datetime.now().isoformat(),
|
|
}
|
|
|
|
torch.save(checkpoint, path)
|
|
print(f"[TRAINER] Saved checkpoint to {path}", flush=True)
|
|
|
|
@classmethod
|
|
def load_checkpoint(cls, path: Path, device: str = 'auto') -> 'ZernikeGNNTrainer':
|
|
"""Load trainer from checkpoint."""
|
|
checkpoint = torch.load(path, map_location='cpu')
|
|
|
|
# Create trainer with same config
|
|
trainer = cls(
|
|
study_versions=checkpoint['study_versions'],
|
|
model_type=checkpoint['config']['model_type'],
|
|
hidden_dim=checkpoint['config']['hidden_dim'],
|
|
n_layers=checkpoint['config']['n_layers'],
|
|
device=device,
|
|
)
|
|
|
|
# Load model weights
|
|
trainer.model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
|
# Restore normalization
|
|
trainer.design_mean = checkpoint['design_mean']
|
|
trainer.design_std = checkpoint['design_std']
|
|
trainer.disp_scale = checkpoint['disp_scale']
|
|
|
|
# Restore history
|
|
trainer.history = checkpoint['history']
|
|
trainer.best_val_loss = checkpoint['best_val_loss']
|
|
|
|
return trainer
|
|
|
|
def predict(self, design_vars: Dict[str, float]) -> Dict[str, Any]:
|
|
"""
|
|
Make prediction for new design.
|
|
|
|
Args:
|
|
design_vars: Dictionary of design parameter values
|
|
|
|
Returns:
|
|
Dictionary with displacement field and objectives
|
|
"""
|
|
self.model.eval()
|
|
|
|
# Convert to tensor
|
|
design_names = self.train_dataset.data_list[0]['design_names']
|
|
design = torch.tensor(
|
|
[design_vars[name] for name in design_names],
|
|
dtype=torch.float32
|
|
)
|
|
|
|
# Normalize
|
|
design_norm = (design - self.design_mean) / self.design_std
|
|
|
|
with torch.no_grad():
|
|
z_disp_scaled = self.model(
|
|
self.node_features,
|
|
self.edge_index,
|
|
self.edge_attr,
|
|
design_norm.to(self.device)
|
|
).cpu()
|
|
|
|
# Unscale
|
|
z_disp_mm = z_disp_scaled / self.disp_scale
|
|
|
|
# Compute objectives
|
|
objectives = self.objective_layer(z_disp_mm)
|
|
|
|
return {
|
|
'z_displacement': z_disp_mm.numpy(),
|
|
'objectives': {k: v.item() for k, v in objectives.items()},
|
|
}
|
|
|
|
|
|
# =============================================================================
|
|
# CLI
|
|
# =============================================================================
|
|
|
|
def main():
    """CLI entry point: train a ZernikeGNN surrogate on one or more studies."""
    import argparse

    parser = argparse.ArgumentParser(description='Train ZernikeGNN surrogate')
    parser.add_argument('studies', nargs='+', help='Study versions (e.g., V11 V12)')
    parser.add_argument('--epochs', type=int, default=200, help='Training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--n-layers', type=int, default=6, help='Message passing layers')
    parser.add_argument('--model-type', choices=['full', 'lite'], default='full')
    parser.add_argument('--output', '-o', type=Path, help='Output checkpoint path')
    parser.add_argument('--device', default='auto', help='Device (cpu, cuda, auto)')
    args = parser.parse_args()

    banner = "=" * 60
    print(banner, flush=True)
    print("ZERNIKE GNN TRAINING", flush=True)
    print(banner, flush=True)

    trainer = ZernikeGNNTrainer(
        study_versions=args.studies,
        model_type=args.model_type,
        hidden_dim=args.hidden_dim,
        n_layers=args.n_layers,
        device=args.device,
    )

    trainer.train(
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
    )

    # Report how well GNN-derived objectives match the FEA-derived ones.
    print("\n--- Objective Prediction Evaluation ---", flush=True)
    for name, metrics in trainer.evaluate_objectives().items():
        print(f"\n{name}:", flush=True)
        print(f"  MAE: {metrics['mae']:.2f} nm", flush=True)
        print(f"  MAPE: {metrics['mape']:.1f}%", flush=True)
        print(f"  R²: {metrics['r2']:.4f}", flush=True)

    # Fall back to a default checkpoint name when none was given.
    output_path = args.output if args.output else Path("zernike_gnn_checkpoint.pt")
    trainer.save_checkpoint(output_path)

    print("\n" + banner, flush=True)
    print("✓ Training complete!", flush=True)
    print(banner, flush=True)


if __name__ == '__main__':
    main()
|