"""
physics_losses.py
Physics-informed loss functions for training FEA field predictors

AtomizerField Physics-Informed Loss Functions v2.0

Key Innovation:
Standard neural networks only minimize prediction error.
Physics-informed networks also enforce physical laws:
- Equilibrium: Forces must balance
- Compatibility: Strains must be compatible with displacements
- Constitutive: Stress must follow material law (σ = C:ε)

This makes the network learn physics, not just patterns.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# Floor applied to the squared von Mises stress before the square root.
# d/dx sqrt(x) is infinite at x == 0, so a zero or purely hydrostatic stress
# state would otherwise back-propagate inf * 0 = NaN into every parameter.
_VON_MISES_EPS = 1e-12


def _zero_scalar(reference):
    """Return a 0-dim float tensor holding 0.0 on the same device as *reference*."""
    return torch.tensor(0.0, device=reference.device)


def _von_mises(stress, eps=_VON_MISES_EPS):
    """
    Compute von Mises stress from stress tensor components

    Args:
        stress: [..., 6] with [σxx, σyy, σzz, τxy, τyz, τxz] on the last axis
        eps (float): Lower bound on the squared von Mises stress; keeps the
            gradient of the square root finite for degenerate stress states

    Returns:
        von_mises: tensor with the last axis of `stress` removed
    """
    sxx, syy, szz = stress[..., 0], stress[..., 1], stress[..., 2]
    txy, tyz, txz = stress[..., 3], stress[..., 4], stress[..., 5]

    vm_squared = 0.5 * (
        (sxx - syy) ** 2
        + (syy - szz) ** 2
        + (szz - sxx) ** 2
        + 6 * (txy ** 2 + tyz ** 2 + txz ** 2)
    )

    # Below the floor the clamp passes no gradient, so the backward pass sees
    # 0 * finite instead of 0 * inf. Values above the floor are untouched.
    return torch.sqrt(torch.clamp(vm_squared, min=eps))


class PhysicsInformedLoss(nn.Module):
    """
    Combined loss function with physics constraints

    Total Loss = λ_data * L_data + λ_physics * L_physics

    Where:
    - L_data: Data fitting loss between prediction and FEA ground truth
    - L_physics: Physics violation penalty (equilibrium, constitutive,
      boundary conditions)

    Note: the equilibrium and constitutive terms are currently placeholders
    that always evaluate to zero; only the boundary term is active.
    """

    def __init__(
        self,
        lambda_data=1.0,
        lambda_equilibrium=0.1,
        lambda_constitutive=0.1,
        lambda_boundary=1.0,
        use_relative_error=True
    ):
        """
        Initialize physics-informed loss

        Args:
            lambda_data (float): Weight for data loss
            lambda_equilibrium (float): Weight for equilibrium violation
            lambda_constitutive (float): Weight for constitutive law violation
            lambda_boundary (float): Weight for boundary condition violation
            use_relative_error (bool): Use relative error instead of absolute
                (applies to the displacement loss only)
        """
        super().__init__()

        self.lambda_data = lambda_data
        self.lambda_equilibrium = lambda_equilibrium
        self.lambda_constitutive = lambda_constitutive
        self.lambda_boundary = lambda_boundary
        self.use_relative_error = use_relative_error

    def forward(self, predictions, targets, data=None):
        """
        Compute total physics-informed loss

        Args:
            predictions (dict): Model predictions
                - displacement: [num_nodes, 6]
                - stress: [num_nodes, 6]
                - von_mises: [num_nodes]
            targets (dict): Ground truth from FEA
                - displacement: [num_nodes, 6]
                - stress: [num_nodes, 6]
            data: Mesh graph data (for physics constraints). When None, all
                physics terms are zero.

        Returns:
            dict with:
                - total_loss: Combined loss
                - data_loss: Data fitting loss (displacement + stress)
                - displacement_loss: Displacement fitting loss
                - stress_loss: Stress fitting loss (zero if stress is absent
                  from either dict)
                - equilibrium_loss: Equilibrium violation
                - constitutive_loss: Material law violation
                - boundary_loss: BC violation
        """
        displacement = predictions['displacement']
        losses = {}

        # 1. Data Loss: How well do predictions match FEA results?
        losses['displacement_loss'] = self._displacement_loss(
            displacement,
            targets['displacement']
        )

        if 'stress' in predictions and 'stress' in targets:
            losses['stress_loss'] = self._stress_loss(
                predictions['stress'],
                targets['stress']
            )
        else:
            losses['stress_loss'] = _zero_scalar(displacement)

        losses['data_loss'] = losses['displacement_loss'] + losses['stress_loss']

        # 2. Physics Losses: How well do predictions obey physics?
        if data is not None:
            # Equilibrium: ∇·σ + f = 0
            losses['equilibrium_loss'] = self._equilibrium_loss(predictions, data)

            # Constitutive: σ = C:ε
            losses['constitutive_loss'] = self._constitutive_loss(predictions, data)

            # Boundary conditions: u = 0 at fixed nodes
            losses['boundary_loss'] = self._boundary_condition_loss(predictions, data)
        else:
            losses['equilibrium_loss'] = _zero_scalar(displacement)
            losses['constitutive_loss'] = _zero_scalar(displacement)
            losses['boundary_loss'] = _zero_scalar(displacement)

        # Total loss
        losses['total_loss'] = (
            self.lambda_data * losses['data_loss'] +
            self.lambda_equilibrium * losses['equilibrium_loss'] +
            self.lambda_constitutive * losses['constitutive_loss'] +
            self.lambda_boundary * losses['boundary_loss']
        )

        return losses

    def _displacement_loss(self, pred, target):
        """
        Loss for displacement field

        Uses relative error to handle different displacement magnitudes
        when `use_relative_error` is set, otherwise plain MSE.
        """
        if not self.use_relative_error:
            # Absolute MSE
            return F.mse_loss(pred, target)

        # Relative L2 error per node, averaged over nodes
        diff = pred - target
        rel_error = torch.norm(diff, dim=-1) / (torch.norm(target, dim=-1) + 1e-8)
        return rel_error.mean()

    def _stress_loss(self, pred, target):
        """
        Loss for stress field

        Emphasizes von Mises stress (most important for failure prediction)
        """
        # Component-wise MSE
        component_loss = F.mse_loss(pred, target)

        # Von Mises stress MSE (computed from components)
        pred_vm = self._compute_von_mises(pred)
        target_vm = self._compute_von_mises(target)
        vm_loss = F.mse_loss(pred_vm, target_vm)

        # Combined: 50% component accuracy, 50% von Mises accuracy
        return 0.5 * component_loss + 0.5 * vm_loss

    def _equilibrium_loss(self, predictions, data):
        """
        Equilibrium loss: ∇·σ + f = 0

        In discrete form: sum of forces at each node should be zero
        (where not externally loaded)

        Placeholder: currently always returns zero.
        """
        # Simplified: For now, return zero (full implementation requires
        # computing stress divergence from node stresses)
        # TODO: Implement finite difference approximation of ∇·σ
        return _zero_scalar(predictions['displacement'])

    def _constitutive_loss(self, predictions, data):
        """
        Constitutive law loss: σ = C:ε

        Check if predicted stress is consistent with predicted strain
        (which comes from displacement gradient)

        Placeholder: currently always returns zero.
        """
        # Simplified: For now, return zero
        # Full implementation would:
        # 1. Compute strain from displacement gradient
        # 2. Compute expected stress from strain using material stiffness
        # 3. Compare with predicted stress
        # TODO: Implement strain computation and constitutive check
        return _zero_scalar(predictions['displacement'])

    def _boundary_condition_loss(self, predictions, data):
        """
        Boundary condition loss: u = 0 at fixed DOFs

        Penalize non-zero displacement at constrained nodes. Returns zero
        when `data` carries no `bc_mask` (missing or None).
        """
        displacement = predictions['displacement']

        # bc_mask: [num_nodes, 6] boolean mask where True = constrained
        bc_mask = getattr(data, 'bc_mask', None)
        if bc_mask is None:
            return _zero_scalar(displacement)

        # Compute penalty for non-zero displacement at constrained DOFs.
        # NOTE: the mean runs over ALL DOFs (constrained or not), so the
        # magnitude scales with the fraction of constrained DOFs.
        constrained_displacement = displacement * bc_mask.float()
        return torch.mean(constrained_displacement ** 2)

    def _compute_von_mises(self, stress):
        """
        Compute von Mises stress from stress tensor components

        Args:
            stress: [num_nodes, 6] with [σxx, σyy, σzz, τxy, τyz, τxz]

        Returns:
            von_mises: [num_nodes]
        """
        return _von_mises(stress)


class FieldMSELoss(nn.Module):
    """
    Simple MSE loss for field prediction (no physics constraints)

    Use this for initial training or when physics constraints are too strict.
    """

    def __init__(self, weight_displacement=1.0, weight_stress=1.0):
        """
        Args:
            weight_displacement (float): Weight for displacement loss
            weight_stress (float): Weight for stress loss
        """
        super().__init__()
        self.weight_displacement = weight_displacement
        self.weight_stress = weight_stress

    def forward(self, predictions, targets):
        """
        Compute MSE loss

        Args:
            predictions (dict): Model outputs
            targets (dict): Ground truth

        Returns:
            dict with 'displacement_loss', 'stress_loss' and 'total_loss'
        """
        losses = {}

        # Displacement MSE
        losses['displacement_loss'] = F.mse_loss(
            predictions['displacement'],
            targets['displacement']
        )

        # Stress MSE (if available)
        if 'stress' in predictions and 'stress' in targets:
            losses['stress_loss'] = F.mse_loss(
                predictions['stress'],
                targets['stress']
            )
        else:
            losses['stress_loss'] = _zero_scalar(predictions['displacement'])

        # Total loss
        losses['total_loss'] = (
            self.weight_displacement * losses['displacement_loss'] +
            self.weight_stress * losses['stress_loss']
        )

        return losses


class RelativeFieldLoss(nn.Module):
    """
    Relative error loss - better for varying displacement/stress magnitudes

    Uses: ||pred - target|| / ||target||
    This makes the loss scale-invariant.
    """

    def __init__(self, epsilon=1e-8):
        """
        Args:
            epsilon (float): Small constant to avoid division by zero
        """
        super().__init__()
        self.epsilon = epsilon

    def _relative_error(self, pred, target):
        """Mean over nodes of ||pred - target|| / (||target|| + epsilon)."""
        error_norm = torch.norm(pred - target, dim=-1)
        target_norm = torch.norm(target, dim=-1)
        return (error_norm / (target_norm + self.epsilon)).mean()

    def forward(self, predictions, targets):
        """
        Compute relative error loss

        Args:
            predictions (dict): Model outputs
            targets (dict): Ground truth

        Returns:
            dict with 'displacement_loss', 'stress_loss' and 'total_loss'
        """
        losses = {}

        # Relative displacement error
        losses['displacement_loss'] = self._relative_error(
            predictions['displacement'],
            targets['displacement']
        )

        # Relative stress error
        if 'stress' in predictions and 'stress' in targets:
            losses['stress_loss'] = self._relative_error(
                predictions['stress'],
                targets['stress']
            )
        else:
            losses['stress_loss'] = _zero_scalar(predictions['displacement'])

        # Total loss
        losses['total_loss'] = losses['displacement_loss'] + losses['stress_loss']

        return losses


class MaxValueLoss(nn.Module):
    """
    Loss on maximum values only (for backward compatibility with scalar optimization)

    This is useful if you want to ensure the network gets the critical max values right,
    even if the field distribution is slightly off.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """
        Compute loss on maximum displacement and stress

        Args:
            predictions (dict): Model outputs with 'displacement', 'von_mises'
            targets (dict): Ground truth

        Returns:
            dict with 'max_displacement_loss', 'max_stress_loss' and 'total_loss'
        """
        losses = {}

        # Max displacement error (magnitude of the first three components)
        pred_max_disp = torch.max(torch.norm(predictions['displacement'][:, :3], dim=1))
        target_max_disp = torch.max(torch.norm(targets['displacement'][:, :3], dim=1))
        losses['max_displacement_loss'] = F.mse_loss(pred_max_disp, target_max_disp)

        # Max von Mises stress error
        if 'von_mises' in predictions and 'stress' in targets:
            pred_max_vm = torch.max(predictions['von_mises'])

            # Target von Mises is derived from the target stress components
            target_max_vm = torch.max(_von_mises(targets['stress']))

            losses['max_stress_loss'] = F.mse_loss(pred_max_vm, target_max_vm)
        else:
            losses['max_stress_loss'] = _zero_scalar(predictions['displacement'])

        # Total loss
        losses['total_loss'] = losses['max_displacement_loss'] + losses['max_stress_loss']

        return losses


def create_loss_function(loss_type='mse', config=None):
    """
    Factory function to create loss function

    Args:
        loss_type (str): Type of loss ('mse', 'relative', 'physics', 'max')
        config (dict): Loss function configuration, forwarded as keyword
            arguments to the selected class constructor

    Returns:
        Loss function instance

    Raises:
        ValueError: If `loss_type` is not one of the supported names
    """
    if config is None:
        config = {}

    loss_classes = {
        'mse': FieldMSELoss,
        'relative': RelativeFieldLoss,
        'physics': PhysicsInformedLoss,
        'max': MaxValueLoss,
    }

    loss_cls = loss_classes.get(loss_type)
    if loss_cls is None:
        raise ValueError(f"Unknown loss type: {loss_type}")

    return loss_cls(**config)


def _run_smoke_test():
    """Run every loss type once on random tensors and print the components."""
    print("Testing AtomizerField Loss Functions...\n")

    # Create dummy predictions and targets
    num_nodes = 100
    pred = {
        'displacement': torch.randn(num_nodes, 6),
        'stress': torch.randn(num_nodes, 6),
        'von_mises': torch.abs(torch.randn(num_nodes))
    }
    target = {
        'displacement': torch.randn(num_nodes, 6),
        'stress': torch.randn(num_nodes, 6)
    }

    # Test each loss function
    for loss_type in ('mse', 'relative', 'physics', 'max'):
        print(f"Testing {loss_type.upper()} loss...")
        loss_fn = create_loss_function(loss_type)
        losses = loss_fn(pred, target)

        print(f" Total loss: {losses['total_loss']:.6f}")
        for key, value in losses.items():
            if key != 'total_loss':
                print(f" {key}: {value:.6f}")
        print()

    print("Loss function tests passed!")


if __name__ == "__main__":
    _run_smoke_test()