feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
617
optimization_engine/gnn/polar_graph.py
Normal file
617
optimization_engine/gnn/polar_graph.py
Normal file
@@ -0,0 +1,617 @@
|
||||
"""
|
||||
Polar Mirror Graph for GNN Training
|
||||
====================================
|
||||
|
||||
This module creates a fixed polar grid graph structure for the mirror optical surface.
|
||||
The key insight is that the mirror has a fixed topology (circular annulus), so we can
|
||||
use a fixed graph structure regardless of FEA mesh variations.
|
||||
|
||||
Why Polar Grid?
|
||||
1. Matches mirror geometry (annulus)
|
||||
2. Same approach as extract_zernike_surface.py
|
||||
3. Enables mesh-independent training
|
||||
4. Edge structure respects radial/angular physics
|
||||
|
||||
Grid Structure:
|
||||
- n_radial points from r_inner to r_outer
|
||||
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
|
||||
- Total nodes = n_radial × n_angular
|
||||
- Edges connect radial neighbors and angular neighbors (wrap-around)
|
||||
|
||||
Usage:
|
||||
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
||||
|
||||
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
|
||||
|
||||
# Interpolate FEA results to fixed grid
|
||||
z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)
|
||||
|
||||
# Get PyTorch Geometric data
|
||||
data = graph.to_pyg_data(z_disp_grid, design_vars)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
import json
|
||||
|
||||
try:
|
||||
import torch
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
|
||||
try:
|
||||
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
|
||||
from scipy.spatial import Delaunay
|
||||
HAS_SCIPY = True
|
||||
except ImportError:
|
||||
HAS_SCIPY = False
|
||||
|
||||
|
||||
class PolarMirrorGraph:
|
||||
"""
|
||||
Fixed polar grid graph for mirror optical surface.
|
||||
|
||||
This creates a mesh-independent graph structure that can be used for GNN training
|
||||
regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.
|
||||
|
||||
Attributes:
|
||||
n_nodes: Total number of nodes (n_radial × n_angular)
|
||||
r: Radial coordinates [n_nodes]
|
||||
theta: Angular coordinates [n_nodes]
|
||||
x: Cartesian X coordinates [n_nodes]
|
||||
y: Cartesian Y coordinates [n_nodes]
|
||||
edge_index: Graph edges [2, n_edges]
|
||||
edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
r_inner: float = 100.0,
|
||||
r_outer: float = 650.0,
|
||||
n_radial: int = 50,
|
||||
n_angular: int = 60
|
||||
):
|
||||
"""
|
||||
Initialize polar grid graph.
|
||||
|
||||
Args:
|
||||
r_inner: Inner radius (central hole), mm
|
||||
r_outer: Outer radius, mm
|
||||
n_radial: Number of radial samples
|
||||
n_angular: Number of angular samples
|
||||
"""
|
||||
self.r_inner = r_inner
|
||||
self.r_outer = r_outer
|
||||
self.n_radial = n_radial
|
||||
self.n_angular = n_angular
|
||||
self.n_nodes = n_radial * n_angular
|
||||
|
||||
# Create polar grid coordinates
|
||||
r_1d = np.linspace(r_inner, r_outer, n_radial)
|
||||
theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)
|
||||
|
||||
# Meshgrid: theta varies fast (angular index), r varies slow (radial index)
|
||||
# Shape after flatten: [n_angular * n_radial] with angular varying fastest
|
||||
Theta, R = np.meshgrid(theta_1d, r_1d) # R shape: [n_radial, n_angular]
|
||||
|
||||
# Flatten: radial index varies slowest
|
||||
self.r = R.flatten().astype(np.float32)
|
||||
self.theta = Theta.flatten().astype(np.float32)
|
||||
self.x = (self.r * np.cos(self.theta)).astype(np.float32)
|
||||
self.y = (self.r * np.sin(self.theta)).astype(np.float32)
|
||||
|
||||
# Build graph edges
|
||||
self.edge_index, self.edge_attr = self._build_polar_edges()
|
||||
|
||||
# Precompute normalization factors
|
||||
self._r_mean = (r_inner + r_outer) / 2
|
||||
self._r_std = (r_outer - r_inner) / 2
|
||||
|
||||
def _node_index(self, i_r: int, i_theta: int) -> int:
|
||||
"""Convert (radial_index, angular_index) to flat node index."""
|
||||
# Angular wraps around
|
||||
i_theta = i_theta % self.n_angular
|
||||
return i_r * self.n_angular + i_theta
|
||||
|
||||
def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Create graph edges respecting polar topology.
|
||||
|
||||
Edge types:
|
||||
1. Radial edges: Connect adjacent radial rings
|
||||
2. Angular edges: Connect adjacent angular positions (with wrap-around)
|
||||
3. Diagonal edges: Connect diagonal neighbors for better message passing
|
||||
|
||||
Returns:
|
||||
edge_index: [2, n_edges] array of (source, target) pairs
|
||||
edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
|
||||
"""
|
||||
edges = []
|
||||
edge_features = []
|
||||
|
||||
for i_r in range(self.n_radial):
|
||||
for i_theta in range(self.n_angular):
|
||||
node = self._node_index(i_r, i_theta)
|
||||
|
||||
# Radial neighbor (outward)
|
||||
if i_r < self.n_radial - 1:
|
||||
neighbor = self._node_index(i_r + 1, i_theta)
|
||||
edges.append([node, neighbor])
|
||||
edges.append([neighbor, node]) # Bidirectional
|
||||
|
||||
# Edge features: (dr, dtheta, distance, relative_angle)
|
||||
dr = self.r[neighbor] - self.r[node]
|
||||
dtheta = 0.0
|
||||
dist = abs(dr)
|
||||
angle = 0.0 # Radial direction
|
||||
edge_features.append([dr, dtheta, dist, angle])
|
||||
edge_features.append([-dr, dtheta, dist, np.pi]) # Reverse
|
||||
|
||||
# Angular neighbor (counterclockwise, with wrap-around)
|
||||
neighbor = self._node_index(i_r, i_theta + 1)
|
||||
edges.append([node, neighbor])
|
||||
edges.append([neighbor, node]) # Bidirectional
|
||||
|
||||
# Edge features for angular edge
|
||||
dr = 0.0
|
||||
dtheta = 2 * np.pi / self.n_angular
|
||||
# Arc length at this radius
|
||||
dist = self.r[node] * dtheta
|
||||
angle = np.pi / 2 # Tangential direction
|
||||
edge_features.append([dr, dtheta, dist, angle])
|
||||
edge_features.append([dr, -dtheta, dist, -np.pi / 2]) # Reverse
|
||||
|
||||
# Diagonal neighbor (outward + counterclockwise) for better connectivity
|
||||
if i_r < self.n_radial - 1:
|
||||
neighbor = self._node_index(i_r + 1, i_theta + 1)
|
||||
edges.append([node, neighbor])
|
||||
edges.append([neighbor, node])
|
||||
|
||||
dr = self.r[neighbor] - self.r[node]
|
||||
dtheta = 2 * np.pi / self.n_angular
|
||||
dx = self.x[neighbor] - self.x[node]
|
||||
dy = self.y[neighbor] - self.y[node]
|
||||
dist = np.sqrt(dx**2 + dy**2)
|
||||
angle = np.arctan2(dy, dx)
|
||||
edge_features.append([dr, dtheta, dist, angle])
|
||||
edge_features.append([-dr, -dtheta, dist, angle + np.pi])
|
||||
|
||||
edge_index = np.array(edges, dtype=np.int64).T # [2, n_edges]
|
||||
edge_attr = np.array(edge_features, dtype=np.float32) # [n_edges, 4]
|
||||
|
||||
return edge_index, edge_attr
|
||||
|
||||
def get_node_features(self, normalized: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Get node features for GNN input.
|
||||
|
||||
Features: (r, theta, x, y) - polar and Cartesian coordinates
|
||||
|
||||
Args:
|
||||
normalized: If True, normalize features to ~[-1, 1] range
|
||||
|
||||
Returns:
|
||||
Node features [n_nodes, 4]
|
||||
"""
|
||||
if normalized:
|
||||
r_norm = (self.r - self._r_mean) / self._r_std
|
||||
theta_norm = self.theta / np.pi - 1 # [0, 2π] → [-1, 1]
|
||||
x_norm = self.x / self.r_outer
|
||||
y_norm = self.y / self.r_outer
|
||||
return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
|
||||
else:
|
||||
return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)
|
||||
|
||||
def get_edge_features(self, normalized: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Get edge features for GNN input.
|
||||
|
||||
Features: (dr, dtheta, distance, angle)
|
||||
|
||||
Args:
|
||||
normalized: If True, normalize features
|
||||
|
||||
Returns:
|
||||
Edge features [n_edges, 4]
|
||||
"""
|
||||
if normalized:
|
||||
edge_attr = self.edge_attr.copy()
|
||||
edge_attr[:, 0] /= self._r_std # dr
|
||||
edge_attr[:, 1] /= np.pi # dtheta
|
||||
edge_attr[:, 2] /= self.r_outer # distance
|
||||
edge_attr[:, 3] /= np.pi # angle
|
||||
return edge_attr
|
||||
else:
|
||||
return self.edge_attr
|
||||
|
||||
def interpolate_from_mesh(
|
||||
self,
|
||||
mesh_coords: np.ndarray,
|
||||
mesh_values: np.ndarray,
|
||||
method: str = 'rbf'
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Interpolate FEA results from mesh nodes to fixed polar grid.
|
||||
|
||||
Args:
|
||||
mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z])
|
||||
mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
|
||||
method: Interpolation method ('rbf', 'linear', 'clough_tocher')
|
||||
|
||||
Returns:
|
||||
Interpolated values on polar grid [n_nodes] or [n_nodes, n_features]
|
||||
"""
|
||||
if not HAS_SCIPY:
|
||||
raise ImportError("scipy required for interpolation: pip install scipy")
|
||||
|
||||
# Use only X, Y coordinates
|
||||
xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords
|
||||
|
||||
# Handle multi-dimensional values
|
||||
values_1d = mesh_values.ndim == 1
|
||||
if values_1d:
|
||||
mesh_values = mesh_values.reshape(-1, 1)
|
||||
|
||||
# Target coordinates
|
||||
target_xy = np.column_stack([self.x, self.y])
|
||||
|
||||
result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32)
|
||||
|
||||
for i in range(mesh_values.shape[1]):
|
||||
vals = mesh_values[:, i]
|
||||
|
||||
if method == 'rbf':
|
||||
# RBF interpolation - smooth, handles scattered data well
|
||||
interp = RBFInterpolator(
|
||||
xy, vals,
|
||||
kernel='thin_plate_spline',
|
||||
smoothing=0.0
|
||||
)
|
||||
result[:, i] = interp(target_xy)
|
||||
|
||||
elif method == 'linear':
|
||||
# Linear interpolation via Delaunay triangulation
|
||||
interp = LinearNDInterpolator(xy, vals, fill_value=np.nan)
|
||||
result[:, i] = interp(target_xy)
|
||||
|
||||
# Handle NaN (points outside convex hull) with nearest neighbor
|
||||
nan_mask = np.isnan(result[:, i])
|
||||
if nan_mask.any():
|
||||
from scipy.spatial import cKDTree
|
||||
tree = cKDTree(xy)
|
||||
_, idx = tree.query(target_xy[nan_mask])
|
||||
result[nan_mask, i] = vals[idx]
|
||||
|
||||
elif method == 'clough_tocher':
|
||||
# Clough-Tocher (C1 smooth) interpolation
|
||||
interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan)
|
||||
result[:, i] = interp(target_xy)
|
||||
|
||||
# Handle NaN
|
||||
nan_mask = np.isnan(result[:, i])
|
||||
if nan_mask.any():
|
||||
from scipy.spatial import cKDTree
|
||||
tree = cKDTree(xy)
|
||||
_, idx = tree.query(target_xy[nan_mask])
|
||||
result[nan_mask, i] = vals[idx]
|
||||
else:
|
||||
raise ValueError(f"Unknown interpolation method: {method}")
|
||||
|
||||
return result[:, 0] if values_1d else result
|
||||
|
||||
def interpolate_field_data(
|
||||
self,
|
||||
field_data: Dict[str, Any],
|
||||
subcases: List[int] = [1, 2, 3, 4],
|
||||
method: str = 'linear' # Changed from 'rbf' - much faster
|
||||
) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Interpolate field data from extract_displacement_field() to polar grid.
|
||||
|
||||
Args:
|
||||
field_data: Output from extract_displacement_field()
|
||||
subcases: List of subcases to interpolate
|
||||
method: Interpolation method
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- z_displacement: [n_nodes, n_subcases] array
|
||||
- original_n_nodes: Number of FEA nodes
|
||||
"""
|
||||
mesh_coords = field_data['node_coords']
|
||||
z_disp_dict = field_data['z_displacement']
|
||||
|
||||
# Stack subcases
|
||||
z_disp_list = []
|
||||
for sc in subcases:
|
||||
if sc in z_disp_dict:
|
||||
z_disp_list.append(z_disp_dict[sc])
|
||||
else:
|
||||
raise KeyError(f"Subcase {sc} not found in field_data")
|
||||
|
||||
# [n_fea_nodes, n_subcases]
|
||||
z_disp_mesh = np.column_stack(z_disp_list)
|
||||
|
||||
# Interpolate to polar grid
|
||||
z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)
|
||||
|
||||
return {
|
||||
'z_displacement': z_disp_grid, # [n_nodes, n_subcases]
|
||||
'original_n_nodes': len(mesh_coords),
|
||||
}
|
||||
|
||||
def to_pyg_data(
|
||||
self,
|
||||
z_displacement: np.ndarray,
|
||||
design_vars: np.ndarray,
|
||||
objectives: Optional[Dict[str, float]] = None
|
||||
):
|
||||
"""
|
||||
Convert to PyTorch Geometric Data object.
|
||||
|
||||
Args:
|
||||
z_displacement: [n_nodes, n_subcases] displacement field
|
||||
design_vars: [n_design_vars] design parameters
|
||||
objectives: Optional dict of objective values (ground truth)
|
||||
|
||||
Returns:
|
||||
torch_geometric.data.Data object
|
||||
"""
|
||||
if not HAS_TORCH:
|
||||
raise ImportError("PyTorch required: pip install torch")
|
||||
|
||||
try:
|
||||
from torch_geometric.data import Data
|
||||
except ImportError:
|
||||
raise ImportError("PyTorch Geometric required: pip install torch-geometric")
|
||||
|
||||
# Node features: (r, theta, x, y)
|
||||
node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)
|
||||
|
||||
# Edge index and features
|
||||
edge_index = torch.tensor(self.edge_index, dtype=torch.long)
|
||||
edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)
|
||||
|
||||
# Target: Z-displacement field
|
||||
y = torch.tensor(z_displacement, dtype=torch.float32)
|
||||
|
||||
# Design variables (global feature)
|
||||
design = torch.tensor(design_vars, dtype=torch.float32)
|
||||
|
||||
data = Data(
|
||||
x=node_features,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
y=y,
|
||||
design=design,
|
||||
)
|
||||
|
||||
# Add objectives if provided
|
||||
if objectives:
|
||||
for key, value in objectives.items():
|
||||
setattr(data, key, torch.tensor([value], dtype=torch.float32))
|
||||
|
||||
return data
|
||||
|
||||
def save(self, path: Path) -> None:
|
||||
"""Save graph structure to JSON file."""
|
||||
path = Path(path)
|
||||
|
||||
data = {
|
||||
'r_inner': self.r_inner,
|
||||
'r_outer': self.r_outer,
|
||||
'n_radial': self.n_radial,
|
||||
'n_angular': self.n_angular,
|
||||
'n_nodes': self.n_nodes,
|
||||
'n_edges': self.edge_index.shape[1],
|
||||
}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
# Save arrays separately for efficiency
|
||||
np.savez_compressed(
|
||||
path.with_suffix('.npz'),
|
||||
r=self.r,
|
||||
theta=self.theta,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
edge_index=self.edge_index,
|
||||
edge_attr=self.edge_attr,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path) -> 'PolarMirrorGraph':
|
||||
"""Load graph structure from file."""
|
||||
path = Path(path)
|
||||
|
||||
with open(path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
return cls(
|
||||
r_inner=data['r_inner'],
|
||||
r_outer=data['r_outer'],
|
||||
n_radial=data['n_radial'],
|
||||
n_angular=data['n_angular'],
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PolarMirrorGraph("
|
||||
f"r=[{self.r_inner}, {self.r_outer}]mm, "
|
||||
f"grid={self.n_radial}×{self.n_angular}, "
|
||||
f"nodes={self.n_nodes}, "
|
||||
f"edges={self.edge_index.shape[1]})"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Convenience functions
|
||||
# =============================================================================
|
||||
|
||||
def create_mirror_dataset(
|
||||
study_dir: Path,
|
||||
polar_graph: Optional[PolarMirrorGraph] = None,
|
||||
subcases: List[int] = [1, 2, 3, 4],
|
||||
verbose: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Create GNN dataset from a study's gnn_data folder.
|
||||
|
||||
Args:
|
||||
study_dir: Path to study directory
|
||||
polar_graph: PolarMirrorGraph instance (created if None)
|
||||
subcases: Subcases to include
|
||||
verbose: Print progress
|
||||
|
||||
Returns:
|
||||
List of data dictionaries, each containing:
|
||||
- z_displacement: [n_nodes, n_subcases]
|
||||
- design_vars: [n_vars]
|
||||
- trial_number: int
|
||||
- original_n_nodes: int
|
||||
"""
|
||||
from optimization_engine.gnn.extract_displacement_field import load_field
|
||||
|
||||
study_dir = Path(study_dir)
|
||||
gnn_data_dir = study_dir / "gnn_data"
|
||||
|
||||
if not gnn_data_dir.exists():
|
||||
raise FileNotFoundError(f"No gnn_data folder in {study_dir}")
|
||||
|
||||
# Load index
|
||||
index_path = gnn_data_dir / "dataset_index.json"
|
||||
with open(index_path, 'r') as f:
|
||||
index = json.load(f)
|
||||
|
||||
if polar_graph is None:
|
||||
polar_graph = PolarMirrorGraph()
|
||||
|
||||
dataset = []
|
||||
|
||||
for trial_num, trial_info in index['trials'].items():
|
||||
if trial_info.get('status') != 'success':
|
||||
continue
|
||||
|
||||
trial_dir = study_dir / trial_info['trial_dir']
|
||||
|
||||
# Find field file
|
||||
field_path = None
|
||||
for ext in ['.h5', '.npz']:
|
||||
candidate = trial_dir / f"displacement_field{ext}"
|
||||
if candidate.exists():
|
||||
field_path = candidate
|
||||
break
|
||||
|
||||
if field_path is None:
|
||||
if verbose:
|
||||
print(f"[WARN] No field file for trial {trial_num}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Load field data
|
||||
field_data = load_field(field_path)
|
||||
|
||||
# Interpolate to polar grid
|
||||
interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)
|
||||
|
||||
# Get design parameters
|
||||
params = trial_info.get('params', {})
|
||||
design_vars = np.array(list(params.values()), dtype=np.float32) if params else np.array([])
|
||||
|
||||
dataset.append({
|
||||
'z_displacement': interp_result['z_displacement'],
|
||||
'design_vars': design_vars,
|
||||
'design_names': list(params.keys()) if params else [],
|
||||
'trial_number': int(trial_num),
|
||||
'original_n_nodes': interp_result['original_n_nodes'],
|
||||
})
|
||||
|
||||
if verbose:
|
||||
print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes")
|
||||
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print(f"[ERR] Trial {trial_num}: {e}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
|
||||
parser.add_argument('--test', action='store_true', help='Run basic tests')
|
||||
parser.add_argument('--study', type=Path, help='Create dataset from study')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.test:
|
||||
print("="*60)
|
||||
print("TESTING PolarMirrorGraph")
|
||||
print("="*60)
|
||||
|
||||
# Create graph
|
||||
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
|
||||
print(f"\n{graph}")
|
||||
|
||||
# Check node features
|
||||
node_feat = graph.get_node_features(normalized=True)
|
||||
print(f"\nNode features shape: {node_feat.shape}")
|
||||
print(f" r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
|
||||
print(f" theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")
|
||||
|
||||
# Check edge features
|
||||
edge_feat = graph.get_edge_features(normalized=True)
|
||||
print(f"\nEdge features shape: {edge_feat.shape}")
|
||||
print(f" dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
|
||||
print(f" distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")
|
||||
|
||||
# Test interpolation with synthetic data
|
||||
print("\n--- Testing Interpolation ---")
|
||||
|
||||
# Create fake mesh data (random points in annulus)
|
||||
np.random.seed(42)
|
||||
n_mesh = 5000
|
||||
r_mesh = np.random.uniform(100, 650, n_mesh)
|
||||
theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
|
||||
x_mesh = r_mesh * np.cos(theta_mesh)
|
||||
y_mesh = r_mesh * np.sin(theta_mesh)
|
||||
mesh_coords = np.column_stack([x_mesh, y_mesh])
|
||||
|
||||
# Synthetic displacement: smooth function
|
||||
mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)
|
||||
|
||||
# Interpolate
|
||||
grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
|
||||
print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
|
||||
print(f" Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
|
||||
print(f" Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")
|
||||
|
||||
print("\n✓ All tests passed!")
|
||||
|
||||
elif args.study:
|
||||
# Create dataset from study
|
||||
print(f"Creating dataset from: {args.study}")
|
||||
|
||||
graph = PolarMirrorGraph()
|
||||
dataset = create_mirror_dataset(args.study, polar_graph=graph)
|
||||
|
||||
print(f"\nDataset: {len(dataset)} samples")
|
||||
if dataset:
|
||||
print(f" Z-displacement shape: {dataset[0]['z_displacement'].shape}")
|
||||
print(f" Design vars: {len(dataset[0]['design_vars'])} variables")
|
||||
|
||||
else:
|
||||
# Default: just show info
|
||||
graph = PolarMirrorGraph()
|
||||
print(graph)
|
||||
print(f"\nNode features: {graph.get_node_features().shape}")
|
||||
print(f"Edge index: {graph.edge_index.shape}")
|
||||
print(f"Edge features: {graph.edge_attr.shape}")
|
||||
Reference in New Issue
Block a user