Files
Atomizer/atomizer-field/neural_models/uncertainty.py

362 lines
12 KiB
Python
Raw Normal View History

"""
uncertainty.py
Uncertainty quantification for neural field predictions
AtomizerField Uncertainty Quantification v2.1
Know when to trust predictions and when to run FEA!
Key Features:
- Ensemble-based uncertainty estimation
- Confidence intervals for predictions
- Automatic FEA recommendation
- Online fine-tuning of the model on new FEA results
"""
import math
from copy import deepcopy
from statistics import NormalDist

import numpy as np
import torch
import torch.nn as nn

from .field_predictor import AtomizerFieldModel
class UncertainFieldPredictor(nn.Module):
    """
    Ensemble of models for uncertainty quantification.

    Holds ``n_ensemble`` ``AtomizerFieldModel`` instances built from the same
    configuration; the Linear layers of each member are re-initialized from
    a different seed. The spread (standard deviation) of the members'
    predictions is used as the uncertainty estimate.

    When uncertainty is high -> recommend FEA validation.
    When uncertainty is low  -> trust the neural prediction.
    """

    def __init__(self, base_model_config, n_ensemble=5):
        """
        Initialize the ensemble.

        Args:
            base_model_config (dict): Keyword arguments passed to every
                ``AtomizerFieldModel`` in the ensemble.
            n_ensemble (int): Number of models in the ensemble. Must be at
                least 1. With a single model the ensemble spread is
                undefined (NaN), so use 2 or more for meaningful uncertainty.

        Raises:
            ValueError: If ``n_ensemble`` is less than 1 (an empty ensemble
                cannot produce predictions).
        """
        if n_ensemble < 1:
            raise ValueError(f"n_ensemble must be at least 1, got {n_ensemble}")
        super().__init__()
        print(f"\nCreating ensemble with {n_ensemble} models...")
        # Create ensemble of models
        self.models = nn.ModuleList([
            AtomizerFieldModel(**base_model_config)
            for _ in range(n_ensemble)
        ])
        self.n_ensemble = n_ensemble
        # Seed i for member i: members' Linear layers differ from each other
        # and are reproducible across runs.
        for i, model in enumerate(self.models):
            self._init_weights(model, seed=i)
        print(f"Ensemble created with {n_ensemble} models")

    def _init_weights(self, model, seed):
        """
        Re-initialize every ``nn.Linear`` layer of ``model`` from ``seed``.

        Weights get Kaiming-normal init (fan_in, ReLU gain) and biases are
        zeroed. Other layer types are left untouched.

        The seeding runs inside ``torch.random.fork_rng`` so the caller's CPU
        random state is restored afterwards. Without this, constructing the
        ensemble would leave the global CPU RNG reseeded, and any later
        randomness (dropout, shuffling) would ignore the user's own seed.
        """
        def init_fn(m):
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # devices=[] saves/restores only the CPU generator. torch.manual_seed
        # still reseeds CUDA generators, as before; presumably parameters are
        # on CPU right after construction — TODO confirm if a default CUDA
        # device is ever set.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            model.apply(init_fn)

    def forward(self, data, return_uncertainty=True, return_all_predictions=False):
        """
        Forward pass through the ensemble (no gradients are tracked).

        Args:
            data: Input graph data, passed unchanged to every member.
            return_uncertainty (bool): Also return spread-based uncertainty.
            return_all_predictions (bool): Also return each member's output.

        Returns:
            dict: Predictions with uncertainty
                - displacement / stress / von_mises: ensemble mean
                - displacement_std / stress_std / von_mises_std: ensemble
                  standard deviation (if return_uncertainty). Uses the
                  unbiased estimator, so it is NaN for a 1-member ensemble.
                - max_displacement_uncertainty / max_stress_uncertainty:
                  largest std value, as Python floats (if return_uncertainty)
                - displacement_rel_uncertainty / stress_rel_uncertainty:
                  mean of std / |mean|, as Python floats (if return_uncertainty)
                - all_predictions: list of member outputs
                  (if return_all_predictions)
        """
        # NOTE(review): members run in whatever train/eval mode this module is
        # in; call .eval() first unless dropout sampling is intended.
        with torch.no_grad():
            all_predictions = [model(data, return_stress=True) for model in self.models]

        # Stack along a new leading "member" dimension.
        displacement_stack = torch.stack([p['displacement'] for p in all_predictions])
        stress_stack = torch.stack([p['stress'] for p in all_predictions])
        von_mises_stack = torch.stack([p['von_mises'] for p in all_predictions])

        results = {
            'displacement': displacement_stack.mean(dim=0),
            'stress': stress_stack.mean(dim=0),
            'von_mises': von_mises_stack.mean(dim=0)
        }

        if return_uncertainty:
            results['displacement_std'] = displacement_stack.std(dim=0)
            results['stress_std'] = stress_stack.std(dim=0)
            results['von_mises_std'] = von_mises_stack.std(dim=0)

            # Overall uncertainty metrics
            results['max_displacement_uncertainty'] = results['displacement_std'].max().item()
            results['max_stress_uncertainty'] = results['von_mises_std'].max().item()

            # Uncertainty relative to the prediction magnitude; the epsilon
            # guards against division by zero. abs() on both so a negative
            # mean cannot produce a negative (falsely reassuring) ratio.
            results['displacement_rel_uncertainty'] = (
                results['displacement_std'] / (torch.abs(results['displacement']) + 1e-8)
            ).mean().item()
            results['stress_rel_uncertainty'] = (
                results['von_mises_std'] / (torch.abs(results['von_mises']) + 1e-8)
            ).mean().item()

        if return_all_predictions:
            results['all_predictions'] = all_predictions

        return results

    def needs_fea_validation(self, predictions, threshold=0.1):
        """
        Determine whether FEA validation is recommended.

        Args:
            predictions (dict): Output of forward() with uncertainty; reads
                'displacement_rel_uncertainty' and 'stress_rel_uncertainty'.
            threshold (float): Relative uncertainty above which FEA is
                recommended (0.1 = 10%).

        Returns:
            dict: 'recommend_fea' (bool), 'reasons' (list of str), and the
            two relative uncertainty values that were checked.
        """
        reasons = []
        for label, value in (
            ('displacement', predictions['displacement_rel_uncertainty']),
            ('stress', predictions['stress_rel_uncertainty']),
        ):
            if math.isnan(value):
                # NaN arises e.g. from a single-member ensemble (unbiased std
                # of one sample). NaN > threshold is False, so without this
                # check an undefined estimate would be reported as confident.
                reasons.append(
                    f"Undefined {label} uncertainty (NaN): "
                    f"ensemble spread could not be estimated"
                )
            elif value > threshold:
                reasons.append(
                    f"High {label} uncertainty: "
                    f"{value*100:.1f}% > {threshold*100:.1f}%"
                )

        return {
            'recommend_fea': len(reasons) > 0,
            'reasons': reasons,
            'displacement_uncertainty': predictions['displacement_rel_uncertainty'],
            'stress_uncertainty': predictions['stress_rel_uncertainty']
        }

    def get_confidence_intervals(self, predictions, confidence=0.95):
        """
        Compute symmetric normal-approximation confidence intervals.

        Args:
            predictions (dict): Output of forward() with uncertainty; reads
                'displacement', 'displacement_std', 'von_mises' and
                'von_mises_std'.
            confidence (float): Two-sided confidence level in (0, 1),
                e.g. 0.95 for a 95% interval.

        Returns:
            dict: Per-node lower/upper tensors for displacement and von Mises
            stress, plus the max-stress estimate and its bounds as floats.

        Raises:
            ValueError: If ``confidence`` is not strictly between 0 and 1.
        """
        if not 0.0 < confidence < 1.0:
            raise ValueError(f"confidence must be in (0, 1), got {confidence}")

        # Keep the exact tabulated z-scores for the common levels (so results
        # match earlier versions); compute any other level from the inverse
        # normal CDF instead of silently falling back to 95%.
        tabulated = {0.90: 1.645, 0.95: 1.96, 0.99: 2.576}
        if confidence in tabulated:
            z_score = tabulated[confidence]
        else:
            # Two-sided: (1 - confidence) / 2 probability mass in each tail.
            z_score = NormalDist().inv_cdf((1.0 + confidence) / 2.0)

        intervals = {}
        intervals['displacement_lower'] = predictions['displacement'] - z_score * predictions['displacement_std']
        intervals['displacement_upper'] = predictions['displacement'] + z_score * predictions['displacement_std']
        intervals['von_mises_lower'] = predictions['von_mises'] - z_score * predictions['von_mises_std']
        intervals['von_mises_upper'] = predictions['von_mises'] + z_score * predictions['von_mises_std']

        # The max std is taken over all nodes, not just the node with the max
        # stress, so this max-stress interval is conservative (at least as wide).
        max_vm = predictions['von_mises'].max()
        max_vm_std = predictions['von_mises_std'].max()
        intervals['max_stress_estimate'] = max_vm.item()
        intervals['max_stress_lower'] = (max_vm - z_score * max_vm_std).item()
        intervals['max_stress_upper'] = (max_vm + z_score * max_vm_std).item()
        return intervals
class OnlineLearner:
    """
    Online learning from FEA runs during optimization.

    As optimization progresses and FEA is run for validation, results are
    stored in a replay buffer and the model is fine-tuned on them.

    This creates a virtuous cycle:
    1. Use the neural network for fast exploration
    2. Run FEA on promising designs
    3. Update the neural network with the new data
    4. The neural network gets better -> less FEA is needed
    """

    def __init__(self, model, learning_rate=0.0001):
        """
        Initialize the online learner.

        Args:
            model: Neural network model; called as
                ``model(graph_data, return_stress=True)`` and expected to
                return a dict with 'displacement' and 'stress'.
            learning_rate (float): Adam learning rate for updates.
        """
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        # List of {'graph_data': ..., 'fea_results': ...} dicts. It grows
        # without bound until clear_buffer() is called.
        self.replay_buffer = []
        self.update_count = 0
        print("\nOnline learner initialized")
        print(f"Learning rate: {learning_rate}")

    def add_fea_result(self, graph_data, fea_results):
        """
        Append an FEA result to the replay buffer.

        Args:
            graph_data: Mesh graph passed to the model as input.
            fea_results (dict): FEA targets; must contain 'displacement' and
                may contain 'stress'.
        """
        self.replay_buffer.append({
            'graph_data': graph_data,
            'fea_results': fea_results
        })
        print(f"Added FEA result to buffer (total: {len(self.replay_buffer)})")

    def _sample_loss(self, sample):
        """Return the MSE training loss of the model on one buffered sample."""
        fea_results = sample['fea_results']
        predictions = self.model(sample['graph_data'], return_stress=True)
        loss = nn.functional.mse_loss(
            predictions['displacement'],
            fea_results['displacement']
        )
        # Stress supervision is optional: only added when FEA provided it.
        if 'stress' in fea_results:
            loss = loss + nn.functional.mse_loss(
                predictions['stress'],
                fea_results['stress']
            )
        return loss

    def quick_update(self, steps=10):
        """
        Fine-tune the model on every sample in the replay buffer.

        Each step is one pass over the buffer with one optimizer update per
        sample. The model is always left in eval mode afterwards, even if a
        step raises.

        Args:
            steps (int): Number of passes over the buffer.
        """
        if not self.replay_buffer:
            print("No data in replay buffer")
            return

        n_samples = len(self.replay_buffer)
        print(f"\nQuick update: {steps} steps on {n_samples} samples")
        self.model.train()
        try:
            for step in range(steps):
                total_loss = 0.0
                for sample in self.replay_buffer:
                    loss = self._sample_loss(sample)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    total_loss += loss.item()
                if step % 5 == 0:
                    avg_loss = total_loss / n_samples
                    print(f" Step {step}/{steps}: Loss = {avg_loss:.6f}")
        finally:
            # Without this, an exception mid-update would leave the model in
            # training mode for all subsequent inference.
            self.model.eval()

        self.update_count += 1
        print(f"Update complete (total updates: {self.update_count})")

    def clear_buffer(self):
        """Remove all stored FEA results from the replay buffer."""
        self.replay_buffer = []
        print("Replay buffer cleared")
def create_uncertain_predictor(model_config, n_ensemble=5):
    """
    Build an ensemble-based uncertainty predictor.

    Convenience factory around ``UncertainFieldPredictor``.

    Args:
        model_config (dict): Configuration passed through to
            ``UncertainFieldPredictor`` as its base model config.
        n_ensemble (int): How many ensemble members to create.

    Returns:
        UncertainFieldPredictor: The newly constructed ensemble.
    """
    predictor = UncertainFieldPredictor(
        base_model_config=model_config,
        n_ensemble=n_ensemble,
    )
    return predictor
if __name__ == "__main__":
    # Smoke test: build a small ensemble and print usage instructions.
    print("Testing Uncertainty Quantification...\n")

    # Create ensemble
    model_config = {
        'node_feature_dim': 12,
        'edge_feature_dim': 5,
        'hidden_dim': 64,
        'num_layers': 4,
        'dropout': 0.1
    }
    ensemble = UncertainFieldPredictor(model_config, n_ensemble=3)
    print(f"\nEnsemble created with {ensemble.n_ensemble} models")
    print("Uncertainty quantification ready!")
    print("\nUsage:")
    # The example is printed as source code, so the if/for/else bodies must
    # be indented for it to be valid Python when copied.
    usage_lines = [
        "",
        "# Get predictions with uncertainty",
        "predictions = ensemble(graph_data, return_uncertainty=True)",
        "# Check if FEA validation needed",
        "recommendation = ensemble.needs_fea_validation(predictions, threshold=0.1)",
        "if recommendation['recommend_fea']:",
        "    print(\"Recommendation: Run FEA for validation\")",
        "    for reason in recommendation['reasons']:",
        "        print(f\" - {reason}\")",
        "else:",
        "    print(\"Prediction confident - no FEA needed!\")",
        "",
    ]
    print("\n".join(usage_lines))