618 lines
22 KiB
Python
618 lines
22 KiB
Python
|
|
"""
|
|||
|
|
Polar Mirror Graph for GNN Training
|
|||
|
|
====================================
|
|||
|
|
|
|||
|
|
This module creates a fixed polar grid graph structure for the mirror optical surface.
|
|||
|
|
The key insight is that the mirror has a fixed topology (circular annulus), so we can
|
|||
|
|
use a fixed graph structure regardless of FEA mesh variations.
|
|||
|
|
|
|||
|
|
Why Polar Grid?
|
|||
|
|
1. Matches mirror geometry (annulus)
|
|||
|
|
2. Same approach as extract_zernike_surface.py
|
|||
|
|
3. Enables mesh-independent training
|
|||
|
|
4. Edge structure respects radial/angular physics
|
|||
|
|
|
|||
|
|
Grid Structure:
|
|||
|
|
- n_radial points from r_inner to r_outer
|
|||
|
|
- n_angular points from 0 to 2π (not including 2π to avoid duplicate)
|
|||
|
|
- Total nodes = n_radial × n_angular
|
|||
|
|
- Edges connect radial neighbors, angular neighbors (wrap-around), and diagonal (outward + counterclockwise) neighbors
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
from optimization_engine.gnn.polar_graph import PolarMirrorGraph
|
|||
|
|
|
|||
|
|
graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
|
|||
|
|
|
|||
|
|
# Interpolate FEA results to fixed grid
|
|||
|
|
z_disp_grid = graph.interpolate_from_mesh(fea_coords, fea_z_disp)
|
|||
|
|
|
|||
|
|
# Get PyTorch Geometric data
|
|||
|
|
data = graph.to_pyg_data(z_disp_grid, design_vars)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Dict, Any, Optional, Tuple, List
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import torch
|
|||
|
|
HAS_TORCH = True
|
|||
|
|
except ImportError:
|
|||
|
|
HAS_TORCH = False
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator, CloughTocher2DInterpolator
|
|||
|
|
from scipy.spatial import Delaunay
|
|||
|
|
HAS_SCIPY = True
|
|||
|
|
except ImportError:
|
|||
|
|
HAS_SCIPY = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PolarMirrorGraph:
    """
    Fixed polar grid graph for mirror optical surface.

    This creates a mesh-independent graph structure that can be used for GNN training
    regardless of the underlying FEA mesh. FEA results are interpolated to this fixed grid.

    Node ordering: flat index = i_r * n_angular + i_theta, i.e. the angular index
    varies fastest and the radial index slowest.

    Attributes:
        n_nodes: Total number of nodes (n_radial × n_angular)
        r: Radial coordinates [n_nodes]
        theta: Angular coordinates [n_nodes]
        x: Cartesian X coordinates [n_nodes]
        y: Cartesian Y coordinates [n_nodes]
        edge_index: Graph edges [2, n_edges]
        edge_attr: Edge features [n_edges, 4] - (dr, dtheta, distance, angle)
    """

    # Interpolation methods understood by interpolate_from_mesh().
    _INTERP_METHODS = ('rbf', 'linear', 'clough_tocher')

    # Attribute names that to_pyg_data() sets itself; objectives may not reuse them.
    _RESERVED_PYG_KEYS = frozenset({'x', 'edge_index', 'edge_attr', 'y', 'design'})

    def __init__(
        self,
        r_inner: float = 100.0,
        r_outer: float = 650.0,
        n_radial: int = 50,
        n_angular: int = 60
    ):
        """
        Initialize polar grid graph.

        Args:
            r_inner: Inner radius (central hole), mm. Must satisfy 0 <= r_inner < r_outer.
            r_outer: Outer radius, mm
            n_radial: Number of radial samples (>= 1)
            n_angular: Number of angular samples (>= 3; fewer would turn the
                wrap-around angular edge into a self-loop or a duplicate edge)

        Raises:
            ValueError: If the radii or sample counts are out of range.
        """
        if not 0 <= r_inner < r_outer:
            # r_inner == r_outer would also make the radial normalization
            # (_r_std) zero and break get_node_features().
            raise ValueError(
                f"Require 0 <= r_inner < r_outer, got r_inner={r_inner}, r_outer={r_outer}"
            )
        if n_radial < 1:
            raise ValueError(f"n_radial must be >= 1, got {n_radial}")
        if n_angular < 3:
            raise ValueError(f"n_angular must be >= 3, got {n_angular}")

        self.r_inner = r_inner
        self.r_outer = r_outer
        self.n_radial = n_radial
        self.n_angular = n_angular
        self.n_nodes = n_radial * n_angular

        # Create polar grid coordinates
        r_1d = np.linspace(r_inner, r_outer, n_radial)
        theta_1d = np.linspace(0, 2 * np.pi, n_angular, endpoint=False)

        # Meshgrid: theta varies fast (angular index), r varies slow (radial index)
        Theta, R = np.meshgrid(theta_1d, r_1d)  # R shape: [n_radial, n_angular]

        # Flatten: radial index varies slowest
        self.r = R.flatten().astype(np.float32)
        self.theta = Theta.flatten().astype(np.float32)
        self.x = (self.r * np.cos(self.theta)).astype(np.float32)
        self.y = (self.r * np.sin(self.theta)).astype(np.float32)

        # Build graph edges
        self.edge_index, self.edge_attr = self._build_polar_edges()

        # Precompute normalization factors (map [r_inner, r_outer] onto [-1, 1])
        self._r_mean = (r_inner + r_outer) / 2
        self._r_std = (r_outer - r_inner) / 2

    def _node_index(self, i_r: int, i_theta: int) -> int:
        """Convert (radial_index, angular_index) to flat node index (angular index wraps)."""
        i_theta = i_theta % self.n_angular
        return i_r * self.n_angular + i_theta

    def _build_polar_edges(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create graph edges respecting polar topology.

        Edge types (each added in both directions):
            1. Radial edges: Connect adjacent radial rings
            2. Angular edges: Connect adjacent angular positions (with wrap-around)
            3. Diagonal edges: Connect (i_r, i_theta) to (i_r + 1, i_theta + 1)
               for better message passing

        NOTE(review): radial and angular edges store ``angle`` in the local polar
        frame (0 = outward, ±π/2 = tangential). Diagonal edges store a global
        Cartesian heading from ``arctan2``, and the reverse diagonal angle
        (``angle + π``) is not wrapped back into (-π, π]. This is left unchanged
        so edge features stay identical for models already trained on this graph.

        Returns:
            edge_index: [2, n_edges] array of (source, target) pairs
            edge_attr: [n_edges, 4] array of (dr, dtheta, distance, angle)
        """
        edges = []
        edge_features = []

        # Angular spacing is uniform, so every angular/diagonal step shares this dtheta.
        dtheta_step = 2 * np.pi / self.n_angular

        for i_r in range(self.n_radial):
            has_outer_ring = i_r < self.n_radial - 1
            for i_theta in range(self.n_angular):
                node = self._node_index(i_r, i_theta)

                # Radial neighbor (outward)
                if has_outer_ring:
                    neighbor = self._node_index(i_r + 1, i_theta)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])  # Bidirectional

                    # Edge features: (dr, dtheta, distance, relative_angle)
                    dr = self.r[neighbor] - self.r[node]
                    dist = abs(dr)
                    edge_features.append([dr, 0.0, dist, 0.0])      # Radial direction
                    edge_features.append([-dr, 0.0, dist, np.pi])   # Reverse

                # Angular neighbor (counterclockwise, with wrap-around)
                neighbor = self._node_index(i_r, i_theta + 1)
                edges.append([node, neighbor])
                edges.append([neighbor, node])  # Bidirectional

                # Arc length at this radius
                dist = self.r[node] * dtheta_step
                edge_features.append([0.0, dtheta_step, dist, np.pi / 2])     # Tangential
                edge_features.append([0.0, -dtheta_step, dist, -np.pi / 2])   # Reverse

                # Diagonal neighbor (outward + counterclockwise) for better connectivity
                if has_outer_ring:
                    neighbor = self._node_index(i_r + 1, i_theta + 1)
                    edges.append([node, neighbor])
                    edges.append([neighbor, node])

                    dr = self.r[neighbor] - self.r[node]
                    dx = self.x[neighbor] - self.x[node]
                    dy = self.y[neighbor] - self.y[node]
                    dist = np.sqrt(dx**2 + dy**2)
                    angle = np.arctan2(dy, dx)
                    edge_features.append([dr, dtheta_step, dist, angle])
                    edge_features.append([-dr, -dtheta_step, dist, angle + np.pi])

        edge_index = np.array(edges, dtype=np.int64).T  # [2, n_edges]
        edge_attr = np.array(edge_features, dtype=np.float32)  # [n_edges, 4]

        return edge_index, edge_attr

    def get_node_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get node features for GNN input.

        Features: (r, theta, x, y) - polar and Cartesian coordinates

        Args:
            normalized: If True, normalize features to ~[-1, 1] range
                (r via the annulus mid-radius/half-width, theta via theta/π - 1,
                x and y divided by r_outer)

        Returns:
            Node features [n_nodes, 4], float32
        """
        if normalized:
            r_norm = (self.r - self._r_mean) / self._r_std
            theta_norm = self.theta / np.pi - 1  # [0, 2π) → [-1, 1)
            x_norm = self.x / self.r_outer
            y_norm = self.y / self.r_outer
            return np.column_stack([r_norm, theta_norm, x_norm, y_norm]).astype(np.float32)
        return np.column_stack([self.r, self.theta, self.x, self.y]).astype(np.float32)

    def get_edge_features(self, normalized: bool = True) -> np.ndarray:
        """
        Get edge features for GNN input.

        Features: (dr, dtheta, distance, angle)

        Args:
            normalized: If True, return a normalized copy (dr / half-width,
                dtheta / π, distance / r_outer, angle / π)

        Returns:
            Edge features [n_edges, 4]. When ``normalized`` is False this is the
            stored ``edge_attr`` array itself, not a copy.
        """
        if normalized:
            edge_attr = self.edge_attr.copy()
            edge_attr[:, 0] /= self._r_std  # dr
            edge_attr[:, 1] /= np.pi  # dtheta
            edge_attr[:, 2] /= self.r_outer  # distance
            edge_attr[:, 3] /= np.pi  # angle
            return edge_attr
        return self.edge_attr

    def interpolate_from_mesh(
        self,
        mesh_coords: np.ndarray,
        mesh_values: np.ndarray,
        method: str = 'rbf'
    ) -> np.ndarray:
        """
        Interpolate FEA results from mesh nodes to fixed polar grid.

        A single interpolator is built for all value columns, because scipy's
        interpolators accept values of shape (npoints, k). The RBF system or
        Delaunay triangulation is therefore computed once, not once per column.

        For 'linear' and 'clough_tocher', grid nodes outside the convex hull of
        the mesh come back as NaN and are filled with the value of the nearest
        mesh node. 'rbf' (thin-plate spline over all mesh nodes) solves a dense
        system and scales poorly to large meshes; prefer 'linear' there.

        Args:
            mesh_coords: FEA node coordinates [n_fea_nodes, 2] or [n_fea_nodes, 3] (X, Y, [Z]);
                only X and Y are used.
            mesh_values: Values to interpolate [n_fea_nodes] or [n_fea_nodes, n_features]
            method: Interpolation method ('rbf', 'linear', 'clough_tocher')

        Returns:
            Interpolated float32 values on polar grid [n_nodes] or [n_nodes, n_features]

        Raises:
            ImportError: If scipy is not installed.
            ValueError: On an unknown method or inconsistent input shapes.
        """
        try:
            from scipy.interpolate import (
                CloughTocher2DInterpolator,
                LinearNDInterpolator,
                RBFInterpolator,
            )
            from scipy.spatial import cKDTree
        except ImportError as exc:
            raise ImportError("scipy required for interpolation: pip install scipy") from exc

        if method not in self._INTERP_METHODS:
            raise ValueError(f"Unknown interpolation method: {method}")

        mesh_coords = np.asarray(mesh_coords, dtype=np.float64)
        mesh_values = np.asarray(mesh_values, dtype=np.float64)

        if mesh_coords.ndim != 2 or mesh_coords.shape[1] < 2:
            raise ValueError(
                f"mesh_coords must have shape [n, 2] or [n, 3], got {mesh_coords.shape}"
            )
        if mesh_values.ndim not in (1, 2):
            raise ValueError(
                f"mesh_values must have shape [n] or [n, n_features], got {mesh_values.shape}"
            )
        if len(mesh_values) != len(mesh_coords):
            raise ValueError(
                f"mesh_coords has {len(mesh_coords)} nodes but mesh_values has {len(mesh_values)}"
            )

        # Use only X, Y coordinates
        xy = mesh_coords[:, :2]

        values_1d = mesh_values.ndim == 1
        values_2d = mesh_values.reshape(-1, 1) if values_1d else mesh_values

        if values_2d.shape[1] == 0:
            return np.zeros((self.n_nodes, 0), dtype=np.float32)

        # Target coordinates
        target_xy = np.column_stack([self.x, self.y])

        if method == 'rbf':
            # RBF interpolation - smooth, handles scattered data well
            interp = RBFInterpolator(
                xy, values_2d,
                kernel='thin_plate_spline',
                smoothing=0.0
            )
            result = np.asarray(interp(target_xy), dtype=np.float64)
        else:
            if method == 'linear':
                # Linear interpolation via Delaunay triangulation
                interp = LinearNDInterpolator(xy, values_2d, fill_value=np.nan)
            else:
                # Clough-Tocher (C1 smooth) interpolation
                interp = CloughTocher2DInterpolator(xy, values_2d, fill_value=np.nan)
            result = np.asarray(interp(target_xy), dtype=np.float64).reshape(self.n_nodes, -1)

            # Fill NaN (points outside convex hull) with nearest-neighbor values.
            tree = None
            for col in range(result.shape[1]):
                nan_mask = np.isnan(result[:, col])
                if nan_mask.any():
                    if tree is None:
                        tree = cKDTree(xy)
                    _, idx = tree.query(target_xy[nan_mask])
                    result[nan_mask, col] = values_2d[idx, col]

        result = result.astype(np.float32)
        return result[:, 0] if values_1d else result

    def interpolate_field_data(
        self,
        field_data: Dict[str, Any],
        subcases: Optional[List[int]] = None,
        method: str = 'linear'  # Changed from 'rbf' - much faster
    ) -> Dict[str, np.ndarray]:
        """
        Interpolate field data from extract_displacement_field() to polar grid.

        Args:
            field_data: Output from extract_displacement_field(); reads the
                'node_coords' array and the 'z_displacement' mapping of
                subcase -> per-node values.
            subcases: List of subcases to interpolate (default: [1, 2, 3, 4])
            method: Interpolation method (see interpolate_from_mesh)

        Returns:
            Dictionary with:
                - z_displacement: [n_nodes, n_subcases] array
                - original_n_nodes: Number of FEA nodes

        Raises:
            KeyError: If a requested subcase is missing from field_data.
            ValueError: If ``subcases`` is empty.
        """
        if subcases is None:
            subcases = [1, 2, 3, 4]
        if len(subcases) == 0:
            raise ValueError("subcases must contain at least one subcase")

        mesh_coords = field_data['node_coords']
        z_disp_dict = field_data['z_displacement']

        # Stack subcases
        z_disp_list = []
        for sc in subcases:
            if sc not in z_disp_dict:
                raise KeyError(f"Subcase {sc} not found in field_data")
            z_disp_list.append(z_disp_dict[sc])

        # [n_fea_nodes, n_subcases]
        z_disp_mesh = np.column_stack(z_disp_list)

        # Interpolate to polar grid
        z_disp_grid = self.interpolate_from_mesh(mesh_coords, z_disp_mesh, method=method)

        return {
            'z_displacement': z_disp_grid,  # [n_nodes, n_subcases]
            'original_n_nodes': len(mesh_coords),
        }

    def to_pyg_data(
        self,
        z_displacement: np.ndarray,
        design_vars: np.ndarray,
        objectives: Optional[Dict[str, float]] = None
    ):
        """
        Convert to PyTorch Geometric Data object.

        Args:
            z_displacement: [n_nodes] or [n_nodes, n_subcases] displacement field on this grid
            design_vars: [n_design_vars] design parameters
            objectives: Optional dict of objective values (ground truth); each
                one is attached to the Data object as a 1-element float tensor.

        Returns:
            torch_geometric.data.Data object with x, edge_index, edge_attr, y and
            design, plus one attribute per objective.

        Raises:
            ImportError: If torch or torch_geometric is not installed.
            ValueError: If z_displacement does not have one row per grid node, or
                an objective name would overwrite one of the core Data attributes.
        """
        if not HAS_TORCH:
            raise ImportError("PyTorch required: pip install torch")

        try:
            from torch_geometric.data import Data
        except ImportError as exc:
            raise ImportError("PyTorch Geometric required: pip install torch-geometric") from exc

        z_shape = tuple(np.shape(z_displacement))
        if len(z_shape) == 0 or z_shape[0] != self.n_nodes:
            raise ValueError(
                f"z_displacement must have {self.n_nodes} rows (one per grid node), "
                f"got shape {z_shape}"
            )

        if objectives:
            clashes = self._RESERVED_PYG_KEYS.intersection(objectives)
            if clashes:
                raise ValueError(
                    f"Objective names would overwrite Data attributes: {sorted(clashes)}"
                )

        # Node features: (r, theta, x, y)
        node_features = torch.tensor(self.get_node_features(normalized=True), dtype=torch.float32)

        # Edge index and features
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        edge_attr = torch.tensor(self.get_edge_features(normalized=True), dtype=torch.float32)

        # Target: Z-displacement field
        y = torch.tensor(z_displacement, dtype=torch.float32)

        # Design variables (global feature)
        design = torch.tensor(design_vars, dtype=torch.float32)

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            design=design,
        )

        # Add objectives if provided
        if objectives:
            for key, value in objectives.items():
                setattr(data, key, torch.tensor([value], dtype=torch.float32))

        return data

    def save(self, path: Path) -> None:
        """
        Save graph structure: JSON metadata to ``path``, arrays to ``path`` with a .npz suffix.

        Raises:
            ValueError: If ``path`` itself ends in '.npz', because the array
                archive would then overwrite the JSON metadata.
        """
        path = Path(path)
        if path.suffix == '.npz':
            raise ValueError(
                f"save() path must not end in '.npz' (metadata would be overwritten): {path}"
            )

        data = {
            'r_inner': self.r_inner,
            'r_outer': self.r_outer,
            'n_radial': self.n_radial,
            'n_angular': self.n_angular,
            'n_nodes': self.n_nodes,
            'n_edges': self.edge_index.shape[1],
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        # Save arrays separately for efficiency
        np.savez_compressed(
            path.with_suffix('.npz'),
            r=self.r,
            theta=self.theta,
            x=self.x,
            y=self.y,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
        )

    @classmethod
    def load(cls, path: Path) -> 'PolarMirrorGraph':
        """
        Load graph structure from the JSON metadata written by save().

        The grid and edges are rebuilt from the stored parameters; the .npz
        array archive is not read.
        """
        path = Path(path)

        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        return cls(
            r_inner=data['r_inner'],
            r_outer=data['r_outer'],
            n_radial=data['n_radial'],
            n_angular=data['n_angular'],
        )

    def __repr__(self) -> str:
        return (
            f"PolarMirrorGraph("
            f"r=[{self.r_inner}, {self.r_outer}]mm, "
            f"grid={self.n_radial}×{self.n_angular}, "
            f"nodes={self.n_nodes}, "
            f"edges={self.edge_index.shape[1]})"
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# Convenience functions
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
def create_mirror_dataset(
    study_dir: Path,
    polar_graph: Optional['PolarMirrorGraph'] = None,
    subcases: Optional[List[int]] = None,
    verbose: bool = True
) -> List[Dict[str, Any]]:
    """
    Create GNN dataset from a study's gnn_data folder.

    Reads ``<study_dir>/gnn_data/dataset_index.json``. For every trial whose
    status is 'success', it loads ``displacement_field.h5`` (preferred) or
    ``displacement_field.npz`` from the trial directory and interpolates it
    onto the polar grid. Trials without a field file, and trials whose
    loading/interpolation raises, are skipped. A message is printed for each
    only when ``verbose`` is True.

    Args:
        study_dir: Path to study directory
        polar_graph: PolarMirrorGraph instance (created with defaults if None)
        subcases: Subcases to include (default: [1, 2, 3, 4])
        verbose: Print progress

    Returns:
        List of data dictionaries, each containing:
            - z_displacement: [n_nodes, n_subcases]
            - design_vars: [n_vars] float32
            - design_names: list of parameter names, in the same order as design_vars
            - trial_number: int
            - original_n_nodes: int

    Raises:
        FileNotFoundError: If the study has no gnn_data folder or no dataset index.
    """
    study_dir = Path(study_dir)
    gnn_data_dir = study_dir / "gnn_data"

    if not gnn_data_dir.exists():
        raise FileNotFoundError(f"No gnn_data folder in {study_dir}")

    # Project-local import, deferred until the study layout has been confirmed.
    from optimization_engine.gnn.extract_displacement_field import load_field

    if subcases is None:
        subcases = [1, 2, 3, 4]

    # Load index
    index_path = gnn_data_dir / "dataset_index.json"
    with open(index_path, 'r', encoding='utf-8') as f:
        index = json.load(f)

    if polar_graph is None:
        polar_graph = PolarMirrorGraph()

    dataset = []

    for trial_num, trial_info in index['trials'].items():
        if trial_info.get('status') != 'success':
            continue

        trial_dir = study_dir / trial_info['trial_dir']

        # Find field file (HDF5 preferred over NPZ)
        field_path = None
        for ext in ('.h5', '.npz'):
            candidate = trial_dir / f"displacement_field{ext}"
            if candidate.exists():
                field_path = candidate
                break

        if field_path is None:
            if verbose:
                print(f"[WARN] No field file for trial {trial_num}")
            continue

        # Best effort: a single broken trial must not abort dataset creation.
        try:
            field_data = load_field(field_path)
            interp_result = polar_graph.interpolate_field_data(field_data, subcases=subcases)

            # Get design parameters (dict insertion order defines the vector layout)
            params = trial_info.get('params', {})
            design_vars = np.array(list(params.values()), dtype=np.float32)

            dataset.append({
                'z_displacement': interp_result['z_displacement'],
                'design_vars': design_vars,
                'design_names': list(params.keys()),
                'trial_number': int(trial_num),
                'original_n_nodes': interp_result['original_n_nodes'],
            })

            if verbose:
                print(f"[OK] Trial {trial_num}: {interp_result['original_n_nodes']} → {polar_graph.n_nodes} nodes")

        except Exception as e:
            if verbose:
                print(f"[ERR] Trial {trial_num}: {e}")

    return dataset
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# CLI
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    import argparse

    def _require(condition, message):
        """Abort the self-test with a non-zero exit status when *condition* is false."""
        if not condition:
            raise SystemExit(f"[FAIL] {message}")

    parser = argparse.ArgumentParser(description='Test PolarMirrorGraph')
    parser.add_argument('--test', action='store_true', help='Run basic tests')
    parser.add_argument('--study', type=Path, help='Create dataset from study')

    args = parser.parse_args()

    if args.test:
        print("="*60)
        print("TESTING PolarMirrorGraph")
        print("="*60)

        # Create graph
        graph = PolarMirrorGraph(r_inner=100, r_outer=650, n_radial=50, n_angular=60)
        print(f"\n{graph}")

        # Check node features
        node_feat = graph.get_node_features(normalized=True)
        print(f"\nNode features shape: {node_feat.shape}")
        print(f"  r range: [{node_feat[:, 0].min():.2f}, {node_feat[:, 0].max():.2f}]")
        print(f"  theta range: [{node_feat[:, 1].min():.2f}, {node_feat[:, 1].max():.2f}]")
        _require(node_feat.shape == (graph.n_nodes, 4),
                 f"node features shape {node_feat.shape} != ({graph.n_nodes}, 4)")
        _require(np.isfinite(node_feat).all(), "node features contain NaN/inf")

        # Check edge features and connectivity
        edge_feat = graph.get_edge_features(normalized=True)
        print(f"\nEdge features shape: {edge_feat.shape}")
        print(f"  dr range: [{edge_feat[:, 0].min():.2f}, {edge_feat[:, 0].max():.2f}]")
        print(f"  distance range: [{edge_feat[:, 2].min():.2f}, {edge_feat[:, 2].max():.2f}]")
        _require(edge_feat.shape == (graph.edge_index.shape[1], 4),
                 f"edge features shape {edge_feat.shape} does not match edge_index")
        _require(np.isfinite(edge_feat).all(), "edge features contain NaN/inf")
        _require(graph.edge_index.min() >= 0 and graph.edge_index.max() < graph.n_nodes,
                 "edge_index references nodes outside the grid")

        # Test interpolation with synthetic data
        print("\n--- Testing Interpolation ---")

        # Create fake mesh data (random points in annulus)
        np.random.seed(42)
        n_mesh = 5000
        r_mesh = np.random.uniform(100, 650, n_mesh)
        theta_mesh = np.random.uniform(0, 2*np.pi, n_mesh)
        x_mesh = r_mesh * np.cos(theta_mesh)
        y_mesh = r_mesh * np.sin(theta_mesh)
        mesh_coords = np.column_stack([x_mesh, y_mesh])

        # Synthetic displacement: smooth function
        mesh_values = 0.001 * (r_mesh / 650) ** 2 * np.cos(2 * theta_mesh)

        # Interpolate
        grid_values = graph.interpolate_from_mesh(mesh_coords, mesh_values, method='rbf')
        print(f"Interpolated {n_mesh} mesh nodes → {len(grid_values)} grid nodes")
        print(f"  Input range: [{mesh_values.min():.6f}, {mesh_values.max():.6f}]")
        print(f"  Output range: [{grid_values.min():.6f}, {grid_values.max():.6f}]")

        # Compare against the same analytic field evaluated at the grid nodes.
        expected = 0.001 * (graph.r.astype(np.float64) / 650) ** 2 * np.cos(2 * graph.theta)
        max_err = float(np.abs(grid_values - expected).max())
        peak = float(np.abs(mesh_values).max())
        print(f"  Max abs error vs analytic: {max_err:.2e} (peak |value| {peak:.2e})")
        _require(grid_values.shape == (graph.n_nodes,),
                 f"interpolated shape {grid_values.shape} != ({graph.n_nodes},)")
        _require(np.isfinite(grid_values).all(), "interpolated values contain NaN/inf")
        _require(max_err <= 0.05 * peak,
                 f"interpolation error {max_err:.2e} exceeds 5% of peak {peak:.2e}")

        print("\n✓ All tests passed!")

    elif args.study:
        # Create dataset from study
        print(f"Creating dataset from: {args.study}")

        graph = PolarMirrorGraph()
        dataset = create_mirror_dataset(args.study, polar_graph=graph)

        print(f"\nDataset: {len(dataset)} samples")
        if dataset:
            print(f"  Z-displacement shape: {dataset[0]['z_displacement'].shape}")
            print(f"  Design vars: {len(dataset[0]['design_vars'])} variables")

    else:
        # Default: just show info
        graph = PolarMirrorGraph()
        print(graph)
        print(f"\nNode features: {graph.get_node_features().shape}")
        print(f"Edge index: {graph.edge_index.shape}")
        print(f"Edge features: {graph.edge_attr.shape}")
|