469 lines
14 KiB
Python
469 lines
14 KiB
Python
|
|
"""
|
||
|
|
test_learning.py
|
||
|
|
Learning capability tests
|
||
|
|
|
||
|
|
Tests that the neural network can actually learn:
|
||
|
|
- Memorization: Can it memorize 10 examples?
|
||
|
|
- Interpolation: Can it generalize between training points?
|
||
|
|
- Extrapolation: Can it predict beyond training range?
|
||
|
|
- Pattern recognition: Does it learn physical relationships?
|
||
|
|
"""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import numpy as np
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Add parent directory to path
|
||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
|
|
||
|
|
from neural_models.field_predictor import create_model
|
||
|
|
from neural_models.physics_losses import create_loss_function
|
||
|
|
from torch_geometric.data import Data
|
||
|
|
|
||
|
|
|
||
|
|
def create_synthetic_dataset(n_samples=10, variation='load'):
    """
    Build a synthetic FEA-like dataset with a known, controllable pattern.

    Args:
        n_samples: Number of graph samples to generate.
        variation: Which parameter to sweep across samples
            ('load', 'stiffness', or 'geometry').

    Returns:
        List of (graph_data, target_displacement, target_stress) tuples,
        where graph_data is a torch_geometric Data object with 20 nodes
        and 40 random edges.
    """
    num_nodes = 20
    num_edges = 40
    samples = []

    for i in range(n_samples):
        # Small random base node features.
        x = torch.randn(num_nodes, 12) * 0.1
        edge_attr = None

        # Sweep the chosen parameter with the sample index.
        if variation == 'load':
            load_factor = 1.0 + i * 0.5  # ramps load from 1.0 to 5.5 over 10 samples
            x[:, 9:12] = torch.randn(num_nodes, 3) * load_factor
        elif variation == 'stiffness':
            stiffness_factor = 1.0 + i * 0.2
            edge_attr = torch.randn(num_edges, 5) * 0.1
            edge_attr[:, 0] = stiffness_factor  # Young's modulus channel
        elif variation == 'geometry':
            geometry_factor = 1.0 + i * 0.1
            x[:, 0:3] = torch.randn(num_nodes, 3) * geometry_factor

        # Random connectivity.
        edge_index = torch.randint(0, num_nodes, (2, num_edges))

        # Default edge attributes when stiffness is not the swept parameter.
        if edge_attr is None:
            edge_attr = torch.randn(num_edges, 5) * 0.1
            edge_attr[:, 0] = 1.0  # constant Young's modulus

        batch = torch.zeros(num_nodes, dtype=torch.long)
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)

        # Targets follow a known relationship: displacement scales with
        # load and inversely with stiffness.
        if variation == 'load':
            target_displacement = torch.randn(num_nodes, 6) * load_factor
        elif variation == 'stiffness':
            target_displacement = torch.randn(num_nodes, 6) / stiffness_factor
        else:
            target_displacement = torch.randn(num_nodes, 6)

        # Stress is a simple linear function of displacement.
        target_stress = target_displacement * 2.0

        samples.append((data, target_displacement, target_stress))

    return samples
|
||
|
|
|
||
|
|
|
||
|
|
def test_memorization():
    """
    Test 1: Can the network memorize a small dataset?

    Success criterion: after 100 epochs on 10 examples, average training
    loss must drop by more than 50% from its initial value.

    This probes basic learning capability — a network that cannot even
    memorize a tiny dataset has something fundamentally wrong.
    """
    print(" Creating small dataset (10 samples)...")

    dataset = create_synthetic_dataset(n_samples=10, variation='load')

    # No dropout: for memorization we *want* the model to overfit.
    model = create_model({
        'node_feature_dim': 12,
        'edge_feature_dim': 5,
        'hidden_dim': 64,
        'num_layers': 4,
        'dropout': 0.0
    })
    loss_fn = create_loss_function('mse')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(" Training for 100 epochs...")

    model.train()
    loss_history = []

    for epoch in range(100):
        running = 0.0

        for graph_data, target_disp, target_stress in dataset:
            optimizer.zero_grad()

            # Forward, loss, backward, step — one sample at a time.
            predictions = model(graph_data, return_stress=True)
            targets = {'displacement': target_disp, 'stress': target_stress}
            loss = loss_fn(predictions, targets)['total_loss']
            loss.backward()
            optimizer.step()

            running += loss.item()

        avg_loss = running / len(dataset)
        loss_history.append(avg_loss)

        if (epoch + 1) % 20 == 0:
            print(f" Epoch {epoch+1}/100: Loss = {avg_loss:.6f}")

    initial_loss = loss_history[0]
    final_loss = loss_history[-1]
    improvement = (initial_loss - final_loss) / initial_loss * 100

    print(f" Initial loss: {initial_loss:.6f}")
    print(f" Final loss: {final_loss:.6f}")
    print(f" Improvement: {improvement:.1f}%")

    # Memorization counts as successful when loss at least halves.
    success = improvement > 50.0

    return {
        'status': 'PASS' if success else 'FAIL',
        'message': f'Memorization {"successful" if success else "failed"} ({improvement:.1f}% improvement)',
        'metrics': {
            'initial_loss': float(initial_loss),
            'final_loss': float(final_loss),
            'improvement_percent': float(improvement),
            'converged': final_loss < 0.1
        }
    }
|
||
|
|
|
||
|
|
|
||
|
|
def test_interpolation():
    """
    Test 2: Can the network interpolate?

    Train on the even-indexed samples and evaluate on the odd-indexed
    ones, whose load factors fall *between* the trained values.

    Success criterion: average relative displacement error below 100%
    (lenient — this is a basic capability check on synthetic data).
    """
    print(" Creating interpolation dataset...")

    # FIX: the original comments had odd/even reversed — [0, 2, 4, 6, 8]
    # are the EVEN indices (training), [1, 3, 5, 7] the ODD ones (held
    # out for interpolation).
    train_indices = [0, 2, 4, 6, 8]  # even indices (training)
    test_indices = [1, 3, 5, 7]      # odd indices (interpolation targets)

    full_dataset = create_synthetic_dataset(n_samples=10, variation='load')

    train_dataset = [full_dataset[i] for i in train_indices]
    test_dataset = [full_dataset[i] for i in test_indices]

    # Create model (light dropout for mild regularization).
    config = {
        'node_feature_dim': 12,
        'edge_feature_dim': 5,
        'hidden_dim': 64,
        'num_layers': 4,
        'dropout': 0.1
    }

    model = create_model(config)
    loss_fn = create_loss_function('mse')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f" Training on {len(train_dataset)} samples...")

    # Train
    model.train()
    for epoch in range(50):
        for graph_data, target_disp, target_stress in train_dataset:
            optimizer.zero_grad()

            predictions = model(graph_data, return_stress=True)

            targets = {
                'displacement': target_disp,
                'stress': target_stress
            }

            loss_dict = loss_fn(predictions, targets)
            loss = loss_dict['total_loss']

            loss.backward()
            optimizer.step()

    # Evaluate on the held-out (interpolation) samples.
    print(f" Testing interpolation on {len(test_dataset)} samples...")

    model.eval()
    test_errors = []

    with torch.no_grad():
        for graph_data, target_disp, target_stress in test_dataset:
            predictions = model(graph_data, return_stress=True)

            # Mean relative error; epsilon guards against zero targets.
            pred_disp = predictions['displacement']
            error = torch.mean(torch.abs(pred_disp - target_disp) / (torch.abs(target_disp) + 1e-8))
            test_errors.append(error.item())

    avg_error = np.mean(test_errors) * 100

    print(f" Average interpolation error: {avg_error:.2f}%")

    # Success if error reasonable for untrained interpolation
    success = avg_error < 100.0  # Lenient for this basic test

    return {
        'status': 'PASS' if success else 'FAIL',
        'message': f'Interpolation test completed ({avg_error:.2f}% error)',
        'metrics': {
            'average_error_percent': float(avg_error),
            'test_samples': len(test_dataset),
            'train_samples': len(train_dataset)
        }
    }
|
||
|
|
|
||
|
|
|
||
|
|
def test_extrapolation():
    """
    Test 3: Can the network extrapolate?

    Train on the five lowest-load samples (indices 0-4) and evaluate on
    the highest-load ones (indices 7-9), beyond the training range.

    Success criterion: average relative displacement error below 200%
    (very lenient — extrapolation is harder than interpolation).
    """
    print(" Creating extrapolation dataset...")

    # Train on the first 5 samples (lowest load factors); test on
    # indices 7-9 (the highest load factors, outside the training range).
    train_indices = list(range(5))
    test_indices = list(range(7, 10))

    full_dataset = create_synthetic_dataset(n_samples=10, variation='load')

    train_dataset = [full_dataset[i] for i in train_indices]
    test_dataset = [full_dataset[i] for i in test_indices]

    # Create model
    config = {
        'node_feature_dim': 12,
        'edge_feature_dim': 5,
        'hidden_dim': 64,
        'num_layers': 4,
        'dropout': 0.1
    }

    model = create_model(config)
    loss_fn = create_loss_function('mse')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # FIX: original print claimed "samples 1-5", contradicting the
    # 0-based indices actually used; report counts like the other tests.
    print(f" Training on {len(train_dataset)} samples...")

    # Train
    model.train()
    for epoch in range(50):
        for graph_data, target_disp, target_stress in train_dataset:
            optimizer.zero_grad()

            predictions = model(graph_data, return_stress=True)

            targets = {
                'displacement': target_disp,
                'stress': target_stress
            }

            loss_dict = loss_fn(predictions, targets)
            loss = loss_dict['total_loss']

            loss.backward()
            optimizer.step()

    # FIX: original print claimed "samples 7-10" but only indices 7-9 exist.
    print(f" Testing extrapolation on {len(test_dataset)} samples...")

    model.eval()
    test_errors = []

    with torch.no_grad():
        for graph_data, target_disp, target_stress in test_dataset:
            predictions = model(graph_data, return_stress=True)

            # Mean relative error; epsilon guards against zero targets.
            pred_disp = predictions['displacement']
            error = torch.mean(torch.abs(pred_disp - target_disp) / (torch.abs(target_disp) + 1e-8))
            test_errors.append(error.item())

    avg_error = np.mean(test_errors) * 100

    print(f" Average extrapolation error: {avg_error:.2f}%")
    print(f" Note: Extrapolation is harder than interpolation.")

    # Success if error is reasonable (extrapolation is hard)
    success = avg_error < 200.0  # Very lenient for basic test

    return {
        'status': 'PASS' if success else 'FAIL',
        'message': f'Extrapolation test completed ({avg_error:.2f}% error)',
        'metrics': {
            'average_error_percent': float(avg_error),
            'test_samples': len(test_dataset),
            'train_samples': len(train_dataset)
        }
    }
|
||
|
|
|
||
|
|
|
||
|
|
def test_pattern_recognition():
    """
    Test 4: Does the network learn physical relationships?

    Trains on a stiffness sweep where the targets encode
    stiffness up -> displacement down, then compares predictions for the
    softest and stiffest samples.

    Success criterion here is only that both predictions have positive
    magnitude; with synthetic random data the physical pattern itself may
    not emerge.
    """
    print(" Testing pattern recognition...")

    # Dataset with a built-in pattern: stiffness up -> displacement down.
    dataset = create_synthetic_dataset(n_samples=20, variation='stiffness')

    model = create_model({
        'node_feature_dim': 12,
        'edge_feature_dim': 5,
        'hidden_dim': 64,
        'num_layers': 4,
        'dropout': 0.1
    })
    loss_fn = create_loss_function('mse')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(" Training on stiffness variation dataset...")

    model.train()
    for _ in range(50):
        for graph_data, target_disp, target_stress in dataset:
            optimizer.zero_grad()

            out = model(graph_data, return_stress=True)
            targets = {'displacement': target_disp, 'stress': target_stress}
            loss = loss_fn(out, targets)['total_loss']

            loss.backward()
            optimizer.step()

    # Compare predictions for the softest (first) and stiffest (last) samples.
    print(" Testing learned pattern...")

    model.eval()

    low_stiff_data, low_stiff_disp, _ = dataset[0]    # low stiffness case
    high_stiff_data, high_stiff_disp, _ = dataset[-1]  # high stiffness case

    with torch.no_grad():
        low_pred = model(low_stiff_data, return_stress=False)
        high_pred = model(high_stiff_data, return_stress=False)

    # Pattern learned when the soft case shows the larger displacement.
    low_disp_mag = torch.mean(torch.abs(low_pred['displacement'])).item()
    high_disp_mag = torch.mean(torch.abs(high_pred['displacement'])).item()

    print(f" Low stiffness displacement: {low_disp_mag:.6f}")
    print(f" High stiffness displacement: {high_disp_mag:.6f}")

    # Ratio > 1.0 would indicate the physical pattern was picked up.
    # (With random synthetic targets this is not guaranteed — template only.)
    pattern_ratio = low_disp_mag / (high_disp_mag + 1e-8)

    print(f" Pattern ratio (should be > 1.0): {pattern_ratio:.2f}")
    print(f" Note: With synthetic random data, pattern may not emerge.")
    print(f" Real training data should show clear physical patterns.")

    # Only require non-degenerate (positive-magnitude) predictions here.
    success = low_disp_mag > 0.0 and high_disp_mag > 0.0

    return {
        'status': 'PASS' if success else 'FAIL',
        'message': 'Pattern recognition test completed',
        'metrics': {
            'low_stiffness_displacement': float(low_disp_mag),
            'high_stiffness_displacement': float(high_disp_mag),
            'pattern_ratio': float(pattern_ratio)
        }
    }
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point: run every capability test, tallying results.
    print("\nRunning learning capability tests...\n")

    tests = [
        ("Memorization Test", test_memorization),
        ("Interpolation Test", test_interpolation),
        ("Extrapolation Test", test_extrapolation),
        ("Pattern Recognition", test_pattern_recognition)
    ]

    passed = failed = 0

    for name, test_func in tests:
        print(f"[TEST] {name}")
        try:
            result = test_func()
            if result['status'] == 'PASS':
                passed += 1
                print(f" ✓ PASS\n")
            else:
                failed += 1
                print(f" ✗ FAIL: {result['message']}\n")
        except Exception as e:
            # A crashing test counts as a failure; keep the traceback visible.
            print(f" ✗ FAIL: {str(e)}\n")
            import traceback
            traceback.print_exc()
            failed += 1

    print(f"\nResults: {passed} passed, {failed} failed")
    print(f"\nNote: These tests use SYNTHETIC data and train for limited epochs.")
    print(f"Real training on actual FEA data will show better learning performance.")