Files
Atomizer/tools/validate_zernike_roundtrip.py
Antoine 4146e9d8f1 feat(V&V): Updated to FEA CSV format + real M2 mesh injection
- Output now matches WFE_from_CSV_OPD format: ,X,Y,Z,DX,DY,DZ (meters)
- Suite regenerated using real M2 mesh (357 nodes, 308mm diameter)
- All 14 clean test cases: PASS (0.000 nm error)
- 3 noisy cases: expected FAIL due to low node count amplifying noise
- Added --inject mode to use real FEA mesh geometry
- Added lateral displacement test case
2026-03-09 15:56:23 +00:00

299 lines
9.1 KiB
Python

#!/usr/bin/env python3
"""
Zernike Round-Trip Validator
=============================
Reads a synthetic WFE CSV (FEA format: X,Y,Z,DX,DY,DZ) + its truth JSON,
fits Zernike coefficients from DZ, and compares recovered vs input.
This validates that the Zernike fitting math is correct by doing a
generate → fit → compare round-trip.
Usage:
# Single file
python validate_zernike_roundtrip.py validation_suite/Z05_astig_0deg.csv
# Full suite
python validate_zernike_roundtrip.py --suite validation_suite/
Author: Mario (Atomizer V&V)
Created: 2026-03-09
"""
import sys
import csv
import json
import argparse
from pathlib import Path
from math import factorial
from typing import Dict, Tuple
import numpy as np
from numpy.linalg import lstsq
# ============================================================================
# Zernike Math (same as generate_synthetic_wfe.py)
# ============================================================================
def noll_indices(j: int) -> Tuple[int, int]:
    """Map a 1-based single mode index ``j`` to its Zernike ``(n, m)`` pair.

    Modes are enumerated one radial order at a time (n = 0, 1, 2, ...).
    Within an order the azimuthal frequencies are listed as ``0`` (even n
    only) followed by ``-|m|, +|m|`` pairs of increasing ``|m|``.

    NOTE(review): this is not the textbook Noll (1976) table -- here
    j=2 -> (1, -1) and j=12 -> (4, -2), whereas Noll has (1, 1) and (4, 2).
    The section banner states this math must match generate_synthetic_wfe.py,
    so the ordering is kept exactly for round-trip consistency.

    Raises:
        ValueError: if ``j < 1``.
    """
    if j < 1:
        raise ValueError("Noll index j must be >= 1")
    remaining = j
    n = 0
    while True:
        # Allowed |m| values at order n share n's parity and run up to n.
        if n % 2 == 0:
            orders = [0]
            for k in range(1, n // 2 + 1):
                orders.extend((-2 * k, 2 * k))
        else:
            orders = []
            for k in range((n + 1) // 2):
                orders.extend((-(2 * k + 1), 2 * k + 1))
        if remaining <= len(orders):
            return n, orders[remaining - 1]
        remaining -= len(orders)
        n += 1
def zernike_radial(n, m, r):
    """Evaluate the Zernike radial polynomial R_n^|m|(r).

    Args:
        n: Radial order (non-negative int).
        m: Azimuthal frequency; only ``abs(m)`` is used.
        r: Normalized radius -- scalar, list, or array of any numeric dtype.

    Returns:
        float64 ndarray with the same shape as ``r``. It is identically zero
        when ``n - |m|`` is odd or negative, as required by the definition.
    """
    # Promote to float first: zeros_like() on an integer array would yield an
    # int accumulator, and the in-place += of float terms would then fail.
    r = np.asarray(r, dtype=float)
    R = np.zeros_like(r)
    m_abs = abs(m)
    if (n - m_abs) % 2 != 0:
        return R
    for s in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** s * factorial(n - s) /
                (factorial(s) * factorial((n + m_abs) // 2 - s) * factorial((n - m_abs) // 2 - s)))
        R += coef * r ** (n - 2 * s)
    return R
def zernike_noll(j, r, theta):
    """Evaluate single-index Zernike mode ``j`` at polar points (r, theta).

    The (n, m) pair comes from noll_indices(). A negative m selects the sine
    term, a positive m the cosine term, and m == 0 is purely radial. No
    normalization factor is applied.
    """
    n, m = noll_indices(j)
    radial = zernike_radial(n, m, r)
    if m < 0:
        return radial * np.sin(abs(m) * theta)
    if m > 0:
        return radial * np.cos(m * theta)
    return radial
# Display labels for the low-order (n, m) pairs; any pair not listed here
# falls back to the generic "Z(n,±m)" label built in zernike_name().
_ZERNIKE_NAMES = {
    (0, 0): "Piston",
    (1, -1): "Tilt X",
    (1, 1): "Tilt Y",
    (2, 0): "Defocus",
    (2, -2): "Astig 45°",
    (2, 2): "Astig 0°",
    (3, -1): "Coma X",
    (3, 1): "Coma Y",
    (3, -3): "Trefoil X",
    (3, 3): "Trefoil Y",
    (4, 0): "Spherical",
    (4, -2): "2ndAstig X",
    (4, 2): "2ndAstig Y",
    (6, 0): "2nd Spherical",
}


def zernike_name(j):
    """Return a display label for mode ``j``, e.g. "Defocus" or "Z(5,+1)"."""
    n, m = noll_indices(j)
    label = _ZERNIKE_NAMES.get((n, m))
    if label is None:
        label = f"Z({n},{m:+d})"
    return label
# ============================================================================
# CSV Reading (FEA format)
# ============================================================================
def read_fea_csv(csv_path: str):
    """
    Read an FEA-format CSV with columns: (index), X, Y, Z, DX, DY, DZ.

    All values are expected in meters. Columns are looked up by header
    name, so a leading unnamed index column, if present, is ignored.

    Args:
        csv_path: Path to the CSV file (str or Path).

    Returns:
        x_m, y_m, z_m, dx_m, dy_m, dz_m: float arrays, one entry per data row.

    Raises:
        KeyError: if a required column header is missing.
        ValueError: if a cell cannot be parsed as a float.
    """
    columns = ("X", "Y", "Z", "DX", "DY", "DZ")
    data = {name: [] for name in columns}
    # newline="" is required by the csv module for correct line-ending and
    # quoted-field handling. utf-8-sig strips a leading BOM so a BOM-prefixed
    # header (e.g. saved from Excel) still yields the key "X", not "\ufeffX".
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            for name in columns:
                data[name].append(float(row[name]))
    return tuple(np.array(data[name], dtype=float) for name in columns)
# ============================================================================
# Zernike Fitting (Least Squares)
# ============================================================================
def fit_zernike(x_m, y_m, dz_m, diameter_mm=None, n_modes=50):
    """
    Fit Zernike coefficients to DZ displacement data by linear least squares.

    Args:
        x_m, y_m: Node positions in meters
        dz_m: Z-displacement in meters
        diameter_mm: Normalization diameter in mm. If None, it is taken as
            twice the largest radial distance of any node from the origin.
        n_modes: Number of Zernike modes to fit (single indices 1..n_modes)

    Returns:
        coefficients_nm: array of shape (n_modes,), amplitudes in nm

    Raises:
        ValueError: if the given or auto-detected diameter is not positive
            (e.g. every node sits at the origin), which would otherwise
            produce NaN normalized radii and a meaningless fit.
    """
    x_m = np.asarray(x_m, dtype=float)
    y_m = np.asarray(y_m, dtype=float)
    r_m = np.hypot(x_m, y_m)
    if diameter_mm is None:
        diameter_mm = 2.0 * np.max(r_m) * 1000.0
    # "not > 0" also rejects NaN.
    if not diameter_mm > 0:
        raise ValueError(f"diameter_mm must be positive, got {diameter_mm}")
    outer_r_m = diameter_mm / 2000.0
    r_norm = r_m / outer_r_m
    theta = np.arctan2(y_m, x_m)
    # Convert DZ to nm so the coefficients come out in nm.
    dz_nm = np.asarray(dz_m, dtype=float) * 1e9
    # Design matrix: one column per mode, one row per node.
    Z = np.zeros((len(x_m), n_modes))
    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)
    # Only the solution is needed; residuals/rank/singular values are unused.
    coeffs, *_ = lstsq(Z, dz_nm, rcond=None)
    return coeffs
# ============================================================================
# Validation
# ============================================================================
def validate_file(csv_path: str, n_modes: int = 50, diameter_mm: float = None,
                  tolerance_nm: float = 0.5, verbose: bool = True):
    """
    Validate a single synthetic WFE file against its ``<stem>_truth.json``.

    The truth JSON must contain ``"input_coefficients"`` (mode index -> nm)
    and may contain ``"diameter_mm"``. Modes 1..n_modes are compared, and a
    mode passes when ``|recovered - input| < tolerance_nm``.

    Truth modes outside 1..n_modes are never fitted. Any such mode whose
    amplitude is at least ``tolerance_nm`` makes the file FAIL, instead of
    being silently ignored.

    Args:
        csv_path: FEA-format CSV (X, Y, Z, DX, DY, DZ in meters).
        n_modes: Number of Zernike modes to fit and compare.
        diameter_mm: Normalization diameter. If None, the truth file's
            ``diameter_mm`` is used, and if that is also missing the fit
            auto-detects it.
        tolerance_nm: Per-mode absolute error threshold in nm.
        verbose: Print a per-mode table and a summary.

    Returns:
        passed: bool, or None if no truth file exists
        results: dict, or None if no truth file exists
    """
    csv_path = Path(csv_path)
    truth_path = csv_path.with_name(csv_path.stem + "_truth.json")
    if not truth_path.exists():
        print(f" WARNING: No truth file found for {csv_path.name}")
        return None, None

    # Load CSV (FEA format); only X/Y positions and DZ feed the fit.
    x_m, y_m, _z_m, _dx_m, _dy_m, dz_m = read_fea_csv(csv_path)

    # Load truth. JSON object keys are strings, so convert them to mode ints.
    with open(truth_path) as f:
        truth = json.load(f)
    input_coeffs = {int(k): float(v) for k, v in truth["input_coefficients"].items()}

    # Use the diameter from truth if none was given explicitly.
    if diameter_mm is None:
        diameter_mm = truth.get("diameter_mm")

    recovered = fit_zernike(x_m, y_m, dz_m, diameter_mm, n_modes)

    max_error = 0.0
    results = {"modes": {}}
    all_passed = True
    if verbose:
        print(f"\n {'Mode':>6} {'Name':>20} {'Input(nm)':>10} {'Recovered(nm)':>14} {'Error(nm)':>10} {'Status':>8}")
        print(f" {'-'*6} {'-'*20} {'-'*10} {'-'*14} {'-'*10} {'-'*8}")
    for j in range(1, n_modes + 1):
        input_val = input_coeffs.get(j, 0.0)
        # Plain Python floats/bools keep `results` JSON-serializable.
        recovered_val = float(recovered[j - 1])
        error = abs(recovered_val - input_val)
        max_error = max(max_error, error)
        mode_passed = error < tolerance_nm
        if not mode_passed:
            all_passed = False
        results["modes"][j] = {
            "input": input_val,
            "recovered": recovered_val,
            "error": error,
            "passed": mode_passed,
        }
        # Only list modes that carry signal, either in the truth or in the fit.
        if verbose and (abs(input_val) > 0.01 or abs(recovered_val) > tolerance_nm):
            status = "✅" if mode_passed else "❌"
            print(f" Z{j:>4d} {zernike_name(j):>20} {input_val:>10.3f} {recovered_val:>14.3f} {error:>10.6f} {status:>8}")

    # Truth modes beyond the fitted range cannot be recovered. Treat each one
    # as recovered = 0, so it fails when its amplitude reaches the tolerance.
    unfitted = {j: v for j, v in sorted(input_coeffs.items()) if not 1 <= j <= n_modes}
    unverifiable = [j for j, v in unfitted.items() if abs(v) >= tolerance_nm]
    if unverifiable:
        all_passed = False
        print(f" WARNING: truth modes {unverifiable} lie outside the fitted range "
              f"1..{n_modes}; counted as FAIL")

    results["unfitted_modes"] = unfitted
    results["max_error_nm"] = float(max_error)
    results["all_passed"] = all_passed
    results["tolerance_nm"] = tolerance_nm
    results["n_points"] = len(x_m)
    if verbose:
        print(f"\n Max error: {max_error:.6f} nm")
        print(f" Tolerance: {tolerance_nm:.3f} nm")
        print(f" Result: {'✅ PASS' if all_passed else '❌ FAIL'}")
    return all_passed, results
def validate_suite(suite_dir: str, n_modes: int = 50, tolerance_nm: float = 0.5):
    """
    Validate every ``*.csv`` test case in ``suite_dir``.

    Each CSV is checked with validate_file(). Cases without a truth JSON
    count as SKIP. The suite passes only if at least one case was actually
    validated and none failed, so an empty or mistyped directory (where the
    glob matches nothing) is reported as a failure, not a vacuous pass.

    Returns:
        passed: bool
        summary: dict mapping CSV stem -> "PASS" / "FAIL" / "SKIP"
    """
    suite_dir = Path(suite_dir)
    csv_files = sorted(suite_dir.glob("*.csv"))
    print(f"\nValidating {len(csv_files)} test cases in: {suite_dir}")
    print("=" * 70)
    summary = {}
    n_pass = 0
    n_fail = 0
    n_skip = 0
    for csv_file in csv_files:
        print(f"\n{'─' * 70}")
        print(f"Test: {csv_file.name}")
        passed, _ = validate_file(csv_file, n_modes, tolerance_nm=tolerance_nm)
        if passed is None:
            n_skip += 1
            summary[csv_file.stem] = "SKIP"
        elif passed:
            n_pass += 1
            summary[csv_file.stem] = "PASS"
        else:
            n_fail += 1
            summary[csv_file.stem] = "FAIL"

    n_validated = n_pass + n_fail
    suite_passed = n_fail == 0 and n_validated > 0
    if n_validated == 0:
        overall = "❌ NO TEST CASES VALIDATED"
    elif n_fail == 0:
        overall = "✅ ALL PASSED"
    else:
        overall = "❌ FAILURES DETECTED"

    print(f"\n{'='*70}")
    print("SUITE SUMMARY")
    print(f"{'='*70}")
    print(f" PASS: {n_pass}")
    print(f" FAIL: {n_fail}")
    print(f" SKIP: {n_skip}")
    print(f" Total: {len(csv_files)}")
    print(f"\n Overall: {overall}")
    return suite_passed, summary
def main():
    """
    Command-line entry point.

    Exits with status 0 when validation passes. It exits with 1 otherwise,
    including a single file without a truth JSON (validate_file returns None)
    and the case where neither a file nor --suite is given.
    """
    parser = argparse.ArgumentParser(description="Validate Zernike round-trip accuracy")
    parser.add_argument("input", nargs="?", help="CSV file or use --suite")
    parser.add_argument("--suite", type=str, help="Validate all CSVs in directory")
    parser.add_argument("--n-modes", type=int, default=50)
    parser.add_argument("--tolerance", type=float, default=0.5,
                        help="Max acceptable coefficient error in nm (default: 0.5)")
    parser.add_argument("--diameter", type=float, default=None,
                        help="Mirror diameter in mm (auto-detected if not set)")
    args = parser.parse_args()
    if args.suite:
        # Suite mode does not forward these options; warn rather than drop
        # them silently.
        if args.input:
            print(f"WARNING: ignoring positional input {args.input!r} because --suite was given",
                  file=sys.stderr)
        if args.diameter is not None:
            print("WARNING: --diameter is ignored with --suite; each case uses its "
                  "truth-file diameter or auto-detection", file=sys.stderr)
        passed, _ = validate_suite(args.suite, args.n_modes, args.tolerance)
        sys.exit(0 if passed else 1)
    elif args.input:
        passed, _ = validate_file(args.input, args.n_modes, args.diameter, args.tolerance)
        sys.exit(0 if passed else 1)
    else:
        parser.print_help()
        sys.exit(1)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()