450 lines
15 KiB
Python
450 lines
15 KiB
Python
|
|
"""
|
|||
|
|
physics_losses.py
|
|||
|
|
Physics-informed loss functions for training FEA field predictors
|
|||
|
|
|
|||
|
|
AtomizerField Physics-Informed Loss Functions v2.0
|
|||
|
|
|
|||
|
|
Key Innovation:
|
|||
|
|
Standard neural networks only minimize prediction error.
|
|||
|
|
Physics-informed networks also enforce physical laws:
|
|||
|
|
- Equilibrium: Forces must balance
|
|||
|
|
- Compatibility: Strains must be compatible with displacements
|
|||
|
|
- Constitutive: Stress must follow material law (σ = C:ε)
|
|||
|
|
|
|||
|
|
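This makes the network learn physics, not just patterns. (Current status: only the boundary-condition penalty is active. The equilibrium and constitutive terms are placeholders that return zero, and no compatibility term is implemented yet.)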
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PhysicsInformedLoss(nn.Module):
    """
    Combined loss function with physics constraints

    Total Loss = λ_data * L_data
               + λ_equilibrium * L_equilibrium
               + λ_constitutive * L_constitutive
               + λ_boundary * L_boundary

    Where:
    - L_data: displacement error (global relative L2 error, or plain MSE when
      ``use_relative_error`` is False) plus a stress error that averages
      component-wise MSE and von Mises MSE
    - L_equilibrium, L_constitutive: physics-violation penalties; both are
      currently placeholders that always return zero
    - L_boundary: mean squared displacement at constrained DOFs
    """

    # Floor applied to the squared von Mises stress before the square root.
    # d sqrt(x)/dx is infinite at x = 0; without the clamp, a stress state with
    # no deviatoric part (e.g. an all-zero prediction) yields NaN gradients.
    _VM_SQ_EPS = 1e-12

    def __init__(
        self,
        lambda_data=1.0,
        lambda_equilibrium=0.1,
        lambda_constitutive=0.1,
        lambda_boundary=1.0,
        use_relative_error=True
    ):
        """
        Initialize physics-informed loss

        Args:
            lambda_data (float): Weight for data loss
            lambda_equilibrium (float): Weight for equilibrium violation
            lambda_constitutive (float): Weight for constitutive law violation
            lambda_boundary (float): Weight for boundary condition violation
            use_relative_error (bool): Use relative error instead of absolute
                for the displacement term
        """
        super().__init__()

        self.lambda_data = lambda_data
        self.lambda_equilibrium = lambda_equilibrium
        self.lambda_constitutive = lambda_constitutive
        self.lambda_boundary = lambda_boundary
        self.use_relative_error = use_relative_error

    def forward(self, predictions, targets, data=None):
        """
        Compute total physics-informed loss

        Args:
            predictions (dict): Model predictions
                - displacement: [num_nodes, 6] (required)
                - stress: [num_nodes, 6] (optional)
                - von_mises: [num_nodes] (not read here)
            targets (dict): Ground truth from FEA
                - displacement: [num_nodes, 6] (required)
                - stress: [num_nodes, 6] (optional)
            data: Mesh graph data (for physics constraints). Only its
                ``bc_mask`` attribute is read at the moment.

        Returns:
            dict with:
                - displacement_loss, stress_loss: the two data terms
                  (stress_loss is zero unless both dicts contain 'stress')
                - data_loss: displacement_loss + stress_loss
                - equilibrium_loss: Equilibrium violation
                - constitutive_loss: Material law violation
                - boundary_loss: BC violation
                  (all three physics terms are zero when ``data`` is None)
                - total_loss: Weighted sum of data and physics terms
        """
        displacement = predictions['displacement']
        losses = {}

        # 1. Data Loss: How well do predictions match FEA results?
        losses['displacement_loss'] = self._displacement_loss(
            displacement,
            targets['displacement']
        )

        if 'stress' in predictions and 'stress' in targets:
            losses['stress_loss'] = self._stress_loss(
                predictions['stress'],
                targets['stress']
            )
        else:
            losses['stress_loss'] = self._zero(displacement)

        losses['data_loss'] = losses['displacement_loss'] + losses['stress_loss']

        # 2. Physics Losses: How well do predictions obey physics?
        if data is not None:
            # Equilibrium: ∇·σ + f = 0
            losses['equilibrium_loss'] = self._equilibrium_loss(predictions, data)
            # Constitutive: σ = C:ε
            losses['constitutive_loss'] = self._constitutive_loss(predictions, data)
            # Boundary conditions: u = 0 at fixed DOFs
            losses['boundary_loss'] = self._boundary_condition_loss(predictions, data)
        else:
            losses['equilibrium_loss'] = self._zero(displacement)
            losses['constitutive_loss'] = self._zero(displacement)
            losses['boundary_loss'] = self._zero(displacement)

        # Total loss
        losses['total_loss'] = (
            self.lambda_data * losses['data_loss'] +
            self.lambda_equilibrium * losses['equilibrium_loss'] +
            self.lambda_constitutive * losses['constitutive_loss'] +
            self.lambda_boundary * losses['boundary_loss']
        )

        return losses

    def _zero(self, ref):
        """Return a scalar zero on the same device and with the same dtype as ``ref``."""
        return ref.new_zeros(())

    def _displacement_loss(self, pred, target):
        """
        Loss for displacement field

        With ``use_relative_error`` this is the global relative L2 error
        ||pred - target|| / (||target|| + 1e-8), with norms taken over the
        whole field (all nodes and components). Normalizing the field as a
        whole keeps the loss scale-invariant. It also stops nodes whose target
        displacement is zero (e.g. fixed DOFs) from dominating: a per-node ratio
        would divide their error by roughly 1e-8. Otherwise plain MSE is used.
        """
        if self.use_relative_error:
            return torch.norm(pred - target) / (torch.norm(target) + 1e-8)
        # Absolute MSE
        return F.mse_loss(pred, target)

    def _stress_loss(self, pred, target):
        """
        Loss for stress field

        Equal-weight blend of component-wise MSE and von Mises MSE (von Mises
        is the quantity most relevant for failure prediction).
        """
        # Component-wise MSE
        component_loss = F.mse_loss(pred, target)

        # Von Mises stress MSE (computed from components)
        pred_vm = self._compute_von_mises(pred)
        target_vm = self._compute_von_mises(target)
        vm_loss = F.mse_loss(pred_vm, target_vm)

        # Combined: 50% component accuracy, 50% von Mises accuracy
        return 0.5 * component_loss + 0.5 * vm_loss

    def _equilibrium_loss(self, predictions, data):
        """
        Equilibrium loss: ∇·σ + f = 0

        In discrete form: sum of forces at each node should be zero
        (where not externally loaded).

        Placeholder: always returns zero. ``data`` is not read yet.
        """
        # TODO: Implement finite difference approximation of ∇·σ
        # (requires computing stress divergence from node stresses)
        return self._zero(predictions['displacement'])

    def _constitutive_loss(self, predictions, data):
        """
        Constitutive law loss: σ = C:ε

        Intended to check that predicted stress is consistent with the strain
        implied by the predicted displacement gradient.

        Placeholder: always returns zero. ``data`` is not read yet.
        """
        # Full implementation would:
        # 1. Compute strain from displacement gradient
        # 2. Compute expected stress from strain using material stiffness
        # 3. Compare with predicted stress
        # TODO: Implement strain computation and constitutive check
        return self._zero(predictions['displacement'])

    def _boundary_condition_loss(self, predictions, data):
        """
        Boundary condition loss: u = 0 at fixed DOFs

        Returns the mean squared predicted displacement over the constrained
        DOFs only. The value therefore does not shrink as the number of free
        nodes grows.

        ``data.bc_mask`` may be either
        - per DOF: same shape as the displacement (e.g. [num_nodes, 6]), or
        - per node: one entry per row of the displacement (e.g. [num_nodes]);
          all DOFs of a flagged node are then treated as constrained.
        Any nonzero / True entry means "constrained". Returns zero when there
        is no mask or no constrained DOF.
        """
        displacement = predictions['displacement']
        bc_mask = getattr(data, 'bc_mask', None)
        if bc_mask is None:
            return self._zero(displacement)

        bc_mask = bc_mask.to(device=displacement.device).bool()
        if bc_mask.dim() == displacement.dim() - 1:
            # Node-level mask: constrain every DOF of a flagged node.
            bc_mask = bc_mask.unsqueeze(-1).expand_as(displacement)

        constrained = displacement[bc_mask]
        if constrained.numel() == 0:
            # Avoid mean() of an empty tensor, which would be NaN.
            return self._zero(displacement)
        return (constrained ** 2).mean()

    def _compute_von_mises(self, stress):
        """
        Compute von Mises stress from stress tensor components

        Args:
            stress: [..., 6] with [σxx, σyy, σzz, τxy, τyz, τxz] in the last dim

        Returns:
            von_mises: [...] (e.g. [num_nodes] for [num_nodes, 6] input).
            The squared value is floored at ``_VM_SQ_EPS`` before the square
            root. An all-zero deviatoric state therefore yields about 1e-6
            and a zero (not NaN) gradient.
        """
        sxx, syy, szz = stress[..., 0], stress[..., 1], stress[..., 2]
        txy, tyz, txz = stress[..., 3], stress[..., 4], stress[..., 5]

        vm_sq = 0.5 * (
            (sxx - syy)**2 + (syy - szz)**2 + (szz - sxx)**2 +
            6 * (txy**2 + tyz**2 + txz**2)
        )
        # clamp's gradient is zero below the bound, which cancels the finite
        # sqrt derivative at the floor instead of producing inf * 0 = NaN.
        return torch.sqrt(torch.clamp(vm_sq, min=self._VM_SQ_EPS))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FieldMSELoss(nn.Module):
    """
    Plain mean-squared-error loss over the predicted fields.

    No physics terms are involved. This makes it a reasonable choice for
    warm-up training, or whenever the physics penalties prove too restrictive.
    """

    def __init__(self, weight_displacement=1.0, weight_stress=1.0):
        """
        Args:
            weight_displacement (float): Multiplier on the displacement MSE
            weight_stress (float): Multiplier on the stress MSE
        """
        super().__init__()
        self.weight_displacement = weight_displacement
        self.weight_stress = weight_stress

    def forward(self, predictions, targets):
        """
        Evaluate the weighted field MSE.

        Args:
            predictions (dict): Model outputs; 'displacement' is required,
                'stress' is optional
            targets (dict): Ground truth with the same keys

        Returns:
            dict with 'displacement_loss', 'stress_loss' (zero when either
            side lacks 'stress') and their weighted sum 'total_loss'
        """
        pred_disp = predictions['displacement']
        disp_term = F.mse_loss(pred_disp, targets['displacement'])

        stress_available = 'stress' in predictions and 'stress' in targets
        if stress_available:
            stress_term = F.mse_loss(predictions['stress'], targets['stress'])
        else:
            stress_term = torch.tensor(0.0, device=pred_disp.device)

        weighted_sum = (
            self.weight_displacement * disp_term
            + self.weight_stress * stress_term
        )
        return {
            'displacement_loss': disp_term,
            'stress_loss': stress_term,
            'total_loss': weighted_sum,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RelativeFieldLoss(nn.Module):
    """
    Relative error loss - better for varying displacement/stress magnitudes

    Uses: ||pred - target|| / (||target|| + epsilon)

    The norms are taken over the whole field (all nodes and components
    together), which makes the loss scale-invariant. Normalizing the field as
    a whole, rather than node by node, stops nodes whose target is (near) zero
    from dominating. With a per-node ratio, any small error at such a node was
    divided by roughly ``epsilon`` and exploded.

    NOTE(review): if several graphs are concatenated into one node tensor, the
    norm spans all of them. Per-sample normalization would need a batch index,
    which this interface does not receive.
    """

    def __init__(self, epsilon=1e-8):
        """
        Args:
            epsilon (float): Small constant to avoid division by zero
        """
        super().__init__()
        self.epsilon = epsilon

    def _relative_error(self, pred, target):
        """Return ||pred - target|| / (||target|| + epsilon), norms over the whole field."""
        return torch.norm(pred - target) / (torch.norm(target) + self.epsilon)

    def forward(self, predictions, targets):
        """
        Compute relative error loss

        Args:
            predictions (dict): Model outputs; 'displacement' is required,
                'stress' is optional
            targets (dict): Ground truth with the same keys

        Returns:
            dict with 'displacement_loss', 'stress_loss' (zero when either
            side lacks 'stress') and 'total_loss' (their unweighted sum)
        """
        losses = {}

        # Relative displacement error
        losses['displacement_loss'] = self._relative_error(
            predictions['displacement'], targets['displacement']
        )

        # Relative stress error
        if 'stress' in predictions and 'stress' in targets:
            losses['stress_loss'] = self._relative_error(
                predictions['stress'], targets['stress']
            )
        else:
            losses['stress_loss'] = torch.tensor(0.0, device=predictions['displacement'].device)

        # Total loss
        losses['total_loss'] = losses['displacement_loss'] + losses['stress_loss']

        return losses
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MaxValueLoss(nn.Module):
    """
    Penalize only the peak values of the predicted fields.

    Kept for backward compatibility with scalar (max-value) optimization. It
    makes sure the critical maxima come out right, even if the spatial
    distribution of the field is somewhat off.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """
        Compute squared errors between predicted and true peak values.

        Args:
            predictions (dict): Model outputs; 'displacement' is required and
                'von_mises' enables the stress term
            targets (dict): Ground truth; 'displacement' is required and
                'stress' enables the stress term

        Returns:
            dict with 'max_displacement_loss', 'max_stress_loss' (zero when
            the stress term is disabled) and 'total_loss' (their sum)
        """
        def peak_norm(field):
            # Largest per-node Euclidean norm over the first three columns.
            return torch.max(torch.norm(field[:, :3], dim=1))

        disp_err = F.mse_loss(
            peak_norm(predictions['displacement']),
            peak_norm(targets['displacement']),
        )

        if 'von_mises' not in predictions or 'stress' not in targets:
            stress_err = torch.tensor(0.0, device=predictions['displacement'].device)
        else:
            predicted_peak = torch.max(predictions['von_mises'])

            # Targets hold raw stress components, so derive von Mises here.
            s = targets['stress']
            normal_terms = (s[:, 0] - s[:, 1])**2 + (s[:, 1] - s[:, 2])**2 + (s[:, 2] - s[:, 0])**2
            shear_terms = 6 * (s[:, 3]**2 + s[:, 4]**2 + s[:, 5]**2)
            true_peak = torch.max(torch.sqrt(0.5 * (normal_terms + shear_terms)))

            stress_err = F.mse_loss(predicted_peak, true_peak)

        return {
            'max_displacement_loss': disp_err,
            'max_stress_loss': stress_err,
            'total_loss': disp_err + stress_err,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_loss_function(loss_type='mse', config=None):
    """
    Build a loss module by name.

    Args:
        loss_type (str): One of 'mse', 'relative', 'physics' or 'max'
        config (dict): Keyword arguments forwarded to the loss constructor;
            None means no arguments

    Returns:
        The constructed loss module

    Raises:
        ValueError: If ``loss_type`` is not one of the names above
    """
    kwargs = {} if config is None else config

    # Factories are lambdas so that only the selected class is instantiated.
    registry = (
        ('mse', lambda: FieldMSELoss(**kwargs)),
        ('relative', lambda: RelativeFieldLoss(**kwargs)),
        ('physics', lambda: PhysicsInformedLoss(**kwargs)),
        ('max', lambda: MaxValueLoss(**kwargs)),
    )
    for name, build in registry:
        if loss_type == name:
            return build()

    raise ValueError(f"Unknown loss type: {loss_type}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Smoke-test every loss type on random data. Physics terms stay zero here
    # because no mesh `data` is passed.
    print("Testing AtomizerField Loss Functions...\n")

    # Fixed seed so the printed numbers are reproducible between runs.
    torch.manual_seed(0)

    # Create dummy predictions and targets
    num_nodes = 100
    pred = {
        'displacement': torch.randn(num_nodes, 6),
        'stress': torch.randn(num_nodes, 6),
        'von_mises': torch.abs(torch.randn(num_nodes))
    }
    target = {
        'displacement': torch.randn(num_nodes, 6),
        'stress': torch.randn(num_nodes, 6)
    }

    # Test each loss function
    loss_types = ['mse', 'relative', 'physics', 'max']

    for loss_type in loss_types:
        print(f"Testing {loss_type.upper()} loss...")
        loss_fn = create_loss_function(loss_type)
        losses = loss_fn(pred, target)

        print(f" Total loss: {losses['total_loss']:.6f}")
        for key, value in losses.items():
            if key != 'total_loss':
                print(f" {key}: {value:.6f}")

        # "passed" must mean something: every loss term has to be finite.
        for key, value in losses.items():
            if not torch.isfinite(value).all():
                raise RuntimeError(
                    f"{loss_type} loss produced a non-finite {key}: {value}"
                )
        print()

    print("Loss function tests passed!")
|