Files
Atomizer/atomizer-field/neural_models/uncertainty.py
Antoine d5ffba099e feat: Merge Atomizer-Field neural network module into main repository
Permanently integrates the Atomizer-Field GNN surrogate system:
- neural_models/: Graph Neural Network for FEA field prediction
- batch_parser.py: Parse training data from FEA exports
- train.py: Neural network training pipeline
- predict.py: Inference engine for fast predictions

This enables 600x-2200x speedup over traditional FEA by replacing
expensive simulations with millisecond neural network predictions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 15:31:33 -05:00

362 lines
12 KiB
Python

"""
uncertainty.py
Uncertainty quantification for neural field predictions
AtomizerField Uncertainty Quantification v2.1
Know when to trust predictions and when to run FEA!
Key Features:
- Ensemble-based uncertainty estimation
- Confidence intervals for predictions
- Automatic FEA recommendation
- Online calibration
"""
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from .field_predictor import AtomizerFieldModel
class UncertainFieldPredictor(nn.Module):
    """
    Ensemble of models for uncertainty quantification.

    Holds ``n_ensemble`` independently initialised ``AtomizerFieldModel``
    instances. The spread (standard deviation) of their predictions is used
    as the uncertainty estimate.

    When uncertainty is high → Recommend FEA validation
    When uncertainty is low → Trust neural prediction
    """

    # Two-sided normal z-scores for the commonly requested confidence levels.
    # Kept as a table so these levels yield exactly the historical values;
    # any other level is computed from the inverse normal CDF.
    _Z_SCORES = {0.90: 1.645, 0.95: 1.96, 0.99: 2.576}

    # Output fields of a member model that are aggregated across the ensemble.
    _FIELDS = ('displacement', 'stress', 'von_mises')

    # Guards the relative-uncertainty division against zero-valued predictions.
    _EPS = 1e-8

    def __init__(self, base_model_config, n_ensemble=5):
        """
        Initialize ensemble.

        Args:
            base_model_config (dict): Keyword arguments forwarded to
                ``AtomizerFieldModel`` for every ensemble member.
            n_ensemble (int): Number of models in ensemble (must be >= 1).

        Raises:
            ValueError: If ``n_ensemble`` is smaller than 1.
        """
        super().__init__()
        if n_ensemble < 1:
            # An empty ensemble cannot produce any prediction; fail here
            # rather than later inside forward().
            raise ValueError(f"n_ensemble must be >= 1, got {n_ensemble}")

        print(f"\nCreating ensemble with {n_ensemble} models...")

        # Create ensemble of models
        self.models = nn.ModuleList([
            AtomizerFieldModel(**base_model_config)
            for _ in range(n_ensemble)
        ])
        self.n_ensemble = n_ensemble

        # Initialize each model differently (member index doubles as seed)
        for i, model in enumerate(self.models):
            self._init_weights(model, seed=i)

        print(f"Ensemble created with {n_ensemble} models")

    def _init_weights(self, model, seed):
        """
        Re-initialize every ``nn.Linear`` layer of ``model`` under ``seed``.

        NOTE: ``torch.manual_seed`` reseeds the *global* torch RNG, so the
        global RNG state after this call is determined by ``seed``.

        Args:
            model (nn.Module): Ensemble member to re-initialize in place.
            seed (int): Seed passed to ``torch.manual_seed``.
        """
        torch.manual_seed(seed)

        def init_fn(m):
            # Only linear layers are touched; other layer types keep their
            # constructor initialisation.
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        model.apply(init_fn)

    def forward(self, data, return_uncertainty=True, return_all_predictions=False):
        """
        Forward pass through ensemble.

        Every member is evaluated under ``torch.no_grad()``, so the returned
        tensors carry no gradient history.

        Args:
            data: Input graph data, passed unchanged to every member.
            return_uncertainty (bool): Return uncertainty estimates
            return_all_predictions (bool): Return all individual predictions

        Returns:
            dict: Predictions with uncertainty
                - displacement: Mean prediction
                - stress: Mean prediction
                - von_mises: Mean prediction
                - displacement_std: Standard deviation (if return_uncertainty)
                - stress_std: Standard deviation (if return_uncertainty)
                - von_mises_std: Standard deviation (if return_uncertainty)
                - max_displacement_uncertainty / max_stress_uncertainty:
                  largest std as a float (if return_uncertainty)
                - displacement_rel_uncertainty / stress_rel_uncertainty:
                  mean of std / |mean| as a float (if return_uncertainty)
                - all_predictions: List of all predictions (if return_all_predictions)
        """
        # Get predictions from all models
        with torch.no_grad():
            all_predictions = [
                model(data, return_stress=True) for model in self.models
            ]

        # Stack predictions along a new leading "ensemble" axis
        stacks = {
            field: torch.stack([pred[field] for pred in all_predictions])
            for field in self._FIELDS
        }

        # Compute mean predictions
        results = {field: stack.mean(dim=0) for field, stack in stacks.items()}

        # Compute uncertainty (standard deviation across ensemble)
        if return_uncertainty:
            # With a single member the sample std is undefined and comes out
            # as NaN; needs_fea_validation() treats that as "run FEA".
            for field, stack in stacks.items():
                results[f'{field}_std'] = stack.std(dim=0)

            # Overall uncertainty metrics
            results['max_displacement_uncertainty'] = results['displacement_std'].max().item()
            results['max_stress_uncertainty'] = results['von_mises_std'].max().item()

            # Uncertainty as a fraction of the prediction magnitude. Both
            # metrics divide by the absolute mean so a negative predicted
            # value cannot flip the sign of the ratio.
            results['displacement_rel_uncertainty'] = (
                results['displacement_std']
                / (torch.abs(results['displacement']) + self._EPS)
            ).mean().item()
            results['stress_rel_uncertainty'] = (
                results['von_mises_std']
                / (torch.abs(results['von_mises']) + self._EPS)
            ).mean().item()

        # Return all predictions if requested
        if return_all_predictions:
            results['all_predictions'] = all_predictions

        return results

    def needs_fea_validation(self, predictions, threshold=0.1):
        """
        Determine if FEA validation is recommended.

        FEA is recommended when either relative uncertainty exceeds
        ``threshold`` or is not a finite number (NaN/inf compare as "not
        greater" and would otherwise be reported as confident).

        Args:
            predictions (dict): Output from forward() with uncertainty
            threshold (float): Relative uncertainty threshold

        Returns:
            dict: ``recommend_fea`` (bool), ``reasons`` (list of str),
            ``displacement_uncertainty`` and ``stress_uncertainty``.

        Raises:
            KeyError: If ``predictions`` lacks the relative-uncertainty
                entries (e.g. forward() was called with
                ``return_uncertainty=False``).
        """
        import math

        checks = (
            ('displacement', 'displacement_rel_uncertainty'),
            ('stress', 'stress_rel_uncertainty'),
        )
        missing = [key for _, key in checks if key not in predictions]
        if missing:
            raise KeyError(
                f"predictions is missing {missing}; "
                "call forward() with return_uncertainty=True"
            )

        reasons = []
        for label, key in checks:
            value = predictions[key]
            if not math.isfinite(value):
                reasons.append(
                    f"{label.capitalize()} uncertainty is not finite ({value}); "
                    f"ensemble spread is unavailable"
                )
            elif value > threshold:
                reasons.append(
                    f"High {label} uncertainty: "
                    f"{value*100:.1f}% > {threshold*100:.1f}%"
                )

        return {
            'recommend_fea': bool(reasons),
            'reasons': reasons,
            'displacement_uncertainty': predictions['displacement_rel_uncertainty'],
            'stress_uncertainty': predictions['stress_rel_uncertainty']
        }

    def get_confidence_intervals(self, predictions, confidence=0.95):
        """
        Compute confidence intervals for predictions.

        Intervals are ``mean ± z * std`` with ``z`` the two-sided normal
        z-score for ``confidence``.

        Args:
            predictions (dict): Output from forward() with uncertainty
            confidence (float): Confidence level in (0, 1), e.g. 0.95 = 95%

        Returns:
            dict: ``displacement_lower/upper`` and ``von_mises_lower/upper``
            tensors, plus floats ``max_stress_estimate``,
            ``max_stress_lower`` and ``max_stress_upper``.

        Raises:
            ValueError: If ``confidence`` is not strictly between 0 and 1.
        """
        z_score = self._Z_SCORES.get(confidence)
        if z_score is None:
            if not 0.0 < confidence < 1.0:
                raise ValueError(
                    f"confidence must be in (0, 1), got {confidence}"
                )
            # Untabulated level: derive z from the inverse normal CDF
            # instead of silently reporting a 95% interval.
            from statistics import NormalDist
            z_score = NormalDist().inv_cdf(0.5 + confidence / 2.0)

        intervals = {}

        # Displacement intervals
        disp_margin = z_score * predictions['displacement_std']
        intervals['displacement_lower'] = predictions['displacement'] - disp_margin
        intervals['displacement_upper'] = predictions['displacement'] + disp_margin

        # Stress intervals
        vm_margin = z_score * predictions['von_mises_std']
        intervals['von_mises_lower'] = predictions['von_mises'] - vm_margin
        intervals['von_mises_upper'] = predictions['von_mises'] + vm_margin

        # Max values with confidence intervals. The margin uses the largest
        # std over all entries, not the std at the location of the maximum.
        max_vm = predictions['von_mises'].max()
        max_vm_std = predictions['von_mises_std'].max()
        intervals['max_stress_estimate'] = max_vm.item()
        intervals['max_stress_lower'] = (max_vm - z_score * max_vm_std).item()
        intervals['max_stress_upper'] = (max_vm + z_score * max_vm_std).item()

        return intervals
class OnlineLearner:
    """
    Online learning from FEA runs during optimization.

    As optimization progresses and you run FEA for validation,
    this module can quickly update the model to improve predictions.

    This creates a virtuous cycle:
    1. Use neural network for fast exploration
    2. Run FEA on promising designs
    3. Update neural network with new data
    4. Neural network gets better → need less FEA
    """

    def __init__(self, model, learning_rate=0.0001):
        """
        Initialize online learner.

        Args:
            model: Neural network model. It is called as
                ``model(graph_data, return_stress=True)`` and must return a
                dict with a ``'displacement'`` entry (and ``'stress'`` when
                stress targets are supplied).
            learning_rate (float): Learning rate for the Adam optimizer
        """
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.replay_buffer = []  # list of {'graph_data', 'fea_results'} dicts
        self.update_count = 0    # number of completed quick_update() calls

        print("\nOnline learner initialized")
        print(f"Learning rate: {learning_rate}")

    def add_fea_result(self, graph_data, fea_results):
        """
        Add new FEA result to replay buffer.

        Args:
            graph_data: Mesh graph
            fea_results (dict): FEA results (displacement, stress)
        """
        self.replay_buffer.append({
            'graph_data': graph_data,
            'fea_results': fea_results
        })
        print(f"Added FEA result to buffer (total: {len(self.replay_buffer)})")

    def _train_on_sample(self, sample):
        """Run one optimizer step on a single buffered sample; return its loss."""
        fea_results = sample['fea_results']

        # Forward pass
        predictions = self.model(sample['graph_data'], return_stress=True)

        # Displacement is always supervised; stress only when FEA provided it
        loss = nn.functional.mse_loss(
            predictions['displacement'],
            fea_results['displacement']
        )
        if 'stress' in fea_results:
            loss = loss + nn.functional.mse_loss(
                predictions['stress'],
                fea_results['stress']
            )

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def quick_update(self, steps=10):
        """
        Quick fine-tuning on the buffered FEA results.

        Each step performs one optimizer update per buffered sample. The
        model is switched to training mode for the update and is always put
        back into eval mode afterwards, including when a sample raises.

        Args:
            steps (int): Number of passes over the replay buffer
        """
        if not self.replay_buffer:
            print("No data in replay buffer")
            return

        n_samples = len(self.replay_buffer)
        print(f"\nQuick update: {steps} steps on {n_samples} samples")

        self.model.train()
        try:
            for step in range(steps):
                # Train on all samples in buffer
                total_loss = sum(
                    self._train_on_sample(sample)
                    for sample in self.replay_buffer
                )

                if step % 5 == 0:
                    avg_loss = total_loss / n_samples
                    print(f" Step {step}/{steps}: Loss = {avg_loss:.6f}")
        finally:
            # Never leave the model in training mode, even if the update
            # was interrupted by a malformed sample.
            self.model.eval()

        self.update_count += 1
        print(f"Update complete (total updates: {self.update_count})")

    def clear_buffer(self):
        """Clear replay buffer"""
        self.replay_buffer = []
        print("Replay buffer cleared")
def create_uncertain_predictor(model_config, n_ensemble=5):
    """
    Build an ensemble predictor from a base-model configuration.

    Thin convenience wrapper around the ``UncertainFieldPredictor``
    constructor.

    Args:
        model_config (dict): Configuration handed to the predictor as its
            base model configuration
        n_ensemble (int): Number of ensemble members

    Returns:
        UncertainFieldPredictor instance
    """
    predictor = UncertainFieldPredictor(
        base_model_config=model_config,
        n_ensemble=n_ensemble,
    )
    return predictor
if __name__ == "__main__":
    # Smoke test: build a small ensemble and print a usage example.
    print("Testing Uncertainty Quantification...\n")

    # Base-model hyperparameters shared by every ensemble member
    model_config = dict(
        node_feature_dim=12,
        edge_feature_dim=5,
        hidden_dim=64,
        num_layers=4,
        dropout=0.1,
    )

    ensemble = UncertainFieldPredictor(model_config, n_ensemble=3)

    print(f"\nEnsemble created with {ensemble.n_ensemble} models")
    print("Uncertainty quantification ready!", "\nUsage:", sep="\n")

    # Usage snippet shown to the user verbatim; it is printed, not executed.
    usage_text = """
# Get predictions with uncertainty
predictions = ensemble(graph_data, return_uncertainty=True)
# Check if FEA validation needed
recommendation = ensemble.needs_fea_validation(predictions, threshold=0.1)
if recommendation['recommend_fea']:
print("Recommendation: Run FEA for validation")
for reason in recommendation['reasons']:
print(f" - {reason}")
else:
print("Prediction confident - no FEA needed!")
"""
    print(usage_text)