feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
544
optimization_engine/gnn/differentiable_zernike.py
Normal file
544
optimization_engine/gnn/differentiable_zernike.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
Differentiable Zernike Fitting Layer
|
||||
=====================================
|
||||
|
||||
This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
|
||||
The key innovation is putting Zernike fitting INSIDE the neural network for
|
||||
end-to-end training.
|
||||
|
||||
Why Differentiable Zernike?
|
||||
|
||||
Current MLP approach:
|
||||
design → MLP → coefficients (learn 200 outputs independently)
|
||||
|
||||
GNN + Differentiable Zernike:
|
||||
design → GNN → displacement field → Zernike fit → coefficients
|
||||
↑
|
||||
Differentiable! Gradients flow back
|
||||
|
||||
This allows the network to learn:
|
||||
1. Spatially coherent displacement fields
|
||||
2. Fields that produce accurate Zernike coefficients
|
||||
3. Correct relative deformation computation
|
||||
|
||||
Components:
|
||||
1. DifferentiableZernikeFit - Fits coefficients from displacement field
|
||||
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing
|
||||
|
||||
Usage:
|
||||
from optimization_engine.gnn.differentiable_zernike import (
|
||||
DifferentiableZernikeFit,
|
||||
ZernikeObjectiveLayer
|
||||
)
|
||||
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
||||
|
||||
graph = PolarMirrorGraph()
|
||||
objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
|
||||
|
||||
# In forward pass:
|
||||
z_disp = gnn_model(...) # [n_nodes, 4]
|
||||
objectives = objective_layer(z_disp) # Dict with RMS values
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute the Zernike polynomial for Noll index j, evaluated at (r, theta).

    Uses the standard Noll indexing convention (Noll 1976):
        j=1: piston, j=2: x-tilt (cos), j=3: y-tilt (sin), j=4: defocus, ...
    Odd Noll indices are the sine (m < 0) terms, even indices cosine (m >= 0).

    Args:
        j: Noll index (1-based)
        r: Radial coordinates, already normalized to [0, 1]
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta), including the Noll RMS
        normalization (sqrt(n+1) for m=0, sqrt(2(n+1)) otherwise).
    """
    # --- Noll index -> (n, m) --------------------------------------------
    # Walk down the rows of the Zernike pyramid: row n holds n+1 modes, so
    # subtract row sizes until j-1 falls inside row n.
    n = 0
    j_row = j - 1
    while j_row > n:
        n += 1
        j_row -= n

    # |m| within row n per Noll's ordering; the sign comes from the parity
    # of j (odd Noll indices are the sine terms).
    m = (n % 2) + 2 * ((j_row + ((n + 1) % 2)) // 2)
    if j % 2 == 1:
        m = -m

    # --- Radial polynomial R_n^|m|(r) ------------------------------------
    m_abs = abs(m)
    R = np.zeros_like(r)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k) /
                (math.factorial(k) *
                 math.factorial((n + m_abs) // 2 - k) *
                 math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)

    # --- Angular part -----------------------------------------------------
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)

    # --- Noll (RMS) normalization factor ----------------------------------
    if m == 0:
        norm = np.sqrt(n + 1)
    else:
        norm = np.sqrt(2 * (n + 1))

    return norm * Z
|
||||
|
||||
|
||||
def build_zernike_matrix(
    r: np.ndarray,
    theta: np.ndarray,
    n_modes: int = 50,
    r_max: float = None
) -> np.ndarray:
    """
    Assemble the Zernike basis matrix evaluated at a set of points.

    Args:
        r: Radial coordinates
        theta: Angular coordinates (radians)
        n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
        r_max: Radius used to normalize r (defaults to max(r))

    Returns:
        Z: [n_points, n_modes] basis matrix; column j-1 holds Noll mode j.
    """
    scale = r.max() if r_max is None else r_max
    rho = r / scale

    Z = np.empty((len(r), n_modes), dtype=np.float64)
    for col, j in enumerate(range(1, n_modes + 1)):
        Z[:, col] = zernike_noll(j, rho, theta)

    return Z
|
||||
|
||||
|
||||
class DifferentiableZernikeFit(nn.Module):
    """
    GPU-accelerated, differentiable Zernike polynomial fitting.

    Fits Zernike coefficients to a displacement field via regularized least
    squares. The closed-form solution

        c = (Z^T Z + lambda*I)^{-1} Z^T @ values

    is a fixed linear map for a fixed node grid, so it is precomputed once;
    fitting then reduces to a single matrix multiply, which is fully
    differentiable.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance (provides r, theta, r_outer)
            n_modes: Number of Zernike modes to fit
            regularization: Tikhonov regularization for numerical stability
        """
        super().__init__()

        self.n_modes = n_modes

        # Node coordinates of the fixed polar grid
        r = polar_graph.r
        theta = polar_graph.theta
        r_max = polar_graph.r_outer

        # Zernike basis evaluated at every node: [n_nodes, n_modes]
        Z = build_zernike_matrix(r, theta, n_modes, r_max)

        # Buffers move with .to(device)/.cuda() but are not trained
        Z_tensor = torch.tensor(Z, dtype=torch.float32)
        self.register_buffer('Z', Z_tensor)

        # Precompute the regularized least-squares operator
        #   c = (Z^T Z + lambda*I)^{-1} Z^T @ values
        # Solving the linear system is numerically preferable to forming an
        # explicit inverse and gives the same operator.
        ZtZ_reg = Z_tensor.T @ Z_tensor + regularization * torch.eye(n_modes)
        pseudo_inv = torch.linalg.solve(ZtZ_reg, Z_tensor.T)  # [n_modes, n_nodes]

        self.register_buffer('pseudo_inverse', pseudo_inv)

    def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
        """
        Fit Zernike coefficients to a displacement field.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement

        Returns:
            coefficients: [n_modes] or [n_subcases, n_modes]
        """
        if z_displacement.dim() == 1:
            # Single field: [n_nodes] -> [n_modes]
            return self.pseudo_inverse @ z_displacement
        else:
            # Multiple subcases: [n_nodes, n_subcases] -> [n_subcases, n_modes]
            return (self.pseudo_inverse @ z_displacement).T

    def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct a displacement field from Zernike coefficients.

        Args:
            coefficients: [n_modes] or [n_subcases, n_modes]

        Returns:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]
        """
        if coefficients.dim() == 1:
            return self.Z @ coefficients
        else:
            return self.Z @ coefficients.T

    def fit_and_residual(
        self,
        z_displacement: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fit coefficients and return the fit residual (content the Zernike
        basis cannot represent).

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]

        Returns:
            (coefficients, residual) — residual has the same shape as the input.
        """
        coeffs = self.forward(z_displacement)
        reconstruction = self.reconstruct(coeffs)
        residual = z_displacement - reconstruction
        return coeffs, residual
|
||||
|
||||
|
||||
class ZernikeObjectiveLayer(nn.Module):
    """
    Compute Zernike-based optimization objectives from a displacement field.

    Replicates the exact computation done in FEA post-processing:
        1. Relative displacement against the 20° reference subcase
        2. Wavefront-error conversion (×2 for reflection, mm → nm)
        3. Least-squares removal of low-order Zernike terms
        4. RMS of the filtered residual

    The computation is fully differentiable, allowing end-to-end training
    with objective-based loss.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes
            regularization: Regularization for Zernike fitting
        """
        super().__init__()

        self.n_modes = n_modes
        self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)

        # Low-order basis subsets used for filtering; registered as buffers
        # so they follow the module across devices.
        Z = self.zernike_fit.Z

        # J1-J4: piston, tip, tilt, defocus
        self.register_buffer('Z_j1_to_j4', Z[:, :4])

        # J1-J3 only (keeps defocus) for the manufacturing objective
        self.register_buffer('Z_j1_to_j3', Z[:, :3])

        # Store node count
        self.n_nodes = Z.shape[0]

    def forward(
        self,
        z_disp_all_subcases: torch.Tensor,
        return_all: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Compute Zernike objectives from a displacement field.

        Args:
            z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases
                Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
            return_all: If True, also return intermediate wavefront maps

        Returns:
            Dictionary with objective values (nm):
                - rel_filtered_rms_40_vs_20: RMS after J1-J4 removal
                - rel_filtered_rms_60_vs_20: RMS after J1-J4 removal
                - mfg_90_optician_workload: RMS after J1-J3 removal
        """
        # Unpack subcases (column order fixed by the FEA subcase numbering)
        disp_90 = z_disp_all_subcases[:, 0]  # Subcase 1: 90°
        disp_20 = z_disp_all_subcases[:, 1]  # Subcase 2: 20° (reference)
        disp_40 = z_disp_all_subcases[:, 2]  # Subcase 3: 40°
        disp_60 = z_disp_all_subcases[:, 3]  # Subcase 4: 60°

        # ×2 for reflection (surface error appears twice in the wavefront),
        # ×1e6 converts mm → nm.

        # === Objective 1: Relative filtered RMS 40° vs 20° ===
        wfe_rel_40 = 2.0 * (disp_40 - disp_20) * 1e6
        rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)

        # === Objective 2: Relative filtered RMS 60° vs 20° ===
        wfe_rel_60 = 2.0 * (disp_60 - disp_20) * 1e6
        rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)

        # === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
        wfe_rel_90 = 2.0 * (disp_90 - disp_20) * 1e6
        mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)

        result = {
            'rel_filtered_rms_40_vs_20': rms_40_vs_20,
            'rel_filtered_rms_60_vs_20': rms_60_vs_20,
            'mfg_90_optician_workload': mfg_90,
        }

        if return_all:
            # Include intermediate wavefront maps for debugging
            result['wfe_rel_40'] = wfe_rel_40
            result['wfe_rel_60'] = wfe_rel_60
            result['wfe_rel_90'] = wfe_rel_90

        return result

    def _filtered_rms(self, wfe: torch.Tensor, Z_low: torch.Tensor) -> torch.Tensor:
        """
        RMS of wfe after least-squares removal of the modes in Z_low.

        Solves the normal equations c = (Z^T Z)^{-1} Z^T @ wfe for the
        low-order coefficients, subtracts the reconstructed low-order
        surface, and returns the RMS of the remaining (high-order) content.
        """
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
        wfe_filtered = wfe - Z_low @ coeffs_low
        return torch.sqrt(torch.mean(wfe_filtered ** 2))

    def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing J1-J4 (piston, tip, tilt, defocus).

        This is the standard filtered RMS for optical performance.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j4)

    def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing only J1-J3 (piston, tip, tilt).

        This keeps defocus (J4), which is harder to polish out - represents
        actual manufacturing workload.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j3)
|
||||
|
||||
|
||||
class ZernikeRMSLoss(nn.Module):
    """
    Combined multi-task loss for GNN training.

    Two terms are combined:
        1. Displacement-field reconstruction loss (scaled MSE)
        2. Objective prediction loss (squared relative error on the
           Zernike RMS objectives)

    Training on both terms pushes the network toward displacement fields
    that are accurate pointwise AND yield accurate objective values.
    """

    # Objective keys supervised when ground truth is provided.
    _OBJECTIVE_KEYS = (
        'rel_filtered_rms_40_vs_20',
        'rel_filtered_rms_60_vs_20',
        'mfg_90_optician_workload',
    )

    def __init__(
        self,
        polar_graph,
        field_weight: float = 1.0,
        objective_weight: float = 0.1,
        n_modes: int = 50
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            field_weight: Weight for the displacement-field loss term
            objective_weight: Weight for the objective loss term
            n_modes: Number of Zernike modes
        """
        super().__init__()

        self.field_weight = field_weight
        self.objective_weight = objective_weight
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)

    def forward(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        objectives_true: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the combined loss.

        Args:
            z_disp_pred: Predicted displacement [n_nodes, 4]
            z_disp_true: Ground-truth displacement [n_nodes, 4]
            objectives_true: Optional dict of true objective values

        Returns:
            (total_loss, dict of individual loss components)
        """
        # Field reconstruction term. Displacements are ~1e-4 mm so the raw
        # MSE is ~1e-8; rescale it into an O(1) range.
        field_term = nn.functional.mse_loss(z_disp_pred, z_disp_true) * 1e8

        components = {'loss_field': field_term}
        total = self.field_weight * field_term

        if objectives_true is not None and self.objective_weight > 0:
            predicted = self.objective_layer(z_disp_pred)

            objective_sum = 0.0
            for key in self._OBJECTIVE_KEYS:
                if key not in objectives_true:
                    continue
                target = objectives_true[key]
                # Squared relative error; epsilon guards near-zero targets.
                err = ((predicted[key] - target) / (target + 1e-6)) ** 2
                objective_sum = objective_sum + err
                components[f'loss_{key}'] = err

            components['loss_objectives'] = objective_sum
            total = total + self.objective_weight * objective_sum

        components['total_loss'] = total
        return total, components
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Testing
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == '__main__':
    import sys

    # Make the repository root importable regardless of the working
    # directory this script is launched from; the file lives at
    # <repo>/optimization_engine/gnn/, so the root is two levels up.
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    print("="*60)
    print("Testing Differentiable Zernike Layer")
    print("="*60)

    # Create polar graph
    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
    print(f"\nPolar Graph: {graph.n_nodes} nodes")

    # Create Zernike fitting layer
    zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
    print(f"Zernike Fit: {zernike_fit.n_modes} modes")
    print(f"  Z matrix: {zernike_fit.Z.shape}")
    print(f"  Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")

    # Test with synthetic displacement
    print("\n--- Test Zernike Fitting ---")

    # Create synthetic displacement (defocus + astigmatism pattern)
    r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
    theta = torch.tensor(graph.theta, dtype=torch.float32)

    # Defocus (J4) + Astigmatism (J5)
    synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)

    # Fit coefficients
    coeffs = zernike_fit(synthetic_disp)
    print(f"Fitted coefficients shape: {coeffs.shape}")
    print(f"First 10 coefficients: {coeffs[:10].tolist()}")

    # Reconstruct and measure round-trip error
    recon = zernike_fit.reconstruct(coeffs)
    error = (synthetic_disp - recon).abs()
    print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")

    # Test with multiple subcases
    print("\n--- Test Multi-Subcase ---")
    z_disp_multi = torch.stack([
        synthetic_disp,
        synthetic_disp * 0.5,
        synthetic_disp * 0.7,
        synthetic_disp * 0.9,
    ], dim=1)  # [n_nodes, 4]

    coeffs_multi = zernike_fit(z_disp_multi)
    print(f"Multi-subcase coefficients: {coeffs_multi.shape}")

    # Test objective layer
    print("\n--- Test Objective Layer ---")
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    objectives = objective_layer(z_disp_multi)
    print("Computed objectives:")
    for key, val in objectives.items():
        print(f"  {key}: {val.item():.2f} nm")

    # Test gradient flow through the objective computation
    print("\n--- Test Gradient Flow ---")
    z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
    objectives = objective_layer(z_disp_grad)
    loss = objectives['rel_filtered_rms_40_vs_20']
    loss.backward()
    print(f"Gradient shape: {z_disp_grad.grad.shape}")
    print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
    print("✓ Gradients flow through Zernike fitting!")

    # Test combined loss function on a perturbed prediction
    print("\n--- Test Loss Function ---")
    loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)

    z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)

    total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
    print(f"Total loss: {total_loss.item():.6f}")
    for key, val in components.items():
        if isinstance(val, torch.Tensor):
            print(f"  {key}: {val.item():.6f}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)
|
||||
Reference in New Issue
Block a user