"""
Differentiable Zernike Fitting Layer
=====================================

This module provides GPU-accelerated, differentiable Zernike polynomial
fitting. The key idea is to put Zernike fitting INSIDE the neural network so
the whole pipeline can be trained end-to-end.

Why Differentiable Zernike?
    Current MLP approach:
        design → MLP → coefficients (learns 200 outputs independently)

    GNN + Differentiable Zernike:
        design → GNN → displacement field → Zernike fit → coefficients
                                                 ↑
                              Differentiable! Gradients flow back

This allows the network to learn:
    1. Spatially coherent displacement fields
    2. Fields that produce accurate Zernike coefficients
    3. Correct relative deformation computation

Components:
    1. DifferentiableZernikeFit - Fits coefficients to a displacement field
    2. ZernikeObjectiveLayer    - Computes RMS objectives like FEA post-processing
    3. ZernikeRMSLoss           - Multi-task loss (field MSE + objective error)

Usage:
    from optimization_engine.gnn.differentiable_zernike import (
        DifferentiableZernikeFit, ZernikeObjectiveLayer
    )
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph()
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)

    # In forward pass:
    z_disp = gnn_model(...)               # [n_nodes, 4]
    objectives = objective_layer(z_disp)  # Dict with RMS values
"""

import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from pathlib import Path


def _noll_to_nm(j: int) -> Tuple[int, int]:
    """
    Convert a 1-based Noll index to (n, m).

    n is the radial order and m the signed azimuthal frequency. Following
    Noll (1976), |m| grows with j inside each radial order, and for m != 0
    an even j selects the cosine term (m > 0) while an odd j selects the
    sine term (m < 0).

    Args:
        j: Noll index (1-based)

    Returns:
        (n, m) tuple

    Raises:
        ValueError: if j < 1
    """
    if j < 1:
        raise ValueError(f"Noll index j must be >= 1, got {j}")

    # Orders 0..n hold (n + 1)(n + 2) / 2 modes in total; n is the smallest
    # order whose cumulative count reaches j. Integer arithmetic avoids the
    # float rounding a sqrt-based closed form can suffer.
    n = 0
    while (n + 1) * (n + 2) // 2 < j:
        n += 1

    # 1-based position of j within radial order n.
    position = j - n * (n + 1) // 2
    # |m| has the same parity as n and increases by 2 every two positions.
    parity = n % 2
    m_abs = 2 * ((position + parity) // 2) - parity

    if m_abs == 0:
        return n, 0
    return n, (m_abs if j % 2 == 0 else -m_abs)


def zernike_noll(j: int, r: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Compute the Zernike polynomial with Noll index j.

    Uses the standard Noll (1976) ordering and normalization (each mode has
    unit RMS over the unit disk):
        j=1: piston        1
        j=2: tilt          2 r cos(theta)
        j=3: tilt          2 r sin(theta)
        j=4: defocus       sqrt(3) (2 r^2 - 1)
        j=5: astigmatism   sqrt(6) r^2 sin(2 theta)
        j=6: astigmatism   sqrt(6) r^2 cos(2 theta)

    Args:
        j: Noll index (1-based)
        r: Radial coordinates (normalized so the aperture edge is 1)
        theta: Angular coordinates (radians)

    Returns:
        Zernike polynomial values at (r, theta), as float64

    Raises:
        ValueError: if j < 1
    """
    n, m = _noll_to_nm(j)
    m_abs = abs(m)
    r = np.asarray(r, dtype=np.float64)
    theta = np.asarray(theta, dtype=np.float64)

    # Radial polynomial R_n^|m|(r); n - |m| is always even for a valid (n, m).
    R = np.zeros_like(r)
    for k in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** k * math.factorial(n - k)
                / (math.factorial(k)
                   * math.factorial((n + m_abs) // 2 - k)
                   * math.factorial((n - m_abs) // 2 - k)))
        R += coef * r ** (n - 2 * k)

    # Angular part: cos for m >= 0 (cos(0) = 1 for rotationally symmetric
    # modes), sin for m < 0.
    if m >= 0:
        Z = R * np.cos(m_abs * theta)
    else:
        Z = R * np.sin(m_abs * theta)

    # Noll normalization factor.
    norm = np.sqrt(n + 1) if m == 0 else np.sqrt(2 * (n + 1))
    return norm * Z


def build_zernike_matrix(
    r: np.ndarray,
    theta: np.ndarray,
    n_modes: int = 50,
    r_max: Optional[float] = None
) -> np.ndarray:
    """
    Build the Zernike basis matrix for a set of points.

    Args:
        r: Radial coordinates (any units; divided by r_max)
        theta: Angular coordinates (radians)
        n_modes: Number of Zernike modes (Noll indices 1 to n_modes)
        r_max: Radius used for normalization (if None, use max(r))

    Returns:
        Z: [n_points, n_modes] float64 Zernike basis matrix; column j-1 holds
           Noll mode j

    Raises:
        ValueError: if r_max (given or derived from r) is not positive
    """
    r = np.asarray(r, dtype=np.float64)
    theta = np.asarray(theta, dtype=np.float64)
    if r_max is None:
        r_max = r.max()
    # `not > 0` also rejects NaN; a zero/negative radius would yield inf/nan.
    if not r_max > 0:
        raise ValueError(f"r_max must be positive, got {r_max}")

    r_norm = r / r_max
    n_points = len(r)

    Z = np.zeros((n_points, n_modes), dtype=np.float64)
    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)

    return Z


class DifferentiableZernikeFit(nn.Module):
    """
    GPU-accelerated, differentiable Zernike polynomial fitting.

    Fits Zernike coefficients to a displacement field by (Tikhonov-regularized)
    least squares, which has the closed-form solution

        c = (Z^T Z + λI)^{-1} Z^T @ values

    The matrix (Z^T Z + λI)^{-1} Z^T depends only on the node coordinates, so
    it is precomputed once and stored as a buffer. Fitting is then a single
    matrix multiply, which is fully differentiable and runs on whatever device
    the module has been moved to.

    Note: on an annular aperture the Zernike modes are not mutually
    orthogonal, so the coefficients are the least-squares optimum over the
    sampled nodes rather than independent projections.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance; only its per-node `r`,
                `theta` and its `r_outer` (normalization radius) are read
            n_modes: Number of Zernike modes to fit
            regularization: Tikhonov regularization λ for numerical stability
        """
        super().__init__()
        self.n_modes = n_modes

        # Node coordinates from the polar graph.
        r = np.asarray(polar_graph.r, dtype=np.float64)
        theta = np.asarray(polar_graph.theta, dtype=np.float64)
        r_max = polar_graph.r_outer

        # Zernike basis matrix [n_nodes, n_modes], float64.
        Z = build_zernike_matrix(r, theta, n_modes, r_max)

        # Solve the regularized normal equations in float64 (better
        # conditioned than an explicit float32 inverse), then cast once.
        ZtZ_reg = Z.T @ Z + regularization * np.eye(n_modes)
        pseudo_inv = np.linalg.solve(ZtZ_reg, Z.T)  # [n_modes, n_nodes]

        self.register_buffer('Z', torch.tensor(Z, dtype=torch.float32))
        self.register_buffer(
            'pseudo_inverse', torch.tensor(pseudo_inv, dtype=torch.float32)
        )

    def forward(self, z_displacement: torch.Tensor) -> torch.Tensor:
        """
        Fit Zernike coefficients to a displacement field.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement

        Returns:
            coefficients: [n_modes] or [n_subcases, n_modes]
        """
        if z_displacement.dim() == 1:
            # Single field: [n_nodes] -> [n_modes]
            return self.pseudo_inverse @ z_displacement
        # Multiple subcases: [n_modes, n_nodes] @ [n_nodes, n_subcases]
        # -> [n_modes, n_subcases], then transpose to [n_subcases, n_modes].
        return (self.pseudo_inverse @ z_displacement).T

    def reconstruct(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct a displacement field from coefficients.

        Args:
            coefficients: [n_modes] or [n_subcases, n_modes]

        Returns:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]
        """
        if coefficients.dim() == 1:
            return self.Z @ coefficients
        return self.Z @ coefficients.T

    def fit_and_residual(
        self,
        z_displacement: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fit coefficients and return the part of the field they do not explain.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases]

        Returns:
            (coefficients, residual) where residual has the input's shape
        """
        coeffs = self.forward(z_displacement)
        reconstruction = self.reconstruct(coeffs)
        residual = z_displacement - reconstruction
        return coeffs, residual


class ZernikeObjectiveLayer(nn.Module):
    """
    Compute Zernike-based optimization objectives from a displacement field.

    This layer replicates the computation done in FEA post-processing:
    1. Compute relative displacement (e.g., 40° - 20°)
    2. Convert to wavefront error (× 2 for reflection, mm → nm)
    3. Fit and remove low-order Zernike terms
    4. Compute the filtered RMS

    The computation is fully differentiable, allowing end-to-end training
    with an objective-based loss.
    """

    def __init__(
        self,
        polar_graph,
        n_modes: int = 50,
        regularization: float = 1e-6
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            n_modes: Number of Zernike modes (must be >= 4 so J1-J4 exist)
            regularization: Regularization for Zernike fitting

        Raises:
            ValueError: if n_modes < 4
        """
        super().__init__()
        if n_modes < 4:
            raise ValueError(
                f"ZernikeObjectiveLayer needs n_modes >= 4 to filter J1-J4, "
                f"got {n_modes}"
            )
        self.n_modes = n_modes
        self.zernike_fit = DifferentiableZernikeFit(polar_graph, n_modes, regularization)

        Z = self.zernike_fit.Z
        # Low-order sub-bases used for filtering. Cloned so each buffer owns
        # contiguous storage rather than being a strided view into Z.
        # J1-J4: piston, tilt (x2), defocus
        self.register_buffer('Z_j1_to_j4', Z[:, :4].clone())
        # J1-J3 only, for the manufacturing objective
        self.register_buffer('Z_j1_to_j3', Z[:, :3].clone())

        # Store node count
        self.n_nodes = Z.shape[0]

    def forward(
        self,
        z_disp_all_subcases: torch.Tensor,
        return_all: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Compute Zernike objectives from a displacement field.

        Args:
            z_disp_all_subcases: [n_nodes, 4] Z-displacement (mm) for 4 subcases.
                Column order: 0=90°, 1=20° (reference), 2=40°, 3=60°
            return_all: If True, also return the relative wavefront-error fields

        Returns:
            Dictionary of scalar tensors (nm):
            - rel_filtered_rms_40_vs_20: RMS after J1-J4 removal
            - rel_filtered_rms_60_vs_20: RMS after J1-J4 removal
            - mfg_90_optician_workload: RMS of (90° - 20°) after J1-J3 removal
            With return_all, additionally 'wfe_rel_40', 'wfe_rel_60' and
            'wfe_rel_90' ([n_nodes] wavefront error in nm).
        """
        # Unpack subcases
        disp_90 = z_disp_all_subcases[:, 0]  # Subcase 1: 90°
        disp_20 = z_disp_all_subcases[:, 1]  # Subcase 2: 20° (reference)
        disp_40 = z_disp_all_subcases[:, 2]  # Subcase 3: 40°
        disp_60 = z_disp_all_subcases[:, 3]  # Subcase 4: 60°

        # === Objective 1: Relative filtered RMS 40° vs 20° ===
        wfe_rel_40 = self._relative_wfe_nm(disp_40, disp_20)
        rms_40_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_40)

        # === Objective 2: Relative filtered RMS 60° vs 20° ===
        wfe_rel_60 = self._relative_wfe_nm(disp_60, disp_20)
        rms_60_vs_20 = self._compute_filtered_rms_j1_to_j4(wfe_rel_60)

        # === Objective 3: Manufacturing 90° (J1-J3 filtered) ===
        wfe_rel_90 = self._relative_wfe_nm(disp_90, disp_20)
        mfg_90 = self._compute_filtered_rms_j1_to_j3(wfe_rel_90)

        result = {
            'rel_filtered_rms_40_vs_20': rms_40_vs_20,
            'rel_filtered_rms_60_vs_20': rms_60_vs_20,
            'mfg_90_optician_workload': mfg_90,
        }

        if return_all:
            # Include intermediate values for debugging
            result['wfe_rel_40'] = wfe_rel_40
            result['wfe_rel_60'] = wfe_rel_60
            result['wfe_rel_90'] = wfe_rel_90

        return result

    @staticmethod
    def _relative_wfe_nm(disp: torch.Tensor, disp_ref: torch.Tensor) -> torch.Tensor:
        """Relative surface displacement (mm) -> wavefront error (nm).

        ×2 because a reflected wavefront sees the surface error twice;
        ×1e6 converts mm to nm.
        """
        disp_rel = disp - disp_ref
        return 2.0 * disp_rel * 1e6

    @staticmethod
    def _filtered_rms(wfe: torch.Tensor, Z_low: torch.Tensor) -> torch.Tensor:
        """RMS of `wfe` after least-squares removal of the modes in `Z_low`.

        The RMS is computed as ||x|| / sqrt(N) rather than sqrt(mean(x^2)):
        the values are equal, but the norm's backward pass returns a zero
        gradient (instead of NaN) when the residual is exactly zero, e.g. for
        a purely low-order or all-zero predicted field.
        """
        ZtZ_low = Z_low.T @ Z_low
        coeffs_low = torch.linalg.solve(ZtZ_low, Z_low.T @ wfe)
        # Residual (content not captured by the low-order modes)
        wfe_filtered = wfe - Z_low @ coeffs_low
        return torch.linalg.vector_norm(wfe_filtered) / math.sqrt(wfe_filtered.numel())

    def _compute_filtered_rms_j1_to_j4(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing J1-J4 (piston, tilt, tilt, defocus).

        This is the standard filtered RMS for optical performance.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j4)

    def _compute_filtered_rms_j1_to_j3(self, wfe: torch.Tensor) -> torch.Tensor:
        """
        Compute RMS after removing only J1-J3 (piston, tilt, tilt).

        Defocus (J4) is kept because it is harder to polish out; this
        represents the actual manufacturing workload.
        """
        return self._filtered_rms(wfe, self.Z_j1_to_j3)


class ZernikeRMSLoss(nn.Module):
    """
    Combined loss function for GNN training.

    This loss combines:
    1. Displacement field reconstruction loss (MSE, rescaled)
    2. Objective prediction loss (squared relative error of Zernike RMS)

    The multi-task loss helps the network learn both accurate displacement
    fields AND accurate objective predictions.
    """

    def __init__(
        self,
        polar_graph,
        field_weight: float = 1.0,
        objective_weight: float = 0.1,
        n_modes: int = 50
    ):
        """
        Args:
            polar_graph: PolarMirrorGraph instance
            field_weight: Weight for displacement field loss
            objective_weight: Weight for objective loss
            n_modes: Number of Zernike modes
        """
        super().__init__()
        self.field_weight = field_weight
        self.objective_weight = objective_weight
        self.objective_layer = ZernikeObjectiveLayer(polar_graph, n_modes)

    def forward(
        self,
        z_disp_pred: torch.Tensor,
        z_disp_true: torch.Tensor,
        objectives_true: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the combined loss.

        Args:
            z_disp_pred: Predicted displacement [n_nodes, 4]
            z_disp_true: Ground truth displacement [n_nodes, 4]
            objectives_true: Optional dict of true objective values, keyed like
                the output of ZernikeObjectiveLayer; missing keys are skipped

        Returns:
            (total_loss, loss_components dict)
        """
        # Field reconstruction loss
        loss_field = nn.functional.mse_loss(z_disp_pred, z_disp_true)

        # Scale field loss to account for small displacement values
        # (displacements ~1e-4 mm give an MSE of ~1e-8).
        loss_field_scaled = loss_field * 1e8

        components = {
            'loss_field': loss_field_scaled,
        }

        total_loss = self.field_weight * loss_field_scaled

        # Objective loss (only if ground truth is provided)
        if objectives_true is not None and self.objective_weight > 0:
            objectives_pred = self.objective_layer(z_disp_pred)

            loss_obj = 0.0
            for key in ['rel_filtered_rms_40_vs_20', 'rel_filtered_rms_60_vs_20',
                        'mfg_90_optician_workload']:
                if key in objectives_true:
                    pred = objectives_pred[key]
                    true = objectives_true[key]
                    # Squared relative error; the epsilon guards against a zero target.
                    rel_err = ((pred - true) / (true + 1e-6)) ** 2
                    loss_obj = loss_obj + rel_err
                    components[f'loss_{key}'] = rel_err

            components['loss_objectives'] = loss_obj
            total_loss = total_loss + self.objective_weight * loss_obj

        components['total_loss'] = total_loss

        return total_loss, components


# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    import sys

    # This module is optimization_engine/gnn/differentiable_zernike.py; put the
    # repository root (two levels above this file's directory) on sys.path so
    # the package import below works when the file is run as a script.
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    print("="*60)
    print("Testing Differentiable Zernike Layer")
    print("="*60)

    # Create polar graph
    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
    print(f"\nPolar Graph: {graph.n_nodes} nodes")

    # Create Zernike fitting layer
    zernike_fit = DifferentiableZernikeFit(graph, n_modes=50)
    print(f"Zernike Fit: {zernike_fit.n_modes} modes")
    print(f" Z matrix: {zernike_fit.Z.shape}")
    print(f" Pseudo-inverse: {zernike_fit.pseudo_inverse.shape}")

    # Test with synthetic displacement
    print("\n--- Test Zernike Fitting ---")

    # Create synthetic displacement (defocus + astigmatism pattern)
    r_norm = torch.tensor(graph.r / graph.r_outer, dtype=torch.float32)
    theta = torch.tensor(graph.theta, dtype=torch.float32)

    # Defocus (J4) + Astigmatism (J6: r^2 cos 2θ in Noll ordering)
    synthetic_disp = 0.001 * (2 * r_norm**2 - 1) + 0.0005 * r_norm**2 * torch.cos(2 * theta)

    # Fit coefficients
    coeffs = zernike_fit(synthetic_disp)
    print(f"Fitted coefficients shape: {coeffs.shape}")
    print(f"First 10 coefficients: {coeffs[:10].tolist()}")

    # Reconstruct
    recon = zernike_fit.reconstruct(coeffs)
    error = (synthetic_disp - recon).abs()
    print(f"Reconstruction error: max={error.max():.6f}, mean={error.mean():.6f}")

    # Test with multiple subcases
    print("\n--- Test Multi-Subcase ---")
    z_disp_multi = torch.stack([
        synthetic_disp,
        synthetic_disp * 0.5,
        synthetic_disp * 0.7,
        synthetic_disp * 0.9,
    ], dim=1)  # [n_nodes, 4]

    coeffs_multi = zernike_fit(z_disp_multi)
    print(f"Multi-subcase coefficients: {coeffs_multi.shape}")

    # Test objective layer
    print("\n--- Test Objective Layer ---")
    objective_layer = ZernikeObjectiveLayer(graph, n_modes=50)
    objectives = objective_layer(z_disp_multi)

    print("Computed objectives:")
    for key, val in objectives.items():
        print(f" {key}: {val.item():.2f} nm")

    # Test gradient flow
    print("\n--- Test Gradient Flow ---")
    z_disp_grad = z_disp_multi.clone().detach().requires_grad_(True)
    objectives = objective_layer(z_disp_grad)
    loss = objectives['rel_filtered_rms_40_vs_20']
    loss.backward()

    print(f"Gradient shape: {z_disp_grad.grad.shape}")
    print(f"Gradient range: [{z_disp_grad.grad.min():.6f}, {z_disp_grad.grad.max():.6f}]")
    print("✓ Gradients flow through Zernike fitting!")

    # Test loss function
    print("\n--- Test Loss Function ---")
    loss_fn = ZernikeRMSLoss(graph, field_weight=1.0, objective_weight=0.1)

    z_pred = (z_disp_multi.detach() + 0.0001 * torch.randn_like(z_disp_multi)).requires_grad_(True)
    total_loss, components = loss_fn(z_pred, z_disp_multi.detach())

    print(f"Total loss: {total_loss.item():.6f}")
    for key, val in components.items():
        if isinstance(val, torch.Tensor):
            print(f" {key}: {val.item():.6f}")

    print("\n" + "="*60)
    print("✓ All tests passed!")
    print("="*60)