This commit introduces the GNN-based surrogate for Zernike mirror optimization and the M1 mirror study progression from V12 (GNN validation) to V13 (pure NSGA-II). ## GNN Surrogate Module (optimization_engine/gnn/) New module for Graph Neural Network surrogate prediction of mirror deformations: - `polar_graph.py`: PolarMirrorGraph - fixed 3000-node polar grid structure - `zernike_gnn.py`: ZernikeGNN with design-conditioned message passing - `differentiable_zernike.py`: GPU-accelerated Zernike fitting and objectives - `train_zernike_gnn.py`: ZernikeGNNTrainer with multi-task loss - `gnn_optimizer.py`: ZernikeGNNOptimizer for turbo mode (~900k trials/hour) - `extract_displacement_field.py`: OP2 to HDF5 field extraction - `backfill_field_data.py`: Extract fields from existing FEA trials Key innovation: Design-conditioned convolutions that modulate message passing based on structural design parameters, enabling accurate field prediction. ## M1 Mirror Studies ### V12: GNN Field Prediction + FEA Validation - Zernike GNN trained on V10/V11 FEA data (238 samples) - Turbo mode: 5000 GNN predictions → top candidates → FEA validation - Calibration workflow for GNN-to-FEA error correction - Scripts: run_gnn_turbo.py, validate_gnn_best.py, compute_full_calibration.py ### V13: Pure NSGA-II FEA (Ground Truth) - Seeds 217 FEA trials from V11+V12 - Pure multi-objective NSGA-II without any surrogate - Establishes ground-truth Pareto front for GNN accuracy evaluation - Narrowed blank_backface_angle range to [4.0, 5.0] ## Documentation Updates - SYS_14: Added Zernike GNN section with architecture diagrams - CLAUDE.md: Added GNN module reference and quick start - V13 README: Study documentation with seeding strategy 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
536 lines
20 KiB
Python
536 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
M1 Mirror - GNN Turbo Optimization with FEA Validation
|
|
=======================================================
|
|
|
|
Runs fast GNN-based turbo optimization (5000 trials in ~2 min) then
|
|
validates top candidates with actual FEA (~5 min each).
|
|
|
|
Usage:
|
|
python run_gnn_turbo.py # Full workflow: 5000 GNN + 5 FEA validations
|
|
python run_gnn_turbo.py --gnn-only # Just GNN turbo, no FEA
|
|
python run_gnn_turbo.py --validate-top 10 # Validate top 10 instead of 5
|
|
python run_gnn_turbo.py --trials 10000 # More GNN trials
|
|
|
|
Estimated time: ~2 min GNN + ~25 min FEA validation = ~27 min total
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
import time
|
|
import argparse
|
|
import logging
|
|
import re
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
# Add parent directories to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
|
from optimization_engine.gnn.gnn_optimizer import ZernikeGNNOptimizer
|
|
from optimization_engine.nx_solver import NXSolver
|
|
from optimization_engine.utils import ensure_nx_running
|
|
from optimization_engine.gnn.extract_displacement_field import (
|
|
extract_displacement_field, save_field_to_hdf5
|
|
)
|
|
from optimization_engine.extractors.extract_zernike_surface import extract_surface_zernike
|
|
|
|
# ============================================================================
# Paths
# ============================================================================

STUDY_DIR = Path(__file__).parent              # Study root (folder containing this script)
SETUP_DIR = STUDY_DIR / "1_setup"              # Study setup assets
MODEL_DIR = SETUP_DIR / "model"                # NX master model used for FEA validation
CONFIG_PATH = SETUP_DIR / "optimization_config.json"  # Design vars / objectives / NX settings
# NOTE(review): machine-specific absolute path -- confirm it exists on the host
# running this script (can be overridden with --checkpoint on the CLI).
CHECKPOINT_PATH = Path("C:/Users/Antoine/Atomizer/zernike_gnn_checkpoint.pt")
RESULTS_DIR = STUDY_DIR / "gnn_turbo_results"  # All turbo/validation outputs land here
LOG_FILE = STUDY_DIR / "gnn_turbo.log"         # Appended to on every run

# Ensure directories exist
RESULTS_DIR.mkdir(exist_ok=True)
|
|
|
|
# ============================================================================
# Logging
# ============================================================================

# Log to both the console (stdout) and a persistent file in append mode, so
# repeated runs accumulate a single history in gnn_turbo.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)-8s | %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(LOG_FILE, mode='a')
    ]
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ============================================================================
|
|
# GNN Turbo Runner
|
|
# ============================================================================
|
|
|
|
class GNNTurboRunner:
    """
    Run GNN turbo optimization with optional FEA validation.

    Workflow:
    1. Load trained GNN model
    2. Run fast turbo optimization (5000 trials in ~2 min)
    3. Extract Pareto front and top candidates per objective
    4. Validate selected candidates with actual FEA
    5. Report GNN vs FEA accuracy
    """

    def __init__(self, config_path: Path, checkpoint_path: Path):
        """Load the optimization config and the trained GNN checkpoint.

        Args:
            config_path: JSON config holding design variables, objectives and
                NX solver settings.
            checkpoint_path: Trained ZernikeGNN checkpoint (.pt file).
        """
        logger.info("=" * 60)
        logger.info("GNN TURBO OPTIMIZER")
        logger.info("=" * 60)

        # Load config
        with open(config_path) as f:
            self.config = json.load(f)

        # Load GNN optimizer
        logger.info(f"Loading GNN from {checkpoint_path}")
        self.gnn = ZernikeGNNOptimizer.from_checkpoint(checkpoint_path, config_path)
        logger.info(f"GNN loaded. Design variables: {self.gnn.design_names}")
        logger.info(f"disp_scale: {self.gnn.disp_scale}")

        # Only variables explicitly enabled (default: enabled) participate.
        self.design_vars = [v for v in self.config['design_variables'] if v.get('enabled', True)]
        self.objectives = self.config['objectives']
        self.objective_names = [obj['name'] for obj in self.objectives]

        # NX Solver for FEA validation -- created lazily so that --gnn-only
        # runs never have to start NX at all.
        self.nx_solver = None  # Lazy init

    def _init_nx_solver(self):
        """Initialize NX solver only when needed (for FEA validation)."""
        if self.nx_solver is not None:
            return

        nx_settings = self.config.get('nx_settings', {})
        nx_install_dir = nx_settings.get('nx_install_path', 'C:\\Program Files\\Siemens\\NX2506')
        # Derive the Nastran version from the install directory name (e.g.
        # "NX2506" -> "2506"); fall back to 2506 when the pattern is absent.
        version_match = re.search(r'NX(\d+)', nx_install_dir)
        nastran_version = version_match.group(1) if version_match else "2506"

        self.nx_solver = NXSolver(
            master_model_dir=str(MODEL_DIR),
            nx_install_dir=nx_install_dir,
            nastran_version=nastran_version,
            timeout=nx_settings.get('simulation_timeout_s', 600),
            use_iteration_folders=True,
            study_name="m1_mirror_adaptive_V12_gnn_validation"
        )

        # Ensure NX is running
        ensure_nx_running(nx_install_dir)

    def run_turbo(self, n_trials: int = 5000) -> dict:
        """
        Run GNN turbo optimization.

        Args:
            n_trials: Number of random designs to evaluate with the GNN.

        Returns dict with:
        - results: Raw prediction results object from the optimizer
        - pareto: Pareto-optimal designs
        - best_per_objective: Best design for each objective
        - elapsed_time: Wall-clock seconds for the turbo run
        """
        logger.info(f"\nRunning turbo optimization ({n_trials} trials)...")
        start_time = time.time()

        results = self.gnn.turbo_optimize(n_trials=n_trials, verbose=True)

        elapsed = time.time() - start_time
        logger.info(f"Turbo completed in {elapsed:.1f}s ({n_trials/elapsed:.0f} trials/sec)")

        # Get Pareto front
        pareto = results.get_pareto_front()
        logger.info(f"Found {len(pareto)} Pareto-optimal designs")

        # Get best per objective
        best_per_obj = {}
        for obj_name in self.objective_names:
            best = results.get_best(n=1, objective=obj_name)[0]
            best_per_obj[obj_name] = best
            logger.info(f"Best {obj_name}: {best.objectives[obj_name]:.2f} nm")

        return {
            'results': results,
            'pareto': pareto,
            'best_per_objective': best_per_obj,
            'elapsed_time': elapsed
        }

    def select_validation_candidates(self, turbo_results: dict, n_validate: int = 5) -> list:
        """
        Select diverse candidates for FEA validation.

        Strategy: Select from Pareto front with diversity preference.
        If Pareto front has fewer than n_validate, add best per objective.

        Args:
            turbo_results: Dict returned by run_turbo().
            n_validate: Maximum number of candidates to return.

        Returns:
            List of candidate dicts with 'design', 'gnn_objectives', 'source'.
        """
        candidates = []
        # Dedupe designs by rounded value tuple so numerically identical
        # designs are only validated once.
        seen_designs = set()

        pareto = turbo_results['pareto']
        best_per_obj = turbo_results['best_per_objective']

        # First, add best per objective (most important to validate)
        for obj_name, pred in best_per_obj.items():
            design_key = tuple(round(v, 4) for v in pred.design.values())
            if design_key not in seen_designs:
                candidates.append({
                    'design': pred.design,
                    'gnn_objectives': pred.objectives,
                    'source': f'best_{obj_name}'
                })
                seen_designs.add(design_key)

            if len(candidates) >= n_validate:
                break

        # Fill remaining slots from Pareto front
        if len(candidates) < n_validate and len(pareto) > 0:
            # Sort Pareto by sum of objectives (balanced designs)
            pareto_sorted = sorted(pareto,
                                   key=lambda p: sum(p.objectives.values()))

            for pred in pareto_sorted:
                design_key = tuple(round(v, 4) for v in pred.design.values())
                if design_key not in seen_designs:
                    candidates.append({
                        'design': pred.design,
                        'gnn_objectives': pred.objectives,
                        'source': 'pareto_front'
                    })
                    seen_designs.add(design_key)

                if len(candidates) >= n_validate:
                    break

        logger.info(f"Selected {len(candidates)} candidates for FEA validation:")
        for i, c in enumerate(candidates):
            logger.info(f"  {i+1}. {c['source']}: 40vs20={c['gnn_objectives']['rel_filtered_rms_40_vs_20']:.2f} nm")

        return candidates

    def run_fea_validation(self, design: dict, trial_num: int) -> dict:
        """
        Run FEA for a single design and extract Zernike objectives.

        Args:
            design: Mapping of design-variable name -> value.
            trial_num: 1-based validation trial index (used for folder names).

        Returns dict with success status and FEA objectives.
        """
        self._init_nx_solver()

        trial_dir = RESULTS_DIR / f"validation_{trial_num:04d}"
        trial_dir.mkdir(exist_ok=True)

        logger.info(f"Validation {trial_num}: Running FEA...")
        start_time = time.time()

        try:
            # Build expression updates
            expressions = {var['expression_name']: design[var['name']]
                           for var in self.design_vars}

            # Create iteration folder
            iter_folder = self.nx_solver.create_iteration_folder(
                iterations_base_dir=RESULTS_DIR / "iterations",
                iteration_number=trial_num,
                expression_updates=expressions
            )

            # Run solve
            op2_path = self.nx_solver.run_solve(
                sim_file=iter_folder / self.config['nx_settings']['sim_file'],
                solution_name=self.config['nx_settings']['solution_name']
            )

            if op2_path is None or not Path(op2_path).exists():
                logger.error(f"Validation {trial_num}: FEA solve failed - no OP2")
                return {'success': False, 'error': 'No OP2 file'}

            # Extract Zernike objectives using the same extractor as training
            bdf_path = iter_folder / "model.bdf"
            if not bdf_path.exists():
                bdf_files = list(iter_folder.glob("*.bdf"))
                bdf_path = bdf_files[0] if bdf_files else None

            # Fail fast if no BDF is present at all; previously a None value
            # was stringified to "None" and passed to the extractor.
            if bdf_path is None:
                logger.error(f"Validation {trial_num}: No BDF file found in {iter_folder}")
                return {'success': False, 'error': 'No BDF file'}

            # Use extract_surface_zernike to get objectives
            zernike_result = extract_surface_zernike(
                op2_path=str(op2_path),
                bdf_path=str(bdf_path),
                n_modes=50,
                r_inner=100.0,
                r_outer=650.0,
                n_radial=50,
                n_angular=60
            )

            if not zernike_result.get('success', False):
                logger.error(f"Validation {trial_num}: Zernike extraction failed")
                return {'success': False, 'error': zernike_result.get('error', 'Unknown')}

            # Compute relative objectives (same as GNN training data)
            objectives = self._compute_relative_objectives(zernike_result)

            elapsed = time.time() - start_time
            logger.info(f"Validation {trial_num}: Completed in {elapsed:.1f}s")
            logger.info(f"  40vs20: {objectives['rel_filtered_rms_40_vs_20']:.2f} nm")
            logger.info(f"  60vs20: {objectives['rel_filtered_rms_60_vs_20']:.2f} nm")
            logger.info(f"  mfg90: {objectives['mfg_90_optician_workload']:.2f} nm")

            # Save results
            results = {
                'success': True,
                'design': design,
                'objectives': objectives,
                'op2_path': str(op2_path),
                'elapsed_time': elapsed
            }

            with open(trial_dir / 'fea_result.json', 'w') as f:
                json.dump(results, f, indent=2)

            return results

        except Exception as e:
            logger.error(f"Validation {trial_num}: Error - {e}")
            import traceback
            traceback.print_exc()
            return {'success': False, 'error': str(e)}

    def _compute_relative_objectives(self, zernike_result: dict) -> dict:
        """
        Compute relative Zernike objectives from extraction result.

        Matches the exact computation used in GNN training data preparation.

        Args:
            zernike_result: Output of extract_surface_zernike; must contain
                data/coefficients keyed by subcase id ('1'..'4').

        Raises:
            KeyError: If any required subcase is missing from the result.
        """
        coeffs = zernike_result['data']['coefficients']  # Dict by subcase

        # Subcase mapping: 1=90deg, 2=20deg(ref), 3=40deg, 4=60deg
        subcases = ['1', '2', '3', '4']

        # All four subcases are required below; fail with a clear message
        # instead of a bare KeyError on a single subcase id.
        missing = [sc for sc in subcases if sc not in coeffs]
        if missing:
            raise KeyError(f"Missing subcases in Zernike result: {missing}")

        # Convert coefficients to arrays
        coeff_arrays = {sc: np.array(coeffs[sc]) for sc in subcases}

        # Objective 1: rel_filtered_rms_40_vs_20
        # Relative = subcase 3 (40deg) - subcase 2 (20deg ref)
        # Filter: remove J1-J4 (first 4 modes)
        rel_40_vs_20 = coeff_arrays['3'] - coeff_arrays['2']
        rel_40_vs_20_filtered = rel_40_vs_20[4:]  # Skip J1-J4
        rms_40_vs_20 = np.sqrt(np.sum(rel_40_vs_20_filtered ** 2))

        # Objective 2: rel_filtered_rms_60_vs_20
        rel_60_vs_20 = coeff_arrays['4'] - coeff_arrays['2']
        rel_60_vs_20_filtered = rel_60_vs_20[4:]  # Skip J1-J4
        rms_60_vs_20 = np.sqrt(np.sum(rel_60_vs_20_filtered ** 2))

        # Objective 3: mfg_90_optician_workload (J1-J3 filtered, keep J4 defocus)
        rel_90_vs_20 = coeff_arrays['1'] - coeff_arrays['2']
        rel_90_vs_20_filtered = rel_90_vs_20[3:]  # Skip only J1-J3 (keep J4 defocus)
        rms_mfg_90 = np.sqrt(np.sum(rel_90_vs_20_filtered ** 2))

        return {
            'rel_filtered_rms_40_vs_20': float(rms_40_vs_20),
            'rel_filtered_rms_60_vs_20': float(rms_60_vs_20),
            'mfg_90_optician_workload': float(rms_mfg_90)
        }

    def compare_results(self, candidates: list) -> dict:
        """
        Compare GNN predictions vs FEA results.

        Args:
            candidates: Candidate dicts, each optionally annotated with
                'fea_objectives' and 'fea_success' by run_full_workflow().

        Returns accuracy statistics (per-objective mean/max error percent).
        """
        logger.info("\n" + "=" * 60)
        logger.info("GNN vs FEA COMPARISON")
        logger.info("=" * 60)

        errors = {obj: [] for obj in self.objective_names}

        for c in candidates:
            # Only candidates with a successful FEA run contribute.
            if 'fea_objectives' not in c or not c.get('fea_success', False):
                continue

            gnn = c['gnn_objectives']
            fea = c['fea_objectives']

            logger.info(f"\n{c['source']}:")
            logger.info(f"  {'Objective':<30} {'GNN':<10} {'FEA':<10} {'Error':<10}")
            logger.info(f"  {'-'*60}")

            for obj in self.objective_names:
                gnn_val = gnn[obj]
                fea_val = fea[obj]
                # Guard against division by zero for a perfect/degenerate FEA value.
                error_pct = abs(gnn_val - fea_val) / fea_val * 100 if fea_val > 0 else 0

                logger.info(f"  {obj:<30} {gnn_val:<10.2f} {fea_val:<10.2f} {error_pct:<10.1f}%")
                errors[obj].append(error_pct)

        # Summary statistics
        logger.info("\n" + "-" * 60)
        logger.info("SUMMARY STATISTICS")
        logger.info("-" * 60)

        summary = {}
        for obj in self.objective_names:
            if errors[obj]:
                mean_err = np.mean(errors[obj])
                max_err = np.max(errors[obj])
                summary[obj] = {'mean_error_pct': mean_err, 'max_error_pct': max_err}
                logger.info(f"{obj}: Mean error = {mean_err:.1f}%, Max error = {max_err:.1f}%")

        return summary

    def run_full_workflow(self, n_trials: int = 5000, n_validate: int = 5, gnn_only: bool = False):
        """
        Run complete workflow: GNN turbo → select candidates → FEA validation → comparison.

        Args:
            n_trials: GNN turbo trial count.
            n_validate: Number of candidates to validate with FEA.
            gnn_only: If True, stop after the turbo phase (no FEA).

        Returns:
            Final report dict (or {'turbo': ...} when gnn_only is set).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Phase 1: GNN Turbo
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 1: GNN TURBO OPTIMIZATION")
        logger.info("=" * 60)

        turbo_results = self.run_turbo(n_trials=n_trials)

        # Save turbo results
        turbo_summary = {
            'timestamp': timestamp,
            'n_trials': n_trials,
            'n_pareto': len(turbo_results['pareto']),
            'elapsed_time': turbo_results['elapsed_time'],
            'best_per_objective': {
                obj: {
                    'design': pred.design,
                    'objectives': pred.objectives
                }
                for obj, pred in turbo_results['best_per_objective'].items()
            },
            'pareto_front': [
                {'design': p.design, 'objectives': p.objectives}
                for p in turbo_results['pareto'][:20]  # Top 20 from Pareto
            ]
        }

        turbo_file = RESULTS_DIR / f'turbo_results_{timestamp}.json'
        with open(turbo_file, 'w') as f:
            json.dump(turbo_summary, f, indent=2)
        logger.info(f"Turbo results saved to {turbo_file}")

        if gnn_only:
            logger.info("\n--gnn-only flag set, skipping FEA validation")
            return {'turbo': turbo_summary}

        # Phase 2: FEA Validation
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 2: FEA VALIDATION")
        logger.info("=" * 60)

        candidates = self.select_validation_candidates(turbo_results, n_validate=n_validate)

        for i, candidate in enumerate(candidates):
            logger.info(f"\n--- Validating candidate {i+1}/{len(candidates)} ---")
            fea_result = self.run_fea_validation(candidate['design'], trial_num=i+1)
            candidate['fea_success'] = fea_result.get('success', False)
            if fea_result.get('success'):
                candidate['fea_objectives'] = fea_result['objectives']
                candidate['fea_time'] = fea_result.get('elapsed_time', 0)

        # Phase 3: Comparison
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 3: RESULTS COMPARISON")
        logger.info("=" * 60)

        comparison = self.compare_results(candidates)

        # Save final report
        final_report = {
            'timestamp': timestamp,
            'turbo_summary': turbo_summary,
            'validation_candidates': [
                {
                    'source': c['source'],
                    'design': c['design'],
                    'gnn_objectives': c['gnn_objectives'],
                    'fea_objectives': c.get('fea_objectives'),
                    'fea_success': c.get('fea_success', False),
                    'fea_time': c.get('fea_time')
                }
                for c in candidates
            ],
            'accuracy_summary': comparison
        }

        report_file = RESULTS_DIR / f'gnn_turbo_report_{timestamp}.json'
        with open(report_file, 'w') as f:
            json.dump(final_report, f, indent=2)
        logger.info(f"\nFinal report saved to {report_file}")

        # Print final summary
        logger.info("\n" + "=" * 60)
        logger.info("WORKFLOW COMPLETE")
        logger.info("=" * 60)
        logger.info(f"GNN Turbo: {n_trials} trials in {turbo_results['elapsed_time']:.1f}s")
        logger.info(f"Pareto front: {len(turbo_results['pareto'])} designs")

        successful_validations = sum(1 for c in candidates if c.get('fea_success', False))
        logger.info(f"FEA Validations: {successful_validations}/{len(candidates)} successful")

        if comparison:
            # Mean of per-objective mean errors (previously wrapped the scalar
            # mean in a redundant single-element list before averaging again).
            mean_errors = [comparison[obj]['mean_error_pct'] for obj in comparison]
            logger.info(f"Overall GNN accuracy: {100 - np.mean(mean_errors):.1f}%")

        return final_report
|
|
|
|
|
|
# ============================================================================
|
|
# Main
|
|
# ============================================================================
|
|
|
|
def main():
    """CLI entry point: parse arguments, build the runner, execute the workflow."""
    cli = argparse.ArgumentParser(
        description="GNN Turbo Optimization with FEA Validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    cli.add_argument('--trials', type=int, default=5000,
                     help='Number of GNN turbo trials (default: 5000)')
    cli.add_argument('--validate-top', type=int, default=5,
                     help='Number of top candidates to validate with FEA (default: 5)')
    cli.add_argument('--gnn-only', action='store_true',
                     help='Run only GNN turbo, skip FEA validation')
    cli.add_argument('--checkpoint', type=str, default=str(CHECKPOINT_PATH),
                     help='Path to GNN checkpoint')

    opts = cli.parse_args()

    # Echo the effective configuration before doing any work.
    logger.info("Starting GNN Turbo Optimization")
    logger.info(f" Checkpoint: {opts.checkpoint}")
    logger.info(f" GNN trials: {opts.trials}")
    logger.info(f" FEA validations: {opts.validate_top if not opts.gnn_only else 'SKIP'}")

    runner = GNNTurboRunner(
        config_path=CONFIG_PATH,
        checkpoint_path=Path(opts.checkpoint),
    )

    runner.run_full_workflow(
        n_trials=opts.trials,
        n_validate=opts.validate_top,
        gnn_only=opts.gnn_only,
    )

    logger.info("\nDone!")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|