"""
Differentiable Zernike Fitting Layer
=====================================
This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
The key innovation is putting Zernike fitting INSIDE the neural network for
end-to-end training.
Why Differentiable Zernike?
Current MLP approach:
design MLP coefficients (learn 200 outputs independently)
GNN + Differentiable Zernike:
design GNN displacement field Zernike fit coefficients
Differentiable! Gradients flow back
This allows the network to learn:
1. Spatially coherent displacement fields
2. Fields that produce accurate Zernike coefficients
3. Correct relative deformation computation
Components:
1. DifferentiableZernikeFit - Fits coefficients from displacement field
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing
Usage:
from optimization_engine.gnn.differentiable_zernike import (
DifferentiableZernikeFit,
ZernikeObjectiveLayer
)
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
graph = PolarMirrorGraph()
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
# In forward pass:
z_disp = gnn_model(...) # [n_nodes, 4]
objectives = objective_layer(z_disp) # Dict with RMS values
"""
import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from pathlib import Path
def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute Zernike polynomial for Noll index j.

    Uses the standard Noll indexing convention:
    j=1: piston, j=2: x-tilt, j=3: y-tilt, j=4: defocus, etc.

    Args:
        j: Noll index (1-based)
        r: Radial coordinates (normalized to [0, 1])
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta)
    """
    # Convert Noll index to (n, m): walk down the rows of the Zernike
    # pyramid to find n, then apply Noll's parity rule for the sign:
    # even j -> cosine term (m >= 0), odd j -> sine term (m < 0)
    n = 0
    j_rem = j - 1
    while j_rem > n:
        n += 1
        j_rem -= n
    m = (n % 2) + 2 * ((j_rem + ((n + 1) % 2)) // 2)
    if j % 2 == 1:
        m = -m
    # Compute radial polynomial R_n^|m|(r)
    R = np.zeros_like(r)
    m_abs = abs(m)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k) /
                (math.factorial(k) *
                 math.factorial((n + m_abs) // 2 - k) *
                 math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)
    # Combine with angular part
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)
    # Noll normalization factor (unit RMS over the unit disk)
    if m == 0:
        norm = np.sqrt(n + 1)
    else:
        norm = np.sqrt(2 * (n + 1))
    return norm * Z
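As a standalone sanity check (not part of the module's API), the Noll-to-(n, m) conversion can be sketched on its own and compared against the first few well-known Noll assignments; `noll_to_nm` below is a hypothetical helper name for the conversion step only.

```python
# Standalone sketch of Noll-to-(n, m) conversion, checked against the
# canonical first few Noll assignments (piston, tilts, defocus, astigmatism).
def noll_to_nm(j: int) -> tuple:
    n, j_rem = 0, j - 1
    while j_rem > n:          # walk down the rows of the Zernike pyramid
        n += 1
        j_rem -= n
    m = (n % 2) + 2 * ((j_rem + ((n + 1) % 2)) // 2)
    return (n, -m) if j % 2 == 1 else (n, m)

expected = {1: (0, 0), 2: (1, 1), 3: (1, -1), 4: (2, 0), 5: (2, -2), 6: (2, 2)}
assert all(noll_to_nm(j) == nm for j, nm in expected.items())
```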
def build_zernike_matrix(
r: np.ndarray,
theta: np.ndarray,
n_modes: int = 50,
    r_max: Optional[float] = None
) -> np.ndarray:
"""
Build Zernike basis matrix for a set of points.
Args:
r: Radial coordinates
theta: Angular coordinates
n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
r_max: Maximum radius for normalization (if None, use max(r))
Returns:
Z: [n_points, n_modes] Zernike basis matrix
"""
if r_max is None:
r_max = r.max()
r_norm = r / r_max
n_points = len(r)
Z = np.zeros((n_points, n_modes), dtype=np.float64)
for j in range(1, n_modes + 1):
Z[:, j - 1] = zernike_noll(j, r_norm, theta)
return Z
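The basis builder relies on the Noll normalization giving each mode unit RMS over the unit disk. A quick numerical check of that property for defocus, using its closed form Z4 = sqrt(3) * (2 r^2 - 1) directly rather than the module's functions:

```python
# Monte Carlo check: the Noll-normalized defocus mode should have RMS ~ 1
# over the unit disk. Area-uniform disk samples have r = sqrt(U), U ~ U(0, 1).
import numpy as np

rng = np.random.default_rng(0)
r = np.sqrt(rng.random(200_000))            # uniform over the disk in area
z4 = np.sqrt(3.0) * (2.0 * r**2 - 1.0)      # closed-form Noll Z4 (defocus)
rms = np.sqrt(np.mean(z4**2))
print(rms)  # close to 1.0
```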
class DifferentiableZernikeFit(nn.Module):
"""
GPU-accelerated, differentiable Zernike polynomial fitting.
This layer fits Zernike coefficients to a displacement field using
least squares. The key insight is that least squares has a closed-form
solution: c = (Z^T Z)^{-1} Z^T @ values
By precomputing (Z^T Z)^{-1} Z^T, we can fit coefficients with a single
matrix multiply, which is fully differentiable.
"""
def __init__(
self,
polar_graph,
n_modes: int = 50,
regularization: float = 1e-6
):
"""
Args:
polar_graph: PolarMirrorGraph instance
n_modes: Number of Zernike modes to fit
regularization: Tikhonov regularization for stability
"""
super().__init__()
self.n_modes = n_modes
# Get coordinates from polar graph
r = polar_graph.r
theta = polar_graph.theta
r_max = polar_graph.r_outer
# Build Zernike basis matrix [n_nodes, n_modes]
Z = build_zernike_matrix(r, theta, n_modes, r_max)
# Convert to tensor and register as buffer
Z_tensor = torch.tensor(Z, dtype=torch.float32)
self.register_buffer('Z', Z_tensor)
# Precompute pseudo-inverse with regularization
# c = (Z^T Z + λI)^{-1} Z^T @ values
ZtZ = Z_tensor.T @ Z_tensor
ZtZ_reg = ZtZ + regularization * torch.eye(n_modes)
ZtZ_inv = torch.inverse(ZtZ_reg)
pseudo_inv = ZtZ_inv @ Z_tensor.T # [n_modes, n_nodes]
self.register_buffer('pseudo_inverse', pseudo_inv)
def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
"""
Fit Zernike coefficients to displacement field.
Args:
z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement
Returns:
coefficients: [n_modes] or [n_subcases, n_modes]
"""
if z_displacement.dim() == 1:
# Single field: [n_nodes] → [n_modes]
return self.pseudo_inverse @ z_displacement
        else:
            # Multiple subcases: [n_nodes, n_subcases] -> [n_subcases, n_modes]
            # One matmul gives [n_modes, n_subcases]; transpose to match.
            return (self.pseudo_inverse @ z_displacement).T
def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
"""
Reconstruct displacement field from coefficients.
Args:
coefficients: [n_modes] or [n_subcases, n_modes]
Returns:
z_displacement: [n_nodes] or [n_nodes, n_subcases]
"""
if coefficients.dim() == 1:
return self.Z @ coefficients
else:
return self.Z @ coefficients.T
def fit_and_residual(
self,
z_displacement: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Fit coefficients and return residual.
Args:
z_displacement: [n_nodes] or [n_nodes, n_subcases]
Returns:
coefficients, residual
"""
coeffs = self.forward(z_displacement)
reconstruction = self.reconstruct(coeffs)
residual = z_displacement - reconstruction
return coeffs, residual
class ZernikeObjectiveLayer(nn.Module):
"""
Compute Zernike-based optimization objectives from displacement field.
This layer replicates the exact computation done in FEA post-processing:
1. Compute relative displacement (e.g., 40° - 20°)
    2. Convert to wavefront error (×2 for reflection, mm → nm)
3. Fit Zernike and remove low-order terms
4. Compute filtered RMS
The computation is fully differentiable, allowing end-to-end training
with objective-based loss.
"""
def __init__(
self,
polar_graph,
n_modes: int = 50,
regularization: float = 1e-6
):
"""
Args:
polar_graph: PolarMirrorGraph instance
n_modes: Number of Zernike modes
regularization: Regularization for Zernike fitting
"""
super().__init__()
self.n_modes = n_modes
self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)
# Precompute Zernike basis subsets for filtering
Z = self.zernike_fit.Z
# Low-order modes (J1-J4: piston, tip, tilt, defocus)
self.register_buffer('Z_j1_to_j4', Z[:, :4])
# Only J1-J3 for manufacturing objective
self.register_buffer('Z_j1_to_j3', Z[:, :3])
# Store node count
self.n_nodes = Z.shape[0]
def forward(
self,
z_disp_all_subcases: torch.Tensor,
return_all: bool = False
) -> Dict[str, torch.Tensor]:
"""
Compute Zernike objectives from displacement field.
Args:
z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases
Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
return_all: If True, return additional diagnostics
Returns:
Dictionary with objective values:
- rel_filtered_rms_40_vs_20: RMS after J1-J4 removal (nm)
- rel_filtered_rms_60_vs_20: RMS after J1-J4 removal (nm)
- mfg_90_optician_workload: RMS after J1-J3 removal (nm)
"""
# Unpack subcases
disp_90 = z_disp_all_subcases[:, 0] # Subcase 1: 90°
disp_20 = z_disp_all_subcases[:, 1] # Subcase 2: 20° (reference)
disp_40 = z_disp_all_subcases[:, 2] # Subcase 3: 40°
disp_60 = z_disp_all_subcases[:, 3] # Subcase 4: 60°
# === Objective 1: Relative filtered RMS 40° vs 20° ===
disp_rel_40 = disp_40 - disp_20
wfe_rel_40 = 2.0 * disp_rel_40 * 1e6 # mm → nm, ×2 for reflection
rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)
# === Objective 2: Relative filtered RMS 60° vs 20° ===
disp_rel_60 = disp_60 - disp_20
wfe_rel_60 = 2.0 * disp_rel_60 * 1e6
rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)
# === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
disp_rel_90 = disp_90 - disp_20
wfe_rel_90 = 2.0 * disp_rel_90 * 1e6
mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)
result = {
'rel_filtered_rms_40_vs_20': rms_40_vs_20,
'rel_filtered_rms_60_vs_20': rms_60_vs_20,
'mfg_90_optician_workload': mfg_90,
}
if return_all:
# Include intermediate values for debugging
result['wfe_rel_40'] = wfe_rel_40
result['wfe_rel_60'] = wfe_rel_60
result['wfe_rel_90'] = wfe_rel_90
return result
def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
"""
Compute RMS after removing J1-J4 (piston, tip, tilt, defocus).
This is the standard filtered RMS for optical performance.
"""
        # Fit low-order coefficients via the normal equations:
        # c = (Z_low^T Z_low)^{-1} Z_low^T @ wfe
        Z_low = self.Z_j1_to_j4
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
# Reconstruct low-order surface
wfe_low = Z_low @ coeffs_low
# Residual (high-order content)
wfe_filtered = wfe - wfe_low
# RMS
return torch.sqrt(torch.mean(wfe_filtered ** 2))
def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
"""
Compute RMS after removing only J1-J3 (piston, tip, tilt).
This keeps defocus (J4), which is harder to polish out - represents
actual manufacturing workload.
"""
Z_low = self.Z_j1_to_j3
ZtZ_low = Z_low.T @ Z_low
coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
wfe_low = Z_low @ coeffs_low
wfe_filtered = wfe - wfe_low
return torch.sqrt(torch.mean(wfe_filtered ** 2))
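The low-order filtering used by both RMS helpers is a least-squares projection: fit the wavefront onto the low-order columns, subtract that projection, and take the RMS of what remains. A self-contained sketch with a random stand-in for the low-order basis and made-up coefficients:

```python
# Sketch of low-order removal: project wfe onto a few basis columns and
# keep only the residual, whose RMS is the "filtered RMS".
import torch

torch.manual_seed(0)
n = 500
Z_low = torch.randn(n, 3)                        # stand-in for the J1-J3 columns
wfe = Z_low @ torch.tensor([3.0, -1.0, 2.0])     # pure low-order content
wfe = wfe + 0.01 * torch.randn(n)                # plus small high-order residue

coeffs = torch.linalg.solve(Z_low.T @ Z_low, Z_low.T @ wfe)
filtered = wfe - Z_low @ coeffs                  # low-order part removed
rms = torch.sqrt(torch.mean(filtered ** 2))
print(float(rms))  # roughly the 0.01 residue level; low-order part is gone
```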
class ZernikeRMSLoss(nn.Module):
"""
Combined loss function for GNN training.
This loss combines:
1. Displacement field reconstruction loss (MSE)
2. Objective prediction loss (relative Zernike RMS)
The multi-task loss helps the network learn both accurate
displacement fields AND accurate objective predictions.
"""
def __init__(
self,
polar_graph,
field_weight: float = 1.0,
objective_weight: float = 0.1,
n_modes: int = 50
):
"""
Args:
polar_graph: PolarMirrorGraph instance
field_weight: Weight for displacement field loss
objective_weight: Weight for objective loss
n_modes: Number of Zernike modes
"""
super().__init__()
self.field_weight = field_weight
self.objective_weight = objective_weight
self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)
def forward(
self,
z_disp_pred: torch.Tensor,
z_disp_true: torch.Tensor,
objectives_true: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Compute combined loss.
Args:
z_disp_pred: Predicted displacement [n_nodes, 4]
z_disp_true: Ground truth displacement [n_nodes, 4]
objectives_true: Optional dict of true objective values
Returns:
total_loss, loss_components dict
"""
# Field reconstruction loss
loss_field = nn.functional.mse_loss(z_disp_pred, z_disp_true)
# Scale field loss to account for small displacement values
# Displacements are ~1e-4 mm, so MSE is ~1e-8
loss_field_scaled = loss_field * 1e8
components = {
'loss_field': loss_field_scaled,
}
total_loss = self.field_weight * loss_field_scaled
# Objective loss (if ground truth provided)
if objectives_true is not None and self.objective_weight > 0:
objectives_pred = self.objective_layer(z_disp_pred)
loss_obj = 0.0
for key in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20', 'mfg_90_optician_workload']:
if key in objectives_true:
pred = objectives_pred[key]
true = objectives_true[key]
# Relative error squared
rel_err = ((pred - true) / (true + 1e-6)) ** 2
loss_obj = loss_obj + rel_err
components[f'loss_{key}'] = rel_err
components['loss_objectives'] = loss_obj
total_loss = total_loss + self.objective_weight * loss_obj
components['total_loss'] = total_loss
return total_loss, components
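The way `ZernikeRMSLoss` combines its terms can be shown with made-up numbers: a scaled field MSE plus a weighted sum of squared relative errors (with the same epsilon guard against tiny targets).

```python
# Sketch of the multi-task combination in ZernikeRMSLoss; all values invented.
import torch

field_weight, objective_weight = 1.0, 0.1
loss_field_scaled = torch.tensor(0.5)            # field MSE already scaled by 1e8

pred = {'rms_40': torch.tensor(105.0)}           # predicted objective (nm)
true = {'rms_40': torch.tensor(100.0)}           # ground-truth objective (nm)
loss_obj = sum(((pred[k] - true[k]) / (true[k] + 1e-6)) ** 2 for k in pred)

total = field_weight * loss_field_scaled + objective_weight * loss_obj
print(float(total))  # 0.5 + 0.1 * 0.05**2 ≈ 0.50025
```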
# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
import sys
sys.path.insert(0, "C:/Users/Antoine/Atomizer")
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
print("="*60)
print("Testing Differentiable Zernike Layer")
print("="*60)
# Create polar graph
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
print(f"\nPolar Graph: {graph.n_nodes} nodes")
# Create Zernike fitting layer
zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
print(f"Zernike Fit: {zernike_fit.n_modes} modes")
print(f" Z matrix: {zernike_fit.Z.shape}")
print(f" Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")
# Test with synthetic displacement
print("\n--- Test Zernike Fitting ---")
# Create synthetic displacement (defocus + astigmatism pattern)
r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
theta = torch.tensor(graph.theta, dtype=torch.float32)
    # Defocus (J4) + astigmatism (J6, the cos 2θ term)
synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)
# Fit coefficients
coeffs = zernike_fit(synthetic_disp)
print(f"Fitted coefficients shape: {coeffs.shape}")
print(f"First 10 coefficients: {coeffs[:10].tolist()}")
# Reconstruct
recon = zernike_fit.reconstruct(coeffs)
error = (synthetic_disp - recon).abs()
print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")
# Test with multiple subcases
print("\n--- Test Multi-Subcase ---")
z_disp_multi = torch.stack([
synthetic_disp,
synthetic_disp * 0.5,
synthetic_disp * 0.7,
synthetic_disp * 0.9,
], dim=1) # [n_nodes, 4]
coeffs_multi = zernike_fit(z_disp_multi)
print(f"Multi-subcase coefficients: {coeffs_multi.shape}")
# Test objective layer
print("\n--- Test Objective Layer ---")
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
objectives = objective_layer(z_disp_multi)
print("Computed objectives:")
for key, val in objectives.items():
print(f" {key}: {val.item():.2f} nm")
# Test gradient flow
print("\n--- Test Gradient Flow ---")
z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
objectives = objective_layer(z_disp_grad)
loss = objectives['rel_filtered_rms_40_vs_20']
loss.backward()
print(f"Gradient shape: {z_disp_grad.grad.shape}")
print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
print("✓ Gradients flow through Zernike fitting!")
# Test loss function
print("\n--- Test Loss Function ---")
loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)
z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)
total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
print(f"Total loss: {total_loss.item():.6f}")
for key, val in components.items():
if isinstance(val, torch.Tensor):
print(f" {key}: {val.item():.6f}")
print("\n" + "="*60)
print("✓ All tests passed!")
print("="*60)