Files
Atomizer/tools/validate_zernike_roundtrip.py
Antoine f9373bee99 feat(V&V): Zernike pipeline validation - synthetic WFE generator + round-trip validator
- generate_synthetic_wfe.py: Creates synthetic OPD surfaces from known Zernike coefficients
- validate_zernike_roundtrip.py: Round-trip validation (generate → fit → compare)
- validation_suite/: 18 test cases (single mode, multi-mode, noisy, edge cases)
- All 18 test cases pass with 0.000 nm error (clean) and <0.3 nm error (10 nm noise)
- M1 params: 1200mm dia, 135.75mm inner radius, 50 Noll modes

Project: P-Zernike-Validation (GigaBIT M1)
Requested by: Adyn Miles (StarSpec) for risk reduction before M2/M3 ordering
2026-03-09 15:49:06 +00:00

262 lines
8.0 KiB
Python

#!/usr/bin/env python3
"""
Zernike Round-Trip Validator
=============================
Reads a synthetic WFE CSV + its truth JSON, fits Zernike coefficients,
and compares recovered vs input coefficients.
This validates that the Zernike fitting math itself is correct by
doing a generate → fit → compare round-trip.
Usage:
# Single file
python validate_zernike_roundtrip.py validation_suite/Z05_astig_0deg.csv
# Full suite
python validate_zernike_roundtrip.py --suite validation_suite/
Author: Mario (Atomizer V&V)
Created: 2026-03-09
"""
import sys
import json
import argparse
from pathlib import Path
from math import factorial
from typing import Dict, Tuple
import numpy as np
from numpy.linalg import lstsq
# ============================================================================
# Zernike Math (same as generate_synthetic_wfe.py)
# ============================================================================
def noll_indices(j: int) -> Tuple[int, int]:
    """
    Map a 1-based sequential mode index j to Zernike indices (n, m).

    Modes are enumerated by increasing radial order n. Within one order the
    azimuthal orders are listed by increasing |m|: an even n starts with
    m = 0, and every |m| > 0 contributes the pair (-|m|, +|m|), negative
    (sine) term first. Examples: j=1 -> (0, 0), j=2 -> (1, -1),
    j=4 -> (2, 0), j=12 -> (4, -2).

    NOTE(review): Noll (1976) puts the cosine term (m > 0) on even j and the
    sine term (m < 0) on odd j. This enumeration agrees with that for
    n = 2, 3, 6, 7, ... but swaps every sine/cosine pair when n % 4 is 0 or 1
    (Noll's j=2 is (1, +1) and j=12 is (4, +2)). The section comment says
    generate_synthetic_wfe.py uses the same math, so a round-trip test cannot
    detect this; any change must be made in both files together.

    Raises:
        ValueError: if j < 1.
    """
    if j < 1:
        raise ValueError("Noll index j must be >= 1")
    position = 0
    order = 0
    while True:
        even_order = order % 2 == 0
        # Even radial orders begin with the rotationally symmetric m = 0 term.
        azimuthal = [0] if even_order else []
        for k in range(2 if even_order else 1, order + 1, 2):
            azimuthal += [-k, k]
        for m in azimuthal:
            position += 1
            if position == j:
                return order, m
        order += 1
def zernike_radial(n, m, r):
    """
    Evaluate the (unnormalized) Zernike radial polynomial R_n^|m| at r.

    Uses the standard finite sum
        R_n^m(r) = sum_s (-1)^s (n-s)! / (s! ((n+m)/2 - s)! ((n-m)/2 - s)!) r^(n-2s)
    with m replaced by |m|. By definition R_n^m is identically zero when
    n - |m| is odd; in that case (and for |m| > n) an all-zero array is returned.

    Args:
        n: radial order.
        m: azimuthal order (only |m| is used).
        r: radius, scalar or array-like. Integer input is promoted to float;
            floating/complex input keeps its dtype.

    Returns:
        numpy array with the same shape as r.
    """
    r = np.asarray(r)
    # Integer input would make the in-place float accumulation below fail.
    if not np.issubdtype(r.dtype, np.inexact):
        r = r.astype(float)
    R = np.zeros_like(r)
    m_abs = abs(m)
    # The sum formula only applies when n - |m| is even and non-negative;
    # otherwise R_n^m is zero by definition (the loop would produce garbage).
    if m_abs > n or (n - m_abs) % 2:
        return R
    for s in range((n - m_abs) // 2 + 1):
        coef = ((-1) ** s * factorial(n - s) /
                (factorial(s) * factorial((n + m_abs) // 2 - s) * factorial((n - m_abs) // 2 - s)))
        R += coef * r ** (n - 2 * s)
    return R
def zernike_noll(j, r, theta):
    """
    Evaluate sequential mode j (indexed via noll_indices) at polar points (r, theta).

    The value is the radial polynomial R_n^|m|(r) from zernike_radial
    multiplied by cos(m * theta) when m > 0, sin(|m| * theta) when m < 0,
    or left as-is when m == 0.

    NOTE(review): no normalization factor is applied (Noll's definition also
    multiplies by sqrt(n + 1) or sqrt(2 (n + 1))), so fitted coefficients
    refer to this unnormalized basis.
    """
    n, m = noll_indices(j)
    radial = zernike_radial(n, m, r)
    if m < 0:
        return radial * np.sin(-m * theta)
    if m > 0:
        return radial * np.cos(m * theta)
    return radial
# Display labels for low-order (n, m) pairs; other pairs get a generic label.
# NOTE(review): negative-m (sine) terms are labelled "X" and positive-m
# (cosine) terms "Y" for tilt, coma, trefoil and secondary astigmatism;
# confirm this matches the reporting convention expected downstream.
_ZERNIKE_NAMES = {
    (0, 0): "Piston",
    (1, -1): "Tilt X",
    (1, 1): "Tilt Y",
    (2, 0): "Defocus",
    (2, -2): "Astig 45°",
    (2, 2): "Astig 0°",
    (3, -1): "Coma X",
    (3, 1): "Coma Y",
    (3, -3): "Trefoil X",
    (3, 3): "Trefoil Y",
    (4, 0): "Spherical",
    (4, -2): "2ndAstig X",
    (4, 2): "2ndAstig Y",
    (6, 0): "2nd Spherical",
}


def zernike_name(j):
    """
    Return a short human-readable label for sequential mode index j.

    Known low-order modes get a descriptive name; any other mode is labelled
    generically as "Z(n,±m)", e.g. "Z(5,-1)".
    """
    n, m = noll_indices(j)
    try:
        return _ZERNIKE_NAMES[(n, m)]
    except KeyError:
        return f"Z({n},{m:+d})"
# ============================================================================
# Zernike Fitting (Least Squares)
# ============================================================================
def fit_zernike(x_mm, y_mm, opd_nm, diameter_mm=1200.0, n_modes=50):
    """
    Fit Zernike coefficients to OPD data via least-squares.

    The radius is normalized by diameter_mm / 2, so a sample on the outer
    edge has r = 1. No aperture mask is applied: every sample whose x, y and
    OPD are all finite enters the fit, including any lying at r > 1.

    Args:
        x_mm, y_mm: sample coordinates in mm (array-likes of matching shape;
            multi-dimensional grids are flattened).
        opd_nm: OPD at each sample, same shape as x_mm / y_mm.
        diameter_mm: diameter used to normalize the radius.
        n_modes: number of modes to fit (sequential indices j = 1..n_modes).

    Returns:
        coefficients: array of shape (n_modes,), Noll-indexed from j=1
        (coefficients[j - 1] is mode j), in the units of opd_nm.

    Warns:
        RuntimeWarning if non-finite samples were dropped, or if the basis
        matrix is rank-deficient (e.g. fewer usable samples than modes); in
        the latter case lstsq returns the minimum-norm solution and the
        individual coefficients are not uniquely determined.
    """
    import warnings

    x_mm = np.asarray(x_mm, dtype=float)
    y_mm = np.asarray(y_mm, dtype=float)
    opd_nm = np.asarray(opd_nm, dtype=float)

    # A single NaN/inf sample would poison the whole least-squares solve,
    # so drop such samples (boolean indexing also flattens grid inputs).
    finite = np.isfinite(x_mm) & np.isfinite(y_mm) & np.isfinite(opd_nm)
    n_dropped = int(finite.size - np.count_nonzero(finite))
    if n_dropped:
        warnings.warn(f"fit_zernike: ignoring {n_dropped} non-finite sample(s)",
                      RuntimeWarning, stacklevel=2)
    x_mm, y_mm, opd_nm = x_mm[finite], y_mm[finite], opd_nm[finite]

    outer_radius = diameter_mm / 2.0
    r_norm = np.sqrt(x_mm**2 + y_mm**2) / outer_radius
    theta = np.arctan2(y_mm, x_mm)

    # Basis matrix: one row per sample, one column per mode.
    Z = np.zeros((x_mm.size, n_modes))
    for j in range(1, n_modes + 1):
        Z[:, j - 1] = zernike_noll(j, r_norm, theta)

    coeffs, _, rank, _ = lstsq(Z, opd_nm, rcond=None)
    if rank < n_modes:
        warnings.warn(
            f"fit_zernike: basis matrix rank {rank} < n_modes={n_modes}; "
            "coefficients are not uniquely determined (minimum-norm solution)",
            RuntimeWarning, stacklevel=2)
    return coeffs
# ============================================================================
# Validation
# ============================================================================
def validate_file(csv_path: str, n_modes: int = 50, diameter_mm: float = 1200.0,
                  tolerance_nm: float = 0.5, verbose: bool = True):
    """
    Validate a single synthetic WFE file.

    Reads ``<stem>.csv`` (one header row, then x_mm, y_mm, opd_nm columns)
    and its companion ``<stem>_truth.json`` (whose ``input_coefficients``
    maps mode index -> input coefficient), fits ``n_modes`` modes with
    fit_zernike, and compares each recovered coefficient with its input.

    Args:
        csv_path: path to the synthetic WFE CSV (str or Path).
        n_modes: number of modes fitted and compared (j = 1..n_modes).
        diameter_mm: diameter used to normalize the radius in the fit.
        tolerance_nm: a mode passes when |recovered - input| < tolerance_nm.
        verbose: print a per-mode table and a summary.

    Returns:
        passed: bool, or None when no truth file exists
        results: dict with comparison details ("modes", "unfitted_modes",
            "max_error_nm", "all_passed", "tolerance_nm", "n_points"),
            or None when no truth file exists

    A truth coefficient whose index lies outside 1..n_modes cannot be
    recovered by the fit; if its magnitude is >= tolerance_nm the file fails.
    """
    csv_path = Path(csv_path)
    truth_path = csv_path.with_name(csv_path.stem + "_truth.json")
    if not truth_path.exists():
        print(f" WARNING: No truth file found for {csv_path.name}")
        return None, None

    # ndmin=2 keeps a single-sample CSV two-dimensional so column slicing works.
    data = np.loadtxt(csv_path, delimiter=",", skiprows=1, ndmin=2)
    x_mm = data[:, 0]
    y_mm = data[:, 1]
    opd_nm = data[:, 2]

    with open(truth_path, encoding="utf-8") as f:
        truth = json.load(f)
    # JSON object keys are strings; convert them to integer mode indices.
    input_coeffs = {int(k): float(v) for k, v in truth["input_coefficients"].items()}

    recovered = fit_zernike(x_mm, y_mm, opd_nm, diameter_mm, n_modes)

    max_error = 0.0
    results = {"modes": {}}
    all_passed = True
    if verbose:
        print(f"\n {'Mode':>6} {'Name':>20} {'Input(nm)':>10} {'Recovered(nm)':>14} {'Error(nm)':>10} {'Status':>8}")
        print(f" {'-'*6} {'-'*20} {'-'*10} {'-'*14} {'-'*10} {'-'*8}")
    for j in range(1, n_modes + 1):
        input_val = input_coeffs.get(j, 0.0)
        recovered_val = float(recovered[j - 1])
        error = abs(recovered_val - input_val)
        max_error = max(max_error, error)
        mode_passed = error < tolerance_nm
        if not mode_passed:
            all_passed = False
        results["modes"][j] = {
            "input": input_val,
            "recovered": recovered_val,
            "error": error,
            "passed": mode_passed,
        }
        # Show modes with non-zero input or significant recovery, and always
        # show failures so the table explains a FAIL verdict.
        if verbose and (abs(input_val) > 0.01 or abs(recovered_val) > tolerance_nm
                        or not mode_passed):
            status = "✅" if mode_passed else "❌"
            print(f" Z{j:>4d} {zernike_name(j):>20} {input_val:>10.3f} {recovered_val:>14.3f} {error:>10.6f} {status:>8}")

    # Truth modes outside the fitted range would otherwise be ignored
    # silently; the fit cannot recover them, so their input is pure error.
    unfitted = {j: v for j, v in sorted(input_coeffs.items()) if not 1 <= j <= n_modes}
    for j, input_val in unfitted.items():
        error = abs(input_val)
        max_error = max(max_error, error)
        if error >= tolerance_nm:
            all_passed = False
            if verbose:
                print(f" WARNING: Z{j} input {input_val:.3f} nm is outside the fitted modes 1..{n_modes}")

    results["unfitted_modes"] = unfitted
    results["max_error_nm"] = float(max_error)
    results["all_passed"] = all_passed
    results["tolerance_nm"] = tolerance_nm
    results["n_points"] = len(x_mm)
    if verbose:
        print(f"\n Max error: {max_error:.6f} nm")
        print(f" Tolerance: {tolerance_nm:.3f} nm")
        print(f" Result: {'✅ PASS' if all_passed else '❌ FAIL'}")
    return all_passed, results
def validate_suite(suite_dir: str, n_modes: int = 50, tolerance_nm: float = 0.5,
                   diameter_mm: float = 1200.0):
    """
    Validate all test cases in a suite directory.

    Every ``*.csv`` directly inside suite_dir is run through validate_file.
    A case without a truth JSON is reported as SKIP. A case that cannot be
    loaded or fitted is reported as FAIL (with the error) instead of aborting
    the remaining cases.

    Args:
        suite_dir: directory containing the test-case CSVs.
        n_modes: number of modes to fit per case.
        tolerance_nm: per-mode coefficient tolerance passed to validate_file.
        diameter_mm: normalization diameter passed to validate_file.

    Returns:
        (passed, summary): passed is True only if no case failed AND at least
        one case actually passed, so an empty, missing or all-skipped suite
        does not count as success. summary maps CSV stem -> "PASS"/"FAIL"/"SKIP".
    """
    suite_dir = Path(suite_dir)
    csv_files = sorted(suite_dir.glob("*.csv"))
    print(f"\nValidating {len(csv_files)} test cases in: {suite_dir}")
    print("=" * 70)
    if not suite_dir.is_dir():
        print(f" WARNING: Suite directory not found: {suite_dir}")
    summary = {}
    n_pass = 0
    n_fail = 0
    n_skip = 0
    for csv_file in csv_files:
        print(f"\n{'-' * 70}")
        print(f"Test: {csv_file.name}")
        try:
            passed, _ = validate_file(csv_file, n_modes, diameter_mm=diameter_mm,
                                      tolerance_nm=tolerance_nm)
        except (OSError, ValueError, KeyError, IndexError, np.linalg.LinAlgError) as exc:
            # One malformed case must not hide the results of the others.
            print(f" ERROR: {type(exc).__name__}: {exc}")
            passed = False
        if passed is None:
            n_skip += 1
            summary[csv_file.stem] = "SKIP"
        elif passed:
            n_pass += 1
            summary[csv_file.stem] = "PASS"
        else:
            n_fail += 1
            summary[csv_file.stem] = "FAIL"

    suite_passed = n_fail == 0 and n_pass > 0
    print(f"\n{'='*70}")
    print("SUITE SUMMARY")
    print(f"{'='*70}")
    print(f" PASS: {n_pass}")
    print(f" FAIL: {n_fail}")
    print(f" SKIP: {n_skip}")
    print(f" Total: {len(csv_files)}")
    if suite_passed:
        overall = '✅ ALL PASSED'
    elif n_fail:
        overall = '❌ FAILURES DETECTED'
    else:
        overall = '❌ NO TEST CASES RAN'
    print(f"\n Overall: {overall}")
    return suite_passed, summary


def main():
    """Command-line entry point; exits with status 0 on success, 1 otherwise."""
    parser = argparse.ArgumentParser(description="Validate Zernike round-trip accuracy")
    parser.add_argument("input", nargs="?",
                        help="CSV file to validate (ignored when --suite is given)")
    parser.add_argument("--suite", type=str, help="Validate all CSVs in directory")
    parser.add_argument("--n-modes", type=int, default=50,
                        help="Number of modes to fit (default: 50)")
    parser.add_argument("--tolerance", type=float, default=0.5,
                        help="Max acceptable coefficient error in nm (default: 0.5)")
    parser.add_argument("--diameter", type=float, default=1200.0,
                        help="Diameter in mm used to normalize the radius (default: 1200.0)")
    args = parser.parse_args()
    if args.suite:
        passed, _ = validate_suite(args.suite, args.n_modes, args.tolerance,
                                   diameter_mm=args.diameter)
    elif args.input:
        passed, _ = validate_file(args.input, args.n_modes, args.diameter, args.tolerance)
    else:
        parser.print_help()
        sys.exit(1)
    # validate_file returns None when the truth file is missing: treat as failure.
    sys.exit(0 if passed else 1)


if __name__ == "__main__":
    main()