feat: Add Zernike GNN surrogate module and M1 mirror V12/V13 studies
This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
535
studies/m1_mirror_adaptive_V12/run_gnn_turbo.py
Normal file
535
studies/m1_mirror_adaptive_V12/run_gnn_turbo.py
Normal file
@@ -0,0 +1,535 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
M1 Mirror - GNN Turbo Optimization with FEA Validation
|
||||
=======================================================
|
||||
|
||||
Runs fast GNN-based turbo optimization (5000 trials in ~2 min) then
|
||||
validates top candidates with actual FEA (~5 min each).
|
||||
|
||||
Usage:
|
||||
python run_gnn_turbo.py # Full workflow: 5000 GNN + 5 FEA validations
|
||||
python run_gnn_turbo.py --gnn-only # Just GNN turbo, no FEA
|
||||
python run_gnn_turbo.py --validate-top 10 # Validate top 10 instead of 5
|
||||
python run_gnn_turbo.py --trials 10000 # More GNN trials
|
||||
|
||||
Estimated time: ~2 min GNN + ~25 min FEA validation = ~27 min total
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add parent directories to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
|
||||
from optimization_engine.nx_solver import NXSolver
|
||||
from optimization_engine.utils import ensure_nx_running
|
||||
from optimization_engine.gnn.extract_displacement_field import (
|
||||
extract_displacement_field, save_field_to_hdf5
|
||||
)
|
||||
from optimization_engine.extractors.extract_zernike_surface import extract_surface_zernike
|
||||
|
||||
# ============================================================================
|
||||
# Paths
|
||||
# ============================================================================
|
||||
|
||||
STUDY_DIR = Path(__file__).parent
|
||||
SETUP_DIR = STUDY_DIR / "1_setup"
|
||||
MODEL_DIR = SETUP_DIR / "model"
|
||||
CONFIG_PATH = SETUP_DIR / "optimization_config.json"
|
||||
CHECKPOINT_PATH = Path("C:/Users/Antoine/Atomizer/zernike_gnn_checkpoint.pt")
|
||||
RESULTS_DIR = STUDY_DIR / "gnn_turbo_results"
|
||||
LOG_FILE = STUDY_DIR / "gnn_turbo.log"
|
||||
|
||||
# Ensure directories exist
|
||||
RESULTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# ============================================================================
|
||||
# Logging
|
||||
# ============================================================================
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s | %(levelname)-8s | %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler(LOG_FILE, mode='a')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GNN Turbo Runner
|
||||
# ============================================================================
|
||||
|
||||
class GNNTurboRunner:
|
||||
"""
|
||||
Run GNN turbo optimization with optional FEA validation.
|
||||
|
||||
Workflow:
|
||||
1. Load trained GNN model
|
||||
2. Run fast turbo optimization (5000 trials in ~2 min)
|
||||
3. Extract Pareto front and top candidates per objective
|
||||
4. Validate selected candidates with actual FEA
|
||||
5. Report GNN vs FEA accuracy
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Path, checkpoint_path: Path):
|
||||
logger.info("=" * 60)
|
||||
logger.info("GNN TURBO OPTIMIZER")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Load config
|
||||
with open(config_path) as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
# Load GNN optimizer
|
||||
logger.info(f"Loading GNN from {checkpoint_path}")
|
||||
self.gnn = ZernikeGNNOptimizer.from_checkpoint(checkpoint_path, config_path)
|
||||
logger.info(f"GNN loaded. Design variables: {self.gnn.design_names}")
|
||||
logger.info(f"disp_scale: {self.gnn.disp_scale}")
|
||||
|
||||
# Design variable info
|
||||
self.design_vars = [v for v in self.config['design_variables'] if v.get('enabled', True)]
|
||||
self.objectives = self.config['objectives']
|
||||
self.objective_names = [obj['name'] for obj in self.objectives]
|
||||
|
||||
# NX Solver for FEA validation
|
||||
self.nx_solver = None # Lazy init
|
||||
|
||||
def _init_nx_solver(self):
|
||||
"""Initialize NX solver only when needed (for FEA validation)."""
|
||||
if self.nx_solver is not None:
|
||||
return
|
||||
|
||||
nx_settings = self.config.get('nx_settings', {})
|
||||
nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
|
||||
version_match = re.search(r'NX(\d+)', nx_install_dir)
|
||||
nastran_version = version_match.group(1) if version_match else "2506"
|
||||
|
||||
self.nx_solver = NXSolver(
|
||||
master_model_dir=str(MODEL_DIR),
|
||||
nx_install_dir=nx_install_dir,
|
||||
nastran_version=nastran_version,
|
||||
timeout=nx_settings.get('simulation_timeout_s', 600),
|
||||
use_iteration_folders=True,
|
||||
study_name="m1_mirror_adaptive_V12_gnn_validation"
|
||||
)
|
||||
|
||||
# Ensure NX is running
|
||||
ensure_nx_running(nx_install_dir)
|
||||
|
||||
def run_turbo(self, n_trials: int = 5000) -> dict:
|
||||
"""
|
||||
Run GNN turbo optimization.
|
||||
|
||||
Returns dict with:
|
||||
- all_predictions: List of all predictions
|
||||
- pareto_front: Pareto-optimal designs
|
||||
- best_per_objective: Best design for each objective
|
||||
"""
|
||||
logger.info(f"\nRunning turbo optimization ({n_trials} trials)...")
|
||||
start_time = time.time()
|
||||
|
||||
results = self.gnn.turbo_optimize(n_trials=n_trials, verbose=True)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Turbo completed in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")
|
||||
|
||||
# Get Pareto front
|
||||
pareto = results.get_pareto_front()
|
||||
logger.info(f"Found {len(pareto)} Pareto-optimal designs")
|
||||
|
||||
# Get best per objective
|
||||
best_per_obj = {}
|
||||
for obj_name in self.objective_names:
|
||||
best = results.get_best(n=1, objective=obj_name)[0]
|
||||
best_per_obj[obj_name] = best
|
||||
logger.info(f"Best {obj_name}: {best.objectives[obj_name]:.2f} nm")
|
||||
|
||||
return {
|
||||
'results': results,
|
||||
'pareto': pareto,
|
||||
'best_per_objective': best_per_obj,
|
||||
'elapsed_time': elapsed
|
||||
}
|
||||
|
||||
def select_validation_candidates(self, turbo_results: dict, n_validate: int = 5) -> list:
|
||||
"""
|
||||
Select diverse candidates for FEA validation.
|
||||
|
||||
Strategy: Select from Pareto front with diversity preference.
|
||||
If Pareto front has fewer than n_validate, add best per objective.
|
||||
"""
|
||||
candidates = []
|
||||
seen_designs = set()
|
||||
|
||||
pareto = turbo_results['pareto']
|
||||
best_per_obj = turbo_results['best_per_objective']
|
||||
|
||||
# First, add best per objective (most important to validate)
|
||||
for obj_name, pred in best_per_obj.items():
|
||||
design_key = tuple(round(v, 4) for v in pred.design.values())
|
||||
if design_key not in seen_designs:
|
||||
candidates.append({
|
||||
'design': pred.design,
|
||||
'gnn_objectives': pred.objectives,
|
||||
'source': f'best_{obj_name}'
|
||||
})
|
||||
seen_designs.add(design_key)
|
||||
|
||||
if len(candidates) >= n_validate:
|
||||
break
|
||||
|
||||
# Fill remaining slots from Pareto front
|
||||
if len(candidates) < n_validate and len(pareto) > 0:
|
||||
# Sort Pareto by sum of objectives (balanced designs)
|
||||
pareto_sorted = sorted(pareto,
|
||||
key=lambda p: sum(p.objectives.values()))
|
||||
|
||||
for pred in pareto_sorted:
|
||||
design_key = tuple(round(v, 4) for v in pred.design.values())
|
||||
if design_key not in seen_designs:
|
||||
candidates.append({
|
||||
'design': pred.design,
|
||||
'gnn_objectives': pred.objectives,
|
||||
'source': 'pareto_front'
|
||||
})
|
||||
seen_designs.add(design_key)
|
||||
|
||||
if len(candidates) >= n_validate:
|
||||
break
|
||||
|
||||
logger.info(f"Selected {len(candidates)} candidates for FEA validation:")
|
||||
for i, c in enumerate(candidates):
|
||||
logger.info(f" {i+1}. {c['source']}: 40vs20={c['gnn_objectives']['rel_filtered_rms_40_vs_20']:.2f} nm")
|
||||
|
||||
return candidates
|
||||
|
||||
def run_fea_validation(self, design: dict, trial_num: int) -> dict:
|
||||
"""
|
||||
Run FEA for a single design and extract Zernike objectives.
|
||||
|
||||
Returns dict with success status and FEA objectives.
|
||||
"""
|
||||
self._init_nx_solver()
|
||||
|
||||
trial_dir = RESULTS_DIR / f"validation_{trial_num:04d}"
|
||||
trial_dir.mkdir(exist_ok=True)
|
||||
|
||||
logger.info(f"Validation {trial_num}: Running FEA...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Build expression updates
|
||||
expressions = {var['expression_name']: design[var['name']]
|
||||
for var in self.design_vars}
|
||||
|
||||
# Create iteration folder
|
||||
iter_folder = self.nx_solver.create_iteration_folder(
|
||||
iterations_base_dir=RESULTS_DIR / "iterations",
|
||||
iteration_number=trial_num,
|
||||
expression_updates=expressions
|
||||
)
|
||||
|
||||
# Run solve
|
||||
op2_path = self.nx_solver.run_solve(
|
||||
sim_file=iter_folder / self.config['nx_settings']['sim_file'],
|
||||
solution_name=self.config['nx_settings']['solution_name']
|
||||
)
|
||||
|
||||
if op2_path is None or not Path(op2_path).exists():
|
||||
logger.error(f"Validation {trial_num}: FEA solve failed - no OP2")
|
||||
return {'success': False, 'error': 'No OP2 file'}
|
||||
|
||||
# Extract Zernike objectives using the same extractor as training
|
||||
bdf_path = iter_folder / "model.bdf"
|
||||
if not bdf_path.exists():
|
||||
bdf_files = list(iter_folder.glob("*.bdf"))
|
||||
bdf_path = bdf_files[0] if bdf_files else None
|
||||
|
||||
# Use extract_surface_zernike to get objectives
|
||||
zernike_result = extract_surface_zernike(
|
||||
op2_path=str(op2_path),
|
||||
bdf_path=str(bdf_path),
|
||||
n_modes=50,
|
||||
r_inner=100.0,
|
||||
r_outer=650.0,
|
||||
n_radial=50,
|
||||
n_angular=60
|
||||
)
|
||||
|
||||
if not zernike_result.get('success', False):
|
||||
logger.error(f"Validation {trial_num}: Zernike extraction failed")
|
||||
return {'success': False, 'error': zernike_result.get('error', 'Unknown')}
|
||||
|
||||
# Compute relative objectives (same as GNN training data)
|
||||
objectives = self._compute_relative_objectives(zernike_result)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Validation {trial_num}: Completed in {elapsed:.1f}s")
|
||||
logger.info(f" 40vs20: {objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
|
||||
logger.info(f" 60vs20: {objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
|
||||
logger.info(f" mfg90: {objectives['mfg_90_optician_workload']:.2f} nm")
|
||||
|
||||
# Save results
|
||||
results = {
|
||||
'success': True,
|
||||
'design': design,
|
||||
'objectives': objectives,
|
||||
'op2_path': str(op2_path),
|
||||
'elapsed_time': elapsed
|
||||
}
|
||||
|
||||
with open(trial_dir / 'fea_result.json', 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Validation {trial_num}: Error - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def _compute_relative_objectives(self, zernike_result: dict) -> dict:
|
||||
"""
|
||||
Compute relative Zernike objectives from extraction result.
|
||||
|
||||
Matches the exact computation used in GNN training data preparation.
|
||||
"""
|
||||
coeffs = zernike_result['data']['coefficients'] # Dict by subcase
|
||||
|
||||
# Subcase mapping: 1=90deg, 2=20deg(ref), 3=40deg, 4=60deg
|
||||
subcases = ['1', '2', '3', '4']
|
||||
|
||||
# Convert coefficients to arrays
|
||||
coeff_arrays = {}
|
||||
for sc in subcases:
|
||||
if sc in coeffs:
|
||||
coeff_arrays[sc] = np.array(coeffs[sc])
|
||||
|
||||
# Objective 1: rel_filtered_rms_40_vs_20
|
||||
# Relative = subcase 3 (40deg) - subcase 2 (20deg ref)
|
||||
# Filter: remove J1-J4 (first 4 modes)
|
||||
rel_40_vs_20 = coeff_arrays['3'] - coeff_arrays['2']
|
||||
rel_40_vs_20_filtered = rel_40_vs_20[4:] # Skip J1-J4
|
||||
rms_40_vs_20 = np.sqrt(np.sum(rel_40_vs_20_filtered ** 2))
|
||||
|
||||
# Objective 2: rel_filtered_rms_60_vs_20
|
||||
rel_60_vs_20 = coeff_arrays['4'] - coeff_arrays['2']
|
||||
rel_60_vs_20_filtered = rel_60_vs_20[4:] # Skip J1-J4
|
||||
rms_60_vs_20 = np.sqrt(np.sum(rel_60_vs_20_filtered ** 2))
|
||||
|
||||
# Objective 3: mfg_90_optician_workload (J1-J3 filtered, keep J4 defocus)
|
||||
rel_90_vs_20 = coeff_arrays['1'] - coeff_arrays['2']
|
||||
rel_90_vs_20_filtered = rel_90_vs_20[3:] # Skip only J1-J3 (keep J4 defocus)
|
||||
rms_mfg_90 = np.sqrt(np.sum(rel_90_vs_20_filtered ** 2))
|
||||
|
||||
return {
|
||||
'rel_filtered_rms_40_vs_20': float(rms_40_vs_20),
|
||||
'rel_filtered_rms_60_vs_20': float(rms_60_vs_20),
|
||||
'mfg_90_optician_workload': float(rms_mfg_90)
|
||||
}
|
||||
|
||||
def compare_results(self, candidates: list) -> dict:
|
||||
"""
|
||||
Compare GNN predictions vs FEA results.
|
||||
|
||||
Returns accuracy statistics.
|
||||
"""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("GNN vs FEA COMPARISON")
|
||||
logger.info("=" * 60)
|
||||
|
||||
errors = {obj: [] for obj in self.objective_names}
|
||||
|
||||
for c in candidates:
|
||||
if 'fea_objectives' not in c or not c.get('fea_success', False):
|
||||
continue
|
||||
|
||||
gnn = c['gnn_objectives']
|
||||
fea = c['fea_objectives']
|
||||
|
||||
logger.info(f"\n{c['source']}:")
|
||||
logger.info(f" {'Objective':<30} {'GNN':<10} {'FEA':<10} {'Error':<10}")
|
||||
logger.info(f" {'-'*60}")
|
||||
|
||||
for obj in self.objective_names:
|
||||
gnn_val = gnn[obj]
|
||||
fea_val = fea[obj]
|
||||
error_pct = abs(gnn_val - fea_val) / fea_val * 100 if fea_val > 0 else 0
|
||||
|
||||
logger.info(f" {obj:<30} {gnn_val:<10.2f} {fea_val:<10.2f} {error_pct:<10.1f}%")
|
||||
errors[obj].append(error_pct)
|
||||
|
||||
# Summary statistics
|
||||
logger.info("\n" + "-" * 60)
|
||||
logger.info("SUMMARY STATISTICS")
|
||||
logger.info("-" * 60)
|
||||
|
||||
summary = {}
|
||||
for obj in self.objective_names:
|
||||
if errors[obj]:
|
||||
mean_err = np.mean(errors[obj])
|
||||
max_err = np.max(errors[obj])
|
||||
summary[obj] = {'mean_error_pct': mean_err, 'max_error_pct': max_err}
|
||||
logger.info(f"{obj}: Mean error = {mean_err:.1f}%, Max error = {max_err:.1f}%")
|
||||
|
||||
return summary
|
||||
|
||||
def run_full_workflow(self, n_trials: int = 5000, n_validate: int = 5, gnn_only: bool = False):
|
||||
"""
|
||||
Run complete workflow: GNN turbo → select candidates → FEA validation → comparison.
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Phase 1: GNN Turbo
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 1: GNN TURBO OPTIMIZATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
turbo_results = self.run_turbo(n_trials=n_trials)
|
||||
|
||||
# Save turbo results
|
||||
turbo_summary = {
|
||||
'timestamp': timestamp,
|
||||
'n_trials': n_trials,
|
||||
'n_pareto': len(turbo_results['pareto']),
|
||||
'elapsed_time': turbo_results['elapsed_time'],
|
||||
'best_per_objective': {
|
||||
obj: {
|
||||
'design': pred.design,
|
||||
'objectives': pred.objectives
|
||||
}
|
||||
for obj, pred in turbo_results['best_per_objective'].items()
|
||||
},
|
||||
'pareto_front': [
|
||||
{'design': p.design, 'objectives': p.objectives}
|
||||
for p in turbo_results['pareto'][:20] # Top 20 from Pareto
|
||||
]
|
||||
}
|
||||
|
||||
turbo_file = RESULTS_DIR / f'turbo_results_{timestamp}.json'
|
||||
with open(turbo_file, 'w') as f:
|
||||
json.dump(turbo_summary, f, indent=2)
|
||||
logger.info(f"Turbo results saved to {turbo_file}")
|
||||
|
||||
if gnn_only:
|
||||
logger.info("\n--gnn-only flag set, skipping FEA validation")
|
||||
return {'turbo': turbo_summary}
|
||||
|
||||
# Phase 2: FEA Validation
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 2: FEA VALIDATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
candidates = self.select_validation_candidates(turbo_results, n_validate=n_validate)
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
logger.info(f"\n--- Validating candidate {i+1}/{len(candidates)} ---")
|
||||
fea_result = self.run_fea_validation(candidate['design'], trial_num=i+1)
|
||||
candidate['fea_success'] = fea_result.get('success', False)
|
||||
if fea_result.get('success'):
|
||||
candidate['fea_objectives'] = fea_result['objectives']
|
||||
candidate['fea_time'] = fea_result.get('elapsed_time', 0)
|
||||
|
||||
# Phase 3: Comparison
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 3: RESULTS COMPARISON")
|
||||
logger.info("=" * 60)
|
||||
|
||||
comparison = self.compare_results(candidates)
|
||||
|
||||
# Save final report
|
||||
final_report = {
|
||||
'timestamp': timestamp,
|
||||
'turbo_summary': turbo_summary,
|
||||
'validation_candidates': [
|
||||
{
|
||||
'source': c['source'],
|
||||
'design': c['design'],
|
||||
'gnn_objectives': c['gnn_objectives'],
|
||||
'fea_objectives': c.get('fea_objectives'),
|
||||
'fea_success': c.get('fea_success', False),
|
||||
'fea_time': c.get('fea_time')
|
||||
}
|
||||
for c in candidates
|
||||
],
|
||||
'accuracy_summary': comparison
|
||||
}
|
||||
|
||||
report_file = RESULTS_DIR / f'gnn_turbo_report_{timestamp}.json'
|
||||
with open(report_file, 'w') as f:
|
||||
json.dump(final_report, f, indent=2)
|
||||
logger.info(f"\nFinal report saved to {report_file}")
|
||||
|
||||
# Print final summary
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("WORKFLOW COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"GNN Turbo: {n_trials} trials in {turbo_results['elapsed_time']:.1f}s")
|
||||
logger.info(f"Pareto front: {len(turbo_results['pareto'])} designs")
|
||||
|
||||
successful_validations = sum(1 for c in candidates if c.get('fea_success', False))
|
||||
logger.info(f"FEA Validations: {successful_validations}/{len(candidates)} successful")
|
||||
|
||||
if comparison:
|
||||
avg_errors = [np.mean([comparison[obj]['mean_error_pct'] for obj in comparison])]
|
||||
logger.info(f"Overall GNN accuracy: {100 - np.mean(avg_errors):.1f}%")
|
||||
|
||||
return final_report
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GNN Turbo Optimization with FEA Validation",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__
|
||||
)
|
||||
parser.add_argument('--trials', type=int, default=5000,
|
||||
help='Number of GNN turbo trials (default: 5000)')
|
||||
parser.add_argument('--validate-top', type=int, default=5,
|
||||
help='Number of top candidates to validate with FEA (default: 5)')
|
||||
parser.add_argument('--gnn-only', action='store_true',
|
||||
help='Run only GNN turbo, skip FEA validation')
|
||||
parser.add_argument('--checkpoint', type=str, default=str(CHECKPOINT_PATH),
|
||||
help='Path to GNN checkpoint')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Starting GNN Turbo Optimization")
|
||||
logger.info(f" Checkpoint: {args.checkpoint}")
|
||||
logger.info(f" GNN trials: {args.trials}")
|
||||
logger.info(f" FEA validations: {args.validate_top if not args.gnn_only else 'SKIP'}")
|
||||
|
||||
runner = GNNTurboRunner(
|
||||
config_path=CONFIG_PATH,
|
||||
checkpoint_path=Path(args.checkpoint)
|
||||
)
|
||||
|
||||
report = runner.run_full_workflow(
|
||||
n_trials=args.trials,
|
||||
n_validate=args.validate_top,
|
||||
gnn_only=args.gnn_only
|
||||
)
|
||||
|
||||
logger.info("\nDone!")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user