feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2-to-HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction.

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation

- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py

### V13: Pure NSGA-II FEA (Ground Truth)

- Seeded with 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
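For orientation, this is the end-to-end workflow the commit enables, pieced together from the module docstrings in the diff below (the study tags and checkpoint path are illustrative):

```python
# 1. Extract displacement fields from existing FEA trials:
#      python -m optimization_engine.gnn.backfill_field_data V11 V12 --merge
# 2. Train the GNN surrogate:
#      python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200
# 3. Run fast surrogate-based optimization, then validate the best designs with FEA:
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer

optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt')
results = optimizer.turbo_optimize(n_trials=5000)   # "turbo mode" predictions
best = results.get_best(n=10)                       # top candidates for FEA validation
```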
69  optimization_engine/gnn/__init__.py  (new file)
@@ -0,0 +1,69 @@
```python
"""
GNN (Graph Neural Network) Surrogate Module for Atomizer
=========================================================

This module provides Graph Neural Network-based surrogates for FEA optimization,
particularly designed for Zernike-based mirror optimization where spatial structure
matters.

Key Components:
- PolarMirrorGraph: Fixed polar grid graph structure for mirror surface
- ZernikeGNN: GNN model for predicting displacement fields
- DifferentiableZernikeFit: GPU-accelerated Zernike fitting
- ZernikeObjectiveLayer: Compute objectives from displacement fields
- ZernikeGNNTrainer: Complete training pipeline

Why GNN over MLP for Zernike?
1. Spatial awareness: GNN learns smooth deformation fields via message passing
2. Correct relative computation: Predicts fields, then subtracts (like FEA)
3. Multi-task learning: Field + objective supervision
4. Physics-informed: Edge structure respects mirror geometry

Usage:
    # Training
    python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200

    # API
    from optimization_engine.gnn import PolarMirrorGraph, ZernikeGNN, ZernikeGNNTrainer
"""

__version__ = "1.0.0"

# Core components
from .polar_graph import PolarMirrorGraph, create_mirror_dataset
from .zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model, load_model
from .differentiable_zernike import (
    DifferentiableZernikeFit,
    ZernikeObjectiveLayer,
    ZernikeRMSLoss,
    build_zernike_matrix,
)
from .extract_displacement_field import (
    extract_displacement_field,
    save_field,
    load_field,
)
from .train_zernike_gnn import ZernikeGNNTrainer, MirrorDataset

__all__ = [
    # Polar Graph
    'PolarMirrorGraph',
    'create_mirror_dataset',
    # GNN Model
    'ZernikeGNN',
    'ZernikeGNNLite',
    'create_model',
    'load_model',
    # Zernike Layers
    'DifferentiableZernikeFit',
    'ZernikeObjectiveLayer',
    'ZernikeRMSLoss',
    'build_zernike_matrix',
    # Field Extraction
    'extract_displacement_field',
    'save_field',
    'load_field',
    # Training
    'ZernikeGNNTrainer',
    'MirrorDataset',
]
```
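The "design-conditioned message passing" named above lives in `zernike_gnn.py`, which is not part of this excerpt. As a rough sketch of the idea only (the layer name, sizes, and the FiLM-style modulation here are assumptions, not the actual ZernikeGNN code):

```python
import torch
import torch.nn as nn

class DesignConditionedConv(nn.Module):
    """Toy message-passing layer whose edge messages are scaled and shifted
    by a global design-parameter vector (FiLM-style conditioning)."""

    def __init__(self, node_dim: int, design_dim: int):
        super().__init__()
        self.msg = nn.Linear(2 * node_dim, node_dim)     # message from (src, dst) pair
        self.film = nn.Linear(design_dim, 2 * node_dim)  # design -> (scale, shift)
        self.update = nn.Linear(2 * node_dim, node_dim)  # combine node + aggregate

    def forward(self, x, edge_index, design):
        # x: [n_nodes, node_dim], edge_index: [2, n_edges], design: [design_dim]
        src, dst = edge_index
        m = self.msg(torch.cat([x[src], x[dst]], dim=-1))  # raw edge messages
        scale, shift = self.film(design).chunk(2, dim=-1)  # design conditioning
        m = (1 + scale) * m + shift                        # modulate messages
        # Mean-aggregate modulated messages onto destination nodes
        agg = x.new_zeros(x.shape).index_add_(0, dst, m)
        deg = x.new_zeros((x.size(0), 1)).index_add_(0, dst, x.new_ones((dst.size(0), 1)))
        return self.update(torch.cat([x, agg / deg.clamp(min=1)], dim=-1))
```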
475  optimization_engine/gnn/backfill_field_data.py  (new file)
@@ -0,0 +1,475 @@
```python
"""
Backfill Displacement Field Data from Existing Trials
======================================================

This script scans existing mirror optimization studies (V11, V12, etc.) and extracts
displacement field data from OP2 files for GNN training.

Structure it expects:
    studies/m1_mirror_adaptive_V11/
    ├── 2_iterations/
    │   ├── iter91/
    │   │   ├── assy_m1_assyfem1_sim1-solution_1.op2
    │   │   ├── assy_m1_assyfem1_sim1-solution_1.dat
    │   │   └── params.exp
    │   ├── iter92/
    │   │   └── ...
    └── 3_results/
        └── study.db (Optuna database)

Output structure:
    studies/m1_mirror_adaptive_V11/
    └── gnn_data/
        ├── trial_0000/
        │   ├── displacement_field.h5
        │   └── metadata.json
        ├── trial_0001/
        │   └── ...
        └── dataset_index.json (maps iter -> trial)

Usage:
    python -m optimization_engine.gnn.backfill_field_data V11
    python -m optimization_engine.gnn.backfill_field_data V11 V12 --merge
"""

import json
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime

import numpy as np

# Add parent to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from optimization_engine.gnn.extract_displacement_field import (
    extract_displacement_field,
    save_field,
    load_field,
    HAS_H5PY,
)


def find_studies(base_dir: Path, pattern: str = "m1_mirror_adaptive_V*") -> List[Path]:
    """Find all matching study directories."""
    studies_dir = base_dir / "studies"
    matches = list(studies_dir.glob(pattern))
    return sorted(matches)


def find_op2_files(study_dir: Path) -> List[Tuple[int, Path, Path]]:
    """
    Find all OP2 files in iteration folders.

    Returns:
        List of (iter_number, op2_path, dat_path) tuples
    """
    iterations_dir = study_dir / "2_iterations"
    if not iterations_dir.exists():
        print(f"[WARN] No 2_iterations folder in {study_dir.name}")
        return []

    results = []
    for iter_dir in sorted(iterations_dir.iterdir()):
        if not iter_dir.is_dir():
            continue

        # Extract iteration number
        match = re.match(r'iter(\d+)', iter_dir.name)
        if not match:
            continue
        iter_num = int(match.group(1))

        # Find OP2 file
        op2_files = list(iter_dir.glob('*-solution_1.op2'))
        if not op2_files:
            op2_files = list(iter_dir.glob('*.op2'))
        if not op2_files:
            continue

        op2_path = op2_files[0]

        # Find DAT file
        dat_path = op2_path.with_suffix('.dat')
        if not dat_path.exists():
            dat_path = op2_path.with_suffix('.bdf')
        if not dat_path.exists():
            print(f"[WARN] No DAT/BDF for {op2_path.name}, skipping")
            continue

        results.append((iter_num, op2_path, dat_path))

    return results


def read_params_exp(iter_dir: Path) -> Optional[Dict[str, float]]:
    """Read design parameters from params.exp file."""
    params_file = iter_dir / "params.exp"
    if not params_file.exists():
        return None

    params = {}
    with open(params_file, 'r') as f:
        for line in f:
            line = line.strip()
            if '=' in line:
                # Format: name = value
                parts = line.split('=')
                if len(parts) == 2:
                    name = parts[0].strip()
                    try:
                        value = float(parts[1].strip())
                        params[name] = value
                    except ValueError:
                        pass
    return params


def backfill_study(
    study_dir: Path,
    output_dir: Optional[Path] = None,
    r_inner: float = 100.0,
    r_outer: float = 650.0,
    overwrite: bool = False,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Backfill displacement field data for a single study.

    Args:
        study_dir: Path to study directory
        output_dir: Output directory (default: study_dir/gnn_data)
        r_inner: Inner radius for surface identification
        r_outer: Outer radius for surface identification
        overwrite: Overwrite existing field data
        verbose: Print progress

    Returns:
        Summary dictionary with statistics
    """
    if output_dir is None:
        output_dir = study_dir / "gnn_data"
    output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"\n{'='*60}")
        print(f"BACKFILLING: {study_dir.name}")
        print(f"{'='*60}")

    # Find all OP2 files
    op2_list = find_op2_files(study_dir)

    if verbose:
        print(f"Found {len(op2_list)} iterations with OP2 files")

    # Track results
    success_count = 0
    skip_count = 0
    error_count = 0
    index = {}

    for iter_num, op2_path, dat_path in op2_list:
        # Create trial directory
        trial_dir = output_dir / f"trial_{iter_num:04d}"

        # Check if already exists
        field_ext = '.h5' if HAS_H5PY else '.npz'
        field_path = trial_dir / f"displacement_field{field_ext}"

        if field_path.exists() and not overwrite:
            if verbose:
                print(f"[SKIP] iter{iter_num}: already processed")
            skip_count += 1
            index[iter_num] = {
                'trial_dir': str(trial_dir.relative_to(study_dir)),
                'status': 'skipped',
            }
            continue

        try:
            # Extract displacement field
            if verbose:
                print(f"[{iter_num:3d}] Extracting from {op2_path.name}...", end=' ')

            field_data = extract_displacement_field(
                op2_path,
                bdf_path=dat_path,
                r_inner=r_inner,
                r_outer=r_outer,
                verbose=False
            )

            # Save field data
            trial_dir.mkdir(parents=True, exist_ok=True)
            save_field(field_data, field_path)

            # Read params if available
            params = read_params_exp(op2_path.parent)

            # Save metadata
            meta = {
                'iter_number': iter_num,
                'op2_file': str(op2_path.name),
                'n_nodes': len(field_data['node_ids']),
                'subcases': list(field_data['z_displacement'].keys()),
                'params': params,
                'extraction_timestamp': datetime.now().isoformat(),
            }
            meta_path = trial_dir / "metadata.json"
            with open(meta_path, 'w') as f:
                json.dump(meta, f, indent=2)

            if verbose:
                print(f"OK ({len(field_data['node_ids'])} nodes)")

            success_count += 1
            index[iter_num] = {
                'trial_dir': str(trial_dir.relative_to(study_dir)),
                'n_nodes': len(field_data['node_ids']),
                'params': params,
                'status': 'success',
            }

        except Exception as e:
            if verbose:
                print(f"ERROR: {e}")
            error_count += 1
            index[iter_num] = {
                'trial_dir': str(trial_dir.relative_to(study_dir)) if trial_dir.exists() else None,
                'error': str(e),
                'status': 'error',
            }

    # Save index file
    index_path = output_dir / "dataset_index.json"
    index_data = {
        'study_name': study_dir.name,
        'generated': datetime.now().isoformat(),
        'summary': {
            'total': len(op2_list),
            'success': success_count,
            'skipped': skip_count,
            'errors': error_count,
        },
        'trials': index,
    }
    with open(index_path, 'w') as f:
        json.dump(index_data, f, indent=2)

    if verbose:
        print(f"\n{'='*60}")
        print(f"SUMMARY: {study_dir.name}")
        print(f"  Success: {success_count}")
        print(f"  Skipped: {skip_count}")
        print(f"  Errors: {error_count}")
        print(f"  Index: {index_path}")
        print(f"{'='*60}")

    return index_data


def merge_datasets(
    study_dirs: List[Path],
    output_dir: Path,
    train_ratio: float = 0.8,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Merge displacement field data from multiple studies into a single dataset.

    Args:
        study_dirs: List of study directories
        output_dir: Output directory for merged dataset
        train_ratio: Fraction of data for training (rest for validation)
        verbose: Print progress

    Returns:
        Dataset metadata dictionary
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"\n{'='*60}")
        print("MERGING DATASETS")
        print(f"{'='*60}")

    all_trials = []

    for study_dir in study_dirs:
        gnn_data_dir = study_dir / "gnn_data"
        index_path = gnn_data_dir / "dataset_index.json"

        if not index_path.exists():
            print(f"[WARN] No index for {study_dir.name}, run backfill first")
            continue

        with open(index_path, 'r') as f:
            index = json.load(f)

        study_name = study_dir.name

        for iter_num, trial_info in index['trials'].items():
            if trial_info.get('status') != 'success':
                continue

            trial_dir = study_dir / trial_info['trial_dir']
            all_trials.append({
                'study': study_name,
                'iter': int(iter_num),
                'trial_dir': trial_dir,
                'params': trial_info.get('params', {}),
                'n_nodes': trial_info.get('n_nodes'),
            })

    if verbose:
        print(f"Total successful trials: {len(all_trials)}")

    # Shuffle and split
    np.random.seed(42)
    indices = np.random.permutation(len(all_trials))
    n_train = int(len(all_trials) * train_ratio)

    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    # Create split files
    splits = {
        'train': [all_trials[i] for i in train_indices],
        'val': [all_trials[i] for i in val_indices],
    }

    for split_name, trials in splits.items():
        split_dir = output_dir / split_name
        split_dir.mkdir(exist_ok=True)

        split_meta = []
        for i, trial in enumerate(trials):
            # Copy/link field data
            src_ext = '.h5' if HAS_H5PY else '.npz'
            src_path = trial['trial_dir'] / f"displacement_field{src_ext}"
            dst_path = split_dir / f"sample_{i:04d}{src_ext}"

            if src_path.exists():
                # Copy file (or could use symlink on Linux)
                import shutil
                shutil.copy(src_path, dst_path)

            split_meta.append({
                'index': i,
                'source_study': trial['study'],
                'source_iter': trial['iter'],
                'params': trial['params'],
                'n_nodes': trial['n_nodes'],
            })

        # Save split metadata
        meta_path = split_dir / "metadata.json"
        with open(meta_path, 'w') as f:
            json.dump({
                'split': split_name,
                'n_samples': len(split_meta),
                'samples': split_meta,
            }, f, indent=2)

        if verbose:
            print(f"  {split_name}: {len(split_meta)} samples")

    # Save overall metadata
    dataset_meta = {
        'created': datetime.now().isoformat(),
        'source_studies': [str(s.name) for s in study_dirs],
        'total_samples': len(all_trials),
        'train_samples': len(splits['train']),
        'val_samples': len(splits['val']),
        'train_ratio': train_ratio,
    }
    with open(output_dir / "dataset_meta.json", 'w') as f:
        json.dump(dataset_meta, f, indent=2)

    if verbose:
        print(f"\nDataset saved to: {output_dir}")
        print(f"  Train: {len(splits['train'])} samples")
        print(f"  Val: {len(splits['val'])} samples")

    return dataset_meta


# =============================================================================
# CLI
# =============================================================================

def main():
    import argparse

    parser = argparse.ArgumentParser(
        description='Backfill displacement field data for GNN training'
    )
    parser.add_argument('studies', nargs='+', type=str,
                        help='Study versions (e.g., V11 V12) or "all"')
    parser.add_argument('--merge', action='store_true',
                        help='Merge data from multiple studies')
    parser.add_argument('--output', '-o', type=Path,
                        help='Output directory for merged dataset')
    parser.add_argument('--r-inner', type=float, default=100.0,
                        help='Inner radius (mm)')
    parser.add_argument('--r-outer', type=float, default=650.0,
                        help='Outer radius (mm)')
    parser.add_argument('--overwrite', action='store_true',
                        help='Overwrite existing field data')
    parser.add_argument('--train-ratio', type=float, default=0.8,
                        help='Train/val split ratio')

    args = parser.parse_args()

    # Find base directory
    base_dir = Path(__file__).parent.parent.parent

    # Find studies
    if args.studies == ['all']:
        study_dirs = find_studies(base_dir, "m1_mirror_adaptive_V*")
    else:
        study_dirs = []
        for s in args.studies:
            if s.startswith('V'):
                pattern = f"m1_mirror_adaptive_{s}"
            else:
                pattern = s
            matches = find_studies(base_dir, pattern)
            study_dirs.extend(matches)

    if not study_dirs:
        print("No studies found!")
        return 1

    print(f"Found {len(study_dirs)} studies:")
    for s in study_dirs:
        print(f"  - {s.name}")

    # Backfill each study
    for study_dir in study_dirs:
        backfill_study(
            study_dir,
            r_inner=args.r_inner,
            r_outer=args.r_outer,
            overwrite=args.overwrite,
        )

    # Merge if requested
    if args.merge and len(study_dirs) > 1:
        output_dir = args.output
        if output_dir is None:
            output_dir = base_dir / "studies" / "gnn_merged_dataset"

        merge_datasets(
            study_dirs,
            output_dir,
            train_ratio=args.train_ratio,
        )

    return 0


if __name__ == '__main__':
    sys.exit(main())
```
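For quick reference, the merged dataset written by `merge_datasets` can be consumed like this (an illustrative sketch based on the metadata keys written above; the merged-dataset path is the default from `main()`):

```python
import json
from pathlib import Path

merged = Path("studies/gnn_merged_dataset")
meta = json.loads((merged / "train" / "metadata.json").read_text())
print(meta["n_samples"])
for sample in meta["samples"][:3]:
    # Each entry records provenance and design parameters for one field file
    print(sample["source_study"], sample["source_iter"], sample["params"])
```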
544  optimization_engine/gnn/differentiable_zernike.py  (new file)
@@ -0,0 +1,544 @@
```python
"""
Differentiable Zernike Fitting Layer
=====================================

This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
The key innovation is putting Zernike fitting INSIDE the neural network for
end-to-end training.

Why Differentiable Zernike?

Current MLP approach:
    design → MLP → coefficients (learn 200 outputs independently)

GNN + Differentiable Zernike:
    design → GNN → displacement field → Zernike fit → coefficients
                                            ↑
                          Differentiable! Gradients flow back

This allows the network to learn:
1. Spatially coherent displacement fields
2. Fields that produce accurate Zernike coefficients
3. Correct relative deformation computation

Components:
1. DifferentiableZernikeFit - Fits coefficients from displacement field
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing

Usage:
    from optimization_engine.gnn.differentiable_zernike import (
        DifferentiableZernikeFit,
        ZernikeObjectiveLayer
    )
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    # In forward pass:
    z_disp = gnn_model(...)               # [n_nodes, 4]
    objectives = objective_layer(z_disp)  # Dict with RMS values
"""

import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from pathlib import Path


def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute Zernike polynomial for Noll index j.

    Uses the standard Noll indexing convention:
    j=1: piston, j=2: tilt-y, j=3: tilt-x, j=4: defocus, etc.

    Args:
        j: Noll index (1-based)
        r: Radial coordinates (normalized to [0, 1])
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta)
    """
    # Convert Noll index to (n, m) using the standard Noll ordering
    n = 0
    j1 = j - 1
    while j1 > n:
        n += 1
        j1 -= n
    # Noll sign convention: even j -> cosine term (m >= 0), odd j -> sine (m < 0)
    m = (-1) ** j * ((n % 2) + 2 * ((j1 + ((n + 1) % 2)) // 2))

    # Compute radial polynomial R_n^|m|(r)
    R = np.zeros_like(r)
    m_abs = abs(m)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k) /
                (math.factorial(k) *
                 math.factorial((n + m_abs) // 2 - k) *
                 math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)

    # Combine with angular part
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)

    # Normalization factor
    if m == 0:
        norm = np.sqrt(n + 1)
    else:
        norm = np.sqrt(2 * (n + 1))

    return norm * Z


def build_zernike_matrix(
    r: np.ndarray,
    theta: np.ndarray,
    n_modes: int = 50,
    r_max: float = None
) -> np.ndarray:
    """
    Build Zernike basis matrix for a set of points.

    Args:
        r: Radial coordinates
        theta: Angular coordinates
        n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
        r_max: Maximum radius for normalization (if None, use max(r))

    Returns:
        Z: [n_points, n_modes] Zernike basis matrix
    """
    if r_max is None:
        r_max = r.max()

    r_norm = r / r_max

    n_points = len(r)
    Z = np.zeros((n_points, n_modes), dtype=np.float64)

    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)

    return Z
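

# Sanity-check sketch (illustrative, not part of the committed module):
# Noll-normalized Zernikes are orthonormal under the uniform average over the
# unit disk, so the Monte Carlo Gram matrix of the basis should be close to
# the identity matrix. Handy for validating the Noll conversion above.
def _check_noll_orthonormality(n_modes: int = 10, n_samples: int = 200_000) -> float:
    rng = np.random.default_rng(0)
    r = np.sqrt(rng.uniform(0.0, 1.0, n_samples))  # sqrt => uniform area density
    theta = rng.uniform(0.0, 2.0 * np.pi, n_samples)
    Z = build_zernike_matrix(r, theta, n_modes=n_modes, r_max=1.0)
    gram = (Z.T @ Z) / n_samples                   # ≈ identity if orthonormal
    return float(np.abs(gram - np.eye(n_modes)).max())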


class DifferentiableZernikeFit(nn.Module):
    """
    GPU-accelerated, differentiable Zernike polynomial fitting.

    This layer fits Zernike coefficients to a displacement field using
    least squares. The key insight is that least squares has a closed-form
    solution: c = (Z^T Z)^{-1} Z^T @ values

    By precomputing (Z^T Z)^{-1} Z^T, we can fit coefficients with a single
    matrix multiply, which is fully differentiable.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes to fit
            regularization: Tikhonov regularization for stability
        """
        super().__init__()

        self.n_modes = n_modes

        # Get coordinates from polar graph
        r = polar_graph.r
        theta = polar_graph.theta
        r_max = polar_graph.r_outer

        # Build Zernike basis matrix [n_nodes, n_modes]
        Z = build_zernike_matrix(r, theta, n_modes, r_max)

        # Convert to tensor and register as buffer
        Z_tensor = torch.tensor(Z, dtype=torch.float32)
        self.register_buffer('Z', Z_tensor)

        # Precompute pseudo-inverse with regularization
        # c = (Z^T Z + λI)^{-1} Z^T @ values
        ZtZ = Z_tensor.T @ Z_tensor
        ZtZ_reg = ZtZ + regularization * torch.eye(n_modes)
        ZtZ_inv = torch.inverse(ZtZ_reg)
        pseudo_inv = ZtZ_inv @ Z_tensor.T  # [n_modes, n_nodes]

        self.register_buffer('pseudo_inverse', pseudo_inv)

    def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
        """
        Fit Zernike coefficients to displacement field.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement

        Returns:
            coefficients: [n_modes] or [n_subcases, n_modes]
        """
        if z_displacement.dim() == 1:
            # Single field: [n_nodes] → [n_modes]
            return self.pseudo_inverse @ z_displacement
        else:
            # Multiple subcases: [n_nodes, n_subcases] → [n_subcases, n_modes]
            # Multiply to get [n_modes, n_subcases], then transpose
            return (self.pseudo_inverse @ z_displacement).T

    def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct displacement field from coefficients.

        Args:
            coefficients: [n_modes] or [n_subcases, n_modes]

        Returns:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]
        """
        if coefficients.dim() == 1:
            return self.Z @ coefficients
        else:
            return self.Z @ coefficients.T

    def fit_and_residual(
        self,
        z_displacement: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fit coefficients and return residual.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]

        Returns:
            coefficients, residual
        """
        coeffs = self.forward(z_displacement)
        reconstruction = self.reconstruct(coeffs)
        residual = z_displacement - reconstruction
        return coeffs, residual


class ZernikeObjectiveLayer(nn.Module):
    """
    Compute Zernike-based optimization objectives from displacement field.

    This layer replicates the exact computation done in FEA post-processing:
    1. Compute relative displacement (e.g., 40° - 20°)
    2. Convert to wavefront error (× 2 for reflection, mm → nm)
    3. Fit Zernike and remove low-order terms
    4. Compute filtered RMS

    The computation is fully differentiable, allowing end-to-end training
    with objective-based loss.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes
            regularization: Regularization for Zernike fitting
        """
        super().__init__()

        self.n_modes = n_modes
        self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)

        # Precompute Zernike basis subsets for filtering
        Z = self.zernike_fit.Z

        # Low-order modes (J1-J4: piston, tip, tilt, defocus)
        self.register_buffer('Z_j1_to_j4', Z[:, :4])

        # Only J1-J3 for manufacturing objective
        self.register_buffer('Z_j1_to_j3', Z[:, :3])

        # Store node count
        self.n_nodes = Z.shape[0]

    def forward(
        self,
        z_disp_all_subcases: torch.Tensor,
        return_all: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Compute Zernike objectives from displacement field.

        Args:
            z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases
                Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
            return_all: If True, return additional diagnostics

        Returns:
            Dictionary with objective values:
            - rel_filtered_rms_40_vs_20: RMS after J1-J4 removal (nm)
            - rel_filtered_rms_60_vs_20: RMS after J1-J4 removal (nm)
            - mfg_90_optician_workload: RMS after J1-J3 removal (nm)
        """
        # Unpack subcases
        disp_90 = z_disp_all_subcases[:, 0]  # Subcase 1: 90°
        disp_20 = z_disp_all_subcases[:, 1]  # Subcase 2: 20° (reference)
        disp_40 = z_disp_all_subcases[:, 2]  # Subcase 3: 40°
        disp_60 = z_disp_all_subcases[:, 3]  # Subcase 4: 60°

        # === Objective 1: Relative filtered RMS 40° vs 20° ===
        disp_rel_40 = disp_40 - disp_20
        wfe_rel_40 = 2.0 * disp_rel_40 * 1e6  # mm → nm, ×2 for reflection
        rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)

        # === Objective 2: Relative filtered RMS 60° vs 20° ===
        disp_rel_60 = disp_60 - disp_20
        wfe_rel_60 = 2.0 * disp_rel_60 * 1e6
        rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)

        # === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
        disp_rel_90 = disp_90 - disp_20
        wfe_rel_90 = 2.0 * disp_rel_90 * 1e6
        mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)

        result = {
            'rel_filtered_rms_40_vs_20': rms_40_vs_20,
            'rel_filtered_rms_60_vs_20': rms_60_vs_20,
            'mfg_90_optician_workload': mfg_90,
        }

        if return_all:
            # Include intermediate values for debugging
            result['wfe_rel_40'] = wfe_rel_40
            result['wfe_rel_60'] = wfe_rel_60
            result['wfe_rel_90'] = wfe_rel_90

        return result

    def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing J1-J4 (piston, tip, tilt, defocus).

        This is the standard filtered RMS for optical performance.
        """
        # Fit low-order coefficients via the normal equations
        # c = (Z^T Z)^{-1} Z^T @ wfe
        Z_low = self.Z_j1_to_j4
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)

        # Reconstruct low-order surface
        wfe_low = Z_low @ coeffs_low

        # Residual (high-order content)
        wfe_filtered = wfe - wfe_low

        # RMS
        return torch.sqrt(torch.mean(wfe_filtered ** 2))

    def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing only J1-J3 (piston, tip, tilt).

        This keeps defocus (J4), which is harder to polish out and therefore
        represents the actual manufacturing workload.
        """
        Z_low = self.Z_j1_to_j3
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)

        wfe_low = Z_low @ coeffs_low
        wfe_filtered = wfe - wfe_low

        return torch.sqrt(torch.mean(wfe_filtered ** 2))


class ZernikeRMSLoss(nn.Module):
    """
    Combined loss function for GNN training.

    This loss combines:
    1. Displacement field reconstruction loss (MSE)
    2. Objective prediction loss (relative Zernike RMS)

    The multi-task loss helps the network learn both accurate
    displacement fields AND accurate objective predictions.
    """

    def __init__(
        self,
        polar_graph,
        field_weight: float = 1.0,
        objective_weight: float = 0.1,
        n_modes: int = 50
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            field_weight: Weight for displacement field loss
            objective_weight: Weight for objective loss
            n_modes: Number of Zernike modes
        """
        super().__init__()

        self.field_weight = field_weight
        self.objective_weight = objective_weight

        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)

    def forward(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        objectives_true: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute combined loss.

        Args:
            z_disp_pred: Predicted displacement [n_nodes, 4]
            z_disp_true: Ground truth displacement [n_nodes, 4]
            objectives_true: Optional dict of true objective values

        Returns:
            total_loss, loss_components dict
        """
        # Field reconstruction loss
        loss_field = nn.functional.mse_loss(z_disp_pred, z_disp_true)

        # Scale field loss to account for small displacement values
        # Displacements are ~1e-4 mm, so MSE is ~1e-8
        loss_field_scaled = loss_field * 1e8

        components = {
            'loss_field': loss_field_scaled,
        }

        total_loss = self.field_weight * loss_field_scaled

        # Objective loss (if ground truth provided)
        if objectives_true is not None and self.objective_weight > 0:
            objectives_pred = self.objective_layer(z_disp_pred)

            loss_obj = 0.0
            for key in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                if key in objectives_true:
                    pred = objectives_pred[key]
                    true = objectives_true[key]
                    # Relative error squared
                    rel_err = ((pred - true) / (true + 1e-6)) ** 2
                    loss_obj = loss_obj + rel_err
                    components[f'loss_{key}'] = rel_err

            components['loss_objectives'] = loss_obj
            total_loss = total_loss + self.objective_weight * loss_obj

        components['total_loss'] = total_loss
        return total_loss, components


# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    import sys
    sys.path.insert(0, "C:/Users/Antoine/Atomizer")

    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    print("="*60)
    print("Testing Differentiable Zernike Layer")
    print("="*60)

    # Create polar graph
    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
    print(f"\nPolar Graph: {graph.n_nodes} nodes")

    # Create Zernike fitting layer
    zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
    print(f"Zernike Fit: {zernike_fit.n_modes} modes")
    print(f"  Z matrix: {zernike_fit.Z.shape}")
    print(f"  Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")

    # Test with synthetic displacement
    print("\n--- Test Zernike Fitting ---")

    # Create synthetic displacement (defocus + astigmatism pattern)
    r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
    theta = torch.tensor(graph.theta, dtype=torch.float32)

    # Defocus (J4) + Astigmatism (J5)
    synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)

    # Fit coefficients
    coeffs = zernike_fit(synthetic_disp)
    print(f"Fitted coefficients shape: {coeffs.shape}")
    print(f"First 10 coefficients: {coeffs[:10].tolist()}")

    # Reconstruct
    recon = zernike_fit.reconstruct(coeffs)
    error = (synthetic_disp - recon).abs()
    print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")

    # Test with multiple subcases
    print("\n--- Test Multi-Subcase ---")
    z_disp_multi = torch.stack([
        synthetic_disp,
        synthetic_disp * 0.5,
        synthetic_disp * 0.7,
        synthetic_disp * 0.9,
    ], dim=1)  # [n_nodes, 4]

    coeffs_multi = zernike_fit(z_disp_multi)
    print(f"Multi-subcase coefficients: {coeffs_multi.shape}")

    # Test objective layer
    print("\n--- Test Objective Layer ---")
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    objectives = objective_layer(z_disp_multi)
    print("Computed objectives:")
    for key, val in objectives.items():
        print(f"  {key}: {val.item():.2f} nm")

    # Test gradient flow
    print("\n--- Test Gradient Flow ---")
    z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
    objectives = objective_layer(z_disp_grad)
    loss = objectives['rel_filtered_rms_40_vs_20']
    loss.backward()
    print(f"Gradient shape: {z_disp_grad.grad.shape}")
    print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
    print("✓ Gradients flow through Zernike fitting!")

    # Test loss function
    print("\n--- Test Loss Function ---")
    loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)

    z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)

    total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
    print(f"Total loss: {total_loss.item():.6f}")
    for key, val in components.items():
        if isinstance(val, torch.Tensor):
            print(f"  {key}: {val.item():.6f}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)
```
455  optimization_engine/gnn/extract_displacement_field.py  (new file)
@@ -0,0 +1,455 @@
```python
"""
Displacement Field Extraction for GNN Training
===============================================

This module extracts full displacement fields from Nastran OP2 files for GNN training.
Unlike the Zernike extractors that reduce to coefficients, this preserves the raw
spatial data that the GNN needs to learn the physics.

Key Features:
1. Extract Z-displacement for all optical surface nodes
2. Store node coordinates for graph construction
3. Support for multiple subcases (gravity orientations)
4. HDF5 storage for efficient loading during training

Output Format (HDF5):
    /node_ids     - [n_nodes] int array
    /node_coords  - [n_nodes, 3] float array (X, Y, Z)
    /subcase_1    - [n_nodes] Z-displacement for subcase 1
    /subcase_2    - [n_nodes] Z-displacement for subcase 2
    /subcase_3    - [n_nodes] Z-displacement for subcase 3
    /subcase_4    - [n_nodes] Z-displacement for subcase 4
    /metadata     - JSON string with extraction info

Usage:
    from optimization_engine.gnn.extract_displacement_field import extract_displacement_field

    field_data = extract_displacement_field(op2_path, bdf_path)
    save_field_to_hdf5(field_data, output_path)
"""

import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime

try:
    import h5py
    HAS_H5PY = True
except ImportError:
    HAS_H5PY = False

from pyNastran.op2.op2 import OP2
from pyNastran.bdf.bdf import BDF


def identify_optical_surface_nodes(
    node_coords: Dict[int, np.ndarray],
    r_inner: float = 100.0,
    r_outer: float = 650.0,
    z_tolerance: float = 100.0
) -> Tuple[List[int], np.ndarray]:
    """
    Identify nodes on the optical surface by spatial filtering.

    The optical surface is identified by:
    1. Radial position (between inner and outer radius)
    2. Consistent Z range (nodes on the curved mirror surface)

    Args:
        node_coords: Dictionary mapping node ID to (X, Y, Z) coordinates
        r_inner: Inner radius cutoff (central hole)
        r_outer: Outer radius limit
        z_tolerance: Maximum Z deviation from mean to include

    Returns:
        Tuple of (node_ids list, coordinates array [n, 3])
    """
    # Get all coordinates as arrays
    nids = list(node_coords.keys())
    coords = np.array([node_coords[nid] for nid in nids])

    # Calculate radial position
    r = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)

    # Initial radial filter
    radial_mask = (r >= r_inner) & (r <= r_outer)

    # Find nodes in radial range
    radial_nids = np.array(nids)[radial_mask]
    radial_coords = coords[radial_mask]

    if len(radial_coords) == 0:
        raise ValueError(f"No nodes found in radial range [{r_inner}, {r_outer}]")

    # The optical surface should have a relatively small Z range
    z_vals = radial_coords[:, 2]
    z_mean = np.mean(z_vals)

    # Filter to nodes within z_tolerance of the mean Z
    z_mask = np.abs(radial_coords[:, 2] - z_mean) < z_tolerance

    surface_nids = radial_nids[z_mask].tolist()
    surface_coords = radial_coords[z_mask]

    return surface_nids, surface_coords


def extract_displacement_field(
    op2_path: Path,
    bdf_path: Optional[Path] = None,
    r_inner: float = 100.0,
    r_outer: float = 650.0,
    subcases: Optional[List[int]] = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Extract full displacement field for GNN training.

    This function extracts Z-displacement from OP2 files for nodes on the optical
    surface (defined by radial position). It builds node coordinates directly from
    the OP2 data matched against BDF geometry, then filters by radial position.

    Args:
        op2_path: Path to OP2 file
        bdf_path: Path to BDF/DAT file (auto-detected if None)
        r_inner: Inner radius for surface identification (mm)
        r_outer: Outer radius for surface identification (mm)
        subcases: List of subcases to extract (default: [1, 2, 3, 4])
        verbose: Print progress messages

    Returns:
        Dictionary containing:
        - node_ids: List of node IDs on optical surface
        - node_coords: Array [n_nodes, 3] of coordinates
        - z_displacement: Dict mapping subcase -> [n_nodes] Z-displacements
        - metadata: Extraction metadata
    """
    op2_path = Path(op2_path)

    # Find BDF file
    if bdf_path is None:
        for ext in ['.dat', '.bdf']:
            candidate = op2_path.with_suffix(ext)
            if candidate.exists():
                bdf_path = candidate
                break
        if bdf_path is None:
            raise FileNotFoundError(f"No .dat or .bdf found for {op2_path}")

    if subcases is None:
        subcases = [1, 2, 3, 4]

    if verbose:
        print(f"[FIELD] Reading geometry from: {bdf_path.name}")

    # Read geometry from BDF
    bdf = BDF()
    bdf.read_bdf(str(bdf_path))
    node_geo = {int(nid): node.get_position() for nid, node in bdf.nodes.items()}

    if verbose:
        print(f"[FIELD] Total nodes in BDF: {len(node_geo)}")

    # Read OP2
    if verbose:
        print(f"[FIELD] Reading displacements from: {op2_path.name}")
    op2 = OP2()
    op2.read_op2(str(op2_path))

    if not op2.displacements:
        raise RuntimeError("No displacement data in OP2")

    # Extract data by iterating through OP2 nodes and matching to BDF geometry
    # This approach works even when node numbering differs between sources
    subcase_data = {}

    for key, darr in op2.displacements.items():
        isub = int(getattr(darr, 'isubcase', key))
        if isub not in subcases:
            continue

        data = darr.data
        dmat = data[0] if data.ndim == 3 else data
        ngt = darr.node_gridtype
        op2_node_ids = ngt[:, 0] if ngt.ndim == 2 else ngt

        # Build arrays of matched data
        nids = []
        X = []
        Y = []
        Z = []
        disp_z = []

        for i, nid in enumerate(op2_node_ids):
            nid_int = int(nid)
            if nid_int in node_geo:
                pos = node_geo[nid_int]
                nids.append(nid_int)
                X.append(pos[0])
                Y.append(pos[1])
                Z.append(pos[2])
                disp_z.append(float(dmat[i, 2]))  # Z component

        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        Z = np.array(Z, dtype=np.float32)
        disp_z = np.array(disp_z, dtype=np.float32)
        nids = np.array(nids, dtype=np.int32)

        # Filter to optical surface by radial position
        r = np.sqrt(X**2 + Y**2)
        surface_mask = (r >= r_inner) & (r <= r_outer)

        subcase_data[isub] = {
            'node_ids': nids[surface_mask],
            'coords': np.column_stack([X[surface_mask], Y[surface_mask], Z[surface_mask]]),
            'disp_z': disp_z[surface_mask],
        }

        if verbose:
            print(f"[FIELD] Subcase {isub}: {len(nids)} matched, {np.sum(surface_mask)} on surface")

    # Get common nodes across all subcases (should be the same)
    all_subcase_keys = list(subcase_data.keys())
    if not all_subcase_keys:
        raise RuntimeError("No subcases found in OP2")

    # Use first subcase to define node list
    ref_subcase = all_subcase_keys[0]
    surface_nids = subcase_data[ref_subcase]['node_ids'].tolist()
    surface_coords = subcase_data[ref_subcase]['coords']

    # Build displacement dict for all subcases
    z_displacement = {}
    for isub in subcases:
        if isub in subcase_data:
            z_displacement[isub] = subcase_data[isub]['disp_z']

    if verbose:
        print(f"[FIELD] Final surface: {len(surface_nids)} nodes")
        r_surface = np.sqrt(surface_coords[:, 0]**2 + surface_coords[:, 1]**2)
        print(f"[FIELD] Radial range: [{r_surface.min():.1f}, {r_surface.max():.1f}] mm")

    # Build metadata
    metadata = {
        'extraction_timestamp': datetime.now().isoformat(),
        'op2_file': str(op2_path.name),
        'bdf_file': str(bdf_path.name),
        'n_nodes': len(surface_nids),
        'r_inner': r_inner,
        'r_outer': r_outer,
        'subcases': list(z_displacement.keys()),
    }

    return {
        'node_ids': surface_nids,
        'node_coords': surface_coords,
        'z_displacement': z_displacement,
        'metadata': metadata,
    }


def save_field_to_hdf5(
    field_data: Dict[str, Any],
    output_path: Path,
    compression: str = 'gzip'
) -> None:
    """
    Save displacement field data to HDF5 file.

    Args:
        field_data: Output from extract_displacement_field()
        output_path: Path to save HDF5 file
        compression: Compression algorithm ('gzip', 'lzf', or None)
    """
    if not HAS_H5PY:
        raise ImportError("h5py required for HDF5 storage: pip install h5py")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(output_path, 'w') as f:
        # Node data
        f.create_dataset('node_ids', data=np.array(field_data['node_ids'], dtype=np.int32),
                         compression=compression)
        f.create_dataset('node_coords', data=field_data['node_coords'].astype(np.float32),
                         compression=compression)

        # Displacement for each subcase
        for subcase, z_disp in field_data['z_displacement'].items():
            f.create_dataset(f'subcase_{subcase}', data=z_disp.astype(np.float32),
                             compression=compression)

        # Metadata as JSON string
        f.attrs['metadata'] = json.dumps(field_data['metadata'])

    # Report file size
    size_kb = output_path.stat().st_size / 1024
    print(f"[FIELD] Saved to {output_path.name} ({size_kb:.1f} KB)")


def load_field_from_hdf5(hdf5_path: Path) -> Dict[str, Any]:
    """
    Load displacement field data from HDF5 file.

    Args:
        hdf5_path: Path to HDF5 file

    Returns:
        Dictionary with same structure as extract_displacement_field()
    """
    if not HAS_H5PY:
        raise ImportError("h5py required for HDF5 storage: pip install h5py")

    with h5py.File(hdf5_path, 'r') as f:
        node_ids = f['node_ids'][:].tolist()
        node_coords = f['node_coords'][:]

        # Load subcases
        z_displacement = {}
        for key in f.keys():
            if key.startswith('subcase_'):
                subcase = int(key.split('_')[1])
                z_displacement[subcase] = f[key][:]

        metadata = json.loads(f.attrs['metadata'])

    return {
        'node_ids': node_ids,
        'node_coords': node_coords,
        'z_displacement': z_displacement,
        'metadata': metadata,
    }


def save_field_to_npz(
    field_data: Dict[str, Any],
    output_path: Path
) -> None:
    """
    Save displacement field data to compressed NPZ file (fallback if no h5py).

    Args:
        field_data: Output from extract_displacement_field()
        output_path: Path to save NPZ file
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    save_dict = {
        'node_ids': np.array(field_data['node_ids'], dtype=np.int32),
        'node_coords': field_data['node_coords'].astype(np.float32),
        'metadata_json': np.array([json.dumps(field_data['metadata'])]),
    }

    # Add subcases
    for subcase, z_disp in field_data['z_displacement'].items():
        save_dict[f'subcase_{subcase}'] = z_disp.astype(np.float32)

    np.savez_compressed(output_path, **save_dict)

    size_kb = output_path.stat().st_size / 1024
    print(f"[FIELD] Saved to {output_path.name} ({size_kb:.1f} KB)")


def load_field_from_npz(npz_path: Path) -> Dict[str, Any]:
    """
    Load displacement field data from NPZ file.

    Args:
        npz_path: Path to NPZ file

    Returns:
        Dictionary with same structure as extract_displacement_field()
    """
    data = np.load(npz_path, allow_pickle=True)

    node_ids = data['node_ids'].tolist()
    node_coords = data['node_coords']
    metadata = json.loads(str(data['metadata_json'][0]))

    # Load subcases
    z_displacement = {}
    for key in data.keys():
        if key.startswith('subcase_'):
            subcase = int(key.split('_')[1])
            z_displacement[subcase] = data[key]

    return {
        'node_ids': node_ids,
        'node_coords': node_coords,
        'z_displacement': z_displacement,
        'metadata': metadata,
    }


# =============================================================================
# Convenience functions
# =============================================================================

def save_field(field_data: Dict[str, Any], output_path: Path) -> None:
    """Save field data using best available format (HDF5 preferred)."""
    output_path = Path(output_path)
    if HAS_H5PY and output_path.suffix == '.h5':
        save_field_to_hdf5(field_data, output_path)
    else:
        if output_path.suffix != '.npz':
            output_path = output_path.with_suffix('.npz')
        save_field_to_npz(field_data, output_path)


def load_field(path: Path) -> Dict[str, Any]:
    """Load field data from HDF5 or NPZ file."""
    path = Path(path)
    if path.suffix == '.h5':
        return load_field_from_hdf5(path)
    else:
        return load_field_from_npz(path)


# =============================================================================
# CLI
# =============================================================================

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Extract displacement field from Nastran OP2 for GNN training'
    )
    parser.add_argument('op2_path', type=Path, help='Path to OP2 file')
    parser.add_argument('-o', '--output', type=Path, help='Output path (default: same dir as OP2)')
    parser.add_argument('--r-inner', type=float, default=100.0, help='Inner radius (mm)')
    parser.add_argument('--r-outer', type=float, default=650.0, help='Outer radius (mm)')
    parser.add_argument('--format', choices=['h5', 'npz'], default='h5',
                        help='Output format (default: h5)')

    args = parser.parse_args()

    # Extract field
    field_data = extract_displacement_field(
        args.op2_path,
        r_inner=args.r_inner,
        r_outer=args.r_outer,
    )

    # Determine output path
    if args.output:
        output_path = args.output
    else:
        output_path = args.op2_path.parent / f'displacement_field.{args.format}'

    # Save
    save_field(field_data, output_path)

    # Print summary
    print("\n" + "="*60)
    print("EXTRACTION SUMMARY")
    print("="*60)
    print(f"Nodes: {len(field_data['node_ids'])}")
    print(f"Subcases: {list(field_data['z_displacement'].keys())}")
    for sc, disp in field_data['z_displacement'].items():
        print(f"  Subcase {sc}: Z range [{disp.min():.4f}, {disp.max():.4f}] mm")
```
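As a usage note, saved fields round-trip through `load_field` regardless of backend. A minimal sketch (the subcase numbering follows the convention documented in the objective layer above, with subcase 2 being the 20° reference):

```python
from optimization_engine.gnn.extract_displacement_field import load_field

field = load_field("displacement_field.h5")  # .npz works the same way
coords = field["node_coords"]                # [n_nodes, 3] float32
dz_ref = field["z_displacement"][2]          # subcase 2 = 20° reference
print(coords.shape, float(dz_ref.min()), float(dz_ref.max()))
```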
718  optimization_engine/gnn/gnn_optimizer.py  (new file)
@@ -0,0 +1,718 @@
|
||||
"""
|
||||
GNN-Based Optimizer for Zernike Mirror Optimization
|
||||
====================================================
|
||||
|
||||
This module provides a fast GNN-based optimization workflow:
|
||||
1. Load trained GNN checkpoint
|
||||
2. Run thousands of fast GNN predictions
|
||||
3. Select top candidates
|
||||
4. Validate with FEA (optional)
|
||||
|
||||
Usage:
|
||||
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
|
||||
|
||||
optimizer = ZernikeGNNOptimizer.from_checkpoint('zernike_gnn_checkpoint.pt')
|
||||
results = optimizer.turbo_optimize(n_trials=5000)
|
||||
|
||||
# Get best designs
|
||||
best = results.get_best(n=10)
|
||||
|
||||
# Validate with FEA
|
||||
validated = optimizer.validate_with_fea(best, study_dir)
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import optuna
|
||||
|
||||
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
||||
from optimization_engine.gnn.zernike_gnn import create_model, load_model
|
||||
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer
|
||||
|
||||
|
||||
@dataclass
|
||||
class GNNPrediction:
|
||||
"""Single GNN prediction result."""
|
||||
design_vars: Dict[str, float]
|
||||
objectives: Dict[str, float]
|
||||
z_displacement: Optional[np.ndarray] = None # [3000, 4] if stored
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
'design_vars': self.design_vars,
|
||||
'objectives': self.objectives,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class OptimizationResults:
    """Container for optimization results."""
    predictions: List[GNNPrediction] = field(default_factory=list)
    pareto_front: List[int] = field(default_factory=list)  # Indices

    def add(self, pred: GNNPrediction):
        self.predictions.append(pred)

    def get_best(self, n: int = 10, objective: str = 'rel_filtered_rms_40_vs_20') -> List[GNNPrediction]:
        """Get top N designs by a single objective."""
        sorted_preds = sorted(self.predictions, key=lambda p: p.objectives.get(objective, float('inf')))
        return sorted_preds[:n]

    def get_pareto_front(self, objectives: Optional[List[str]] = None) -> List[GNNPrediction]:
        """Get Pareto-optimal designs."""
        if objectives is None:
            objectives = ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']

        # Extract objective values
        obj_values = np.array([
            [p.objectives.get(obj, float('inf')) for obj in objectives]
            for p in self.predictions
        ])

        # Find Pareto front (all objectives are minimized)
        pareto_indices = []
        for i in range(len(self.predictions)):
            is_dominated = False
            for j in range(len(self.predictions)):
                if i != j:
                    # j dominates i if j is <= in all objectives and < in at least one
                    if np.all(obj_values[j] <= obj_values[i]) and np.any(obj_values[j] < obj_values[i]):
                        is_dominated = True
                        break
            if not is_dominated:
                pareto_indices.append(i)

        self.pareto_front = pareto_indices
        return [self.predictions[i] for i in pareto_indices]

    def to_dataframe(self):
        """Convert to pandas DataFrame."""
        import pandas as pd

        rows = []
        for i, pred in enumerate(self.predictions):
            row = {'index': i}
            row.update(pred.design_vars)
            row.update({f'obj_{k}': v for k, v in pred.objectives.items()})
            rows.append(row)

        return pd.DataFrame(rows)

    def save(self, path: Path):
        """Save results to JSON."""
        data = {
            'n_predictions': len(self.predictions),
            'pareto_front_indices': self.pareto_front,
            'predictions': [p.to_dict() for p in self.predictions],
            'timestamp': datetime.now().isoformat(),
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)


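# Worked example of the dominance rule used in get_pareto_front() above
# (editorial illustration, not part of the optimizer): with all objectives
# minimized, a = [1.0, 2.0, 3.0] dominates b = [1.5, 2.0, 3.1] because a <= b
# in every objective and a < b in at least one; a = [1.0, 2.5] and
# b = [1.2, 2.0] do not dominate each other, so both stay on the front.
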
class ZernikeGNNOptimizer:
    """
    GNN-based optimizer for Zernike mirror optimization.

    Provides fast objective prediction using trained GNN surrogate.
    """

    def __init__(
        self,
        model: nn.Module,
        polar_graph: PolarMirrorGraph,
        design_names: List[str],
        design_bounds: Dict[str, Tuple[float, float]],
        design_mean: torch.Tensor,
        design_std: torch.Tensor,
        device: str = 'cpu',
        disp_scale: float = 1.0
    ):
        self.model = model.to(device)
        self.model.eval()
        self.polar_graph = polar_graph
        self.design_names = design_names
        self.design_bounds = design_bounds
        self.design_mean = design_mean.to(device)
        self.design_std = design_std.to(device)
        self.device = torch.device(device)
        self.disp_scale = disp_scale  # Scaling factor from training

        # Prepare fixed graph tensors
        self.node_features = torch.tensor(
            polar_graph.get_node_features(normalized=True),
            dtype=torch.float32
        ).to(device)

        self.edge_index = torch.tensor(
            polar_graph.edge_index,
            dtype=torch.long
        ).to(device)

        self.edge_attr = torch.tensor(
            polar_graph.get_edge_features(normalized=True),
            dtype=torch.float32
        ).to(device)

        # Objective computation layer (must be on same device as model)
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes=50).to(device)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: Path,
        config_path: Optional[Path] = None,
        device: str = 'auto'
    ) -> 'ZernikeGNNOptimizer':
        """
        Load optimizer from trained checkpoint.

        Args:
            checkpoint_path: Path to zernike_gnn_checkpoint.pt
            config_path: Path to optimization_config.json (for design bounds)
            device: Device to use ('cpu', 'cuda', 'auto')
        """
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        checkpoint_path = Path(checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Create polar graph
        polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)

        # Create model - handle both old ('model_config') and new ('config') format
        model_config = checkpoint.get('model_config') or checkpoint.get('config', {})
        model = create_model(**model_config)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Get design info from checkpoint
        design_mean = checkpoint['design_mean']
        design_std = checkpoint['design_std']
        disp_scale = checkpoint.get('disp_scale', 1.0)  # Displacement scaling factor

        # Try to get design names and bounds from config
        design_names = []
        design_bounds = {}

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                config = json.load(f)

            for var in config.get('design_variables', []):
                name = var['name']
                design_names.append(name)
                design_bounds[name] = (var['min'], var['max'])
        else:
            # Use generic names based on checkpoint
            n_vars = len(design_mean)
            design_names = [f'var_{i}' for i in range(n_vars)]
            # Default bounds (will be overridden if config provided)
            for name in design_names:
                design_bounds[name] = (-100, 100)

        return cls(
            model=model,
            polar_graph=polar_graph,
            design_names=design_names,
            design_bounds=design_bounds,
            design_mean=design_mean,
            design_std=design_std,
            device=device,
            disp_scale=disp_scale
        )

    @torch.no_grad()
    def predict(self, design_vars: Dict[str, float], return_field: bool = False) -> GNNPrediction:
        """
        Predict objectives for a single design.

        Args:
            design_vars: Dict mapping variable names to values
            return_field: Whether to include displacement field in result

        Returns:
            GNNPrediction with objectives
        """
        # Convert to tensor
        design_values = [design_vars.get(name, 0.0) for name in self.design_names]
        design_tensor = torch.tensor(design_values, dtype=torch.float32).to(self.device)

        # Normalize
        design_norm = (design_tensor - self.design_mean) / self.design_std

        # Forward pass
        z_disp_scaled = self.model(
            self.node_features,
            self.edge_index,
            self.edge_attr,
            design_norm
        )  # [3000, 4] in scaled units (nm if disp_scale=1e6)

        # Convert back to mm before computing objectives
        # During training: z_disp_mm * disp_scale = z_disp_scaled
        # So: z_disp_mm = z_disp_scaled / disp_scale
        z_disp_mm = z_disp_scaled / self.disp_scale

        # Compute objectives (ZernikeObjectiveLayer expects mm input)
        objectives = self.objective_layer(z_disp_mm)

        # Objectives are now directly in nm (no additional scaling needed)
        obj_dict = {
            'rel_filtered_rms_40_vs_20': objectives['rel_filtered_rms_40_vs_20'].item(),
            'rel_filtered_rms_60_vs_20': objectives['rel_filtered_rms_60_vs_20'].item(),
            'mfg_90_optician_workload': objectives['mfg_90_optician_workload'].item(),
        }

        field_data = z_disp_mm.cpu().numpy() if return_field else None

        return GNNPrediction(
            design_vars=design_vars,
            objectives=obj_dict,
            z_displacement=field_data
        )

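    # Unit sanity check (editorial note): assuming disp_scale=1e6 as in
    # training, the raw network output is in nm, so a predicted value of
    # 500.0 becomes 500.0 / 1e6 = 5e-4 mm after the division above, which is
    # the unit ZernikeObjectiveLayer expects.
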
    @torch.no_grad()
    def predict_batch(self, designs: List[Dict[str, float]]) -> List[GNNPrediction]:
        """
        Predict objectives for multiple designs.

        Note: predictions currently run sequentially, one forward pass per
        design; the shared fixed graph keeps each pass cheap.

        Args:
            designs: List of design variable dicts

        Returns:
            List of GNNPrediction
        """
        results = []
        for design in designs:
            results.append(self.predict(design))
        return results

    def random_design(self) -> Dict[str, float]:
        """Generate a random design within bounds."""
        design = {}
        for name in self.design_names:
            low, high = self.design_bounds.get(name, (-100, 100))
            design[name] = np.random.uniform(low, high)
        return design

    def turbo_optimize(
        self,
        n_trials: int = 5000,
        sampler: str = 'tpe',
        seed: int = 42,
        verbose: bool = True
    ) -> OptimizationResults:
        """
        Run fast GNN-based optimization.

        Args:
            n_trials: Number of GNN trials to run
            sampler: Optuna sampler ('tpe', 'random', 'cmaes')
            seed: Random seed
            verbose: Print progress

        Returns:
            OptimizationResults with all predictions
        """
        np.random.seed(seed)
        results = OptimizationResults()

        if verbose:
            print(f"\n{'='*60}")
            print("GNN TURBO OPTIMIZATION")
            print(f"{'='*60}")
            print(f"Trials: {n_trials}")
            print(f"Sampler: {sampler}")
            print(f"Design variables: {len(self.design_names)}")
            print(f"Device: {self.device}")

        # Create Optuna study for smart sampling
        if sampler == 'tpe':
            optuna_sampler = optuna.samplers.TPESampler(seed=seed)
        elif sampler == 'random':
            optuna_sampler = optuna.samplers.RandomSampler(seed=seed)
        elif sampler == 'cmaes':
            optuna_sampler = optuna.samplers.CmaEsSampler(seed=seed)
        else:
            optuna_sampler = optuna.samplers.TPESampler(seed=seed)

        study = optuna.create_study(
            directions=['minimize', 'minimize', 'minimize'],  # 3 objectives
            sampler=optuna_sampler
        )

        start_time = datetime.now()

        def objective(trial):
            # Sample design
            design = {}
            for name in self.design_names:
                low, high = self.design_bounds.get(name, (-100, 100))
                design[name] = trial.suggest_float(name, low, high)

            # Predict with GNN
            pred = self.predict(design)
            results.add(pred)

            return (
                pred.objectives['rel_filtered_rms_40_vs_20'],
                pred.objectives['rel_filtered_rms_60_vs_20'],
                pred.objectives['mfg_90_optician_workload']
            )

        # Run optimization
        if verbose:
            print(f"\nRunning {n_trials} GNN trials...")

        optuna.logging.set_verbosity(optuna.logging.WARNING)
        study.optimize(objective, n_trials=n_trials, show_progress_bar=verbose)

        elapsed = (datetime.now() - start_time).total_seconds()

        # Compute Pareto front
        pareto = results.get_pareto_front()

        if verbose:
            print(f"\nCompleted in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")
            print(f"Pareto front: {len(pareto)} designs")

            # Best by each objective
            print("\nBest by objective:")
            for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                best = results.get_best(n=1, objective=obj)[0]
                print(f"  {obj}: {best.objectives[obj]:.2f} nm")

        return results

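    # Rough throughput arithmetic (editorial sketch; actual rates depend on
    # hardware and model size): at ~250 GNN trials/sec the default
    # n_trials=5000 finishes in about 5000 / 250 = 20 s, versus minutes per
    # trial for a full FEA solve.
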
    def validate_with_fea(
        self,
        candidates: List[GNNPrediction],
        study_dir: Path,
        verbose: bool = True,
        start_trial_num: int = 9000
    ) -> List[Dict]:
        """
        Validate GNN predictions with actual FEA.

        This runs the full NX + Nastran workflow on each candidate
        to get true objective values.

        Args:
            candidates: GNN predictions to validate
            study_dir: Path to study directory (for config and scripts)
            verbose: Print progress
            start_trial_num: Starting trial number for iteration folders

        Returns:
            List of dicts with 'gnn' and 'fea' objectives for comparison
        """
        import time
        import re
        from optimization_engine.nx_solver import NXSolver
        from optimization_engine.extractors import ZernikeExtractor

        study_dir = Path(study_dir)
        config_path = study_dir / "1_setup" / "optimization_config.json"
        model_dir = study_dir / "1_setup" / "model"
        iterations_dir = study_dir / "2_iterations"

        # Load config
        if not config_path.exists():
            raise FileNotFoundError(f"Config not found: {config_path}")

        with open(config_path) as f:
            config = json.load(f)

        # Initialize NX Solver
        nx_settings = config.get('nx_settings', {})
        nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
        version_match = re.search(r'NX(\d+)', nx_install_dir)
        nastran_version = version_match.group(1) if version_match else "2506"

        solver = NXSolver(
            master_model_dir=str(model_dir),
            nx_install_dir=nx_install_dir,
            nastran_version=nastran_version,
            timeout=nx_settings.get('simulation_timeout_s', 600),
            use_iteration_folders=True,
            study_name="gnn_validation"
        )

        iterations_dir.mkdir(exist_ok=True)

        results = []

        if verbose:
            print(f"\n{'='*60}")
            print("FEA VALIDATION OF GNN PREDICTIONS")
            print(f"{'='*60}")
            print(f"Validating {len(candidates)} candidates")
            print(f"Study: {study_dir.name}")

        for i, candidate in enumerate(candidates):
            trial_num = start_trial_num + i

            if verbose:
                print(f"\n[{i+1}/{len(candidates)}] Trial {trial_num}")
                print(f"  GNN predicted: 40vs20={candidate.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")

            # Build expression updates from design variables
            expressions = {}
            for var in config.get('design_variables', []):
                var_name = var['name']
                expr_name = var.get('expression_name', var_name)
                if var_name in candidate.design_vars:
                    expressions[expr_name] = candidate.design_vars[var_name]

            # Create iteration folder with model copies
            try:
                iter_folder = solver.create_iteration_folder(
                    iterations_base_dir=iterations_dir,
                    iteration_number=trial_num,
                    expression_updates=expressions
                )
            except Exception as e:
                if verbose:
                    print(f"  ERROR creating iteration folder: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'error',
                    'error': str(e)
                })
                continue

            # Run simulation
            sim_file = iter_folder / nx_settings.get('sim_file', 'ASSY_M1_assyfem1_sim1.sim')
            solution_name = nx_settings.get('solution_name', 'Solution 1')

            t_start = time.time()
            try:
                solve_result = solver.run_simulation(
                    sim_file=sim_file,
                    working_dir=iter_folder,
                    expression_updates=expressions,
                    solution_name=solution_name,
                    cleanup=False
                )
            except Exception as e:
                if verbose:
                    print(f"  ERROR in simulation: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'solve_error',
                    'error': str(e)
                })
                continue

            solve_time = time.time() - t_start

            if not solve_result['success']:
                if verbose:
                    print(f"  Solve FAILED: {solve_result.get('errors', ['Unknown'])}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'solve_failed',
                    'errors': solve_result.get('errors', [])
                })
                continue

            if verbose:
                print(f"  Solved in {solve_time:.1f}s")

            # Extract objectives using ZernikeExtractor
            op2_path = solve_result['op2_file']
            if op2_path is None or not Path(op2_path).exists():
                if verbose:
                    print("  ERROR: OP2 file not found")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'no_op2',
                })
                continue

            try:
                zernike_settings = config.get('zernike_settings', {})
                extractor = ZernikeExtractor(
                    op2_path,
                    bdf_path=None,
                    displacement_unit=zernike_settings.get('displacement_unit', 'mm'),
                    n_modes=zernike_settings.get('n_modes', 50),
                    filter_orders=zernike_settings.get('filter_low_orders', 4)
                )

                ref = zernike_settings.get('reference_subcase', '2')

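                # Subcase numbering used below (editorial note, matching the
                # V11 test data in test_polar_graph.py): subcase 1 = 90°,
                # subcase 2 = 20° (the reference), subcase 3 = 40°,
                # subcase 4 = 60°.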
                # Extract objectives: 40 vs 20, 60 vs 20, mfg 90
                rel_40 = extractor.extract_relative("3", ref)
                rel_60 = extractor.extract_relative("4", ref)
                rel_90 = extractor.extract_relative("1", ref)

                fea_objectives = {
                    'rel_filtered_rms_40_vs_20': rel_40['relative_filtered_rms_nm'],
                    'rel_filtered_rms_60_vs_20': rel_60['relative_filtered_rms_nm'],
                    'mfg_90_optician_workload': rel_90['relative_rms_filter_j1to3'],
                }

            except Exception as e:
                if verbose:
                    print(f"  ERROR in Zernike extraction: {e}")
                results.append({
                    'design': candidate.design_vars,
                    'gnn_objectives': candidate.objectives,
                    'fea_objectives': None,
                    'status': 'extraction_error',
                    'error': str(e)
                })
                continue

            # Compute errors
            errors = {}
            for obj_name in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                gnn_val = candidate.objectives[obj_name]
                fea_val = fea_objectives[obj_name]
                errors[f'{obj_name}_abs_error'] = abs(gnn_val - fea_val)
                errors[f'{obj_name}_pct_error'] = 100 * abs(gnn_val - fea_val) / max(fea_val, 0.01)

            if verbose:
                print("  FEA results:")
                print(f"    40vs20: {fea_objectives['rel_filtered_rms_40_vs_20']:.2f} nm "
                      f"(GNN: {candidate.objectives['rel_filtered_rms_40_vs_20']:.2f}, "
                      f"err: {errors['rel_filtered_rms_40_vs_20_pct_error']:.1f}%)")
                print(f"    60vs20: {fea_objectives['rel_filtered_rms_60_vs_20']:.2f} nm "
                      f"(GNN: {candidate.objectives['rel_filtered_rms_60_vs_20']:.2f}, "
                      f"err: {errors['rel_filtered_rms_60_vs_20_pct_error']:.1f}%)")
                print(f"    mfg90: {fea_objectives['mfg_90_optician_workload']:.2f} nm "
                      f"(GNN: {candidate.objectives['mfg_90_optician_workload']:.2f}, "
                      f"err: {errors['mfg_90_optician_workload_pct_error']:.1f}%)")

            results.append({
                'design': candidate.design_vars,
                'gnn_objectives': candidate.objectives,
                'fea_objectives': fea_objectives,
                'errors': errors,
                'solve_time': solve_time,
                'trial_num': trial_num,
                'status': 'success'
            })

        # Summary
        if verbose:
            successful = [r for r in results if r['status'] == 'success']
            print(f"\n{'='*60}")
            print("VALIDATION SUMMARY")
            print(f"{'='*60}")
            print(f"Successful: {len(successful)}/{len(candidates)}")

            if successful:
                avg_errors = {}
                for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                    avg_errors[obj] = np.mean([r['errors'][f'{obj}_pct_error'] for r in successful])

                print("\nAverage prediction errors:")
                print(f"  40 vs 20: {avg_errors['rel_filtered_rms_40_vs_20']:.1f}%")
                print(f"  60 vs 20: {avg_errors['rel_filtered_rms_60_vs_20']:.1f}%")
                print(f"  mfg 90: {avg_errors['mfg_90_optician_workload']:.1f}%")

        return results

    def save_validation_report(
        self,
        validation_results: List[Dict],
        output_path: Path
    ):
        """Save validation results to JSON file."""
        report = {
            'timestamp': datetime.now().isoformat(),
            'n_candidates': len(validation_results),
            'n_successful': len([r for r in validation_results if r['status'] == 'success']),
            'results': validation_results,
        }

        # Compute summary statistics if we have successful results
        successful = [r for r in validation_results if r['status'] == 'success']
        if successful:
            avg_errors = {}
            for obj in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                errors = [r['errors'][f'{obj}_pct_error'] for r in successful]
                avg_errors[obj] = {
                    'mean_pct': float(np.mean(errors)),
                    'std_pct': float(np.std(errors)),
                    'max_pct': float(np.max(errors)),
                }
            report['error_summary'] = avg_errors

        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2)

        print(f"Validation report saved to: {output_path}")


def main():
    """Example usage of GNN optimizer."""
    import argparse

    parser = argparse.ArgumentParser(description='GNN-based Zernike optimization')
    parser.add_argument('checkpoint', type=Path, help='Path to GNN checkpoint')
    parser.add_argument('--config', type=Path, help='Path to optimization_config.json')
    parser.add_argument('--trials', type=int, default=5000, help='Number of GNN trials')
    parser.add_argument('--output', '-o', type=Path, help='Output results JSON')
    parser.add_argument('--top-n', type=int, default=20, help='Number of top candidates to show')

    args = parser.parse_args()

    # Load optimizer
    print(f"Loading GNN from {args.checkpoint}...")
    optimizer = ZernikeGNNOptimizer.from_checkpoint(
        args.checkpoint,
        config_path=args.config
    )

    # Run turbo optimization
    results = optimizer.turbo_optimize(n_trials=args.trials)

    # Show top candidates
    print(f"\n{'='*60}")
    print(f"TOP {args.top_n} CANDIDATES (by rel_filtered_rms_40_vs_20)")
    print(f"{'='*60}")

    top = results.get_best(n=args.top_n, objective='rel_filtered_rms_40_vs_20')
    for i, pred in enumerate(top):
        print(f"\n#{i+1}:")
        print(f"  40 vs 20: {pred.objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
        print(f"  60 vs 20: {pred.objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
        print(f"  mfg_90: {pred.objectives['mfg_90_optician_workload']:.2f} nm")

    # Save results
    if args.output:
        results.save(args.output)
        print(f"\nResults saved to {args.output}")

    # Show Pareto front
    pareto = results.get_pareto_front()
    print(f"\n{'='*60}")
    print(f"PARETO FRONT: {len(pareto)} designs")
    print(f"{'='*60}")

    for i, pred in enumerate(pareto[:10]):  # Show first 10
        print(f"  [{i+1}] 40vs20={pred.objectives['rel_filtered_rms_40_vs_20']:.1f}, "
              f"60vs20={pred.objectives['rel_filtered_rms_60_vs_20']:.1f}, "
              f"mfg={pred.objectives['mfg_90_optician_workload']:.1f}")


if __name__ == '__main__':
    main()
617
optimization_engine/gnn/polar_graph.py
Normal file
@@ -0,0 +1,617 @@
"""
Polar Mirror Graph for GNN Training
====================================

This module creates a fixed polar grid graph structure for the mirror optical surface.
The key insight is that the mirror has a fixed topology (circular annulus), so we can
use a fixed graph structure regardless of FEA mesh variations.

Why Polar Grid?
1. Matches mirror geometry (annulus)
2. Same approach as extract_zernike_surface.py
3. Enables mesh-independent training
4. Edge structure respects radial/angular physics

Grid Structure:
- n_radial points from r_inner to r_outer
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
- Total nodes = n_radial × n_angular
- Edges connect radial neighbors and angular neighbors (wrap-around)

Usage:
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)

    # Interpolate FEA results to fixed grid
    z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)

    # Get PyTorch Geometric data
    data = graph.to_pyg_data(z_disp_grid, design_vars)
"""

import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
import json

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
    from scipy.spatial import Delaunay
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


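# Size check (editorial note): with the defaults below (r_inner=100,
# r_outer=650, n_radial=50, n_angular=60) the grid has 50 * 60 = 3000 nodes,
# which is where the [3000, 4] displacement-field shapes used by the GNN
# come from.
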
class PolarMirrorGraph:
    """
    Fixed polar grid graph for mirror optical surface.

    This creates a mesh-independent graph structure that can be used for GNN training
    regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.

    Attributes:
        n_nodes: Total number of nodes (n_radial × n_angular)
        r: Radial coordinates [n_nodes]
        theta: Angular coordinates [n_nodes]
        x: Cartesian X coordinates [n_nodes]
        y: Cartesian Y coordinates [n_nodes]
        edge_index: Graph edges [2, n_edges]
        edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
    """

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 650.0,
        n_radial: int = 50,
        n_angular: int = 60
    ):
        """
        Initialize polar grid graph.

        Args:
            r_inner: Inner radius (central hole), mm
            r_outer: Outer radius, mm
            n_radial: Number of radial samples
            n_angular: Number of angular samples
        """
        self.r_inner = r_inner
        self.r_outer = r_outer
        self.n_radial = n_radial
        self.n_angular = n_angular
        self.n_nodes = n_radial * n_angular

        # Create polar grid coordinates
        r_1d = np.linspace(r_inner, r_outer, n_radial)
        theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)

        # Meshgrid: theta varies fast (angular index), r varies slow (radial index)
        # Shape after flatten: [n_radial * n_angular] with angular varying fastest
        Theta, R = np.meshgrid(theta_1d, r_1d)  # R shape: [n_radial, n_angular]

        # Flatten: radial index varies slowest
        self.r = R.flatten().astype(np.float32)
        self.theta = Theta.flatten().astype(np.float32)
        self.x = (self.r * np.cos(self.theta)).astype(np.float32)
        self.y = (self.r * np.sin(self.theta)).astype(np.float32)

        # Build graph edges
        self.edge_index, self.edge_attr = self._build_polar_edges()

        # Precompute normalization factors
        self._r_mean = (r_inner + r_outer) / 2
        self._r_std = (r_outer - r_inner) / 2

    def _node_index(self, i_r: int, i_theta: int) -> int:
        """Convert (radial_index, angular_index) to flat node index."""
        # Angular wraps around
        i_theta = i_theta % self.n_angular
        return i_r * self.n_angular + i_theta

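    # Index arithmetic example (editorial note): with n_angular=60, node
    # (i_r=2, i_theta=5) maps to 2*60 + 5 = 125, and (i_r=2, i_theta=60)
    # wraps around to (i_r=2, i_theta=0) = 120.
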
    def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create graph edges respecting polar topology.

        Edge types:
        1. Radial edges: Connect adjacent radial rings
        2. Angular edges: Connect adjacent angular positions (with wrap-around)
        3. Diagonal edges: Connect diagonal neighbors for better message passing

        Returns:
            edge_index: [2, n_edges] array of (source, target) pairs
            edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
        """
        edges = []
        edge_features = []

        for i_r in range(self.n_radial):
            for i_theta in range(self.n_angular):
                node = self._node_index(i_r, i_theta)

                # Radial neighbor (outward)
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])  # Bidirectional

                    # Edge features: (dr, dtheta, distance, relative_angle)
                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 0.0
                    dist = abs(dr)
                    angle = 0.0  # Radial direction
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, dtheta, dist, np.pi])  # Reverse

                # Angular neighbor (counterclockwise, with wrap-around)
                neighbor = self._node_index(i_r, i_theta + 1)
                edges.append([node, neighbor])
                edges.append([neighbor, node])  # Bidirectional

                # Edge features for angular edge
                dr = 0.0
                dtheta = 2 * np.pi / self.n_angular
                # Arc length at this radius
                dist = self.r[node] * dtheta
                angle = np.pi / 2  # Tangential direction
                edge_features.append([dr, dtheta, dist, angle])
                edge_features.append([dr, -dtheta, dist, -np.pi / 2])  # Reverse

                # Diagonal neighbor (outward + counterclockwise) for better connectivity
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta + 1)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])

                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 2 * np.pi / self.n_angular
                    dx = self.x[neighbor] - self.x[node]
                    dy = self.y[neighbor] - self.y[node]
                    dist = np.sqrt(dx**2 + dy**2)
                    angle = np.arctan2(dy, dx)
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, -dtheta, dist, angle + np.pi])

        edge_index = np.array(edges, dtype=np.int64).T  # [2, n_edges]
        edge_attr = np.array(edge_features, dtype=np.float32)  # [n_edges, 4]

        return edge_index, edge_attr

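    # Edge-count arithmetic (editorial note): each undirected link is stored
    # as two directed edges, so the default 50x60 grid yields
    # 2 * (49*60 radial + 50*60 angular + 49*60 diagonal) = 17760 directed
    # edges.
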
    def get_node_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get node features for GNN input.

        Features: (r, theta, x, y) - polar and Cartesian coordinates

        Args:
            normalized: If True, normalize features to ~[-1, 1] range

        Returns:
            Node features [n_nodes, 4]
        """
        if normalized:
            r_norm = (self.r - self._r_mean) / self._r_std
            theta_norm = self.theta / np.pi - 1  # [0, 2π] → [-1, 1]
            x_norm = self.x / self.r_outer
            y_norm = self.y / self.r_outer
            return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
        else:
            return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)

    def get_edge_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get edge features for GNN input.

        Features: (dr, dtheta, distance, angle)

        Args:
            normalized: If True, normalize features

        Returns:
            Edge features [n_edges, 4]
        """
        if normalized:
            edge_attr = self.edge_attr.copy()
            edge_attr[:, 0] /= self._r_std  # dr
            edge_attr[:, 1] /= np.pi  # dtheta
            edge_attr[:, 2] /= self.r_outer  # distance
            edge_attr[:, 3] /= np.pi  # angle
            return edge_attr
        else:
            return self.edge_attr

    def interpolate_from_mesh(
        self,
        mesh_coords: np.ndarray,
        mesh_values: np.ndarray,
        method: str = 'rbf'
    ) -> np.ndarray:
        """
        Interpolate FEA results from mesh nodes to fixed polar grid.

        Args:
            mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z])
            mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
            method: Interpolation method ('rbf', 'linear', 'clough_tocher')

        Returns:
            Interpolated values on polar grid [n_nodes] or [n_nodes, n_features]
        """
        if not HAS_SCIPY:
            raise ImportError("scipy required for interpolation: pip install scipy")

        # Use only X, Y coordinates
        xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords

        # Handle multi-dimensional values
        values_1d = mesh_values.ndim == 1
        if values_1d:
            mesh_values = mesh_values.reshape(-1, 1)

        # Target coordinates
        target_xy = np.column_stack([self.x, self.y])

        result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)

        for i in range(mesh_values.shape[1]):
            vals = mesh_values[:, i]

            if method == 'rbf':
                # RBF interpolation - smooth, handles scattered data well
                interp = RBFInterpolator(
                    xy, vals,
                    kernel='thin_plate_spline',
                    smoothing=0.0
                )
                result[:, i] = interp(target_xy)

            elif method == 'linear':
                # Linear interpolation via Delaunay triangulation
                interp = LinearNDInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN (points outside convex hull) with nearest neighbor
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]

            elif method == 'clough_tocher':
                # Clough-Tocher (C1 smooth) interpolation
                interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan)
                result[:, i] = interp(target_xy)

                # Handle NaN
                nan_mask = np.isnan(result[:, i])
                if nan_mask.any():
                    from scipy.spatial import cKDTree
                    tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, i] = vals[idx]
            else:
                raise ValueError(f"Unknown interpolation method: {method}")

        return result[:, 0] if values_1d else result

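    # Method trade-off (editorial note): the thin-plate-spline RBF fit solves
    # a dense system over all source points, so on large FEA meshes the
    # Delaunay-based 'linear' method is much faster (hence the default switch
    # in interpolate_field_data below), e.g.:
    #   z_grid = graph.interpolate_from_mesh(coords, z_disp, method='linear')
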
    def interpolate_field_data(
        self,
        field_data: Dict[str, Any],
        subcases: List[int] = [1, 2, 3, 4],
        method: str = 'linear'  # Changed from 'rbf' - much faster
    ) -> Dict[str, np.ndarray]:
        """
        Interpolate field data from extract_displacement_field() to polar grid.

        Args:
            field_data: Output from extract_displacement_field()
            subcases: List of subcases to interpolate
            method: Interpolation method

        Returns:
            Dictionary with:
            - z_displacement: [n_nodes, n_subcases] array
            - original_n_nodes: Number of FEA nodes
        """
        mesh_coords = field_data['node_coords']
        z_disp_dict = field_data['z_displacement']

        # Stack subcases
        z_disp_list = []
        for sc in subcases:
            if sc in z_disp_dict:
                z_disp_list.append(z_disp_dict[sc])
            else:
                raise KeyError(f"Subcase {sc} not found in field_data")

        # [n_fea_nodes, n_subcases]
        z_disp_mesh = np.column_stack(z_disp_list)

        # Interpolate to polar grid
        z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)

        return {
            'z_displacement': z_disp_grid,  # [n_nodes, n_subcases]
            'original_n_nodes': len(mesh_coords),
        }

    def to_pyg_data(
        self,
        z_displacement: np.ndarray,
        design_vars: np.ndarray,
        objectives: Optional[Dict[str, float]] = None
    ):
        """
        Convert to PyTorch Geometric Data object.

        Args:
            z_displacement: [n_nodes, n_subcases] displacement field
            design_vars: [n_design_vars] design parameters
            objectives: Optional dict of objective values (ground truth)

        Returns:
            torch_geometric.data.Data object
        """
        if not HAS_TORCH:
            raise ImportError("PyTorch required: pip install torch")

        try:
            from torch_geometric.data import Data
        except ImportError:
            raise ImportError("PyTorch Geometric required: pip install torch-geometric")

        # Node features: (r, theta, x, y)
        node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)

        # Edge index and features
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)

        # Target: Z-displacement field
        y = torch.tensor(z_displacement, dtype=torch.float32)

        # Design variables (global feature)
        design = torch.tensor(design_vars, dtype=torch.float32)

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            design=design,
        )

        # Add objectives if provided
        if objectives:
            for key, value in objectives.items():
                setattr(data, key, torch.tensor([value], dtype=torch.float32))

        return data

    def save(self, path: Path) -> None:
        """Save graph structure to JSON file."""
        path = Path(path)

        data = {
            'r_inner': self.r_inner,
            'r_outer': self.r_outer,
            'n_radial': self.n_radial,
            'n_angular': self.n_angular,
            'n_nodes': self.n_nodes,
            'n_edges': self.edge_index.shape[1],
        }

        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

        # Save arrays separately for efficiency
        np.savez_compressed(
            path.with_suffix('.npz'),
            r=self.r,
            theta=self.theta,
            x=self.x,
            y=self.y,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
        )

    @classmethod
    def load(cls, path: Path) -> 'PolarMirrorGraph':
        """Load graph structure from file."""
        path = Path(path)

        with open(path, 'r') as f:
            data = json.load(f)

        return cls(
            r_inner=data['r_inner'],
            r_outer=data['r_outer'],
            n_radial=data['n_radial'],
            n_angular=data['n_angular'],
        )

    def __repr__(self) -> str:
        return (
            f"PolarMirrorGraph("
            f"r=[{self.r_inner}, {self.r_outer}]mm, "
            f"grid={self.n_radial}×{self.n_angular}, "
            f"nodes={self.n_nodes}, "
            f"edges={self.edge_index.shape[1]})"
        )


# =============================================================================
# Convenience functions
# =============================================================================

def create_mirror_dataset(
    study_dir: Path,
    polar_graph: Optional[PolarMirrorGraph] = None,
    subcases: List[int] = [1, 2, 3, 4],
    verbose: bool = True
) -> List[Dict[str, Any]]:
    """
    Create GNN dataset from a study's gnn_data folder.

    Args:
        study_dir: Path to study directory
        polar_graph: PolarMirrorGraph instance (created if None)
        subcases: Subcases to include
        verbose: Print progress

    Returns:
        List of data dictionaries, each containing:
        - z_displacement: [n_nodes, n_subcases]
        - design_vars: [n_vars]
        - trial_number: int
        - original_n_nodes: int
    """
    from optimization_engine.gnn.extract_displacement_field import load_field

    study_dir = Path(study_dir)
    gnn_data_dir = study_dir / "gnn_data"

    if not gnn_data_dir.exists():
        raise FileNotFoundError(f"No gnn_data folder in {study_dir}")

    # Load index
    index_path = gnn_data_dir / "dataset_index.json"
    with open(index_path, 'r') as f:
        index = json.load(f)

    if polar_graph is None:
        polar_graph = PolarMirrorGraph()

    dataset = []

    for trial_num, trial_info in index['trials'].items():
        if trial_info.get('status') != 'success':
            continue

        trial_dir = study_dir / trial_info['trial_dir']

        # Find field file
        field_path = None
        for ext in ['.h5', '.npz']:
            candidate = trial_dir / f"displacement_field{ext}"
            if candidate.exists():
                field_path = candidate
                break

        if field_path is None:
            if verbose:
                print(f"[WARN] No field file for trial {trial_num}")
            continue

        try:
            # Load field data
            field_data = load_field(field_path)

            # Interpolate to polar grid
            interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)

            # Get design parameters
            params = trial_info.get('params', {})
            design_vars = np.array(list(params.values()), dtype=np.float32) if params else np.array([])

            dataset.append({
                'z_displacement': interp_result['z_displacement'],
                'design_vars': design_vars,
                'design_names': list(params.keys()) if params else [],
                'trial_number': int(trial_num),
                'original_n_nodes': interp_result['original_n_nodes'],
            })

            if verbose:
                print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes")

        except Exception as e:
            if verbose:
                print(f"[ERR] Trial {trial_num}: {e}")

    return dataset


# =============================================================================
# CLI
# =============================================================================

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
    parser.add_argument('--test', action='store_true', help='Run basic tests')
    parser.add_argument('--study', type=Path, help='Create dataset from study')

    args = parser.parse_args()

    if args.test:
        print("="*60)
        print("TESTING PolarMirrorGraph")
        print("="*60)

        # Create graph
        graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"\n{graph}")

        # Check node features
        node_feat = graph.get_node_features(normalized=True)
        print(f"\nNode features shape: {node_feat.shape}")
        print(f"  r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
        print(f"  theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")

        # Check edge features
        edge_feat = graph.get_edge_features(normalized=True)
        print(f"\nEdge features shape: {edge_feat.shape}")
        print(f"  dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
        print(f"  distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")

        # Test interpolation with synthetic data
        print("\n--- Testing Interpolation ---")

        # Create fake mesh data (random points in annulus)
        np.random.seed(42)
        n_mesh = 5000
        r_mesh = np.random.uniform(100, 650, n_mesh)
        theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
        x_mesh = r_mesh * np.cos(theta_mesh)
        y_mesh = r_mesh * np.sin(theta_mesh)
        mesh_coords = np.column_stack([x_mesh, y_mesh])

        # Synthetic displacement: smooth function
        mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)

        # Interpolate
        grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
        print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
        print(f"  Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
        print(f"  Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")

        print("\n✓ All tests passed!")

    elif args.study:
        # Create dataset from study
        print(f"Creating dataset from: {args.study}")

        graph = PolarMirrorGraph()
        dataset = create_mirror_dataset(args.study, polar_graph=graph)

        print(f"\nDataset: {len(dataset)} samples")
        if dataset:
            print(f"  Z-displacement shape: {dataset[0]['z_displacement'].shape}")
            print(f"  Design vars: {len(dataset[0]['design_vars'])} variables")

    else:
        # Default: just show info
        graph = PolarMirrorGraph()
        print(graph)
        print(f"\nNode features: {graph.get_node_features().shape}")
        print(f"Edge index: {graph.edge_index.shape}")
        print(f"Edge features: {graph.edge_attr.shape}")
37
optimization_engine/gnn/test_field_extraction.py
Normal file
@@ -0,0 +1,37 @@
"""Quick test script for displacement field extraction."""
import h5py
import numpy as np
from pathlib import Path

# Test file
h5_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/gnn_data/trial_0091/displacement_field.h5")

print(f"Testing: {h5_path}")
print(f"Exists: {h5_path.exists()}")

if h5_path.exists():
    with h5py.File(h5_path, 'r') as f:
        print(f"\nDatasets in file: {list(f.keys())}")

        node_coords = f['node_coords'][:]
        node_ids = f['node_ids'][:]

        print(f"\nTotal nodes: {len(node_ids)}")

        # Calculate radial position
        r = np.sqrt(node_coords[:, 0]**2 + node_coords[:, 1]**2)
        print(f"Radial range: [{r.min():.1f}, {r.max():.1f}] mm")
        print(f"Z range: [{node_coords[:, 2].min():.1f}, {node_coords[:, 2].max():.1f}] mm")

        # Check nodes in optical surface range (100-650 mm radius)
        surface_mask = (r >= 100) & (r <= 650)
        print(f"Nodes in r=[100, 650]: {np.sum(surface_mask)}")

        # Check subcases
        subcases = [k for k in f.keys() if k.startswith("subcase_")]
        print(f"Subcases: {subcases}")

        if subcases:
            for sc in subcases:
                disp = f[sc][:]
                print(f"  {sc}: Z-disp range [{disp.min():.4f}, {disp.max():.4f}] mm")
35
optimization_engine/gnn/test_new_extraction.py
Normal file
@@ -0,0 +1,35 @@
"""Test the fixed extraction function directly on OP2."""
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")

import numpy as np
from pathlib import Path
from optimization_engine.gnn.extract_displacement_field import extract_displacement_field

# Test direct extraction from OP2
op2_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/2_iterations/iter91/assy_m1_assyfem1_sim1-solution_1.op2")

print(f"Testing extraction from: {op2_path.name}")
print(f"Exists: {op2_path.exists()}")

if op2_path.exists():
    field_data = extract_displacement_field(op2_path, r_inner=100.0, r_outer=650.0)

    print("\n=== EXTRACTION RESULT ===")
    print(f"Total surface nodes: {len(field_data['node_ids'])}")

    coords = field_data['node_coords']
    r = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)
    print(f"Radial range: [{r.min():.1f}, {r.max():.1f}] mm")
    print(f"Z range: [{coords[:, 2].min():.1f}, {coords[:, 2].max():.1f}] mm")

    print(f"\nSubcases: {list(field_data['z_displacement'].keys())}")
    for sc, disp in field_data['z_displacement'].items():
        nan_count = np.sum(np.isnan(disp))
        if nan_count == 0:
            print(f"  Subcase {sc}: Z-disp range [{disp.min():.6f}, {disp.max():.6f}] mm")
        else:
            valid = disp[~np.isnan(disp)]
            print(f"  Subcase {sc}: {nan_count}/{len(disp)} NaN values, valid range: [{valid.min():.6f}, {valid.max():.6f}]")
else:
    print("OP2 file not found!")
108
optimization_engine/gnn/test_polar_graph.py
Normal file
@@ -0,0 +1,108 @@
"""Test PolarMirrorGraph with actual V11 data."""
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")

import numpy as np
from pathlib import Path
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.extract_displacement_field import load_field

# Test 1: Basic graph construction
print("="*60)
print("TEST 1: Graph Construction")
print("="*60)

graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\n{graph}")

node_feat = graph.get_node_features(normalized=True)
edge_feat = graph.get_edge_features(normalized=True)

print(f"\nNode features: {node_feat.shape}")
print(f"  r normalized: [{node_feat[:, 0].min():.3f}, {node_feat[:, 0].max():.3f}]")
print(f"  theta normalized: [{node_feat[:, 1].min():.3f}, {node_feat[:, 1].max():.3f}]")
print(f"  x normalized: [{node_feat[:, 2].min():.3f}, {node_feat[:, 2].max():.3f}]")
print(f"  y normalized: [{node_feat[:, 3].min():.3f}, {node_feat[:, 3].max():.3f}]")

print(f"\nEdge features: {edge_feat.shape}")
print(f"  Edges per node: {edge_feat.shape[0] / graph.n_nodes:.1f}")

# Test 2: Load actual V11 field data and interpolate
print("\n" + "="*60)
print("TEST 2: Interpolation from V11 Data")
print("="*60)

field_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/gnn_data/trial_0091/displacement_field.h5")

if field_path.exists():
    field_data = load_field(field_path)

    print("\nLoaded field data:")
    print(f"  FEA nodes: {len(field_data['node_ids'])}")
    print(f"  Subcases: {list(field_data['z_displacement'].keys())}")

    # Interpolate to polar grid
    result = graph.interpolate_field_data(field_data, subcases=[1, 2, 3, 4])
    z_grid = result['z_displacement']

    print("\nInterpolation result:")
    print(f"  Shape: {z_grid.shape} (expected: {graph.n_nodes} × 4)")
    print(f"  NaN count: {np.sum(np.isnan(z_grid))}")

    for i, sc in enumerate([1, 2, 3, 4]):
        disp = z_grid[:, i]
        print(f"  Subcase {sc}: [{disp.min():.6f}, {disp.max():.6f}] mm")

    # Test relative deformation computation
    print("\n--- Relative Deformations (like Zernike extraction) ---")
    disp_90 = z_grid[:, 0]  # Subcase 1 = 90°
    disp_20 = z_grid[:, 1]  # Subcase 2 = 20° (reference)
    disp_40 = z_grid[:, 2]  # Subcase 3 = 40°
    disp_60 = z_grid[:, 3]  # Subcase 4 = 60°

    rel_40_vs_20 = disp_40 - disp_20
    rel_60_vs_20 = disp_60 - disp_20
    rel_90_vs_20 = disp_90 - disp_20

    print(f"  40° - 20°: [{rel_40_vs_20.min():.6f}, {rel_40_vs_20.max():.6f}] mm, RMS={np.std(rel_40_vs_20)*1e6:.2f} nm")
    print(f"  60° - 20°: [{rel_60_vs_20.min():.6f}, {rel_60_vs_20.max():.6f}] mm, RMS={np.std(rel_60_vs_20)*1e6:.2f} nm")
    print(f"  90° - 20°: [{rel_90_vs_20.min():.6f}, {rel_90_vs_20.max():.6f}] mm, RMS={np.std(rel_90_vs_20)*1e6:.2f} nm")
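    # Editorial note: np.std removes the mean before taking the RMS, so these
    # are piston-removed surface RMS figures; the 1e6 factor converts mm to nm.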
else:
    print(f"Field file not found: {field_path}")

# Test 3: Create full dataset from V11
print("\n" + "="*60)
print("TEST 3: Create Dataset from V11")
print("="*60)

study_dir = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11")
if (study_dir / "gnn_data").exists():
    dataset = create_mirror_dataset(study_dir, polar_graph=graph, verbose=True)

    print("\n--- Dataset Summary ---")
    print(f"Total samples: {len(dataset)}")

    if dataset:
        # Check consistency
        shapes = [d['z_displacement'].shape for d in dataset]
        unique_shapes = set(shapes)
        print(f"Unique shapes: {unique_shapes}")

        # Design variable info
        n_vars = len(dataset[0]['design_vars'])
        print(f"Design variables: {n_vars}")
        if dataset[0]['design_names']:
            print(f"  Names: {dataset[0]['design_names'][:3]}...")

        # Stack for statistics
        all_z = np.stack([d['z_displacement'] for d in dataset])
        print(f"\nAll data shape: {all_z.shape}")
        print("  Per-subcase ranges:")
        for i in range(4):
            print(f"    Subcase {i+1}: [{all_z[:,:,i].min():.6f}, {all_z[:,:,i].max():.6f}] mm")
else:
    print(f"No gnn_data folder found in {study_dir}")

print("\n" + "="*60)
print("✓ All tests completed!")
print("="*60)
600
optimization_engine/gnn/train_zernike_gnn.py
Normal file
@@ -0,0 +1,600 @@
"""
Training Pipeline for ZernikeGNN
=================================

This module provides the complete training pipeline for the Zernike GNN surrogate.

Training Flow:
1. Load displacement field data from gnn_data/ folders
2. Interpolate to fixed polar grid
3. Normalize inputs (design vars) and outputs (displacements)
4. Train with multi-task loss (field + objectives)
5. Validate on held-out data
6. Save best model checkpoint

Usage:
    # Command line
    python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200

    # Python API
    from optimization_engine.gnn.train_zernike_gnn import ZernikeGNNTrainer

    trainer = ZernikeGNNTrainer(['V11', 'V12'])
    trainer.train(epochs=200)
    trainer.save_checkpoint('model.pt')
"""

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import sys

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model
from optimization_engine.gnn.differentiable_zernike import ZernikeObjectiveLayer, ZernikeRMSLoss


class MirrorDataset(Dataset):
    """PyTorch Dataset for mirror displacement fields."""

    def __init__(
        self,
        data_list: List[Dict[str, Any]],
        design_mean: Optional[torch.Tensor] = None,
        design_std: Optional[torch.Tensor] = None,
        disp_scale: float = 1e6  # mm → nm for numerical stability
    ):
        """
        Args:
            data_list: Output from create_mirror_dataset()
            design_mean: Mean for design normalization (computed if None)
            design_std: Std for design normalization (computed if None)
            disp_scale: Scale factor for displacements
        """
        self.data_list = data_list
        self.disp_scale = disp_scale

        # Stack all design variables for normalization
        all_designs = np.stack([d['design_vars'] for d in data_list])

        if design_mean is None:
            self.design_mean = torch.tensor(np.mean(all_designs, axis=0), dtype=torch.float32)
        else:
            self.design_mean = design_mean

        if design_std is None:
            self.design_std = torch.tensor(np.std(all_designs, axis=0) + 1e-6, dtype=torch.float32)
        else:
            self.design_std = design_std

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data_list[idx]

        # Normalize design variables
        design = torch.tensor(item['design_vars'], dtype=torch.float32)
        design_norm = (design - self.design_mean) / self.design_std

        # Scale displacements for numerical stability
        z_disp = torch.tensor(item['z_displacement'], dtype=torch.float32)
        z_disp_scaled = z_disp * self.disp_scale

        return {
            'design': design_norm,
            'design_raw': design,
            'z_displacement': z_disp_scaled,
            'trial_number': item['trial_number'],
        }


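# Scale sanity check (editorial note): disp_scale=1e6 maps a field value of
# 5e-4 mm to 500.0, i.e. training targets are expressed in nm (1 mm = 1e6 nm).
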
class ZernikeGNNTrainer:
|
||||
"""
|
||||
Complete training pipeline for ZernikeGNN.
|
||||
|
||||
Handles:
|
||||
- Data loading and preprocessing
|
||||
- Model initialization
|
||||
- Training loop with validation
|
||||
- Checkpointing
|
||||
- Metrics tracking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
study_versions: List[str],
|
||||
base_dir: Optional[Path] = None,
|
||||
model_type: str = 'full',
|
||||
hidden_dim: int = 128,
|
||||
n_layers: int = 6,
|
||||
device: str = 'auto'
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
study_versions: List of study versions (e.g., ['V11', 'V12'])
|
||||
base_dir: Base Atomizer directory
|
||||
model_type: 'full' or 'lite'
|
||||
hidden_dim: Model hidden dimension
|
||||
n_layers: Number of message passing layers
|
||||
device: 'cpu', 'cuda', or 'auto'
|
||||
"""
|
||||
if base_dir is None:
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
|
||||
self.base_dir = Path(base_dir)
|
||||
self.study_versions = study_versions
|
||||
|
||||
# Determine device
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
print(f"[TRAINER] Device: {self.device}", flush=True)
|
||||
|
||||
# Create polar graph (fixed structure)
|
||||
self.polar_graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
|
||||
print(f"[TRAINER] Polar graph: {self.polar_graph.n_nodes} nodes, {self.polar_graph.edge_index.shape[1]} edges", flush=True)
|
||||
|
||||
# Prepare graph tensors
|
||||
self.node_features = torch.tensor(
|
||||
self.polar_graph.get_node_features(normalized=True),
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
self.edge_index = torch.tensor(
|
||||
self.polar_graph.edge_index,
|
||||
dtype=torch.long
|
||||
).to(self.device)
|
||||
|
||||
self.edge_attr = torch.tensor(
|
||||
self.polar_graph.get_edge_features(normalized=True),
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
# Load data
|
||||
self._load_data()
|
||||
|
||||
# Create model
|
||||
self.model_config = {
|
||||
'model_type': model_type,
|
||||
'n_design_vars': len(self.train_dataset.data_list[0]['design_vars']),
|
||||
'n_subcases': 4,
|
||||
'hidden_dim': hidden_dim,
|
||||
'n_layers': n_layers,
|
||||
}
|
||||
|
||||
self.model = create_model(**self.model_config).to(self.device)
|
||||
print(f"[TRAINER] Model: {self.model.__class__.__name__} with {sum(p.numel() for p in self.model.parameters()):,} parameters", flush=True)
|
||||
|
||||
# Objective layer for evaluation
|
||||
self.objective_layer = ZernikeObjectiveLayer(self.polar_graph, n_modes=50)
|
||||
|
||||
# Training state
|
||||
self.best_val_loss = float('inf')
|
||||
self.history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
|
||||
|

    def _load_data(self):
        """Load and prepare training data from studies."""
        all_data = []

        for version in self.study_versions:
            study_dir = self.base_dir / "studies" / f"m1_mirror_adaptive_{version}"

            if not study_dir.exists():
                print(f"[WARN] Study not found: {study_dir}", flush=True)
                continue

            print(f"[TRAINER] Loading data from {study_dir.name}...", flush=True)
            dataset = create_mirror_dataset(study_dir, polar_graph=self.polar_graph, verbose=True)
            print(f"[TRAINER] Loaded {len(dataset)} samples", flush=True)
            all_data.extend(dataset)

        if not all_data:
            raise ValueError("No data loaded!")

        print(f"[TRAINER] Total samples: {len(all_data)}", flush=True)

        # Train/val split (80/20)
        np.random.seed(42)
        indices = np.random.permutation(len(all_data))
        n_train = int(0.8 * len(all_data))

        train_data = [all_data[i] for i in indices[:n_train]]
        val_data = [all_data[i] for i in indices[n_train:]]

        print(f"[TRAINER] Train: {len(train_data)}, Val: {len(val_data)}", flush=True)

        # Create datasets
        self.train_dataset = MirrorDataset(train_data)
        self.val_dataset = MirrorDataset(
            val_data,
            design_mean=self.train_dataset.design_mean,
            design_std=self.train_dataset.design_std
        )

        # Store normalization params for inference
        self.design_mean = self.train_dataset.design_mean
        self.design_std = self.train_dataset.design_std
        self.disp_scale = self.train_dataset.disp_scale

    def train(
        self,
        epochs: int = 200,
        lr: float = 1e-3,
        weight_decay: float = 1e-5,
        batch_size: int = 4,
        field_weight: float = 1.0,
        patience: int = 50,
        verbose: bool = True
    ):
        """
        Train the GNN model.

        Args:
            epochs: Number of training epochs
            lr: Learning rate
            weight_decay: Weight decay for regularization
            batch_size: Training batch size
            field_weight: Weight for field loss
            patience: Early stopping patience
            verbose: Print training progress
        """
        # Create data loaders
        train_loader = DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True
        )
        val_loader = DataLoader(
            self.val_dataset, batch_size=batch_size, shuffle=False
        )

        # Optimizer
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )

        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Training loop
        no_improve = 0

        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0

            for batch in train_loader:
                optimizer.zero_grad()

                # Move to device
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)

                # Forward pass for each sample in batch
                batch_loss = 0.0
                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features,
                        self.edge_index,
                        self.edge_attr,
                        design[i]
                    )

                    # MSE loss on displacement field
                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    batch_loss = batch_loss + loss

                batch_loss = batch_loss / design.size(0)
                batch_loss.backward()
                optimizer.step()

                train_loss += batch_loss.item()

            train_loss /= len(train_loader)
            scheduler.step()

            # Validation
            val_loss, val_metrics = self._validate(val_loader)

            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_r2'].append(val_metrics.get('r2_mean', 0))

            # Early stopping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1

            if verbose and epoch % 10 == 0:
                print(f"[Epoch {epoch:3d}] Train: {train_loss:.6f}, Val: {val_loss:.6f}, "
                      f"R²: {val_metrics.get('r2_mean', 0):.4f}, LR: {scheduler.get_last_lr()[0]:.2e}", flush=True)

            if no_improve >= patience:
                print(f"[TRAINER] Early stopping at epoch {epoch}", flush=True)
                break

        # Restore best model
        self.model.load_state_dict(self.best_model_state)
        print(f"[TRAINER] Training complete. Best val loss: {self.best_val_loss:.6f}", flush=True)

    def _validate(self, val_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
        """Run validation and compute metrics."""
        self.model.eval()
        val_loss = 0.0

        all_pred = []
        all_true = []

        with torch.no_grad():
            for batch in val_loader:
                design = batch['design'].to(self.device)
                z_disp_true = batch['z_displacement'].to(self.device)

                for i in range(design.size(0)):
                    z_disp_pred = self.model(
                        self.node_features,
                        self.edge_index,
                        self.edge_attr,
                        design[i]
                    )

                    loss = F.mse_loss(z_disp_pred, z_disp_true[i])
                    val_loss += loss.item()

                    all_pred.append(z_disp_pred.cpu())
                    all_true.append(z_disp_true[i].cpu())

        val_loss /= len(self.val_dataset)

        # Compute R² for each subcase
        all_pred = torch.stack(all_pred)  # [n_val, n_nodes, 4]
        all_true = torch.stack(all_true)

        r2_per_subcase = []
        for sc in range(4):
            pred_flat = all_pred[:, :, sc].flatten()
            true_flat = all_true[:, :, sc].flatten()

            ss_res = ((true_flat - pred_flat) ** 2).sum()
            ss_tot = ((true_flat - true_flat.mean()) ** 2).sum()
            r2 = 1 - ss_res / (ss_tot + 1e-8)
            r2_per_subcase.append(r2.item())

        metrics = {
            'r2_mean': np.mean(r2_per_subcase),
            'r2_per_subcase': r2_per_subcase,
        }

        return val_loss, metrics

    def evaluate_objectives(self) -> Dict[str, Any]:
        """
        Evaluate objective prediction accuracy on the validation set.

        Returns:
            Dictionary with per-objective metrics
        """
        self.model.eval()

        obj_pred_all = {k: [] for k in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']}
        obj_true_all = {k: [] for k in obj_pred_all}

        # Objective layer runs on CPU for now (small dataset)
        with torch.no_grad():
            for i in range(len(self.val_dataset)):
                item = self.val_dataset[i]

                design = item['design'].to(self.device)
                z_disp_true = item['z_displacement']  # Already scaled

                # Predict
                z_disp_pred = self.model(
                    self.node_features,
                    self.edge_index,
                    self.edge_attr,
                    design
                ).cpu()

                # Unscale for objective computation
                z_disp_pred_mm = z_disp_pred / self.disp_scale
                z_disp_true_mm = z_disp_true / self.disp_scale

                # Compute objectives
                obj_pred = self.objective_layer(z_disp_pred_mm)
                obj_true = self.objective_layer(z_disp_true_mm)

                for k in obj_pred_all:
                    obj_pred_all[k].append(obj_pred[k].item())
                    obj_true_all[k].append(obj_true[k].item())

        # Compute metrics per objective
        results = {}
        for k in obj_pred_all:
            pred = np.array(obj_pred_all[k])
            true = np.array(obj_true_all[k])

            mae = np.mean(np.abs(pred - true))
            mape = np.mean(np.abs(pred - true) / (np.abs(true) + 1e-6)) * 100

            ss_res = np.sum((true - pred) ** 2)
            ss_tot = np.sum((true - np.mean(true)) ** 2)
            r2 = 1 - ss_res / (ss_tot + 1e-8)

            results[k] = {
                'mae': mae,
                'mape': mape,
                'r2': r2,
                'pred_range': [pred.min(), pred.max()],
                'true_range': [true.min(), true.max()],
            }

        return results

    def save_checkpoint(self, path: Path) -> None:
        """Save model checkpoint."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'config': self.model_config,
            'design_mean': self.design_mean,
            'design_std': self.design_std,
            'disp_scale': self.disp_scale,
            'history': self.history,
            'best_val_loss': self.best_val_loss,
            'study_versions': self.study_versions,
            'timestamp': datetime.now().isoformat(),
        }

        torch.save(checkpoint, path)
        print(f"[TRAINER] Saved checkpoint to {path}", flush=True)

    @classmethod
    def load_checkpoint(cls, path: Path, device: str = 'auto') -> 'ZernikeGNNTrainer':
        """Load trainer from checkpoint."""
        checkpoint = torch.load(path, map_location='cpu')

        # Create trainer with same config
        trainer = cls(
            study_versions=checkpoint['study_versions'],
            model_type=checkpoint['config']['model_type'],
            hidden_dim=checkpoint['config']['hidden_dim'],
            n_layers=checkpoint['config']['n_layers'],
            device=device,
        )

        # Load model weights
        trainer.model.load_state_dict(checkpoint['model_state_dict'])

        # Restore normalization
        trainer.design_mean = checkpoint['design_mean']
        trainer.design_std = checkpoint['design_std']
        trainer.disp_scale = checkpoint['disp_scale']

        # Restore history
        trainer.history = checkpoint['history']
        trainer.best_val_loss = checkpoint['best_val_loss']

        return trainer

    def predict(self, design_vars: Dict[str, float]) -> Dict[str, Any]:
        """
        Make a prediction for a new design.

        Args:
            design_vars: Dictionary of design parameter values

        Returns:
            Dictionary with displacement field and objectives
        """
        self.model.eval()

        # Convert to tensor
        design_names = self.train_dataset.data_list[0]['design_names']
        design = torch.tensor(
            [design_vars[name] for name in design_names],
            dtype=torch.float32
        )

        # Normalize
        design_norm = (design - self.design_mean) / self.design_std

        with torch.no_grad():
            z_disp_scaled = self.model(
                self.node_features,
                self.edge_index,
                self.edge_attr,
                design_norm.to(self.device)
            ).cpu()

        # Unscale
        z_disp_mm = z_disp_scaled / self.disp_scale

        # Compute objectives
        objectives = self.objective_layer(z_disp_mm)

        return {
            'z_displacement': z_disp_mm.numpy(),
            'objectives': {k: v.item() for k, v in objectives.items()},
        }
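
# Programmatic usage sketch (illustrative; mirrors the CLI flow in main() below):
#
#   >>> trainer = ZernikeGNNTrainer(['V11', 'V12'], model_type='full')
#   >>> trainer.train(epochs=200, lr=1e-3, batch_size=4)
#   >>> trainer.save_checkpoint(Path('zernike_gnn_checkpoint.pt'))
#   >>> # Later, restore and query a new design (keys must match 'design_names'):
#   >>> trainer = ZernikeGNNTrainer.load_checkpoint(Path('zernike_gnn_checkpoint.pt'))
#   >>> result = trainer.predict({name: 0.0 for name in
#   ...                           trainer.train_dataset.data_list[0]['design_names']})
#   >>> result['objectives']  # dict of the three Zernike-based objectives
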

# =============================================================================
# CLI
# =============================================================================

def main():
    import argparse

    parser = argparse.ArgumentParser(description='Train ZernikeGNN surrogate')
    parser.add_argument('studies', nargs='+', help='Study versions (e.g., V11 V12)')
    parser.add_argument('--epochs', type=int, default=200, help='Training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--n-layers', type=int, default=6, help='Message passing layers')
    parser.add_argument('--model-type', choices=['full', 'lite'], default='full')
    parser.add_argument('--output', '-o', type=Path, help='Output checkpoint path')
    parser.add_argument('--device', default='auto', help='Device (cpu, cuda, auto)')

    args = parser.parse_args()

    # Create trainer
    print("="*60, flush=True)
    print("ZERNIKE GNN TRAINING", flush=True)
    print("="*60, flush=True)

    trainer = ZernikeGNNTrainer(
        study_versions=args.studies,
        model_type=args.model_type,
        hidden_dim=args.hidden_dim,
        n_layers=args.n_layers,
        device=args.device,
    )

    # Train
    trainer.train(
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
    )

    # Evaluate objectives
    print("\n--- Objective Prediction Evaluation ---", flush=True)
    obj_results = trainer.evaluate_objectives()
    for k, v in obj_results.items():
        print(f"\n{k}:", flush=True)
        print(f"  MAE:  {v['mae']:.2f} nm", flush=True)
        print(f"  MAPE: {v['mape']:.1f}%", flush=True)
        print(f"  R²:   {v['r2']:.4f}", flush=True)

    # Save checkpoint
    if args.output:
        output_path = args.output
    else:
        output_path = Path("zernike_gnn_checkpoint.pt")

    trainer.save_checkpoint(output_path)

    print("\n" + "="*60, flush=True)
    print("✓ Training complete!", flush=True)
    print("="*60, flush=True)


if __name__ == '__main__':
    main()

582
optimization_engine/gnn/zernike_gnn.py
Normal file
@@ -0,0 +1,582 @@
"""
Zernike GNN Model for Mirror Surface Deformation Prediction
============================================================

This module implements a Graph Neural Network specifically designed for predicting
mirror surface displacement fields from design parameters. The key innovation is
using design-conditioned message passing on a polar grid graph.

Architecture:

    Design Variables [11]
              │
              ▼
    Design Encoder [11 → 128]
              │
              └──────────────────┐
                                 │
    Node Features [r, θ, x, y]   │
              │                  │
              ▼                  │
    Node Encoder [4 → 128]       │
              │                  │
              └─────────┬────────┘
                        │
                        ▼
    ┌─────────────────────────────┐
    │ Design-Conditioned          │
    │ Message Passing (× 6)       │
    │                             │
    │ • Polar-aware edges         │
    │ • Design modulates messages │
    │ • Residual connections      │
    └─────────────┬───────────────┘
                  │
                  ▼
    Per-Node Decoder [128 → 4]
              │
              ▼
    Z-Displacement Field [3000, 4]
    (one value per node per subcase)

Usage:
    from optimization_engine.gnn.zernike_gnn import ZernikeGNN
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    model = ZernikeGNN(n_design_vars=11, n_subcases=4)

    # Forward pass
    z_disp = model(
        node_features=graph.get_node_features(),
        edge_index=graph.edge_index,
        edge_attr=graph.get_edge_features(),
        design_vars=design_tensor
    )
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

try:
    from torch_geometric.nn import MessagePassing
    HAS_PYG = True
except ImportError:
    HAS_PYG = False
    MessagePassing = nn.Module  # Fallback for type hints


class DesignConditionedConv(MessagePassing if HAS_PYG else nn.Module):
    """
    Message passing layer conditioned on global design parameters.

    This layer propagates information through the polar graph while
    conditioning on design parameters. The design embedding modulates
    how messages flow between nodes.

    Key insight: Design parameters affect the stiffness distribution
    in the mirror support structure. This layer learns how those changes
    propagate spatially through the optical surface.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        design_channels: int,
        edge_channels: int = 4,
        aggr: str = 'mean'
    ):
        """
        Args:
            in_channels: Input node feature dimension
            out_channels: Output node feature dimension
            design_channels: Design embedding dimension
            edge_channels: Edge feature dimension
            aggr: Aggregation method ('mean', 'sum', 'max')
        """
        if HAS_PYG:
            super().__init__(aggr=aggr)
        else:
            super().__init__()
            self.aggr = aggr

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message network: source node + target node + design + edge
        msg_input_dim = 2 * in_channels + design_channels + edge_channels
        self.message_net = nn.Sequential(
            nn.Linear(msg_input_dim, out_channels * 2),
            nn.LayerNorm(out_channels * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(out_channels * 2, out_channels),
        )

        # Update network: combines aggregated messages with original features
        self.update_net = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.SiLU(),
        )

        # Design gate: allows design to modulate message importance
        self.design_gate = nn.Sequential(
            nn.Linear(design_channels, out_channels),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with design conditioning.

        Args:
            x: Node features [n_nodes, in_channels]
            edge_index: Graph connectivity [2, n_edges]
            edge_attr: Edge features [n_edges, edge_channels]
            design_embed: Design embedding [design_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        if HAS_PYG:
            # Use PyG's message passing
            out = self.propagate(
                edge_index, x=x, edge_attr=edge_attr, design=design_embed
            )
        else:
            # Fallback implementation without PyG
            out = self._manual_propagate(x, edge_index, edge_attr, design_embed)

        # Apply design-based gating
        gate = self.design_gate(design_embed)
        out = out * gate

        return out

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute messages from source (j) to target (i) nodes.

        Args:
            x_i: Target node features [n_edges, in_channels]
            x_j: Source node features [n_edges, in_channels]
            edge_attr: Edge features [n_edges, edge_channels]
            design: Design embedding, broadcast to edges

        Returns:
            Messages [n_edges, out_channels]
        """
        # Broadcast design to all edges
        design_broadcast = design.expand(x_i.size(0), -1)

        # Concatenate all inputs
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)

        return self.message_net(msg_input)

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Update node features with aggregated messages.

        Args:
            aggr_out: Aggregated messages [n_nodes, out_channels]
            x: Original node features [n_nodes, in_channels]

        Returns:
            Updated node features [n_nodes, out_channels]
        """
        return self.update_net(torch.cat([x, aggr_out], dim=-1))

    def _manual_propagate(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design: torch.Tensor
    ) -> torch.Tensor:
        """Fallback message passing without PyG."""
        row, col = edge_index  # row = target, col = source

        # Gather features
        x_i = x[row]  # Target features
        x_j = x[col]  # Source features

        # Compute messages
        design_broadcast = design.expand(x_i.size(0), -1)
        msg_input = torch.cat([x_i, x_j, design_broadcast, edge_attr], dim=-1)
        messages = self.message_net(msg_input)

        # Aggregate (mean)
        n_nodes = x.size(0)
        aggr_out = torch.zeros(n_nodes, messages.size(-1), device=x.device)
        count = torch.zeros(n_nodes, 1, device=x.device)

        aggr_out.scatter_add_(0, row.unsqueeze(-1).expand_as(messages), messages)
        count.scatter_add_(0, row.unsqueeze(-1), torch.ones_like(row, dtype=torch.float).unsqueeze(-1))
        count = count.clamp(min=1)
        aggr_out = aggr_out / count

        # Update
        return self.update_net(torch.cat([x, aggr_out], dim=-1))
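
# Minimal smoke-test sketch for a single layer (illustrative; random tensors,
# dimensions chosen to match the defaults used by ZernikeGNN below):
#
#   >>> conv = DesignConditionedConv(in_channels=128, out_channels=128,
#   ...                              design_channels=128, edge_channels=64)
#   >>> x = torch.randn(3000, 128)               # node embeddings
#   >>> ei = torch.randint(0, 3000, (2, 17760))  # edge_index
#   >>> ea = torch.randn(17760, 64)              # encoded edge features
#   >>> d = torch.randn(128)                     # design embedding
#   >>> out = conv(x, ei, ea, d)                 # -> [3000, 128]
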

class ZernikeGNN(nn.Module):
    """
    Graph Neural Network for mirror surface displacement prediction.

    This model learns to predict Z-displacement fields for all 4 gravity
    subcases from 11 design parameters. It uses a fixed polar grid graph
    structure and design-conditioned message passing.

    The key advantages over MLP:
    1. Spatial awareness through message passing
    2. Design conditioning modulates spatial information flow
    3. Predicts full field (enabling correct relative computation)
    4. Respects physics: smooth fields, radial/angular structure
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 128,
        n_layers: int = 6,
        node_feat_dim: int = 4,
        edge_feat_dim: int = 4,
        dropout: float = 0.1
    ):
        """
        Args:
            n_design_vars: Number of design parameters (11 for mirror)
            n_subcases: Number of gravity subcases (4: 90°, 20°, 40°, 60°)
            hidden_dim: Hidden layer dimension
            n_layers: Number of message passing layers
            node_feat_dim: Node feature dimension (r, theta, x, y)
            edge_feat_dim: Edge feature dimension (dr, dtheta, dist, angle)
            dropout: Dropout rate
        """
        super().__init__()

        self.n_design_vars = n_design_vars
        self.n_subcases = n_subcases
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # === Design Encoder ===
        # Maps design parameters to hidden space
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Node Encoder ===
        # Maps polar coordinates to hidden space
        self.node_encoder = nn.Sequential(
            nn.Linear(node_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # === Edge Encoder ===
        # Maps edge features (dr, dtheta, distance, angle) to hidden space
        edge_hidden = hidden_dim // 2
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_feat_dim, edge_hidden),
            nn.SiLU(),
            nn.Linear(edge_hidden, edge_hidden),
        )

        # === Message Passing Layers ===
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                design_channels=hidden_dim,
                edge_channels=edge_hidden,
            )
            for _ in range(n_layers)
        ])

        # Layer norms for residual connections
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(n_layers)
        ])

        # === Displacement Decoder ===
        # Predicts Z-displacement for each subcase
        self.displacement_decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with Xavier/Glorot initialization."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        design_vars: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: design parameters → displacement field.

        Args:
            node_features: [n_nodes, 4] - (r, theta, x, y), normalized
            edge_index: [2, n_edges] - graph connectivity
            edge_attr: [n_edges, 4] - edge features, normalized
            design_vars: [n_design_vars] or [batch, n_design_vars]

        Returns:
            z_displacement: [n_nodes, n_subcases] - Z-disp per subcase,
                or [batch, n_nodes, n_subcases] if batched
        """
        # Handle batched vs single design
        is_batched = design_vars.dim() == 2
        if not is_batched:
            design_vars = design_vars.unsqueeze(0)  # [1, n_design_vars]

        batch_size = design_vars.size(0)
        n_nodes = node_features.size(0)

        # Encode inputs
        design_h = self.design_encoder(design_vars)  # [batch, hidden]
        node_h = self.node_encoder(node_features)    # [n_nodes, hidden]
        edge_h = self.edge_encoder(edge_attr)        # [n_edges, edge_hidden]

        # Process each batch item
        outputs = []
        for b in range(batch_size):
            h = node_h.clone()  # Start fresh for each design

            # Message passing with residual connections
            for conv, norm in zip(self.conv_layers, self.layer_norms):
                h_new = conv(h, edge_index, edge_h, design_h[b])
                h = norm(h + h_new)  # Residual + LayerNorm

            # Decode to displacement
            z_disp = self.displacement_decoder(h)  # [n_nodes, n_subcases]
            outputs.append(z_disp)

        # Stack outputs
        if is_batched:
            return torch.stack(outputs, dim=0)  # [batch, n_nodes, n_subcases]
        else:
            return outputs[0]  # [n_nodes, n_subcases]

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
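
# Because the forward pass is built from differentiable ops, a predicted field
# can in principle be chained with ZernikeObjectiveLayer (differentiable_zernike.py,
# as ZernikeGNNTrainer does) to get objective gradients w.r.t. the design vector.
# Illustrative sketch only; `model`, graph tensors, and `objective_layer` are
# assumed to be set up as in ZernikeGNNTrainer:
#
#   >>> design = torch.randn(11, requires_grad=True)
#   >>> z_disp = model(node_features, edge_index, edge_attr, design)
#   >>> objectives = objective_layer(z_disp)
#   >>> objectives['rel_filtered_rms_40_vs_20'].backward()
#   >>> design.grad  # sensitivity of the objective to each design variable
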

class ZernikeGNNLite(nn.Module):
    """
    Lightweight version of ZernikeGNN for faster training/inference.

    Uses fewer layers and a smaller hidden dimension, suitable for
    initial experiments or when training data is limited.
    """

    def __init__(
        self,
        n_design_vars: int = 11,
        n_subcases: int = 4,
        hidden_dim: int = 64,
        n_layers: int = 4
    ):
        super().__init__()

        self.n_subcases = n_subcases

        # Simpler design encoder
        self.design_encoder = nn.Sequential(
            nn.Linear(n_design_vars, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Simpler node encoder
        self.node_encoder = nn.Sequential(
            nn.Linear(4, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Edge encoder
        self.edge_encoder = nn.Linear(4, hidden_dim // 2)

        # Message passing
        self.conv_layers = nn.ModuleList([
            DesignConditionedConv(hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2)
            for _ in range(n_layers)
        ])

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_subcases),
        )

    def forward(self, node_features, edge_index, edge_attr, design_vars):
        """Forward pass (single design vector only, unlike ZernikeGNN)."""
        design_h = self.design_encoder(design_vars)
        node_h = self.node_encoder(node_features)
        edge_h = self.edge_encoder(edge_attr)

        for conv in self.conv_layers:
            node_h = node_h + conv(node_h, edge_index, edge_h, design_h)

        return self.decoder(node_h)


# =============================================================================
# Utility functions
# =============================================================================

def create_model(
    n_design_vars: int = 11,
    n_subcases: int = 4,
    model_type: str = 'full',
    **kwargs
) -> nn.Module:
    """
    Factory function to create a GNN model.

    Args:
        n_design_vars: Number of design parameters
        n_subcases: Number of subcases
        model_type: 'full' or 'lite'
        **kwargs: Additional arguments passed to the model

    Returns:
        GNN model instance
    """
    if model_type == 'lite':
        return ZernikeGNNLite(n_design_vars, n_subcases, **kwargs)
    else:
        return ZernikeGNN(n_design_vars, n_subcases, **kwargs)
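
# Usage sketch:
#
#   >>> model = create_model(model_type='full', hidden_dim=128, n_layers=6)
#   >>> model_lite = create_model(model_type='lite', hidden_dim=64, n_layers=4)
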

def load_model(checkpoint_path: str, device: str = 'cpu') -> nn.Module:
    """
    Load a trained model from a checkpoint.

    Args:
        checkpoint_path: Path to .pt checkpoint file
        device: Device to load the model onto

    Returns:
        Loaded model in eval mode
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Get model config
    config = checkpoint.get('config', {})
    model_type = config.get('model_type', 'full')

    # Create model
    model = create_model(
        n_design_vars=config.get('n_design_vars', 11),
        n_subcases=config.get('n_subcases', 4),
        model_type=model_type,
        hidden_dim=config.get('hidden_dim', 128),
        n_layers=config.get('n_layers', 6),
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
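
# Usage sketch (checkpoint written by ZernikeGNNTrainer.save_checkpoint):
#
#   >>> model = load_model('zernike_gnn_checkpoint.pt', device='cpu')
#   >>> model.training  # False: returned in eval mode
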

# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    print("="*60)
    print("Testing ZernikeGNN")
    print("="*60)

    # Create model
    model = ZernikeGNN(n_design_vars=11, n_subcases=4, hidden_dim=128, n_layers=6)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")

    # Create dummy inputs
    n_nodes = 3000
    n_edges = 17760

    node_features = torch.randn(n_nodes, 4)
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    edge_attr = torch.randn(n_edges, 4)
    design_vars = torch.randn(11)

    # Forward pass
    print("\n--- Single Forward Pass ---")
    with torch.no_grad():
        output = model(node_features, edge_index, edge_attr, design_vars)
    print(f"Input design: {design_vars.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.6f}, {output.max():.6f}]")

    # Batched forward pass
    print("\n--- Batched Forward Pass ---")
    batch_design = torch.randn(8, 11)
    with torch.no_grad():
        output_batch = model(node_features, edge_index, edge_attr, batch_design)
    print(f"Batch design: {batch_design.shape}")
    print(f"Batch output: {output_batch.shape}")

    # Test lite model
    print("\n--- Lite Model ---")
    model_lite = ZernikeGNNLite(n_design_vars=11, n_subcases=4)
    print(f"Lite parameters: {sum(p.numel() for p in model_lite.parameters()):,}")

    with torch.no_grad():
        output_lite = model_lite(node_features, edge_index, edge_attr, design_vars)
    print(f"Lite output shape: {output_lite.shape}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)