Atomizer/optimization_engine/gnn/test_polar_graph.py
Antoine 96b196de58 feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials
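The commit describes `PolarMirrorGraph` only as a fixed 3000-node polar grid. The test script builds it with `n_radial=50` and `n_angular=60`, and 50 × 60 = 3000. The following is a minimal sketch of how such a grid and its neighbour edges could be built with NumPy. Several details are assumptions and not the module's actual implementation:

- the function name `build_polar_grid`;
- the feature normalisation;
- the 4-neighbour connectivity (angular ring plus radial spokes).

```python
import numpy as np


def build_polar_grid(r_inner=100.0, r_outer=650.0, n_radial=50, n_angular=60):
    """Hypothetical sketch of a fixed polar grid (not the PolarMirrorGraph source).

    Returns node features (n_nodes, 4) and a directed edge_index (2, n_edges).
    """
    r = np.linspace(r_inner, r_outer, n_radial)                      # radial stations
    theta = np.linspace(0.0, 2 * np.pi, n_angular, endpoint=False)  # no duplicate node at 2*pi
    rr, tt = np.meshgrid(r, theta, indexing="ij")                    # shape (n_radial, n_angular)
    rr, tt = rr.ravel(), tt.ravel()                                  # node id = ir * n_angular + it

    # Assumed normalisation: r and theta mapped to [0, 1], x and y scaled by r_outer
    node_feat = np.stack([
        (rr - r_inner) / (r_outer - r_inner),
        tt / (2 * np.pi),
        rr * np.cos(tt) / r_outer,
        rr * np.sin(tt) / r_outer,
    ], axis=1)

    edges = []
    for ir in range(n_radial):
        for it in range(n_angular):
            node = ir * n_angular + it
            # Angular neighbour: the ring closes on itself
            edges.append((node, ir * n_angular + (it + 1) % n_angular))
            # Radial neighbour: no wrap between the inner and outer edge
            if ir + 1 < n_radial:
                edges.append((node, (ir + 1) * n_angular + it))
    edges = np.array(edges)
    # Store both directions so messages can flow either way
    edge_index = np.concatenate([edges, edges[:, ::-1]], axis=0).T
    return node_feat, edge_index


node_feat, edge_index = build_polar_grid()
print(node_feat.shape)                            # (3000, 4)
print(edge_index.shape[1] / node_feat.shape[0])   # 3.96 directed edges per node
```

With this connectivity, interior nodes have four neighbours. Nodes on the inner and outer rings have three, which is why the average is slightly below 4.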

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.
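The commit does not say how the design parameters modulate message passing. One common way to condition a GNN on a global vector is FiLM-style feature-wise scaling and shifting, and the sketch below assumes that scheme. The following are all hypothetical choices made for illustration, not `ZernikeGNN` internals:

- `design_conditioned_layer` and its weight matrices;
- mean aggregation;
- the tanh residual update.

```python
import numpy as np


def design_conditioned_layer(h, edge_index, design, W_msg, W_gamma, W_beta):
    """One FiLM-style message-passing step (hypothetical sketch).

    h: (n_nodes, hidden) node states; edge_index: (2, n_edges) src/dst ids;
    design: (n_design,) design parameters shared by the whole mirror.
    """
    src, dst = edge_index
    gamma = 1.0 + design @ W_gamma          # (hidden,) per-channel scale from the design
    beta = design @ W_beta                  # (hidden,) per-channel shift from the design
    msgs = (h[src] @ W_msg) * gamma + beta  # (n_edges, hidden) design-modulated messages
    agg = np.zeros_like(h)
    np.add.at(agg, dst, msgs)               # sum incoming messages per destination node
    deg = np.maximum(np.bincount(dst, minlength=h.shape[0]), 1)[:, None]
    return np.tanh(h + agg / deg)           # mean aggregation plus residual connection


# Toy example: 4-node ring, 8 hidden channels, 3 design parameters
rng = np.random.default_rng(0)
ring = np.array([[0, 1, 2, 3, 1, 2, 3, 0],
                 [1, 2, 3, 0, 0, 1, 2, 3]])
h0 = rng.normal(size=(4, 8))
W_msg = rng.normal(size=(8, 8))
W_gamma, W_beta = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))
out_a = design_conditioned_layer(h0, ring, np.array([4.0, 1.0, 0.5]), W_msg, W_gamma, W_beta)
out_b = design_conditioned_layer(h0, ring, np.array([5.0, 1.0, 0.5]), W_msg, W_gamma, W_beta)
print(np.abs(out_a - out_b).max() > 0)  # same geometry, different design -> different output
```

The graph and node inputs are identical in both calls, so any difference in the output comes only from the design vector. That is the property the commit attributes to design-conditioned convolutions.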

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py
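The turbo workflow above has two parts. First, a cheap surrogate screens many designs and a shortlist goes to expensive FEA verification. Second, a calibration step maps surrogate outputs onto FEA values. The sketch below assumes a single scalar objective to minimise and a linear GNN-to-FEA correction. `screen_top_k`, `fit_linear_calibration`, the toy surrogate, and the calibration pairs are hypothetical. The actual `run_gnn_turbo.py` and `compute_full_calibration.py` may work differently.

```python
import numpy as np


def screen_top_k(predict, bounds, n_samples=5000, k=10, seed=0):
    """Sample designs uniformly, rank them with a cheap surrogate, keep the k best (minimisation)."""
    lo, hi = bounds[:, 0], bounds[:, 1]
    rng = np.random.default_rng(seed)
    X = lo + (hi - lo) * rng.random((n_samples, len(lo)))  # uniform samples inside the bounds
    y = predict(X)                                          # one surrogate evaluation per row
    best = np.argsort(y)[:k]                                # lowest predicted objective first
    return X[best], y[best]


def fit_linear_calibration(y_gnn, y_fea):
    """Least-squares map y_fea ~ slope * y_gnn + intercept, fitted on FEA-validated designs."""
    slope, intercept = np.polyfit(y_gnn, y_fea, deg=1)
    return lambda y: slope * y + intercept


# Toy surrogate standing in for the GNN: quadratic bowl centred at 0.5
surrogate = lambda X: np.sum((X - 0.5) ** 2, axis=1)
bounds = np.array([[0.0, 1.0], [0.0, 1.0]])
X_top, y_top = screen_top_k(surrogate, bounds)
# Made-up (GNN prediction, FEA result) pairs for the shortlisted designs
calibrate = fit_linear_calibration(np.array([1.0, 2.0, 3.0]), np.array([1.5, 2.5, 3.5]))
print(X_top.shape, calibrate(4.0))
```

A linear map is the simplest correction for a consistent bias or scale error in the surrogate. A nonlinear or per-objective correction would need more validated FEA points.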

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]
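A ground-truth Pareto front is the set of FEA trials that no other trial dominates. The sketch below computes that first front with a plain O(n²) dominance check, not NSGA-II's fast non-dominated sorting. `pareto_mask` is a hypothetical helper, all objectives are assumed to be minimised, and the objective values are toy numbers.

```python
import numpy as np


def pareto_mask(F):
    """True for rows of F (n_points, n_objectives) that no other row dominates; all objectives minimised."""
    F = np.asarray(F, dtype=float)
    mask = np.ones(F.shape[0], dtype=bool)
    for i in range(F.shape[0]):
        # Row j dominates row i if it is no worse everywhere and strictly better somewhere
        dominated_by = np.all(F <= F[i], axis=1) & np.any(F < F[i], axis=1)
        mask[i] = not dominated_by.any()
    return mask


# Two objectives per trial (toy numbers)
F = np.array([[1.0, 4.0], [2.0, 2.0], [3.0, 3.0], [4.0, 1.0], [2.0, 5.0]])
print(pareto_mask(F))  # [ True  True False  True False]
```

`[3, 3]` is dominated by `[2, 2]`, and `[2, 5]` is dominated by `[1, 4]`. Comparing GNN-selected designs against a front computed this way from FEA trials is one way to measure surrogate accuracy.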

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00

"""Test PolarMirrorGraph with actual V11 data."""
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")
import numpy as np
from pathlib import Path
from optimization_engine.gnn.polar_graph import PolarMirrorGraph, create_mirror_dataset
from optimization_engine.gnn.extract_displacement_field import load_field
# Test 1: Basic graph construction
print("="*60)
print("TEST 1: Graph Construction")
print("="*60)
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\n{graph}")
node_feat = graph.get_node_features(normalized=True)
edge_feat = graph.get_edge_features(normalized=True)
print(f"\nNode features: {node_feat.shape}")
print(f" r normalized: [{node_feat[:, 0].min():.3f}, {node_feat[:, 0].max():.3f}]")
print(f" theta normalized: [{node_feat[:, 1].min():.3f}, {node_feat[:, 1].max():.3f}]")
print(f" x normalized: [{node_feat[:, 2].min():.3f}, {node_feat[:, 2].max():.3f}]")
print(f" y normalized: [{node_feat[:, 3].min():.3f}, {node_feat[:, 3].max():.3f}]")
print(f"\nEdge features: {edge_feat.shape}")
print(f" Edges per node: {edge_feat.shape[0] / graph.n_nodes:.1f}")
# Test 2: Load actual V11 field data and interpolate
print("\n" + "="*60)
print("TEST 2: Interpolation from V11 Data")
print("="*60)
field_path = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11/gnn_data/trial_0091/displacement_field.h5")
if field_path.exists():
field_data = load_field(field_path)
print(f"\nLoaded field data:")
print(f" FEA nodes: {len(field_data['node_ids'])}")
print(f" Subcases: {list(field_data['z_displacement'].keys())}")
# Interpolate to polar grid
result = graph.interpolate_field_data(field_data, subcases=[1, 2, 3, 4])
z_grid = result['z_displacement']
print(f"\nInterpolation result:")
print(f" Shape: {z_grid.shape} (expected: {graph.n_nodes} × 4)")
print(f" NaN count: {np.sum(np.isnan(z_grid))}")
for i, sc in enumerate([1, 2, 3, 4]):
disp = z_grid[:, i]
print(f" Subcase {sc}: [{disp.min():.6f}, {disp.max():.6f}] mm")
# Test relative deformation computation
print("\n--- Relative Deformations (like Zernike extraction) ---")
disp_90 = z_grid[:, 0] # Subcase 1 = 90°
disp_20 = z_grid[:, 1] # Subcase 2 = 20° (reference)
disp_40 = z_grid[:, 2] # Subcase 3 = 40°
disp_60 = z_grid[:, 3] # Subcase 4 = 60°
rel_40_vs_20 = disp_40 - disp_20
rel_60_vs_20 = disp_60 - disp_20
rel_90_vs_20 = disp_90 - disp_20
print(f" 40° - 20°: [{rel_40_vs_20.min():.6f}, {rel_40_vs_20.max():.6f}] mm, RMS={np.std(rel_40_vs_20)*1e6:.2f} nm")
print(f" 60° - 20°: [{rel_60_vs_20.min():.6f}, {rel_60_vs_20.max():.6f}] mm, RMS={np.std(rel_60_vs_20)*1e6:.2f} nm")
print(f" 90° - 20°: [{rel_90_vs_20.min():.6f}, {rel_90_vs_20.max():.6f}] mm, RMS={np.std(rel_90_vs_20)*1e6:.2f} nm")
else:
print(f"Field file not found: {field_path}")
# Test 3: Create full dataset from V11
print("\n" + "="*60)
print("TEST 3: Create Dataset from V11")
print("="*60)
study_dir = Path("C:/Users/Antoine/Atomizer/studies/m1_mirror_adaptive_V11")
if (study_dir / "gnn_data").exists():
dataset = create_mirror_dataset(study_dir, polar_graph=graph, verbose=True)
print(f"\n--- Dataset Summary ---")
print(f"Total samples: {len(dataset)}")
if dataset:
# Check consistency
shapes = [d['z_displacement'].shape for d in dataset]
unique_shapes = set(shapes)
print(f"Unique shapes: {unique_shapes}")
# Design variable info
n_vars = len(dataset[0]['design_vars'])
print(f"Design variables: {n_vars}")
if dataset[0]['design_names']:
print(f" Names: {dataset[0]['design_names'][:3]}...")
# Stack for statistics
all_z = np.stack([d['z_displacement'] for d in dataset])
print(f"\nAll data shape: {all_z.shape}")
print(f" Per-subcase ranges:")
for i in range(4):
print(f" Subcase {i+1}: [{all_z[:,:,i].min():.6f}, {all_z[:,:,i].max():.6f}] mm")
else:
print(f"No gnn_data folder found in {study_dir}")
print("\n" + "="*60)
print("✓ All tests completed!")
print("="*60)