"""
Surface-based Zernike Extraction
================================

This module implements a surface-fitting approach for Zernike extraction that
is robust to mesh regeneration. Instead of relying on specific mesh node IDs
(which change when the mesh regenerates), it:

1. Reads ALL nodes from the BDF (geometry) and the OP2 (displacements)
2. Identifies optical surface nodes by spatial filtering. The extractor
   filters on radial position only; ``identify_optical_surface_nodes``
   additionally offers a Z-band filter.
3. Interpolates displacements onto a fixed regular polar grid
4. Fits Zernike polynomials to the interpolated surface

This approach is stable across mesh changes because:
- The optical surface geometry is defined by physical location, not node IDs
- The interpolation handles different mesh densities consistently
- The fixed evaluation grid ensures comparable results across iterations
"""

import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from math import factorial
from scipy.interpolate import RBFInterpolator

try:
    from pyNastran.op2.op2 import OP2
    from pyNastran.bdf.bdf import BDF
except ImportError:
    # pyNastran is only needed for file I/O in SurfaceZernikeExtractor;
    # the numerical helpers below remain importable without it.
    OP2 = None
    BDF = None


# =============================================================================
# Zernike Polynomial Functions (Noll ordering)
# =============================================================================

def noll_indices(j: int) -> Tuple[int, int]:
    """
    Convert Noll index j to radial order n and azimuthal order m.

    Ordering follows Noll (1976):
    - Within each radial order n, modes are sorted by increasing |m|.
    - For m != 0, the even-j member of each pair is the cosine term (m > 0).
    - The odd-j member is the sine term (m < 0).

    Args:
        j: Noll index (1-based)

    Returns:
        Tuple (n, m)

    Raises:
        ValueError: if j < 1
    """
    if j < 1:
        raise ValueError("Noll index j must be >= 1")

    count = 0
    n = 0
    while True:
        # |m| takes the values n % 2, n % 2 + 2, ..., n (increasing order)
        for abs_m in range(n % 2, n + 1, 2):
            if abs_m == 0:
                count += 1
                if count == j:
                    return n, 0
                continue
            for _ in range(2):
                count += 1
                if count == j:
                    # Noll parity rule: even j -> cos (m > 0), odd j -> sin (m < 0)
                    return n, (abs_m if count % 2 == 0 else -abs_m)
        n += 1


def zernike_noll(j: int, r: np.ndarray, th: np.ndarray) -> np.ndarray:
    """
    Compute the (unnormalized) Zernike polynomial for Noll index j.

    The radial polynomial R_n^|m|(r) is multiplied as follows:
    - by cos(m * th) for m > 0;
    - by sin(|m| * th) for m < 0;
    - by nothing for m == 0.

    No Noll normalization factor (sqrt(n + 1) / sqrt(2(n + 1))) is applied.

    Args:
        j: Noll index (1-based)
        r: Normalized radial coordinates (0 to 1)
        th: Angular coordinates in radians

    Returns:
        Float64 Zernike polynomial values at each (r, theta) point
    """
    n, m = noll_indices(j)
    abs_m = abs(m)
    r = np.asarray(r)

    # Always accumulate in float: np.zeros_like on an integer `r` would make
    # the in-place `+=` below fail with a casting error.
    R = np.zeros(np.shape(r), dtype=np.float64)
    for s in range((n - abs_m) // 2 + 1):
        c = ((-1) ** s * factorial(n - s) /
             (factorial(s) * factorial((n + abs_m) // 2 - s) * factorial((n - abs_m) // 2 - s)))
        R += c * r ** (n - 2 * s)

    if m == 0:
        return R
    return R * (np.cos(m * th) if m > 0 else np.sin(-m * th))


def zernike_mode_name(j: int) -> str:
    """
    Return a descriptive name for Zernike mode j.

    Modes without a dedicated name are returned as "Z{j}".
    """
    # NOTE(review): under Noll ordering Z7/Z9 are the sin(m*theta) terms and
    # Z8/Z10 the cos(m*theta) terms; many references call the cos terms the
    # "X" variants, so the 7/8 and 9/10 labels may be swapped — confirm the
    # intended convention before relying on these names.
    names = {
        1: "Piston",
        2: "Tilt X",
        3: "Tilt Y",
        4: "Defocus",
        5: "Astigmatism 45°",
        6: "Astigmatism 0°",
        7: "Coma X",
        8: "Coma Y",
        9: "Trefoil X",
        10: "Trefoil Y",
        11: "Spherical",
    }
    return names.get(j, f"Z{j}")


# =============================================================================
# Surface Identification
# =============================================================================

def identify_optical_surface_nodes(
    node_coords: Dict[int, np.ndarray],
    r_inner: float = 100.0,
    r_outer: float = 650.0,
    z_tolerance: float = 100.0
) -> List[int]:
    """
    Identify nodes on the optical surface by spatial filtering.

    The optical surface is identified by:
    1. Radial position (r_inner <= r <= r_outer, with r = hypot(X, Y))
    2. Consistent Z range: |Z - mean Z| < z_tolerance, where the mean is
       taken over the nodes that passed the radial filter

    SurfaceZernikeExtractor does not call this function; it applies only
    the radial filter.

    Args:
        node_coords: Dictionary mapping node ID to (X, Y, Z) coordinates
        r_inner: Inner radius cutoff (central hole)
        r_outer: Outer radius limit
        z_tolerance: Maximum Z deviation from mean to include

    Returns:
        List of node IDs on the optical surface

    Raises:
        ValueError: if no node passes the radial filter or the Z filter
    """
    # Get all coordinates as arrays; reshape keeps a (0, 3) shape for empty input
    nids = list(node_coords.keys())
    coords = np.array([node_coords[nid] for nid in nids], dtype=float).reshape(-1, 3)

    # Calculate radial position
    r = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)

    # Initial radial filter
    radial_mask = (r >= r_inner) & (r <= r_outer)
    radial_nids = np.array(nids)[radial_mask]
    radial_coords = coords[radial_mask]

    if len(radial_coords) == 0:
        raise ValueError(f"No nodes found in radial range [{r_inner}, {r_outer}]")

    # The optical surface should occupy a relatively narrow Z band; the mean Z
    # of the radially selected nodes is used as the band centre.
    z_vals = radial_coords[:, 2]
    z_mean = np.mean(z_vals)
    z_mask = np.abs(z_vals - z_mean) < z_tolerance

    if not np.any(z_mask):
        raise ValueError(
            f"No nodes within z_tolerance={z_tolerance} of mean Z ({z_mean:.1f})"
        )

    surface_nids = radial_nids[z_mask]

    print(f"[SURFACE] Identified {len(surface_nids)} optical surface nodes")
    print(f"[SURFACE] Radial range: [{r_inner:.1f}, {r_outer:.1f}] mm")
    print(f"[SURFACE] Z range: [{z_vals[z_mask].min():.1f}, {z_vals[z_mask].max():.1f}] mm")

    return surface_nids.tolist()


# =============================================================================
# Interpolation to Regular Grid
# =============================================================================

def create_polar_grid(
    r_max: float,
    r_min: float = 0.0,
    n_radial: int = 50,
    n_angular: int = 60
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create a regular polar grid for evaluation.

    Radii are evenly spaced from r_min to r_max (both included). Angles are
    evenly spaced over [0, 2*pi), excluding the endpoint. Because spacing is
    uniform in r rather than in area, inner radii get more points per unit
    area than outer radii. With r_min == 0 the origin appears n_angular times.

    Args:
        r_max: Maximum radius
        r_min: Minimum radius (central hole)
        n_radial: Number of radial points
        n_angular: Number of angular points

    Returns:
        Flattened X, Y coordinates of the n_radial * n_angular grid points
    """
    r = np.linspace(r_min, r_max, n_radial)
    theta = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)
    R, Theta = np.meshgrid(r, theta)
    X = R * np.cos(Theta)
    Y = R * np.sin(Theta)
    return X.flatten(), Y.flatten()


def interpolate_to_grid(
    X_mesh: np.ndarray,
    Y_mesh: np.ndarray,
    values: np.ndarray,
    X_grid: np.ndarray,
    Y_grid: np.ndarray,
    method: str = 'rbf'
) -> np.ndarray:
    """
    Interpolate scattered data to a target grid.

    Uses a thin-plate-spline RBF with a small smoothing term (1e-6). Points
    where the value or either coordinate is NaN are dropped first.

    Args:
        X_mesh, Y_mesh: Scattered mesh coordinates
        values: Values at mesh points
        X_grid, Y_grid: Target grid coordinates
        method: Interpolation method. Only 'rbf' is implemented; the argument
            is currently not inspected.

    Returns:
        Interpolated values at grid points

    Raises:
        ValueError: if fewer than 10 valid points remain
    """
    # Remove NaN values
    valid = ~np.isnan(values) & ~np.isnan(X_mesh) & ~np.isnan(Y_mesh)
    X_valid = X_mesh[valid]
    Y_valid = Y_mesh[valid]
    V_valid = values[valid]

    if len(V_valid) < 10:
        raise ValueError(f"Not enough valid points for interpolation: {len(V_valid)}")

    # RBF interpolation (works well for scattered data)
    points = np.column_stack([X_valid, Y_valid])
    interp = RBFInterpolator(points, V_valid, kernel='thin_plate_spline', smoothing=1e-6)

    grid_points = np.column_stack([X_grid, Y_grid])
    return interp(grid_points)


# =============================================================================
# Zernike Fitting
# =============================================================================

def fit_zernike_coefficients(
    X: np.ndarray,
    Y: np.ndarray,
    W: np.ndarray,
    n_modes: int = 50,
    R_max: Optional[float] = None
) -> Tuple[np.ndarray, float]:
    """
    Fit Zernike polynomials (Noll indices 1..n_modes) to surface data.

    Coordinates are centred on their mean and divided by R_max. Only points
    with normalized radius <= 1 and non-NaN W enter the least-squares fit.

    Args:
        X, Y: Coordinates
        W: Surface values (WFE in nm)
        n_modes: Number of Zernike modes to fit
        R_max: Maximum radius for normalization. If None, uses the largest
            centred radius in the data.

    Returns:
        Tuple of (coefficients array of length n_modes, R_max used)

    Raises:
        RuntimeError: if no valid point lies inside the unit disk
    """
    # Center the data
    X_c = X - np.mean(X)
    Y_c = Y - np.mean(Y)

    # Normalize to unit disk
    if R_max is None:
        R_max = np.max(np.hypot(X_c, Y_c))

    r = np.hypot(X_c / R_max, Y_c / R_max)
    theta = np.arctan2(Y_c, X_c)

    # Mask points inside unit disk with valid values
    mask = (r <= 1.0) & ~np.isnan(W)
    if not np.any(mask):
        raise RuntimeError("No valid points inside unit disk")

    r_valid = r[mask]
    theta_valid = theta[mask]
    W_valid = W[mask]

    # Build Zernike basis matrix (one column per Noll mode)
    Z = np.column_stack([
        zernike_noll(j, r_valid, theta_valid).astype(np.float64)
        for j in range(1, n_modes + 1)
    ])

    # Least squares fit
    try:
        coeffs = np.linalg.lstsq(Z, W_valid, rcond=None)[0]
    except np.linalg.LinAlgError:
        # Fallback to pseudo-inverse
        coeffs = np.linalg.pinv(Z) @ W_valid

    return coeffs, R_max


def compute_rms_from_surface(
    X: np.ndarray,
    Y: np.ndarray,
    W: np.ndarray,
    coeffs: np.ndarray,
    R_max: float,
    filter_low_orders: int = 4
) -> Tuple[float, float]:
    """
    Compute global and filtered RMS from surface data.

    The centring and normalization match fit_zernike_coefficients, so the
    same X, Y and R_max used for the fit should be passed here. Every
    selected point gets equal weight. On a grid from create_polar_grid this
    over-weights inner radii relative to a true area-weighted RMS.

    Args:
        X, Y: Coordinates
        W: Surface values (WFE in nm)
        coeffs: Zernike coefficients (Noll order)
        R_max: Normalization radius
        filter_low_orders: Number of leading Noll modes to subtract (4 removes
            piston, both tilts and defocus). Values <= 0 disable filtering.

    Returns:
        Tuple of (global_rms, filtered_rms)

    Raises:
        ValueError: if filter_low_orders exceeds the number of coefficients
    """
    if filter_low_orders > len(coeffs):
        raise ValueError(
            f"filter_low_orders={filter_low_orders} exceeds number of coefficients ({len(coeffs)})"
        )

    # Center and normalize
    X_c = X - np.mean(X)
    Y_c = Y - np.mean(Y)
    r = np.hypot(X_c / R_max, Y_c / R_max)
    theta = np.arctan2(Y_c, X_c)

    # Mask for valid points
    mask = (r <= 1.0) & ~np.isnan(W)
    W_valid = W[mask]

    # Global RMS
    global_rms = np.sqrt(np.mean(W_valid ** 2))

    if filter_low_orders <= 0:
        # Nothing to remove: the filtered residual is the surface itself
        return global_rms, global_rms

    # Build low-order Zernike matrix
    Z_low = np.column_stack([
        zernike_noll(j, r[mask], theta[mask])
        for j in range(1, filter_low_orders + 1)
    ])

    # Reconstruct low-order surface
    W_low = Z_low @ coeffs[:filter_low_orders]

    # Filtered residual
    W_filtered = W_valid - W_low
    filtered_rms = np.sqrt(np.mean(W_filtered ** 2))

    return global_rms, filtered_rms


def _gather_surface_points(
    node_ids: np.ndarray,
    dmat: np.ndarray,
    node_geo: Dict[int, np.ndarray],
    r_inner: float,
    r_outer: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Pair result rows with geometry and keep nodes in the radial band.

    Row i of `dmat` belongs to `node_ids[i]`; column 2 is taken as the Z
    displacement. Nodes missing from `node_geo` are skipped.

    Returns:
        (X, Y, disp_z) arrays for nodes with r_inner <= hypot(X, Y) <= r_outer
    """
    ids = [int(nid) for nid in node_ids]
    keep = np.fromiter((nid in node_geo for nid in ids), dtype=bool, count=len(ids))

    # reshape keeps a (0, 3) shape when no node has geometry
    coords = np.array(
        [node_geo[nid] for nid, k in zip(ids, keep) if k], dtype=float
    ).reshape(-1, 3)
    disp_z = np.asarray(dmat)[keep, 2].astype(float)

    X = coords[:, 0]
    Y = coords[:, 1]
    r = np.sqrt(X**2 + Y**2)
    surface_mask = (r >= r_inner) & (r <= r_outer)
    return X[surface_mask], Y[surface_mask], disp_z[surface_mask]


# =============================================================================
# Main Extraction Function
# =============================================================================

class SurfaceZernikeExtractor:
    """
    Surface-based Zernike extractor that is robust to mesh changes.

    Results are evaluated on a fixed polar grid (built once in __init__) so
    that runs with different meshes are comparable.
    """

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 600.0,
        z_tolerance: float = 100.0,
        n_modes: int = 50,
        grid_n_radial: int = 50,
        grid_n_angular: int = 60
    ):
        """
        Initialize the extractor.

        Args:
            r_inner: Inner radius cutoff for surface identification
            r_outer: Outer radius cutoff
            z_tolerance: Z tolerance for surface identification (stored, but
                not used by extract_from_op2, which filters radially only)
            n_modes: Number of Zernike modes to fit
            grid_n_radial: Number of radial points in interpolation grid
            grid_n_angular: Number of angular points in interpolation grid
        """
        self.r_inner = r_inner
        self.r_outer = r_outer
        self.z_tolerance = z_tolerance
        self.n_modes = n_modes
        self.grid_n_radial = grid_n_radial
        self.grid_n_angular = grid_n_angular

        # Create the evaluation grid once
        self.X_grid, self.Y_grid = create_polar_grid(
            r_max=r_outer,
            r_min=r_inner,
            n_radial=grid_n_radial,
            n_angular=grid_n_angular
        )

    def extract_from_op2(
        self,
        op2_path: Path,
        bdf_path: Optional[Path] = None,
        reference_subcase: str = "2",
        subcases: Optional[List[str]] = None
    ) -> Dict[str, Optional[Dict[str, Any]]]:
        """
        Extract Zernike coefficients from OP2 file using surface fitting.

        Args:
            op2_path: Path to OP2 file
            bdf_path: Path to BDF/DAT file. If None, looks next to op2_path
                for .dat, then .bdf.
            reference_subcase: Reference subcase ID for relative calculations.
                It is loaded even when it is not listed in `subcases`.
            subcases: List of subcase IDs to process (default: ["1", "2", "3", "4"])

        Returns:
            Dictionary keyed by subcase ID.
            - The value is None when the subcase is missing or its
              interpolation failed.
            - Otherwise it has 'coefficients', 'global_rms_nm',
              'filtered_rms_nm', 'R_max' and 'node_count'.
            - Non-reference subcases also get 'relative_coefficients',
              'relative_global_rms_nm' and 'relative_filtered_rms_nm' when the
              reference surface is available.

        Raises:
            ImportError: if pyNastran is not installed
            FileNotFoundError: if no BDF/DAT file can be found
            RuntimeError: if the OP2 contains no displacement results
        """
        if OP2 is None or BDF is None:
            raise ImportError("pyNastran is required to read OP2/BDF files")

        op2_path = Path(op2_path)

        # Find BDF file
        if bdf_path is None:
            for ext in ['.dat', '.bdf']:
                candidate = op2_path.with_suffix(ext)
                if candidate.exists():
                    bdf_path = candidate
                    break
            if bdf_path is None:
                raise FileNotFoundError(f"No .dat or .bdf found for {op2_path}")
        # Accept plain strings too (.name is used below)
        bdf_path = Path(bdf_path)

        if subcases is None:
            subcases = ["1", "2", "3", "4"]
        # Load the reference even if it was not requested; otherwise relative
        # results would be silently skipped.
        subcases_to_load = set(subcases) | {reference_subcase}

        print(f"[SURFACE ZERNIKE] Reading geometry from: {bdf_path.name}")

        # Read geometry
        bdf = BDF()
        bdf.read_bdf(str(bdf_path))
        node_geo = {int(nid): node.get_position() for nid, node in bdf.nodes.items()}
        print(f"[SURFACE ZERNIKE] Total nodes in BDF: {len(node_geo)}")

        # Read OP2
        print(f"[SURFACE ZERNIKE] Reading displacements from: {op2_path.name}")
        op2 = OP2()
        op2.read_op2(str(op2_path))

        if not op2.displacements:
            raise RuntimeError("No displacement data in OP2")

        # Extract surface points for each subcase
        subcase_data = {}
        NM_PER_MM = 1e6  # mm to nm conversion

        for key, darr in op2.displacements.items():
            isub = str(getattr(darr, 'isubcase', key))
            if isub not in subcases_to_load:
                continue

            data = darr.data
            # For 3-D results only the first entry along axis 0 is used
            # (presumably the first time/load step — confirm for transient runs).
            dmat = data[0] if data.ndim == 3 else data
            ngt = darr.node_gridtype
            node_ids = ngt[:, 0] if ngt.ndim == 2 else ngt

            X, Y, disp_z = _gather_surface_points(
                node_ids, dmat, node_geo, self.r_inner, self.r_outer
            )
            subcase_data[isub] = {
                'X': X,
                'Y': Y,
                'disp_z': disp_z,
                'node_count': int(X.size),
            }

            print(f"[SURFACE ZERNIKE] Subcase {isub}: {subcase_data[isub]['node_count']} surface nodes")

        # Interpolated WFE grids, computed at most once per subcase; the
        # reference grid is shared by every relative calculation.
        grid_cache: Dict[str, Optional[np.ndarray]] = {}

        def wfe_grid_for(sub_id: str) -> Optional[np.ndarray]:
            """Return WFE (nm) of `sub_id` on the evaluation grid, or None if interpolation failed."""
            if sub_id not in grid_cache:
                sub = subcase_data[sub_id]
                # Factor 2 converts surface Z displacement to WFE (presumably
                # the reflective double-pass factor).
                wfe_nm = 2.0 * sub['disp_z'] * NM_PER_MM
                try:
                    grid_cache[sub_id] = interpolate_to_grid(
                        sub['X'], sub['Y'], wfe_nm, self.X_grid, self.Y_grid
                    )
                except Exception as e:
                    print(f"[SURFACE ZERNIKE] WARNING: Interpolation failed for subcase {sub_id}: {e}")
                    grid_cache[sub_id] = None
            return grid_cache[sub_id]

        # Process each subcase
        results = {}
        for isub in subcases:
            if isub not in subcase_data:
                print(f"[SURFACE ZERNIKE] WARNING: Subcase {isub} not found")
                results[isub] = None
                continue

            wfe_grid = wfe_grid_for(isub)
            if wfe_grid is None:
                results[isub] = None
                continue

            # Fit Zernike to interpolated surface
            coeffs, R_max = fit_zernike_coefficients(
                self.X_grid, self.Y_grid, wfe_grid, self.n_modes
            )

            # Compute RMS
            global_rms, filtered_rms = compute_rms_from_surface(
                self.X_grid, self.Y_grid, wfe_grid, coeffs, R_max
            )

            result = {
                'coefficients': coeffs.tolist(),
                'global_rms_nm': global_rms,
                'filtered_rms_nm': filtered_rms,
                'R_max': R_max,
                'node_count': subcase_data[isub]['node_count'],
            }

            # Compute relative (vs reference) if not the reference
            if isub != reference_subcase and reference_subcase in subcase_data:
                ref_grid = wfe_grid_for(reference_subcase)
                if ref_grid is not None:
                    # Both surfaces live on the same fixed grid, so they can
                    # be subtracted point-wise.
                    rel_wfe = wfe_grid - ref_grid

                    # Fit and compute RMS for relative surface
                    rel_coeffs, _ = fit_zernike_coefficients(
                        self.X_grid, self.Y_grid, rel_wfe, self.n_modes, R_max
                    )
                    rel_global_rms, rel_filtered_rms = compute_rms_from_surface(
                        self.X_grid, self.Y_grid, rel_wfe, rel_coeffs, R_max
                    )

                    result['relative_coefficients'] = rel_coeffs.tolist()
                    result['relative_global_rms_nm'] = rel_global_rms
                    result['relative_filtered_rms_nm'] = rel_filtered_rms

            results[isub] = result

        return results


# =============================================================================
# Convenience Function
# =============================================================================

def extract_surface_zernike(
    op2_path: Path,
    bdf_path: Optional[Path] = None,
    reference_subcase: str = "2",
    r_inner: float = 100.0,
    r_outer: float = 600.0,
    subcases: Optional[List[str]] = None
) -> Dict[str, Optional[Dict[str, Any]]]:
    """
    Convenience function to extract Zernike coefficients using surface fitting.

    Args:
        op2_path: Path to OP2 file
        bdf_path: Path to BDF/DAT file (auto-detected if None)
        reference_subcase: Reference subcase ID
        r_inner: Inner radius for surface identification
        r_outer: Outer radius for surface identification
        subcases: Subcase IDs to process (extractor default if None)

    Returns:
        Dictionary with results for each subcase (see
        SurfaceZernikeExtractor.extract_from_op2)
    """
    extractor = SurfaceZernikeExtractor(
        r_inner=r_inner,
        r_outer=r_outer
    )
    return extractor.extract_from_op2(op2_path, bdf_path, reference_subcase, subcases)


if __name__ == '__main__':
    # Test the extractor
    import sys

    if len(sys.argv) < 2:
        print("Usage: python extract_zernike_surface.py <op2_path>")
        sys.exit(1)

    op2_path = Path(sys.argv[1])
    results = extract_surface_zernike(op2_path)

    print("\n" + "="*60)
    print("RESULTS")
    print("="*60)

    for isub, data in results.items():
        if data is None:
            print(f"\nSubcase {isub}: No data")
            continue

        print(f"\nSubcase {isub}:")
        print(f"  Global RMS: {data['global_rms_nm']:.2f} nm")
        print(f"  Filtered RMS: {data['filtered_rms_nm']:.2f} nm")

        if 'relative_global_rms_nm' in data:
            print(f"  Relative Global RMS: {data['relative_global_rms_nm']:.2f} nm")
            print(f"  Relative Filtered RMS: {data['relative_filtered_rms_nm']:.2f} nm")