62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
|
|
"""Test parametric surrogate integration."""
|
||
|
|
|
||
|
|
import time
|
||
|
|
from optimization_engine.neural_surrogate import create_parametric_surrogate_for_study
|
||
|
|
|
||
|
|
# Example design point, used for the single prediction and the speed test.
TEST_PARAMS = {
    "beam_half_core_thickness": 7.0,
    "beam_face_thickness": 2.5,
    "holes_diameter": 35.0,
    "hole_count": 10.0,
}

# Approximate (label, min, max, unit) of each output seen in the training data.
# Used to print the expected range and to flag predictions outside it; an
# out-of-range value only produces a warning because the bounds are rough.
TRAINING_RANGES = {
    "mass": ("Mass", 2808.0, 5107.0, "g"),
    "frequency": ("Frequency", 15.8, 21.9, "Hz"),
    "max_displacement": ("Max displacement", 0.02, 0.03, "mm"),
}


def run_speed_test(predict, params, n):
    """Call ``predict(params)`` *n* times and return the elapsed seconds.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``, which can jump if the system clock is adjusted.

    Raises:
        ValueError: if *n* is not positive, since the per-call average
            reported by the caller would be undefined.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    start = time.perf_counter()
    for _ in range(n):
        predict(params)
    return time.perf_counter() - start


def outputs_outside_training_range(results):
    """Return the TRAINING_RANGES keys whose value in *results* is out of bounds.

    Keys missing from *results* are skipped. NaN values are reported, because
    every comparison with NaN is false.
    """
    return [
        key
        for key, (_label, low, high, _unit) in TRAINING_RANGES.items()
        if key in results and not low <= results[key] <= high
    ]


def main(surrogate_factory=None, n_speed=100):
    """Run the parametric-surrogate smoke test and return a process exit code.

    Args:
        surrogate_factory: zero-argument callable returning a surrogate, or
            ``None`` on failure. Defaults to
            ``create_parametric_surrogate_for_study`` (auto-detection);
            injectable so the script can be exercised without a trained model.
        n_speed: number of predictions timed in the speed test.

    Returns:
        0 on success, 1 if the surrogate could not be created.
    """
    if surrogate_factory is None:
        # Resolved at call time so the module imports cleanly on its own.
        surrogate_factory = create_parametric_surrogate_for_study

    print("Testing Parametric Neural Surrogate")
    print("=" * 60)

    # Create surrogate with auto-detection
    surrogate = surrogate_factory()
    if surrogate is None:
        print("ERROR: Failed to create surrogate")
        return 1

    print("Surrogate created successfully!")
    print(f" Device: {surrogate.device}")
    print(f" Nodes: {surrogate.num_nodes}")
    print(f" Model val_loss: {surrogate.best_val_loss:.4f}")
    print(f" Design vars: {surrogate.design_var_names}")

    # Single prediction with the example design point
    print(f"\nTest prediction with params: {TEST_PARAMS}")
    results = surrogate.predict(TEST_PARAMS)

    print("\nResults:")
    print(f" Mass: {results['mass']:.2f} g")
    print(f" Frequency: {results['frequency']:.2f} Hz")
    print(f" Max displacement: {results['max_displacement']:.6f} mm")
    print(f" Max stress: {results['max_stress']:.2f} MPa")
    print(f" Inference time: {results['inference_time_ms']:.2f} ms")

    # Speed test
    elapsed = run_speed_test(surrogate.predict, TEST_PARAMS, n_speed)
    print(f"\nSpeed test: {n_speed} predictions in {elapsed:.3f}s")
    print(f" Average: {elapsed / n_speed * 1000:.2f} ms per prediction")

    # Compare with training data range
    print("\nExpected range (from training data):")
    for label, low, high, unit in TRAINING_RANGES.values():
        print(f" {label}: ~{low:g} - {high:g} {unit}")
    for key in outputs_outside_training_range(results):
        print(
            f" WARNING: predicted {key} = {results[key]:g} "
            "is outside the training range"
        )

    stats = surrogate.get_statistics()
    print("\nStatistics:")
    print(f" Total predictions: {stats['total_predictions']}")
    print(f" Average time: {stats['average_time_ms']:.2f} ms")

    print("\nParametric surrogate test PASSED!")
    return 0


if __name__ == "__main__":
    # SystemExit instead of the site-provided exit(), which is meant for the
    # interactive interpreter and is absent under `python -S`.
    raise SystemExit(main())