Atomizer/optimization_engine/gnn/differentiable_zernike.py
Antoine 96b196de58 feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization
and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II).

## GNN Surrogate Module (optimization_engine/gnn/)

New module for Graph Neural Network surrogate prediction of mirror deformations:

- `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure
- `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing
- `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives
- `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss
- `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour)
- `extract_displacement_field.py`: OP2 to HDF5 field extraction
- `backfill_field_data.py`: Extract fields from existing FEA trials

Key innovation: Design-conditioned convolutions that modulate message passing
based on structural design parameters, enabling accurate field prediction.
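The design-conditioned convolution itself lives in `zernike_gnn.py` and is not shown in this commit view. As a rough illustration of the idea only, a FiLM-style gate can scale edge messages by a function of the global design vector (the layer name and internals below are hypothetical, not the repo's implementation):

```python
import torch
import torch.nn as nn

class DesignConditionedConv(nn.Module):
    """Sketch of one message-passing step gated by a global design vector."""

    def __init__(self, node_dim: int, design_dim: int):
        super().__init__()
        self.message = nn.Linear(2 * node_dim, node_dim)
        self.gate = nn.Linear(design_dim, node_dim)  # design -> per-channel scale

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                design: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # [2, n_edges] source/target node indices
        msg = self.message(torch.cat([x[src], x[dst]], dim=-1))
        msg = msg * torch.sigmoid(self.gate(design))  # design-conditioned modulation
        agg = torch.zeros_like(x).index_add_(0, dst, msg)  # sum messages per node
        deg = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(dst.numel()))
        return x + agg / deg.clamp(min=1).unsqueeze(-1)  # mean-aggregate, residual update
```

Because the gate depends on the design parameters, the same graph structure produces different message dynamics per design, which is what lets one fixed 3000-node grid serve every candidate.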

## M1 Mirror Studies

### V12: GNN Field Prediction + FEA Validation
- Zernike GNN trained on V10/V11 FEA data (238 samples)
- Turbo mode: 5000 GNN predictions → top candidates → FEA validation
- Calibration workflow for GNN-to-FEA error correction
- Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py
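The turbo workflow amounts to cheap surrogate screening followed by expensive validation. A minimal sketch of that pattern (not the actual `run_gnn_turbo.py` logic; the sampler and surrogate below are made-up stand-ins):

```python
import numpy as np

def turbo_screen(predict, sample_design, n_predict=5000, top_k=20, seed=0):
    """Score many candidate designs with a cheap surrogate, return the
    top_k (lowest-objective) designs for expensive FEA validation."""
    rng = np.random.default_rng(seed)
    designs = [sample_design(rng) for _ in range(n_predict)]
    scores = np.array([predict(d) for d in designs])
    best = np.argsort(scores)[:top_k]  # ascending: lower objective is better
    return [designs[i] for i in best]

# Toy stand-ins for the GNN surrogate and the design sampler
candidates = turbo_screen(
    predict=lambda d: float(np.sum(d ** 2)),
    sample_design=lambda rng: rng.uniform(-1, 1, size=3),
)
```

The GNN-to-FEA calibration step then corrects the systematic error between surrogate scores and the FEA results on the validated candidates.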

### V13: Pure NSGA-II FEA (Ground Truth)
- Seeds 217 FEA trials from V11+V12
- Pure multi-objective NSGA-II without any surrogate
- Establishes ground-truth Pareto front for GNN accuracy evaluation
- Narrowed blank_backface_angle range to [4.0, 5.0]
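Comparing GNN accuracy against V13 requires the ground-truth Pareto front from the FEA trials, which a standard non-dominated filter extracts. A minimal sketch assuming all objectives are minimized (`pareto_front` is an illustrative helper, not the repo's tooling):

```python
import numpy as np

def pareto_front(points):
    """Boolean mask of non-dominated points (all objectives minimized)."""
    pts = np.asarray(points, dtype=float)
    mask = np.ones(len(pts), dtype=bool)
    for i in range(len(pts)):
        # i is dominated if some other point is <= in every objective
        # and strictly < in at least one
        dominated = np.all(pts <= pts[i], axis=1) & np.any(pts < pts[i], axis=1)
        if dominated.any():
            mask[i] = False
    return mask
```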

## Documentation Updates

- SYS_14: Added Zernike GNN section with architecture diagrams
- CLAUDE.md: Added GNN module reference and quick start
- V13 README: Study documentation with seeding strategy

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 08:44:04 -05:00

"""
Differentiable Zernike Fitting Layer
=====================================
This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
The key innovation is putting Zernike fitting INSIDE the neural network for
end-to-end training.
Why Differentiable Zernike?
Current MLP approach:
design → MLP → coefficients (learn 200 outputs independently)
GNN + Differentiable Zernike:
design → GNN → displacement field → Zernike fit → coefficients
Differentiable! Gradients flow back
This allows the network to learn:
1. Spatially coherent displacement fields
2. Fields that produce accurate Zernike coefficients
3. Correct relative deformation computation
Components:
1. DifferentiableZernikeFit - Fits coefficients from displacement field
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing
Usage:
from optimization_engine.gnn.differentiable_zernike import (
DifferentiableZernikeFit,
ZernikeObjectiveLayer
)
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph()
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
# In forward pass:
z_disp = gnn_model(...) # [n_nodes, 4]
objectives = objective_layer(z_disp) # Dict with RMS values
"""
import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from pathlib import Path
def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute Zernike polynomial for Noll index j.

    Uses the standard Noll indexing convention:
    j=1: piston, j=2: tip (x-tilt), j=3: tilt (y-tilt), j=4: defocus, etc.

    Args:
        j: Noll index (1-based)
        r: Radial coordinates (normalized to [0, 1])
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta)
    """
    # Convert Noll index j to (n, m)
    n = int(np.ceil((-3 + np.sqrt(9 + 8 * (j - 1))) / 2))
    # |m| has the same parity as n; the sign follows Noll's convention
    # (even j -> cosine term, m >= 0; odd j -> sine term, m < 0)
    p = j - n * (n + 1) // 2
    m = ((p + n % 2) // 2) * 2 - n % 2
    if m != 0 and j % 2 == 1:
        m = -m

    # Compute radial polynomial R_n^|m|(r)
    R = np.zeros_like(r)
    m_abs = abs(m)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k) /
                (math.factorial(k) *
                 math.factorial((n + m_abs) // 2 - k) *
                 math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)

    # Combine with angular part
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)

    # Noll (RMS) normalization factor
    if m == 0:
        norm = np.sqrt(n + 1)
    else:
        norm = np.sqrt(2 * (n + 1))
    return norm * Z
def build_zernike_matrix(
    r: np.ndarray,
    theta: np.ndarray,
    n_modes: int = 50,
    r_max: Optional[float] = None
) -> np.ndarray:
    """
    Build Zernike basis matrix for a set of points.

    Args:
        r: Radial coordinates
        theta: Angular coordinates
        n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
        r_max: Maximum radius for normalization (if None, use max(r))

    Returns:
        Z: [n_points, n_modes] Zernike basis matrix
    """
    if r_max is None:
        r_max = r.max()
    r_norm = r / r_max
    n_points = len(r)
    Z = np.zeros((n_points, n_modes), dtype=np.float64)
    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)
    return Z
class DifferentiableZernikeFit(nn.Module):
    """
    GPU-accelerated, differentiable Zernike polynomial fitting.

    This layer fits Zernike coefficients to a displacement field using
    least squares. The key insight is that least squares has a closed-form
    solution: c = (Z^T Z)^{-1} Z^T @ values

    By precomputing (Z^T Z)^{-1} Z^T, we can fit coefficients with a single
    matrix multiply, which is fully differentiable.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes to fit
            regularization: Tikhonov regularization for stability
        """
        super().__init__()
        self.n_modes = n_modes

        # Get coordinates from the polar graph
        r = polar_graph.r
        theta = polar_graph.theta
        r_max = polar_graph.r_outer

        # Build Zernike basis matrix [n_nodes, n_modes]
        Z = build_zernike_matrix(r, theta, n_modes, r_max)

        # Convert to tensor and register as buffer
        Z_tensor = torch.tensor(Z, dtype=torch.float32)
        self.register_buffer('Z', Z_tensor)

        # Precompute the regularized pseudo-inverse:
        # c = (Z^T Z + λI)^{-1} Z^T @ values
        ZtZ = Z_tensor.T @ Z_tensor
        ZtZ_reg = ZtZ + regularization * torch.eye(n_modes)
        ZtZ_inv = torch.inverse(ZtZ_reg)
        pseudo_inv = ZtZ_inv @ Z_tensor.T  # [n_modes, n_nodes]
        self.register_buffer('pseudo_inverse', pseudo_inv)

    def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
        """
        Fit Zernike coefficients to a displacement field.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement

        Returns:
            coefficients: [n_modes] or [n_subcases, n_modes]
        """
        if z_displacement.dim() == 1:
            # Single field: [n_nodes] → [n_modes]
            return self.pseudo_inverse @ z_displacement
        else:
            # Multiple subcases: [n_nodes, n_subcases] → [n_subcases, n_modes]
            # (multiply, then transpose so each row holds one subcase's coefficients)
            return (self.pseudo_inverse @ z_displacement).T

    def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct a displacement field from coefficients.

        Args:
            coefficients: [n_modes] or [n_subcases, n_modes]

        Returns:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]
        """
        if coefficients.dim() == 1:
            return self.Z @ coefficients
        else:
            return self.Z @ coefficients.T

    def fit_and_residual(
        self,
        z_displacement: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fit coefficients and return the residual.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]

        Returns:
            coefficients, residual
        """
        coeffs = self.forward(z_displacement)
        reconstruction = self.reconstruct(coeffs)
        residual = z_displacement - reconstruction
        return coeffs, residual
class ZernikeObjectiveLayer(nn.Module):
    """
    Compute Zernike-based optimization objectives from a displacement field.

    This layer replicates the exact computation done in FEA post-processing:
    1. Compute relative displacement (e.g., 40° - 20°)
    2. Convert to wavefront error (× 2 for reflection, mm → nm)
    3. Fit Zernike and remove low-order terms
    4. Compute filtered RMS

    The computation is fully differentiable, allowing end-to-end training
    with objective-based loss.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes
            regularization: Regularization for Zernike fitting
        """
        super().__init__()
        self.n_modes = n_modes
        self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)

        # Precompute Zernike basis subsets for filtering
        Z = self.zernike_fit.Z
        # Low-order modes (J1-J4: piston, tip, tilt, defocus)
        self.register_buffer('Z_j1_to_j4', Z[:, :4])
        # Only J1-J3 for the manufacturing objective
        self.register_buffer('Z_j1_to_j3', Z[:, :3])

        # Store node count
        self.n_nodes = Z.shape[0]

    def forward(
        self,
        z_disp_all_subcases: torch.Tensor,
        return_all: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Compute Zernike objectives from a displacement field.

        Args:
            z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases
                Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
            return_all: If True, return additional diagnostics

        Returns:
            Dictionary with objective values:
            - rel_filtered_rms_40_vs_20: RMS after J1-J4 removal (nm)
            - rel_filtered_rms_60_vs_20: RMS after J1-J4 removal (nm)
            - mfg_90_optician_workload: RMS after J1-J3 removal (nm)
        """
        # Unpack subcases
        disp_90 = z_disp_all_subcases[:, 0]  # Subcase 1: 90°
        disp_20 = z_disp_all_subcases[:, 1]  # Subcase 2: 20° (reference)
        disp_40 = z_disp_all_subcases[:, 2]  # Subcase 3: 40°
        disp_60 = z_disp_all_subcases[:, 3]  # Subcase 4: 60°

        # === Objective 1: Relative filtered RMS 40° vs 20° ===
        disp_rel_40 = disp_40 - disp_20
        wfe_rel_40 = 2.0 * disp_rel_40 * 1e6  # mm → nm, ×2 for reflection
        rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)

        # === Objective 2: Relative filtered RMS 60° vs 20° ===
        disp_rel_60 = disp_60 - disp_20
        wfe_rel_60 = 2.0 * disp_rel_60 * 1e6
        rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)

        # === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
        disp_rel_90 = disp_90 - disp_20
        wfe_rel_90 = 2.0 * disp_rel_90 * 1e6
        mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)

        result = {
            'rel_filtered_rms_40_vs_20': rms_40_vs_20,
            'rel_filtered_rms_60_vs_20': rms_60_vs_20,
            'mfg_90_optician_workload': mfg_90,
        }
        if return_all:
            # Include intermediate values for debugging
            result['wfe_rel_40'] = wfe_rel_40
            result['wfe_rel_60'] = wfe_rel_60
            result['wfe_rel_90'] = wfe_rel_90
        return result

    def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing J1-J4 (piston, tip, tilt, defocus).

        This is the standard filtered RMS for optical performance.
        """
        # Fit low-order coefficients by solving the normal equations:
        # c = (Z^T Z)^{-1} Z^T @ wfe
        Z_low = self.Z_j1_to_j4
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
        # Reconstruct the low-order surface
        wfe_low = Z_low @ coeffs_low
        # Residual (high-order content)
        wfe_filtered = wfe - wfe_low
        # RMS
        return torch.sqrt(torch.mean(wfe_filtered ** 2))

    def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing only J1-J3 (piston, tip, tilt).

        This keeps defocus (J4), which is harder to polish out and therefore
        represents actual manufacturing workload.
        """
        Z_low = self.Z_j1_to_j3
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
        wfe_low = Z_low @ coeffs_low
        wfe_filtered = wfe - wfe_low
        return torch.sqrt(torch.mean(wfe_filtered ** 2))
class ZernikeRMSLoss(nn.Module):
    """
    Combined loss function for GNN training.

    This loss combines:
    1. Displacement field reconstruction loss (MSE)
    2. Objective prediction loss (relative Zernike RMS)

    The multi-task loss helps the network learn both accurate
    displacement fields AND accurate objective predictions.
    """

    def __init__(
        self,
        polar_graph,
        field_weight: float = 1.0,
        objective_weight: float = 0.1,
        n_modes: int = 50
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            field_weight: Weight for the displacement field loss
            objective_weight: Weight for the objective loss
            n_modes: Number of Zernike modes
        """
        super().__init__()
        self.field_weight = field_weight
        self.objective_weight = objective_weight
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)

    def forward(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        objectives_true: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the combined loss.

        Args:
            z_disp_pred: Predicted displacement [n_nodes, 4]
            z_disp_true: Ground truth displacement [n_nodes, 4]
            objectives_true: Optional dict of true objective values

        Returns:
            total_loss, loss_components dict
        """
        # Field reconstruction loss
        loss_field = nn.functional.mse_loss(z_disp_pred, z_disp_true)
        # Scale the field loss to account for small displacement values:
        # displacements are ~1e-4 mm, so the raw MSE is ~1e-8
        loss_field_scaled = loss_field * 1e8

        components = {
            'loss_field': loss_field_scaled,
        }
        total_loss = self.field_weight * loss_field_scaled

        # Objective loss (if ground truth provided)
        if objectives_true is not None and self.objective_weight > 0:
            objectives_pred = self.objective_layer(z_disp_pred)
            loss_obj = 0.0
            for key in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
                if key in objectives_true:
                    pred = objectives_pred[key]
                    true = objectives_true[key]
                    # Squared relative error
                    rel_err = ((pred - true) / (true + 1e-6)) ** 2
                    loss_obj = loss_obj + rel_err
                    components[f'loss_{key}'] = rel_err
            components['loss_objectives'] = loss_obj
            total_loss = total_loss + self.objective_weight * loss_obj

        components['total_loss'] = total_loss
        return total_loss, components
# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    import sys
    sys.path.insert(0, "C:/Users/Antoine/Atomizer")
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    print("=" * 60)
    print("Testing Differentiable Zernike Layer")
    print("=" * 60)

    # Create polar graph
    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
    print(f"\nPolar Graph: {graph.n_nodes} nodes")

    # Create Zernike fitting layer
    zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
    print(f"Zernike Fit: {zernike_fit.n_modes} modes")
    print(f"  Z matrix: {zernike_fit.Z.shape}")
    print(f"  Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")

    # Test with synthetic displacement
    print("\n--- Test Zernike Fitting ---")
    # Create synthetic displacement (defocus + astigmatism pattern)
    r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
    theta = torch.tensor(graph.theta, dtype=torch.float32)
    # Defocus (J4) + cosine astigmatism (J6)
    synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)

    # Fit coefficients
    coeffs = zernike_fit(synthetic_disp)
    print(f"Fitted coefficients shape: {coeffs.shape}")
    print(f"First 10 coefficients: {coeffs[:10].tolist()}")

    # Reconstruct
    recon = zernike_fit.reconstruct(coeffs)
    error = (synthetic_disp - recon).abs()
    print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")

    # Test with multiple subcases
    print("\n--- Test Multi-Subcase ---")
    z_disp_multi = torch.stack([
        synthetic_disp,
        synthetic_disp * 0.5,
        synthetic_disp * 0.7,
        synthetic_disp * 0.9,
    ], dim=1)  # [n_nodes, 4]
    coeffs_multi = zernike_fit(z_disp_multi)
    print(f"Multi-subcase coefficients: {coeffs_multi.shape}")

    # Test objective layer
    print("\n--- Test Objective Layer ---")
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
    objectives = objective_layer(z_disp_multi)
    print("Computed objectives:")
    for key, val in objectives.items():
        print(f"  {key}: {val.item():.2f} nm")

    # Test gradient flow
    print("\n--- Test Gradient Flow ---")
    z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
    objectives = objective_layer(z_disp_grad)
    loss = objectives['rel_filtered_rms_40_vs_20']
    loss.backward()
    print(f"Gradient shape: {z_disp_grad.grad.shape}")
    print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
    print("✓ Gradients flow through Zernike fitting!")

    # Test loss function
    print("\n--- Test Loss Function ---")
    loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)
    z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)
    total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
    print(f"Total loss: {total_loss.item():.6f}")
    for key, val in components.items():
        if isinstance(val, torch.Tensor):
            print(f"  {key}: {val.item():.6f}")

    print("\n" + "=" * 60)
    print("✓ All tests passed!")
    print("=" * 60)