This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
618 lines
22 KiB
Python
618 lines
22 KiB
Python
"""
|
||
Polar Mirror Graph for GNN Training
====================================

This module creates a fixed polar grid graph structure for the mirror optical surface.
The key insight is that the mirror has a fixed topology (circular annulus), so we can
use a fixed graph structure regardless of FEA mesh variations.

Why Polar Grid?
1. Matches mirror geometry (annulus)
2. Same approach as extract_zernike_surface.py
3. Enables mesh-independent training
4. Edge structure respects radial/angular physics

Grid Structure:
- n_radial points from r_inner to r_outer (both ends included)
- n_angular points from 0 to 2π (excluding 2π itself, to avoid a duplicate node)
- Total nodes = n_radial × n_angular
- Edges connect radial neighbors, angular neighbors (with wrap-around), and
  diagonal (outward + counterclockwise) neighbors

Usage:
    from optimization_engine.gnn.polar_graph import PolarMirrorGraph

    graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)

    # Interpolate FEA results to fixed grid
    z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)

    # Get PyTorch Geometric data
    data = graph.to_pyg_data(z_disp_grid, design_vars)
|
||
"""
|
||
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional, Tuple, List
|
||
import json
|
||
|
||
try:
|
||
import torch
|
||
HAS_TORCH = True
|
||
except ImportError:
|
||
HAS_TORCH = False
|
||
|
||
try:
|
||
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
|
||
from scipy.spatial import Delaunay
|
||
HAS_SCIPY = True
|
||
except ImportError:
|
||
HAS_SCIPY = False
|
||
|
||
|
||
class PolarMirrorGraph:
    """
    Fixed polar grid graph for mirror optical surface.

    This creates a mesh-independent graph structure that can be used for GNN training
    regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.

    Nodes are ordered ring by ring: flat index = i_r * n_angular + i_theta, so the
    angular index varies fastest and the radial index slowest.

    Attributes:
        n_nodes: Total number of nodes (n_radial × n_angular)
        r: Radial coordinates [n_nodes]
        theta: Angular coordinates [n_nodes]
        x: Cartesian X coordinates [n_nodes]
        y: Cartesian Y coordinates [n_nodes]
        edge_index: Graph edges [2, n_edges]
        edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
    """

    # Interpolation methods accepted by interpolate_from_mesh().
    _INTERP_METHODS = ('rbf', 'linear', 'clough_tocher')

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 650.0,
        n_radial: int = 50,
        n_angular: int = 60
    ):
        """
        Initialize polar grid graph.

        Args:
            r_inner: Inner radius (central hole), mm
            r_outer: Outer radius, mm; must be strictly greater than r_inner
            n_radial: Number of radial samples (>= 1)
            n_angular: Number of angular samples (>= 1)

        Raises:
            ValueError: If the sample counts are < 1 or r_outer <= r_inner.
        """
        # Reject degenerate grids: zero counts give an empty edge set (and a
        # broken __repr__); r_outer <= r_inner makes _r_std zero/negative, which
        # silently produces inf/NaN or sign-flipped normalized features.
        if n_radial < 1 or n_angular < 1:
            raise ValueError(
                f"n_radial and n_angular must be >= 1 (got {n_radial}, {n_angular})"
            )
        if not r_outer > r_inner:
            raise ValueError(
                f"r_outer must be greater than r_inner "
                f"(got r_inner={r_inner}, r_outer={r_outer})"
            )

        self.r_inner = r_inner
        self.r_outer = r_outer
        self.n_radial = n_radial
        self.n_angular = n_angular
        self.n_nodes = n_radial * n_angular

        # Create polar grid coordinates (r_outer included, 2π excluded)
        r_1d = np.linspace(r_inner, r_outer, n_radial)
        theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)

        # Meshgrid: theta varies fast (angular index), r varies slow (radial index)
        Theta, R = np.meshgrid(theta_1d, r_1d)  # R shape: [n_radial, n_angular]

        # Flatten: radial index varies slowest
        self.r = R.flatten().astype(np.float32)
        self.theta = Theta.flatten().astype(np.float32)
        self.x = (self.r * np.cos(self.theta)).astype(np.float32)
        self.y = (self.r * np.sin(self.theta)).astype(np.float32)

        # Build graph edges
        self.edge_index, self.edge_attr = self._build_polar_edges()

        # Precompute normalization factors (maps [r_inner, r_outer] to [-1, 1])
        self._r_mean = (r_inner + r_outer) / 2
        self._r_std = (r_outer - r_inner) / 2

    def _node_index(self, i_r: int, i_theta: int) -> int:
        """Convert (radial_index, angular_index) to flat node index."""
        # Angular index wraps around (e.g. n_angular maps back to 0)
        i_theta = i_theta % self.n_angular
        return i_r * self.n_angular + i_theta

    def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create graph edges respecting polar topology.

        Edge types (each added in both directions):
        1. Radial edges: Connect adjacent radial rings
        2. Angular edges: Connect adjacent angular positions (with wrap-around)
        3. Diagonal edges: Connect diagonal neighbors for better message passing

        NOTE: the 'angle' feature is a local direction (0 = outward, π/2 =
        counterclockwise) for radial/angular edges, but a global Cartesian
        arctan2 angle for diagonal edges. Reversed diagonal edges use angle + π,
        which can exceed π. These values are left unchanged because models may
        have been trained on them.

        NOTE: with n_angular == 1 the angular edge is a self-loop, and with
        n_angular == 2 each angular/diagonal pair is emitted twice (wrap-around).

        Returns:
            edge_index: [2, n_edges] array of (source, target) pairs
            edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
        """
        edges = []
        edge_features = []

        for i_r in range(self.n_radial):
            for i_theta in range(self.n_angular):
                node = self._node_index(i_r, i_theta)

                # Radial neighbor (outward)
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])  # Bidirectional

                    # Edge features: (dr, dtheta, distance, relative_angle)
                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 0.0
                    dist = abs(dr)
                    angle = 0.0  # Radial direction
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, dtheta, dist, np.pi])  # Reverse

                # Angular neighbor (counterclockwise, with wrap-around)
                neighbor = self._node_index(i_r, i_theta + 1)
                edges.append([node, neighbor])
                edges.append([neighbor, node])  # Bidirectional

                # Edge features for angular edge
                dr = 0.0
                dtheta = 2 * np.pi / self.n_angular
                # Arc length at this radius
                dist = self.r[node] * dtheta
                angle = np.pi / 2  # Tangential direction
                edge_features.append([dr, dtheta, dist, angle])
                edge_features.append([dr, -dtheta, dist, -np.pi / 2])  # Reverse

                # Diagonal neighbor (outward + counterclockwise) for better connectivity
                if i_r < self.n_radial - 1:
                    neighbor = self._node_index(i_r + 1, i_theta + 1)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])

                    dr = self.r[neighbor] - self.r[node]
                    dtheta = 2 * np.pi / self.n_angular
                    dx = self.x[neighbor] - self.x[node]
                    dy = self.y[neighbor] - self.y[node]
                    dist = np.sqrt(dx**2 + dy**2)
                    angle = np.arctan2(dy, dx)
                    edge_features.append([dr, dtheta, dist, angle])
                    edge_features.append([-dr, -dtheta, dist, angle + np.pi])

        edge_index = np.array(edges, dtype=np.int64).T  # [2, n_edges]
        edge_attr = np.array(edge_features, dtype=np.float32)  # [n_edges, 4]

        return edge_index, edge_attr

    def get_node_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get node features for GNN input.

        Features: (r, theta, x, y) - polar and Cartesian coordinates

        Args:
            normalized: If True, map r to [-1, 1], theta from [0, 2π) to [-1, 1),
                and x, y by dividing by r_outer

        Returns:
            Node features [n_nodes, 4] (a new float32 array)
        """
        if normalized:
            r_norm = (self.r - self._r_mean) / self._r_std
            theta_norm = self.theta / np.pi - 1  # [0, 2π) → [-1, 1)
            x_norm = self.x / self.r_outer
            y_norm = self.y / self.r_outer
            return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
        return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)

    def get_edge_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get edge features for GNN input.

        Features: (dr, dtheta, distance, angle)

        Args:
            normalized: If True, divide dr by the half radial span, dtheta and
                angle by π, and distance by r_outer

        Returns:
            Edge features [n_edges, 4]. Always a new array, so callers may modify
            it without corrupting the graph's stored edge_attr.
        """
        edge_attr = self.edge_attr.copy()
        if normalized:
            edge_attr[:, 0] /= self._r_std  # dr
            edge_attr[:, 1] /= np.pi  # dtheta
            edge_attr[:, 2] /= self.r_outer  # distance
            edge_attr[:, 3] /= np.pi  # angle
        return edge_attr

    @staticmethod
    def _fill_nan_nearest(
        result: np.ndarray,
        xy: np.ndarray,
        target_xy: np.ndarray,
        mesh_values: np.ndarray
    ) -> None:
        """Replace NaNs in ``result`` (in place), per column, with the value of the nearest mesh node."""
        nan_mask = np.isnan(result)
        if not nan_mask.any():
            return

        from scipy.spatial import cKDTree

        # One tree serves every column; only the NaN positions differ per column.
        tree = cKDTree(xy)
        for i in range(result.shape[1]):
            col_mask = nan_mask[:, i]
            if col_mask.any():
                _, idx = tree.query(target_xy[col_mask])
                result[col_mask, i] = mesh_values[idx, i]

    def interpolate_from_mesh(
        self,
        mesh_coords: np.ndarray,
        mesh_values: np.ndarray,
        method: str = 'rbf'
    ) -> np.ndarray:
        """
        Interpolate FEA results from mesh nodes to fixed polar grid.

        The interpolator (RBF system / Delaunay triangulation) is built once and
        evaluated for all value columns together.

        Args:
            mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z]);
                only X and Y are used
            mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
            method: Interpolation method ('rbf', 'linear', 'clough_tocher').
                For 'linear' and 'clough_tocher', grid nodes outside the mesh's
                convex hull take the value of the nearest mesh node.

        Returns:
            Interpolated float32 values on polar grid [n_nodes] or [n_nodes, n_features]

        Raises:
            ValueError: If method is not one of the supported methods.
            ImportError: If scipy is not installed.
        """
        if method not in self._INTERP_METHODS:
            raise ValueError(f"Unknown interpolation method: {method}")
        if not HAS_SCIPY:
            raise ImportError("scipy required for interpolation: pip install scipy")

        # Use only X, Y coordinates
        xy = np.asarray(mesh_coords)[:, :2]

        # Handle multi-dimensional values: work on a 2-D [n_fea_nodes, n_features] view
        mesh_values = np.asarray(mesh_values)
        values_1d = mesh_values.ndim == 1
        if values_1d:
            mesh_values = mesh_values.reshape(-1, 1)

        # Target coordinates
        target_xy = np.column_stack([self.x, self.y])

        result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)
        if result.shape[1] == 0:
            return result

        if method == 'rbf':
            # RBF interpolation - smooth, handles scattered data well (extrapolates, no NaN)
            interp = RBFInterpolator(
                xy, mesh_values,
                kernel='thin_plate_spline',
                smoothing=0.0
            )
        elif method == 'linear':
            # Linear interpolation via Delaunay triangulation
            interp = LinearNDInterpolator(xy, mesh_values, fill_value=np.nan)
        else:
            # Clough-Tocher (C1 smooth) interpolation
            interp = CloughTocher2DInterpolator(xy, mesh_values, fill_value=np.nan)

        result[:] = interp(target_xy)

        if method != 'rbf':
            # Points outside the convex hull come back as NaN
            self._fill_nan_nearest(result, xy, target_xy, mesh_values)

        return result[:, 0] if values_1d else result

    def interpolate_field_data(
        self,
        field_data: Dict[str, Any],
        subcases: Optional[List[int]] = None,
        method: str = 'linear'
    ) -> Dict[str, np.ndarray]:
        """
        Interpolate field data from extract_displacement_field() to polar grid.

        Args:
            field_data: Output from extract_displacement_field(); this method reads
                its 'node_coords' and 'z_displacement' (keyed by subcase) entries
            subcases: List of subcases to interpolate (default: [1, 2, 3, 4])
            method: Interpolation method; defaults to 'linear', which is much
                faster than 'rbf'

        Returns:
            Dictionary with:
            - z_displacement: [n_nodes, n_subcases] array
            - original_n_nodes: Number of FEA nodes

        Raises:
            KeyError: If a requested subcase is missing from field_data['z_displacement'].
        """
        if subcases is None:
            subcases = [1, 2, 3, 4]

        mesh_coords = field_data['node_coords']
        z_disp_dict = field_data['z_displacement']

        for sc in subcases:
            if sc not in z_disp_dict:
                raise KeyError(f"Subcase {sc} not found in field_data")

        # Stack subcases: [n_fea_nodes, n_subcases]
        z_disp_mesh = np.column_stack([z_disp_dict[sc] for sc in subcases])

        # Interpolate to polar grid
        z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)

        return {
            'z_displacement': z_disp_grid,  # [n_nodes, n_subcases]
            'original_n_nodes': len(mesh_coords),
        }

    def to_pyg_data(
        self,
        z_displacement: np.ndarray,
        design_vars: np.ndarray,
        objectives: Optional[Dict[str, float]] = None
    ):
        """
        Convert to PyTorch Geometric Data object.

        Args:
            z_displacement: [n_nodes, n_subcases] displacement field
            design_vars: [n_design_vars] design parameters
            objectives: Optional dict of objective values (ground truth); each is
                attached as a 1-element float32 tensor attribute named by its key

        Returns:
            torch_geometric.data.Data object

        Raises:
            ImportError: If torch or torch_geometric is not installed.
        """
        if not HAS_TORCH:
            raise ImportError("PyTorch required: pip install torch")

        try:
            from torch_geometric.data import Data
        except ImportError:
            raise ImportError("PyTorch Geometric required: pip install torch-geometric")

        # Node features: (r, theta, x, y)
        node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)

        # Edge index and features
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)

        # Target: Z-displacement field
        y = torch.tensor(z_displacement, dtype=torch.float32)

        # Design variables (global feature)
        design = torch.tensor(design_vars, dtype=torch.float32)

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            design=design,
        )

        # Add objectives if provided
        if objectives:
            for key, value in objectives.items():
                setattr(data, key, torch.tensor([value], dtype=torch.float32))

        return data

    def save(self, path: Path) -> None:
        """
        Save graph structure to a JSON file plus a compressed array file.

        The JSON holds the constructor parameters (what load() reads); the arrays
        are written next to it at ``path.with_suffix('.npz')``.
        """
        path = Path(path)

        def _plain(value):
            # json cannot serialize NumPy scalars such as np.int64 / np.float32.
            return value.item() if isinstance(value, np.generic) else value

        data = {
            'r_inner': _plain(self.r_inner),
            'r_outer': _plain(self.r_outer),
            'n_radial': _plain(self.n_radial),
            'n_angular': _plain(self.n_angular),
            'n_nodes': _plain(self.n_nodes),
            'n_edges': int(self.edge_index.shape[1]),
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        # Save arrays separately for efficiency
        np.savez_compressed(
            path.with_suffix('.npz'),
            r=self.r,
            theta=self.theta,
            x=self.x,
            y=self.y,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
        )

    @classmethod
    def load(cls, path: Path) -> 'PolarMirrorGraph':
        """
        Load graph structure from a JSON file written by save().

        Only the constructor parameters are read; the grid and edges are
        rebuilt from them (the .npz array file is not read).
        """
        path = Path(path)

        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        return cls(
            r_inner=data['r_inner'],
            r_outer=data['r_outer'],
            n_radial=data['n_radial'],
            n_angular=data['n_angular'],
        )

    def __repr__(self) -> str:
        return (
            f"PolarMirrorGraph("
            f"r=[{self.r_inner}, {self.r_outer}]mm, "
            f"grid={self.n_radial}×{self.n_angular}, "
            f"nodes={self.n_nodes}, "
            f"edges={self.edge_index.shape[1]})"
        )
|
||
|
||
|
||
# =============================================================================
|
||
# Convenience functions
|
||
# =============================================================================
|
||
|
||
def create_mirror_dataset(
    study_dir: Path,
    polar_graph: Optional[PolarMirrorGraph] = None,
    subcases: Optional[List[int]] = None,
    verbose: bool = True
) -> List[Dict[str, Any]]:
    """
    Create GNN dataset from a study's gnn_data folder.

    Reads ``<study_dir>/gnn_data/dataset_index.json``. For every trial whose
    status is 'success', loads ``displacement_field.h5`` (or ``.npz``) from the
    trial directory and interpolates it onto the polar grid. Trials without a
    field file, or whose loading/interpolation/parameter conversion fails, are
    skipped (with a message when verbose).

    Args:
        study_dir: Path to study directory
        polar_graph: PolarMirrorGraph instance (a default one is created if None)
        subcases: Subcases to include (default: [1, 2, 3, 4])
        verbose: Print progress

    Returns:
        List of data dictionaries, each containing:
        - z_displacement: [n_nodes, n_subcases]
        - design_vars: [n_vars] float32 (empty when the trial has no params)
        - design_names: parameter names, in the same order as design_vars
        - trial_number: int
        - original_n_nodes: int

    Raises:
        FileNotFoundError: If the study has no gnn_data folder or no dataset_index.json.
    """
    from optimization_engine.gnn.extract_displacement_field import load_field

    if subcases is None:
        subcases = [1, 2, 3, 4]

    study_dir = Path(study_dir)
    gnn_data_dir = study_dir / "gnn_data"

    if not gnn_data_dir.exists():
        raise FileNotFoundError(f"No gnn_data folder in {study_dir}")

    # Load index
    index_path = gnn_data_dir / "dataset_index.json"
    with open(index_path, 'r', encoding='utf-8') as f:
        index = json.load(f)

    if polar_graph is None:
        polar_graph = PolarMirrorGraph()

    dataset = []

    for trial_num, trial_info in index['trials'].items():
        if trial_info.get('status') != 'success':
            continue

        trial_dir = study_dir / trial_info['trial_dir']

        # Find field file (HDF5 preferred over NPZ)
        field_path = None
        for ext in ('.h5', '.npz'):
            candidate = trial_dir / f"displacement_field{ext}"
            if candidate.exists():
                field_path = candidate
                break

        if field_path is None:
            if verbose:
                print(f"[WARN] No field file for trial {trial_num}")
            continue

        try:
            field_data = load_field(field_path)
            interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)

            # Design parameters; names and values share the dict's iteration order
            params = trial_info.get('params') or {}
            sample = {
                'z_displacement': interp_result['z_displacement'],
                'design_vars': np.array(list(params.values()), dtype=np.float32),
                'design_names': list(params.keys()),
                'trial_number': int(trial_num),
                'original_n_nodes': interp_result['original_n_nodes'],
            }
        except Exception as e:
            # Best-effort: one corrupt or incomplete trial must not abort the whole build.
            if verbose:
                print(f"[ERR] Trial {trial_num}: {e}")
            continue

        dataset.append(sample)

        # Outside the try so a console problem can't be misreported as a trial
        # error; ASCII arrow so non-UTF-8 consoles can print it.
        if verbose:
            print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} -> {polar_graph.n_nodes} nodes")

    return dataset
|
||
|
||
|
||
# =============================================================================
|
||
# CLI
|
||
# =============================================================================
|
||
|
||
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
    parser.add_argument('--test', action='store_true', help='Run basic tests')
    parser.add_argument('--study', type=Path, help='Create dataset from study')

    args = parser.parse_args()

    if args.test:
        print("="*60)
        print("TESTING PolarMirrorGraph")
        print("="*60)

        # Failed checks are collected so the final verdict reflects what was verified.
        failures = []

        # Create graph
        graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"\n{graph}")

        # Check node features
        node_feat = graph.get_node_features(normalized=True)
        print(f"\nNode features shape: {node_feat.shape}")
        print(f" r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
        print(f" theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")
        if node_feat.shape != (graph.n_nodes, 4):
            failures.append(f"node features shape {node_feat.shape} != {(graph.n_nodes, 4)}")

        # Check edge features. Every node has 2 angular edges; every node not on
        # the outermost ring also has 2 radial and 2 diagonal edges.
        edge_feat = graph.get_edge_features(normalized=True)
        print(f"\nEdge features shape: {edge_feat.shape}")
        print(f" dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
        print(f" distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")
        expected_edges = (
            2 * graph.n_radial * graph.n_angular
            + 4 * (graph.n_radial - 1) * graph.n_angular
        )
        if edge_feat.shape != (expected_edges, 4):
            failures.append(f"edge features shape {edge_feat.shape} != {(expected_edges, 4)}")
        if not (np.all(np.isfinite(node_feat)) and np.all(np.isfinite(edge_feat))):
            failures.append("node/edge features contain NaN or inf")

        # Test interpolation with synthetic data
        print("\n--- Testing Interpolation ---")

        # Create fake mesh data (random points in annulus)
        np.random.seed(42)
        n_mesh = 5000
        r_mesh = np.random.uniform(100, 650, n_mesh)
        theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
        x_mesh = r_mesh * np.cos(theta_mesh)
        y_mesh = r_mesh * np.sin(theta_mesh)
        mesh_coords = np.column_stack([x_mesh, y_mesh])

        # Synthetic displacement: smooth function
        mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)

        # Interpolate
        grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
        print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
        print(f" Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
        print(f" Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")

        # Informational accuracy report: same analytic field evaluated at the grid nodes
        expected_grid = 0.001 * (graph.r / 650) ** 2 * np.cos(2 * graph.theta)
        print(f" Max abs error vs analytic: {np.abs(grid_values - expected_grid).max():.2e}")
        if grid_values.shape != (graph.n_nodes,):
            failures.append(f"interpolated shape {grid_values.shape} != {(graph.n_nodes,)}")
        if not np.all(np.isfinite(grid_values)):
            failures.append("interpolated values contain NaN or inf")

        if failures:
            for msg in failures:
                print(f"[FAIL] {msg}")
            raise SystemExit(1)

        print("\n✓ All tests passed!")

    elif args.study:
        # Create dataset from study
        print(f"Creating dataset from: {args.study}")

        graph = PolarMirrorGraph()
        dataset = create_mirror_dataset(args.study, polar_graph=graph)

        print(f"\nDataset: {len(dataset)} samples")
        if dataset:
            print(f" Z-displacement shape: {dataset[0]['z_displacement'].shape}")
            print(f" Design vars: {len(dataset[0]['design_vars'])} variables")

    else:
        # Default: just show info
        graph = PolarMirrorGraph()
        print(graph)
        print(f"\nNode features: {graph.get_node_features().shape}")
        print(f"Edge index: {graph.edge_index.shape}")
        print(f"Edge features: {graph.edge_attr.shape}")
|