"""
Neural Network Only Optimization

This script runs multi-objective optimization using ONLY the neural network
surrogate (no FEA). This demonstrates the speed improvement from NN predictions.

Objectives:
- Minimize mass
- Maximize frequency (minimize -frequency)
"""
import json
import sys
import time
from pathlib import Path
from typing import Optional

import numpy as np
import optuna
from optuna.samplers import NSGAIISampler

# Add project paths
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'atomizer-field'))

from optimization_engine.processors.surrogates.simple_mlp_surrogate import SimpleSurrogate  # noqa: E402

# Design-variable bounds (from UAV arm study).
BOUNDS = {
    'beam_half_core_thickness': (1.0, 5.0),
    'beam_face_thickness': (0.5, 3.0),
    'holes_diameter': (0.5, 5.0),
    'hole_count': (0.0, 6.0),
}
# Variables sampled with suggest_int; all others are continuous.
INTEGER_VARS = {'hole_count'}


def _load_surrogate():
    """Load the trained MLP surrogate; print the reason and return None on failure."""
    model_path = project_root / "simple_mlp_surrogate.pt"
    if not model_path.exists():
        print(f"ERROR: Model not found at {model_path}", file=sys.stderr)
        # NOTE(review): SimpleSurrogate is imported from
        # optimization_engine/processors/surrogates/; confirm this training
        # command still points at the right script.
        print("Run 'python optimization_engine/simple_mlp_surrogate.py' first to train",
              file=sys.stderr)
        return None

    surrogate = SimpleSurrogate.load(model_path)
    if not surrogate:
        print("ERROR: Could not load neural surrogate", file=sys.stderr)
        return None
    return surrogate


def _create_study(storage_path, seed=None):
    """Create a fresh two-objective NSGA-II study stored in SQLite at *storage_path*.

    Any existing database file at that path is deleted first so trials from a
    previous run are never mixed into this one.
    """
    if storage_path.exists():
        storage_path.unlink()

    storage = optuna.storages.RDBStorage(f"sqlite:///{storage_path}")
    return optuna.create_study(
        study_name="nn_only_optimization",
        storage=storage,
        directions=["minimize", "minimize"],  # mass, -frequency (minimize both)
        sampler=NSGAIISampler(seed=seed),
    )


def _make_objective(surrogate, trial_times):
    """Build an Optuna objective that scores designs with the surrogate only.

    Each call appends its elapsed time in milliseconds to *trial_times* and
    returns ``(mass, -frequency)`` so both objectives are minimized.
    """
    def objective(trial: optuna.Trial):
        # perf_counter is monotonic and high-resolution; time.time() is too
        # coarse on some platforms for sub-millisecond NN predictions.
        trial_start = time.perf_counter()

        # Suggest parameters
        params = {}
        for name, (low, high) in BOUNDS.items():
            if name in INTEGER_VARS:
                params[name] = trial.suggest_int(name, int(low), int(high))
            else:
                params[name] = trial.suggest_float(name, low, high)

        # Predict with NN
        prediction = surrogate.predict(params)
        mass = prediction['mass']
        frequency = prediction['frequency']

        trial_time = (time.perf_counter() - trial_start) * 1000
        trial_times.append(trial_time)

        # Log progress every 100 trials
        if trial.number % 100 == 0:
            print(f" Trial {trial.number}: mass={mass:.1f}g, freq={frequency:.2f}Hz, time={trial_time:.1f}ms")

        # Minimize mass, minimize -frequency (= maximize frequency)
        return mass, -frequency

    return objective


def main(n_trials: int = 1000, seed: Optional[int] = None) -> int:
    """Run the NN-only NSGA-II optimization, print a summary and save results.

    Args:
        n_trials: Number of Optuna trials to run.
        seed: Optional sampler seed for reproducible runs (None = random).

    Returns:
        Process exit status: 0 on success, 1 if the surrogate could not be loaded.
    """
    print("=" * 60)
    print("Neural Network Only Optimization (Simple MLP)")
    print("=" * 60)

    print("\n[1] Loading neural surrogate...")
    surrogate = _load_surrogate()
    if surrogate is None:
        return 1
    print(f" Design variables: {surrogate.design_var_names}")
    print(f" Bounds: {BOUNDS}")

    print("\n[2] Creating Optuna study...")
    storage_path = project_root / "nn_only_optimization_study.db"
    study = _create_study(storage_path, seed=seed)

    trial_times = []
    objective = _make_objective(surrogate, trial_times)

    print(f"\n[3] Running {n_trials} trials...")
    start_time = time.perf_counter()
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
    total_time = time.perf_counter() - start_time

    # Plain floats (not NaN / numpy scalars) so the JSON below is always valid.
    avg_trial_ms = float(np.mean(trial_times)) if trial_times else 0.0
    trials_per_second = n_trials / total_time if total_time > 0 else 0.0

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(f"\nTotal time: {total_time:.1f}s for {n_trials} trials")
    print(f"Average time per trial: {avg_trial_ms:.1f}ms")
    print(f"Trials per second: {trials_per_second:.1f}")

    # study.best_trials is not guaranteed to be ranked by either objective;
    # sort by mass so "top 5" means the five lightest Pareto designs.
    pareto_front = sorted(study.best_trials, key=lambda t: t.values[0])
    print(f"\nPareto front size: {len(pareto_front)} designs")
    print("\nTop 5 Pareto-optimal designs:")
    for i, trial in enumerate(pareto_front[:5]):
        mass = trial.values[0]
        freq = -trial.values[1]  # Convert back to positive
        print(f" {i+1}. Mass={mass:.1f}g, Freq={freq:.2f}Hz")
        print(f" Params: {trial.params}")

    # Save results
    results_file = project_root / "nn_optimization_results.json"
    results = {
        'n_trials': n_trials,
        'total_time_s': total_time,
        'avg_trial_time_ms': avg_trial_ms,
        'trials_per_second': trials_per_second,
        'pareto_front_size': len(pareto_front),
        'pareto_designs': [
            {
                'mass': float(t.values[0]),
                'frequency': -float(t.values[1]),
                'params': t.params,
            }
            for t in pareto_front
        ],
    }
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    print(f"Study database: {storage_path}")
    print("\nView in Optuna dashboard:")
    print(f" optuna-dashboard sqlite:///{storage_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())