"""
GNN (Graph Neural Network) Surrogate Module for Atomizer
=========================================================

This module provides Graph Neural Network-based surrogates for FEA optimization,
particularly designed for Zernike-based mirror optimization where spatial structure
matters.

Key Components:
- PolarMirrorGraph: Fixed polar grid graph structure for mirror surface
- ZernikeGNN: GNN model for predicting displacement fields
- DifferentiableZernikeFit: GPU-accelerated Zernike fitting
- ZernikeObjectiveLayer: Compute objectives from displacement fields
- ZernikeGNNTrainer: Complete training pipeline

Why GNN over MLP for Zernike?
1. Spatial awareness: GNN learns smooth deformation fields via message passing
2. Correct relative computation: Predicts fields, then subtracts (like FEA)
3. Multi-task learning: Field + objective supervision
4. Physics-informed: Edge structure respects mirror geometry

Usage:
    # Training
    python -m optimization_engine.gnn.train_zernike_gnn V11 V12 --epochs 200

    # API
    from optimization_engine.gnn import PolarMirrorGraph, ZernikeGNN, ZernikeGNNTrainer
"""

__version__ = "1.0.0"

# Core components
from .polar_graph import PolarMirrorGraph, create_mirror_dataset
from .zernike_gnn import ZernikeGNN, ZernikeGNNLite, create_model, load_model
from .differentiable_zernike import (
    DifferentiableZernikeFit,
    ZernikeObjectiveLayer,
    ZernikeRMSLoss,
    build_zernike_matrix,
)
from .extract_displacement_field import (
    extract_displacement_field,
    save_field,
    load_field,
)
# NOTE(review): train_zernike_gnn is also meant to be run as a script
# (``python -m optimization_engine.gnn.train_zernike_gnn``, see Usage above).
# Because this package __init__ imports it eagerly, runpy finds the module
# already in sys.modules before executing it as __main__. Python then emits a
# RuntimeWarning and the module body runs twice. Consider a lazy import if that
# becomes a problem.
from .train_zernike_gnn import ZernikeGNNTrainer, MirrorDataset

# Public API re-exported by this package (also controls ``from ... import *``).
__all__ = [
    # Polar Graph
    'PolarMirrorGraph',
    'create_mirror_dataset',
    # GNN Model
    'ZernikeGNN',
    'ZernikeGNNLite',
    'create_model',
    'load_model',
    # Zernike Layers
    'DifferentiableZernikeFit',
    'ZernikeObjectiveLayer',
    'ZernikeRMSLoss',
    'build_zernike_matrix',
    # Field Extraction
    'extract_displacement_field',
    'save_field',
    'load_field',
    # Training
    'ZernikeGNNTrainer',
    'MirrorDataset',
]