"""
Differentiable Zernike Fitting Layer
=====================================

This module provides GPU-accelerated, differentiable Zernike polynomial fitting.
The key innovation is putting Zernike fitting INSIDE the neural network for
end-to-end training.

Why Differentiable Zernike?

Current MLP approach:
    design → MLP → coefficients (learn 200 outputs independently)

GNN + Differentiable Zernike:
    design → GNN → displacement field → Zernike fit → coefficients
                                            ↑
                          Differentiable! Gradients flow back

This allows the network to learn:
1. Spatially coherent displacement fields
2. Fields that produce accurate Zernike coefficients
3. Correct relative deformation computation

Components:
1. DifferentiableZernikeFit - Fits coefficients from displacement field
2. ZernikeObjectiveLayer - Computes RMS objectives like FEA post-processing

Usage:
    from optimization_engine.gnn.differentiable_zernike import (
        DifferentiableZernikeFit,
        ZernikeObjectiveLayer
    )
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    # In forward pass:
    z_disp = gnn_model(...)  # [n_nodes, 4]
    objectives = objective_layer(z_disp)  # Dict with RMS values
"""
|
|||
|
|
|
|||
|
|
import math
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
from typing import Dict, Optional, Tuple
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute Zernike polynomial for Noll index j.

    Uses the standard Noll indexing convention:
    j=1: piston, j=2: tilt (2r cos θ), j=3: tilt (2r sin θ), j=4: defocus, ...
    Even j maps to cosine terms (m >= 0), odd j to sine terms (m < 0).

    Args:
        j: Noll index (1-based)
        r: Radial coordinates (normalized to [0, 1])
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta), with Noll normalization
        (each mode has unit RMS over the unit disk).
    """
    # --- Convert Noll index j to (n, m) ---
    # BUG FIX: the previous parity branches assigned even |m| to odd n and
    # vice versa (e.g. j=4 defocus mapped to (2,-1) instead of (2,0)),
    # producing an invalid basis. This is the standard verified inverse.
    # Radial degree: row n holds Noll indices n*(n+1)/2 + 1 ... (n+1)*(n+2)/2.
    n = int((-1 + np.sqrt(8 * (j - 1) + 1)) / 2)
    # 1-based position of j within the row of degree n.
    p = j - n * (n + 1) // 2
    # |m| has the same parity as n; consecutive (cos, sin) pairs share |m|.
    parity = n % 2
    m = 2 * ((p + parity) // 2) - parity
    # Noll sign convention: odd j are the sine terms (m < 0).
    if m != 0 and j % 2 == 1:
        m = -m

    # --- Radial polynomial R_n^|m|(r) ---
    R = np.zeros_like(r)
    m_abs = abs(m)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k) /
                (math.factorial(k) *
                 math.factorial((n + m_abs) // 2 - k) *
                 math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)

    # --- Angular part ---
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)

    # --- Noll normalization factor (unit RMS over the unit disk) ---
    if m == 0:
        norm = np.sqrt(n + 1)
    else:
        norm = np.sqrt(2 * (n + 1))

    return norm * Z
def build_zernike_matrix(
    r: np.ndarray,
    theta: np.ndarray,
    n_modes: int = 50,
    r_max: Optional[float] = None
) -> np.ndarray:
    """
    Build Zernike basis matrix for a set of points.

    Args:
        r: Radial coordinates
        theta: Angular coordinates (radians)
        n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
        r_max: Radius used to normalize r onto the unit disk
            (if None, use max(r))

    Returns:
        Z: [n_points, n_modes] Zernike basis matrix; column j-1 holds
        Noll mode j sampled at all points.
    """
    if r_max is None:
        r_max = r.max()

    # Zernike polynomials are defined on the unit disk.
    r_norm = r / r_max

    n_points = len(r)
    Z = np.zeros((n_points, n_modes), dtype=np.float64)

    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)

    return Z
class DifferentiableZernikeFit(nn.Module):
    """
    GPU-accelerated, differentiable least-squares Zernike fitting.

    Least squares has the closed-form solution c = (Z^T Z)^{-1} Z^T v, so
    the (regularized) pseudo-inverse is precomputed once at construction.
    Fitting then reduces to a single matrix multiply, which autograd can
    differentiate straight through.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance (provides r, theta, r_outer)
            n_modes: Number of Zernike modes to fit
            regularization: Tikhonov term added to Z^T Z for stability
        """
        super().__init__()

        self.n_modes = n_modes

        # Sample the Zernike basis at the graph's node coordinates
        # → [n_nodes, n_modes].
        basis = build_zernike_matrix(
            polar_graph.r, polar_graph.theta, n_modes, polar_graph.r_outer
        )
        basis_t = torch.tensor(basis, dtype=torch.float32)
        self.register_buffer('Z', basis_t)

        # Regularized normal equations: c = (Z^T Z + λI)^{-1} Z^T v.
        gram = basis_t.T @ basis_t
        gram_reg = gram + regularization * torch.eye(n_modes)
        gram_inv = torch.inverse(gram_reg)
        self.register_buffer('pseudo_inverse', gram_inv @ basis_t.T)  # [n_modes, n_nodes]

    def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
        """
        Fit Zernike coefficients to one or more displacement fields.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement

        Returns:
            coefficients: [n_modes] for a single field,
            [n_subcases, n_modes] for a batch of subcases.
        """
        fitted = self.pseudo_inverse @ z_displacement
        # For a [n_nodes, n_subcases] input the product is
        # [n_modes, n_subcases]; transpose so subcases lead.
        return fitted if z_displacement.dim() == 1 else fitted.T

    def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the fitted surface back at the graph nodes.

        Args:
            coefficients: [n_modes] or [n_subcases, n_modes]

        Returns:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]
        """
        if coefficients.dim() == 1:
            return self.Z @ coefficients
        return self.Z @ coefficients.T

    def fit_and_residual(
        self,
        z_displacement: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fit coefficients and also return the un-captured residual field.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]

        Returns:
            (coefficients, residual) where residual = input - reconstruction.
        """
        fitted = self.forward(z_displacement)
        leftover = z_displacement - self.reconstruct(fitted)
        return fitted, leftover
class ZernikeObjectiveLayer(nn.Module):
    """
    Differentiable replica of the FEA post-processing Zernike objectives.

    Pipeline, identical to the offline post-processor:
      1. Relative displacement w.r.t. the 20° reference subcase
      2. Wavefront-error conversion (× 2 for reflection, mm → nm)
      3. Low-order Zernike removal
      4. RMS of the filtered residual

    Every step is differentiable, so objective-based losses can train the
    network end-to-end.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes
            regularization: Regularization for Zernike fitting
        """
        super().__init__()

        self.n_modes = n_modes
        self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)

        basis = self.zernike_fit.Z
        # Low-order column subsets used for filtering:
        #   J1-J4 (piston, tip, tilt, defocus) → optical performance metric
        #   J1-J3 (piston, tip, tilt only)     → manufacturing metric
        self.register_buffer('Z_j1_to_j4', basis[:, :4])
        self.register_buffer('Z_j1_to_j3', basis[:, :3])

        self.n_nodes = basis.shape[0]

    def forward(
        self,
        z_disp_all_subcases: torch.Tensor,
        return_all: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Compute Zernike objectives from a displacement field.

        Args:
            z_disp_all_subcases: [n_nodes, 4] Z-displacement for 4 subcases.
                Subcase order: 1=90°, 2=20°(ref), 3=40°, 4=60°
            return_all: If True, also include intermediate wavefront errors.

        Returns:
            Dictionary with objective values:
            - rel_filtered_rms_40_vs_20: RMS after J1-J4 removal (nm)
            - rel_filtered_rms_60_vs_20: RMS after J1-J4 removal (nm)
            - mfg_90_optician_workload: RMS after J1-J3 removal (nm)
        """
        disp_90 = z_disp_all_subcases[:, 0]
        disp_20 = z_disp_all_subcases[:, 1]  # 20° is the reference subcase
        disp_40 = z_disp_all_subcases[:, 2]
        disp_60 = z_disp_all_subcases[:, 3]

        # Relative displacement → wavefront error in nm
        # (× 2 for reflection, × 1e6 for mm → nm).
        wfe_rel_40 = 2.0 * (disp_40 - disp_20) * 1e6
        wfe_rel_60 = 2.0 * (disp_60 - disp_20) * 1e6
        wfe_rel_90 = 2.0 * (disp_90 - disp_20) * 1e6

        result = {
            'rel_filtered_rms_40_vs_20': self._compute_filtered_rms_j1_to_j4(wfe_rel_40),
            'rel_filtered_rms_60_vs_20': self._compute_filtered_rms_j1_to_j4(wfe_rel_60),
            'mfg_90_optician_workload': self._compute_filtered_rms_j1_to_j3(wfe_rel_90),
        }

        if return_all:
            # Expose intermediate wavefront errors for debugging.
            result['wfe_rel_40'] = wfe_rel_40
            result['wfe_rel_60'] = wfe_rel_60
            result['wfe_rel_90'] = wfe_rel_90

        return result

    def _filtered_rms(self, wfe: torch.Tensor, Z_low: torch.Tensor) -> torch.Tensor:
        """Remove the least-squares projection onto Z_low, return residual RMS."""
        # Normal equations: c = (Z^T Z)^{-1} Z^T @ wfe
        coeffs_low = torch.linalg.solve(Z_low.T @ Z_low, Z_low.T @ wfe)
        # Residual = high-order content left after low-order removal.
        residual = wfe - Z_low @ coeffs_low
        return torch.sqrt(torch.mean(residual ** 2))

    def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        RMS after removing J1-J4 (piston, tip, tilt, defocus).

        This is the standard filtered RMS for optical performance.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j4)

    def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        RMS after removing only J1-J3 (piston, tip, tilt).

        Keeps defocus (J4), which is harder to polish out — represents the
        actual manufacturing workload.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j3)
class ZernikeRMSLoss(nn.Module):
    """
    Combined multi-task loss for GNN training.

    Two terms:
      1. Displacement-field reconstruction (MSE)
      2. Objective prediction accuracy (relative error on Zernike RMS values)

    Learning both pushes the network toward fields that are accurate
    point-wise AND produce accurate optical objectives.
    """

    def __init__(
        self,
        polar_graph,
        field_weight: float = 1.0,
        objective_weight: float = 0.1,
        n_modes: int = 50
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            field_weight: Weight for displacement field loss
            objective_weight: Weight for objective loss
            n_modes: Number of Zernike modes
        """
        super().__init__()
        self.field_weight = field_weight
        self.objective_weight = objective_weight
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)

    def forward(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        objectives_true: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the combined loss.

        Args:
            z_disp_pred: Predicted displacement [n_nodes, 4]
            z_disp_true: Ground truth displacement [n_nodes, 4]
            objectives_true: Optional dict of true objective values

        Returns:
            (total_loss, loss_components dict)
        """
        # Raw displacements are ~1e-4 mm, so plain MSE is ~1e-8; rescale by
        # 1e8 so the field term is O(1) and comparable to the objective term.
        field_term = nn.functional.mse_loss(z_disp_pred, z_disp_true) * 1e8

        components = {'loss_field': field_term}
        total = self.field_weight * field_term

        # Objective term only when ground truth is supplied and weighted.
        if objectives_true is not None and self.objective_weight > 0:
            predicted = self.objective_layer(z_disp_pred)

            obj_term = 0.0
            for key in ('rel_filtered_rms_40_vs_20',
                        'rel_filtered_rms_60_vs_20',
                        'mfg_90_optician_workload'):
                if key not in objectives_true:
                    continue
                target = objectives_true[key]
                # Squared relative error; epsilon guards near-zero targets.
                err = ((predicted[key] - target) / (target + 1e-6)) ** 2
                obj_term = obj_term + err
                components[f'loss_{key}'] = err

            components['loss_objectives'] = obj_term
            total = total + self.objective_weight * obj_term

        components['total_loss'] = total
        return total, components
|
# =============================================================================
|
|||
|
|
# Testing
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Manual smoke test: exercises fitting, multi-subcase handling,
    # objective computation, gradient flow, and the combined loss.
    import sys
    # NOTE(review): hard-coded developer checkout path — assumes this
    # specific machine's layout; confirm before running elsewhere.
    sys.path.insert(0, "C:/Users/Antoine/Atomizer")

    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    print("="*60)
    print("Testing Differentiable Zernike Layer")
    print("="*60)

    # Create polar graph
    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
    print(f"\nPolar Graph: {graph.n_nodes} nodes")

    # Create Zernike fitting layer
    zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
    print(f"Zernike Fit: {zernike_fit.n_modes} modes")
    print(f"  Z matrix: {zernike_fit.Z.shape}")
    print(f"  Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")

    # Test with synthetic displacement
    print("\n--- Test Zernike Fitting ---")

    # Create synthetic displacement (defocus + astigmatism pattern)
    r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
    theta = torch.tensor(graph.theta, dtype=torch.float32)

    # Defocus (J4) + Astigmatism (J5) — known modes so the fit is checkable
    synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)

    # Fit coefficients
    coeffs = zernike_fit(synthetic_disp)
    print(f"Fitted coefficients shape: {coeffs.shape}")
    print(f"First 10 coefficients: {coeffs[:10].tolist()}")

    # Reconstruct and report round-trip error
    recon = zernike_fit.reconstruct(coeffs)
    error = (synthetic_disp - recon).abs()
    print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")

    # Test with multiple subcases (scaled copies of the same field)
    print("\n--- Test Multi-Subcase ---")
    z_disp_multi = torch.stack([
        synthetic_disp,
        synthetic_disp * 0.5,
        synthetic_disp * 0.7,
        synthetic_disp * 0.9,
    ], dim=1)  # [n_nodes, 4]

    coeffs_multi = zernike_fit(z_disp_multi)
    print(f"Multi-subcase coefficients: {coeffs_multi.shape}")

    # Test objective layer
    print("\n--- Test Objective Layer ---")
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    objectives = objective_layer(z_disp_multi)
    print("Computed objectives:")
    for key, val in objectives.items():
        print(f"  {key}: {val.item():.2f} nm")

    # Test gradient flow: backprop an objective to the displacement input
    print("\n--- Test Gradient Flow ---")
    z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
    objectives = objective_layer(z_disp_grad)
    loss = objectives['rel_filtered_rms_40_vs_20']
    loss.backward()
    print(f"Gradient shape: {z_disp_grad.grad.shape}")
    print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
    print("✓ Gradients flow through Zernike fitting!")

    # Test loss function with a slightly perturbed "prediction"
    print("\n--- Test Loss Function ---")
    loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)

    z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)

    total_loss, components = loss_fn(z_pred, z_disp_multi.detach())
    print(f"Total loss: {total_loss.item():.6f}")
    for key, val in components.items():
        if isinstance(val, torch.Tensor):
            print(f"  {key}: {val.item():.6f}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)