Files
Atomizer/archive/scripts/run_nn_optimization.py

164 lines
4.7 KiB
Python
Raw Normal View History

"""
Neural Network Only Optimization
This script runs multi-objective optimization using ONLY the neural network
surrogate (no FEA). This demonstrates the speed improvement from NN predictions.
Objectives:
- Minimize mass
- Maximize frequency (minimize -frequency)
"""
import sys
from pathlib import Path
import time
import json
import optuna
from optuna.samplers import NSGAIISampler
import numpy as np
# Make the project packages importable when this file is run as a script.
# Final order at the front of sys.path is: <root>/atomizer-field, then <root>.
# NOTE(review): this script appears to live under archive/scripts now; if it
# was moved there from the repository root, Path(__file__).parent may no
# longer be the project root — confirm before relying on these paths.
project_root = Path(__file__).parent
sys.path[0:0] = [str(project_root / 'atomizer-field'), str(project_root)]
from optimization_engine.processors.surrogates.simple_mlp_surrogate import SimpleSurrogate
def _suggest_params(trial, bounds):
    """Sample one candidate design from ``bounds`` using an Optuna trial.

    Args:
        trial: The active ``optuna.Trial``.
        bounds: Mapping of design-variable name -> ``(low, high)``.

    Returns:
        Dict of sampled values. ``hole_count`` is sampled as an integer
        (its bounds truncated with ``int()``); every other variable is a
        float in ``[low, high]``.
    """
    params = {}
    for name, (low, high) in bounds.items():
        if name == 'hole_count':
            params[name] = trial.suggest_int(name, int(low), int(high))
        else:
            params[name] = trial.suggest_float(name, low, high)
    return params


def _pareto_record(trial):
    """Convert a finished Pareto-optimal trial into a JSON-friendly dict.

    The study stores objectives as ``(mass, -frequency)``, so the second
    value is negated to recover a positive frequency.
    """
    return {
        'mass': trial.values[0],
        'frequency': -trial.values[1],
        'params': trial.params,
    }


def main():
    """Run NSGA-II multi-objective optimization using only the NN surrogate.

    Loads the trained surrogate from ``project_root/simple_mlp_surrogate.pt``,
    recreates a fresh SQLite-backed Optuna study (deleting any previous
    database file), runs a fixed number of trials that each query the
    surrogate for ``mass`` and ``frequency``, prints a timing and Pareto-front
    summary, and writes the Pareto front to ``nn_optimization_results.json``.

    Objectives: minimize mass, maximize frequency (stored as minimizing
    ``-frequency``).

    Returns:
        0 on success; 1 if the surrogate model file is missing or could not
        be loaded, so a ``__main__`` guard can use it as the exit status.
    """
    print("="*60)
    print("Neural Network Only Optimization (Simple MLP)")
    print("="*60)

    # Load surrogate
    print("\n[1] Loading neural surrogate...")
    model_path = project_root / "simple_mlp_surrogate.pt"
    if not model_path.exists():
        print(f"ERROR: Model not found at {model_path}", file=sys.stderr)
        # Path mirrors the module imported at the top of this file.
        print(
            "Run 'python optimization_engine/processors/surrogates/"
            "simple_mlp_surrogate.py' first to train",
            file=sys.stderr,
        )
        return 1
    surrogate = SimpleSurrogate.load(model_path)
    if not surrogate:
        print("ERROR: Could not load neural surrogate", file=sys.stderr)
        return 1
    print(f" Design variables: {surrogate.design_var_names}")

    # Define bounds (from UAV arm study)
    bounds = {
        'beam_half_core_thickness': (1.0, 5.0),
        'beam_face_thickness': (0.5, 3.0),
        'holes_diameter': (0.5, 5.0),
        'hole_count': (0.0, 6.0),
    }
    print(f" Bounds: {bounds}")

    # Create Optuna study
    print("\n[2] Creating Optuna study...")
    storage_path = project_root / "nn_only_optimization_study.db"
    # Start from a clean database so trials from earlier runs don't mix in.
    if storage_path.exists():
        storage_path.unlink()
    storage = optuna.storages.RDBStorage(f"sqlite:///{storage_path}")
    study = optuna.create_study(
        study_name="nn_only_optimization",
        storage=storage,
        directions=["minimize", "minimize"],  # mass, -frequency (minimize both)
        sampler=NSGAIISampler(),
    )

    trial_times = []  # per-trial wall time in milliseconds

    def objective(trial: optuna.Trial):
        # perf_counter is monotonic and high-resolution; time.time() can be
        # too coarse to resolve sub-millisecond surrogate predictions.
        trial_start = time.perf_counter()
        params = _suggest_params(trial, bounds)

        # Predict with NN
        results = surrogate.predict(params)
        mass = results['mass']
        frequency = results['frequency']

        trial_time = (time.perf_counter() - trial_start) * 1000
        trial_times.append(trial_time)

        # Log progress every 100 trials
        if trial.number % 100 == 0:
            print(f" Trial {trial.number}: mass={mass:.1f}g, freq={frequency:.2f}Hz, time={trial_time:.1f}ms")

        # Minimize mass and -frequency (= maximize frequency)
        return mass, -frequency

    # Run optimization
    n_trials = 1000  # Much faster with NN!
    print(f"\n[3] Running {n_trials} trials...")
    start_time = time.perf_counter()
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
    total_time = time.perf_counter() - start_time

    # Computed once; guarded so an empty run doesn't produce NaN + a warning.
    avg_trial_time_ms = float(np.mean(trial_times)) if trial_times else 0.0
    trials_per_second = n_trials / total_time

    # Results
    print("\n" + "="*60)
    print("RESULTS")
    print("="*60)
    print(f"\nTotal time: {total_time:.1f}s for {n_trials} trials")
    print(f"Average time per trial: {avg_trial_time_ms:.1f}ms")
    print(f"Trials per second: {trials_per_second:.1f}")

    # Pareto front, sorted by mass so "top 5" is a well-defined selection
    # (the lightest non-dominated designs) rather than an arbitrary slice.
    pareto_front = sorted(study.best_trials, key=lambda t: t.values[0])
    print(f"\nPareto front size: {len(pareto_front)} designs")
    print("\nTop 5 Pareto-optimal designs (lightest first):")
    for i, trial in enumerate(pareto_front[:5], start=1):
        record = _pareto_record(trial)
        print(f" {i}. Mass={record['mass']:.1f}g, Freq={record['frequency']:.2f}Hz")
        print(f" Params: {record['params']}")

    # Save results
    results_file = project_root / "nn_optimization_results.json"
    results = {
        'n_trials': n_trials,
        'total_time_s': total_time,
        'avg_trial_time_ms': avg_trial_time_ms,
        'trials_per_second': trials_per_second,
        'pareto_front_size': len(pareto_front),
        'pareto_designs': [_pareto_record(t) for t in pareto_front],
    }
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    print(f"Study database: {storage_path}")
    print("\nView in Optuna dashboard:")
    print(f" optuna-dashboard sqlite:///{storage_path}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (sys.exit(None) exits 0, so a main() without a return still succeeds).
    sys.exit(main())