""" Polar Mirror Graph for GNN Training ==================================== This module creates a fixed polar grid graph structure for the mirror optical surface. The key insight is that the mirror has a fixed topology (circular annulus), so we can use a fixed graph structure regardless of FEA mesh variations. Why Polar Grid? 1. Matches mirror geometry (annulus) 2. Same approach as extract_zernike_surface.py 3. Enables mesh-independent training 4. Edge structure respects radial/angular physics Grid Structure: - n_radial points from r_inner to r_outer - n_angular points from 0 to 2π (not including 2π to avoid duplicate) - Total nodes = n_radial × n_angular - Edges connect radial neighbors and angular neighbors (wrap-around) Usage: from optimization_engine.gnn.polar_graph import PolarMirrorGraph graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60) # Interpolate FEA results to fixed grid z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp) # Get PyTorch Geometric data data = graph.to_pyg_data(z_disp_grid, design_vars) """ import numpy as np from pathlib import Path from typing import Dict, Any, Optional, Tuple, List import json try: import torch HAS_TORCH = True except ImportError: HAS_TORCH = False try: from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator from scipy.spatial import Delaunay HAS_SCIPY = True except ImportError: HAS_SCIPY = False class PolarMirrorGraph: """ Fixed polar grid graph for mirror optical surface. This creates a mesh-independent graph structure that can be used for GNN training regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid. Attributes: n_nodes: Total number of nodes (n_radial × n_angular) r: Radial coordinates [n_nodes] theta: Angular coordinates [n_nodes] x: Cartesian X coordinates [n_nodes] y: Cartesian Y coordinates [n_nodes] edge_index: Graph edges [2, n_edges] edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle) """ def __init__( self, r_inner: float = 100.0, r_outer: float = 650.0, n_radial: int = 50, n_angular: int = 60 ): """ Initialize polar grid graph. Args: r_inner: Inner radius (central hole), mm r_outer: Outer radius, mm n_radial: Number of radial samples n_angular: Number of angular samples """ self.r_inner = r_inner self.r_outer = r_outer self.n_radial = n_radial self.n_angular = n_angular self.n_nodes = n_radial * n_angular # Create polar grid coordinates r_1d = np.linspace(r_inner, r_outer, n_radial) theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False) # Meshgrid: theta varies fast (angular index), r varies slow (radial index) # Shape after flatten: [n_angular * n_radial] with angular varying fastest Theta, R = np.meshgrid(theta_1d, r_1d) # R shape: [n_radial, n_angular] # Flatten: radial index varies slowest self.r = R.flatten().astype(np.float32) self.theta = Theta.flatten().astype(np.float32) self.x = (self.r * np.cos(self.theta)).astype(np.float32) self.y = (self.r * np.sin(self.theta)).astype(np.float32) # Build graph edges self.edge_index, self.edge_attr = self._build_polar_edges() # Precompute normalization factors self._r_mean = (r_inner + r_outer) / 2 self._r_std = (r_outer - r_inner) / 2 def _node_index(self, i_r: int, i_theta: int) -> int: """Convert (radial_index, angular_index) to flat node index.""" # Angular wraps around i_theta = i_theta % self.n_angular return i_r * self.n_angular + i_theta def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]: """ Create graph edges respecting polar topology. Edge types: 1. Radial edges: Connect adjacent radial rings 2. Angular edges: Connect adjacent angular positions (with wrap-around) 3. Diagonal edges: Connect diagonal neighbors for better message passing Returns: edge_index: [2, n_edges] array of (source, target) pairs edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle) """ edges = [] edge_features = [] for i_r in range(self.n_radial): for i_theta in range(self.n_angular): node = self._node_index(i_r, i_theta) # Radial neighbor (outward) if i_r < self.n_radial - 1: neighbor = self._node_index(i_r + 1, i_theta) edges.append([node, neighbor]) edges.append([neighbor, node]) # Bidirectional # Edge features: (dr, dtheta, distance, relative_angle) dr = self.r[neighbor] - self.r[node] dtheta = 0.0 dist = abs(dr) angle = 0.0 # Radial direction edge_features.append([dr, dtheta, dist, angle]) edge_features.append([-dr, dtheta, dist, np.pi]) # Reverse # Angular neighbor (counterclockwise, with wrap-around) neighbor = self._node_index(i_r, i_theta + 1) edges.append([node, neighbor]) edges.append([neighbor, node]) # Bidirectional # Edge features for angular edge dr = 0.0 dtheta = 2 * np.pi / self.n_angular # Arc length at this radius dist = self.r[node] * dtheta angle = np.pi / 2 # Tangential direction edge_features.append([dr, dtheta, dist, angle]) edge_features.append([dr, -dtheta, dist, -np.pi / 2]) # Reverse # Diagonal neighbor (outward + counterclockwise) for better connectivity if i_r < self.n_radial - 1: neighbor = self._node_index(i_r + 1, i_theta + 1) edges.append([node, neighbor]) edges.append([neighbor, node]) dr = self.r[neighbor] - self.r[node] dtheta = 2 * np.pi / self.n_angular dx = self.x[neighbor] - self.x[node] dy = self.y[neighbor] - self.y[node] dist = np.sqrt(dx**2 + dy**2) angle = np.arctan2(dy, dx) edge_features.append([dr, dtheta, dist, angle]) edge_features.append([-dr, -dtheta, dist, angle + np.pi]) edge_index = np.array(edges, dtype=np.int64).T # [2, n_edges] edge_attr = np.array(edge_features, dtype=np.float32) # [n_edges, 4] return edge_index, edge_attr def get_node_features(self, normalized: bool = True) -> np.ndarray: """ Get node features for GNN input. Features: (r, theta, x, y) - polar and Cartesian coordinates Args: normalized: If True, normalize features to ~[-1, 1] range Returns: Node features [n_nodes, 4] """ if normalized: r_norm = (self.r - self._r_mean) / self._r_std theta_norm = self.theta / np.pi - 1 # [0, 2π] → [-1, 1] x_norm = self.x / self.r_outer y_norm = self.y / self.r_outer return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32) else: return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32) def get_edge_features(self, normalized: bool = True) -> np.ndarray: """ Get edge features for GNN input. Features: (dr, dtheta, distance, angle) Args: normalized: If True, normalize features Returns: Edge features [n_edges, 4] """ if normalized: edge_attr = self.edge_attr.copy() edge_attr[:, 0] /= self._r_std # dr edge_attr[:, 1] /= np.pi # dtheta edge_attr[:, 2] /= self.r_outer # distance edge_attr[:, 3] /= np.pi # angle return edge_attr else: return self.edge_attr def interpolate_from_mesh( self, mesh_coords: np.ndarray, mesh_values: np.ndarray, method: str = 'rbf' ) -> np.ndarray: """ Interpolate FEA results from mesh nodes to fixed polar grid. Args: mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z]) mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features] method: Interpolation method ('rbf', 'linear', 'clough_tocher') Returns: Interpolated values on polar grid [n_nodes] or [n_nodes, n_features] """ if not HAS_SCIPY: raise ImportError("scipy required for interpolation: pip install scipy") # Use only X, Y coordinates xy = mesh_coords[:, :2] if mesh_coords.shape[1] > 2 else mesh_coords # Handle multi-dimensional values values_1d = mesh_values.ndim == 1 if values_1d: mesh_values = mesh_values.reshape(-1, 1) # Target coordinates target_xy = np.column_stack([self.x, self.y]) result = np.zeros((self.n_nodes, mesh_values.shape[1]), dtype=np.float32) for i in range(mesh_values.shape[1]): vals = mesh_values[:, i] if method == 'rbf': # RBF interpolation - smooth, handles scattered data well interp = RBFInterpolator( xy, vals, kernel='thin_plate_spline', smoothing=0.0 ) result[:, i] = interp(target_xy) elif method == 'linear': # Linear interpolation via Delaunay triangulation interp = LinearNDInterpolator(xy, vals, fill_value=np.nan) result[:, i] = interp(target_xy) # Handle NaN (points outside convex hull) with nearest neighbor nan_mask = np.isnan(result[:, i]) if nan_mask.any(): from scipy.spatial import cKDTree tree = cKDTree(xy) _, idx = tree.query(target_xy[nan_mask]) result[nan_mask, i] = vals[idx] elif method == 'clough_tocher': # Clough-Tocher (C1 smooth) interpolation interp = CloughTocher2DInterpolator(xy, vals, fill_value=np.nan) result[:, i] = interp(target_xy) # Handle NaN nan_mask = np.isnan(result[:, i]) if nan_mask.any(): from scipy.spatial import cKDTree tree = cKDTree(xy) _, idx = tree.query(target_xy[nan_mask]) result[nan_mask, i] = vals[idx] else: raise ValueError(f"Unknown interpolation method: {method}") return result[:, 0] if values_1d else result def interpolate_field_data( self, field_data: Dict[str, Any], subcases: List[int] = [1, 2, 3, 4], method: str = 'linear' # Changed from 'rbf' - much faster ) -> Dict[str, np.ndarray]: """ Interpolate field data from extract_displacement_field() to polar grid. Args: field_data: Output from extract_displacement_field() subcases: List of subcases to interpolate method: Interpolation method Returns: Dictionary with: - z_displacement: [n_nodes, n_subcases] array - original_n_nodes: Number of FEA nodes """ mesh_coords = field_data['node_coords'] z_disp_dict = field_data['z_displacement'] # Stack subcases z_disp_list = [] for sc in subcases: if sc in z_disp_dict: z_disp_list.append(z_disp_dict[sc]) else: raise KeyError(f"Subcase {sc} not found in field_data") # [n_fea_nodes, n_subcases] z_disp_mesh = np.column_stack(z_disp_list) # Interpolate to polar grid z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method) return { 'z_displacement': z_disp_grid, # [n_nodes, n_subcases] 'original_n_nodes': len(mesh_coords), } def to_pyg_data( self, z_displacement: np.ndarray, design_vars: np.ndarray, objectives: Optional[Dict[str, float]] = None ): """ Convert to PyTorch Geometric Data object. Args: z_displacement: [n_nodes, n_subcases] displacement field design_vars: [n_design_vars] design parameters objectives: Optional dict of objective values (ground truth) Returns: torch_geometric.data.Data object """ if not HAS_TORCH: raise ImportError("PyTorch required: pip install torch") try: from torch_geometric.data import Data except ImportError: raise ImportError("PyTorch Geometric required: pip install torch-geometric") # Node features: (r, theta, x, y) node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32) # Edge index and features edge_index = torch.tensor(self.edge_index, dtype=torch.long) edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32) # Target: Z-displacement field y = torch.tensor(z_displacement, dtype=torch.float32) # Design variables (global feature) design = torch.tensor(design_vars, dtype=torch.float32) data = Data( x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y, design=design, ) # Add objectives if provided if objectives: for key, value in objectives.items(): setattr(data, key, torch.tensor([value], dtype=torch.float32)) return data def save(self, path: Path) -> None: """Save graph structure to JSON file.""" path = Path(path) data = { 'r_inner': self.r_inner, 'r_outer': self.r_outer, 'n_radial': self.n_radial, 'n_angular': self.n_angular, 'n_nodes': self.n_nodes, 'n_edges': self.edge_index.shape[1], } with open(path, 'w') as f: json.dump(data, f, indent=2) # Save arrays separately for efficiency np.savez_compressed( path.with_suffix('.npz'), r=self.r, theta=self.theta, x=self.x, y=self.y, edge_index=self.edge_index, edge_attr=self.edge_attr, ) @classmethod def load(cls, path: Path) -> 'PolarMirrorGraph': """Load graph structure from file.""" path = Path(path) with open(path, 'r') as f: data = json.load(f) return cls( r_inner=data['r_inner'], r_outer=data['r_outer'], n_radial=data['n_radial'], n_angular=data['n_angular'], ) def __repr__(self) -> str: return ( f"PolarMirrorGraph(" f"r=[{self.r_inner}, {self.r_outer}]mm, " f"grid={self.n_radial}×{self.n_angular}, " f"nodes={self.n_nodes}, " f"edges={self.edge_index.shape[1]})" ) # ============================================================================= # Convenience functions # ============================================================================= def create_mirror_dataset( study_dir: Path, polar_graph: Optional[PolarMirrorGraph] = None, subcases: List[int] = [1, 2, 3, 4], verbose: bool = True ) -> List[Dict[str, Any]]: """ Create GNN dataset from a study's gnn_data folder. Args: study_dir: Path to study directory polar_graph: PolarMirrorGraph instance (created if None) subcases: Subcases to include verbose: Print progress Returns: List of data dictionaries, each containing: - z_displacement: [n_nodes, n_subcases] - design_vars: [n_vars] - trial_number: int - original_n_nodes: int """ from optimization_engine.gnn.extract_displacement_field import load_field study_dir = Path(study_dir) gnn_data_dir = study_dir / "gnn_data" if not gnn_data_dir.exists(): raise FileNotFoundError(f"No gnn_data folder in {study_dir}") # Load index index_path = gnn_data_dir / "dataset_index.json" with open(index_path, 'r') as f: index = json.load(f) if polar_graph is None: polar_graph = PolarMirrorGraph() dataset = [] for trial_num, trial_info in index['trials'].items(): if trial_info.get('status') != 'success': continue trial_dir = study_dir / trial_info['trial_dir'] # Find field file field_path = None for ext in ['.h5', '.npz']: candidate = trial_dir / f"displacement_field{ext}" if candidate.exists(): field_path = candidate break if field_path is None: if verbose: print(f"[WARN] No field file for trial {trial_num}") continue try: # Load field data field_data = load_field(field_path) # Interpolate to polar grid interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases) # Get design parameters params = trial_info.get('params', {}) design_vars = np.array(list(params.values()), dtype=np.float32) if params else np.array([]) dataset.append({ 'z_displacement': interp_result['z_displacement'], 'design_vars': design_vars, 'design_names': list(params.keys()) if params else [], 'trial_number': int(trial_num), 'original_n_nodes': interp_result['original_n_nodes'], }) if verbose: print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes") except Exception as e: if verbose: print(f"[ERR] Trial {trial_num}: {e}") return dataset # ============================================================================= # CLI # ============================================================================= if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Test PolarMirrorGraph') parser.add_argument('--test', action='store_true', help='Run basic tests') parser.add_argument('--study', type=Path, help='Create dataset from study') args = parser.parse_args() if args.test: print("="*60) print("TESTING PolarMirrorGraph") print("="*60) # Create graph graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60) print(f"\n{graph}") # Check node features node_feat = graph.get_node_features(normalized=True) print(f"\nNode features shape: {node_feat.shape}") print(f" r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]") print(f" theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]") # Check edge features edge_feat = graph.get_edge_features(normalized=True) print(f"\nEdge features shape: {edge_feat.shape}") print(f" dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]") print(f" distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]") # Test interpolation with synthetic data print("\n--- Testing Interpolation ---") # Create fake mesh data (random points in annulus) np.random.seed(42) n_mesh = 5000 r_mesh = np.random.uniform(100, 650, n_mesh) theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh) x_mesh = r_mesh * np.cos(theta_mesh) y_mesh = r_mesh * np.sin(theta_mesh) mesh_coords = np.column_stack([x_mesh, y_mesh]) # Synthetic displacement: smooth function mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh) # Interpolate grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf') print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes") print(f" Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]") print(f" Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]") print("\n✓ All tests passed!") elif args.study: # Create dataset from study print(f"Creating dataset from: {args.study}") graph = PolarMirrorGraph() dataset = create_mirror_dataset(args.study, polar_graph=graph) print(f"\nDataset: {len(dataset)} samples") if dataset: print(f" Z-displacement shape: {dataset[0]['z_displacement'].shape}") print(f" Design vars: {len(dataset[0]['design_vars'])} variables") else: # Default: just show info graph = PolarMirrorGraph() print(graph) print(f"\nNode features: {graph.get_node_features().shape}") print(f"Edge index: {graph.edge_index.shape}") print(f"Edge features: {graph.edge_attr.shape}")