# Physics Loss Functions Guide
**Selecting and configuring loss functions for AtomizerField training**
---
## Overview
AtomizerField uses physics-informed loss functions to train neural networks that respect engineering principles. This guide explains each loss function and when to use it.
---
## Available Loss Functions
| Loss Function | Purpose | Best For |
|--------------|---------|----------|
| **MSE Loss** | Standard L2 error | General training, balanced outputs |
| **Relative Loss** | Percentage error | Multi-scale outputs (MPa + mm) |
| **Physics-Informed Loss** | Enforce physics | Better generalization, extrapolation |
| **Max Error Loss** | Penalize outliers | Safety-critical applications |
| **Combined Loss** | Weighted combination | Production models |
---
## 1. MSE Loss (Mean Squared Error)
### Description
Standard L2 loss that treats all predictions equally.
```python
loss = mean((predicted - target)²)
```
### Implementation
```python
import torch

def mse_loss(predicted, target):
    """Simple MSE loss"""
    return torch.mean((predicted - target) ** 2)
```
### When to Use
- Starting point for new models
- When all outputs have similar magnitudes
- When you don't have physics constraints
### Pros & Cons
| Pros | Cons |
|------|------|
| Simple and stable | Ignores physics |
| Fast computation | Scale-sensitive |
| Well-understood | Large errors dominate |
---
## 2. Relative Loss
### Description
Computes percentage error instead of absolute error. Critical for multi-scale outputs.
```python
loss = mean(|predicted - target| / |target|)
```
### Implementation
```python
def relative_loss(predicted, target, epsilon=1e-8):
    """Relative (percentage) loss"""
    relative_error = torch.abs(predicted - target) / (torch.abs(target) + epsilon)
    return torch.mean(relative_error)
```
### When to Use
- Outputs have different scales (stress in MPa, displacement in mm)
- Percentage accuracy matters more than absolute accuracy
- Training data has wide range of values
### Pros & Cons
| Pros | Cons |
|------|------|
| Scale-independent | Unstable near zero |
| Intuitive (% error) | Requires epsilon |
| Equal weight to all magnitudes | May overfit small values |
### Example
```
# Without relative loss:
stress_error = |100 MPa - 105 MPa| = 5 MPa
displacement_error = |0.01 mm - 0.02 mm| = 0.01 mm
# MSE dominated by stress; displacement ignored

# With relative loss:
stress_error = |5| / |100| = 5%
displacement_error = |0.01| / |0.01| = 100%
# Both contribute proportionally
```
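The comparison above can be reproduced directly. This sketch uses the same hypothetical numbers (stress target 100 MPa, predicted 105; displacement target 0.01 mm, predicted 0.02):

```python
def relative_error(predicted, target):
    """Absolute error as a fraction of the target magnitude."""
    return abs(predicted - target) / abs(target)

# Hypothetical values from the example above
stress_rel = relative_error(105.0, 100.0)      # 0.05 → 5%
displacement_rel = relative_error(0.02, 0.01)  # 1.0 → 100%
```

Both errors now sit on the same dimensionless scale, so neither output dominates the loss.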
---
## 3. Physics-Informed Loss
### Description
Adds physics constraints as regularization terms. The network learns to satisfy physical laws.
```python
loss = mse_loss + λ₁·equilibrium + λ₂·constitutive + λ₃·boundary
```
### Implementation
```python
def physics_informed_loss(predicted, target, data, config):
    """
    Physics-informed loss with multiple constraint terms.

    Components:
    1. Data loss (MSE)
    2. Equilibrium loss (F = ma)
    3. Constitutive loss (σ = Eε)
    4. Boundary condition loss (u = 0 at supports)

    `predicted` and `target` are dicts with 'displacement' and 'stress' tensors.
    """
    # Data loss: MSE over each output field
    data_loss = sum(mse_loss(predicted[k], target[k]) for k in predicted)

    # Equilibrium loss: sum of forces at each node = 0
    equilibrium_loss = compute_equilibrium_residual(
        predicted['displacement'],
        data.edge_index,
        data.stiffness
    )

    # Constitutive loss: stress-strain relationship
    predicted_stress = compute_stress_from_displacement(
        predicted['displacement'],
        data.material,
        data.strain_operator
    )
    constitutive_loss = mse_loss(predicted['stress'], predicted_stress)

    # Boundary condition loss: fixed nodes have zero displacement
    bc_mask = data.boundary_conditions > 0
    bc_loss = torch.mean(predicted['displacement'][bc_mask] ** 2)

    # Combine with weights
    total_loss = (
        data_loss +
        config.lambda_equilibrium * equilibrium_loss +
        config.lambda_constitutive * constitutive_loss +
        config.lambda_bc * bc_loss
    )
    return total_loss
```
### Physics Constraints
#### Equilibrium (Force Balance)
At each node, the sum of forces must be zero:
```
∑F = 0 at every node
```
```python
def equilibrium_residual(displacement, stiffness_matrix, external_forces):
    """
    Check that Ku = F (stiffness × displacement = external force).
    The residual should be zero for valid solutions.
    """
    internal_forces = stiffness_matrix @ displacement
    residual = internal_forces - external_forces
    return torch.mean(residual ** 2)
```
#### Constitutive (Stress-Strain)
Stress must follow material law:
```
σ = Eε (Hooke's law)
```
```python
def constitutive_residual(displacement, stress, material):
    """
    Check if stress follows the constitutive law.
    """
    strain = compute_strain(displacement)
    predicted_stress = material.E * strain
    residual = stress - predicted_stress
    return torch.mean(residual ** 2)
```
#### Boundary Conditions
Fixed nodes must have zero displacement:
```python
def boundary_residual(displacement, bc_mask):
    """
    Fixed nodes should have zero displacement.
    """
    return torch.mean(displacement[bc_mask] ** 2)
```
### When to Use
- When you need good generalization
- When extrapolating beyond training data
- When physical correctness is important
- When training data is limited
### Pros & Cons
| Pros | Cons |
|------|------|
| Physics consistency | More computation |
| Better extrapolation | Requires physics info |
| Works with less data | Weight tuning needed |
### Weight Selection
| Constraint | Typical λ | Notes |
|------------|-----------|-------|
| Equilibrium | 0.1 - 0.5 | Most important |
| Constitutive | 0.05 - 0.2 | Material law |
| Boundary | 0.5 - 1.0 | Hard constraint |
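The table can be read as picking one λ per constraint and summing the weighted residuals onto the data loss. An illustrative sketch (function name and the specific λ values are mid-range picks from the table, not a fixed API):

```python
# Mid-range λ values from the table above (illustrative)
LAMBDAS = {'equilibrium': 0.3, 'constitutive': 0.1, 'boundary': 0.75}

def weighted_physics_loss(data_loss, residuals, lambdas=LAMBDAS):
    """Total loss = data loss + Σ λᵢ · residualᵢ."""
    return data_loss + sum(lambdas[name] * r for name, r in residuals.items())

total = weighted_physics_loss(
    data_loss=1.0,
    residuals={'equilibrium': 0.2, 'constitutive': 0.5, 'boundary': 0.04},
)  # ≈ 1.0 + 0.06 + 0.05 + 0.03 = 1.14
```

A useful sanity check during tuning: each λ·residual term should be smaller than the data loss, or the physics terms will dominate training.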
---
## 4. Max Error Loss
### Description
Penalizes the worst predictions rather than the average. Important for safety-critical applications.
```python
loss = max(|predicted - target|)
```
### Implementation
```python
def max_error_loss(predicted, target, percentile=99):
    """
    Penalize worst predictions.
    Uses a high percentile to avoid a single outlier dominating.
    """
    errors = torch.abs(predicted - target)
    # Use percentile instead of max for stability
    max_error = torch.quantile(errors, percentile / 100.0)
    return max_error
```
### When to Use
- Safety-critical applications
- When outliers are unacceptable
- Quality assurance requirements
- Certification contexts
### Pros & Cons
| Pros | Cons |
|------|------|
| Controls worst case | Unstable gradients |
| Safety-focused | May slow convergence |
| Clear metric | Sensitive to outliers |
---
## 5. Combined Loss (Production)
### Description
Combines multiple loss functions for production models.
```python
loss = α·MSE + β·Relative + γ·Physics + δ·MaxError
```
### Implementation
```python
def combined_loss(predicted, target, data, config):
    """
    Production loss combining multiple objectives.
    """
    losses = {}
    # MSE component
    losses['mse'] = mse_loss(predicted, target)
    # Relative component
    losses['relative'] = relative_loss(predicted, target)
    # Physics component
    losses['physics'] = physics_informed_loss(predicted, target, data, config)
    # Max error component
    losses['max'] = max_error_loss(predicted, target)
    # Weighted combination
    total = (
        config.alpha * losses['mse'] +
        config.beta * losses['relative'] +
        config.gamma * losses['physics'] +
        config.delta * losses['max']
    )
    return total, losses
```
### Recommended Weights
| Application | MSE (α) | Relative (β) | Physics (γ) | Max (δ) |
|-------------|---------|--------------|-------------|---------|
| General | 0.5 | 0.3 | 0.2 | 0.0 |
| Multi-scale | 0.2 | 0.5 | 0.2 | 0.1 |
| Safety-critical | 0.2 | 0.2 | 0.3 | 0.3 |
| Extrapolation | 0.2 | 0.2 | 0.5 | 0.1 |
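The table can be captured as a simple lookup; the profile names below are illustrative, not part of any fixed API:

```python
# Weights (α, β, γ, δ) from the table above, keyed by application profile
RECOMMENDED_WEIGHTS = {
    'general':         {'alpha': 0.5, 'beta': 0.3, 'gamma': 0.2, 'delta': 0.0},
    'multi_scale':     {'alpha': 0.2, 'beta': 0.5, 'gamma': 0.2, 'delta': 0.1},
    'safety_critical': {'alpha': 0.2, 'beta': 0.2, 'gamma': 0.3, 'delta': 0.3},
    'extrapolation':   {'alpha': 0.2, 'beta': 0.2, 'gamma': 0.5, 'delta': 0.1},
}

def loss_weights(application):
    """Return the weight set for a profile; each row sums to 1.0."""
    return RECOMMENDED_WEIGHTS[application]
```

Keeping each row summing to 1.0 makes total-loss magnitudes comparable across profiles, which simplifies comparing training runs.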
---
## Configuration Examples
### Basic Training
```yaml
# config.yaml
loss:
  type: "mse"
```
### Multi-Scale Outputs
```yaml
# config.yaml
loss:
  type: "combined"
  weights:
    mse: 0.2
    relative: 0.5
    physics: 0.2
    max_error: 0.1
```
### Physics-Informed Training
```yaml
# config.yaml
loss:
  type: "physics_informed"
  physics_weight: 0.3
  constraints:
    equilibrium: 0.3
    constitutive: 0.1
    boundary: 0.5
```
### Safety-Critical
```yaml
# config.yaml
loss:
  type: "combined"
  weights:
    mse: 0.2
    relative: 0.2
    physics: 0.3
    max_error: 0.3
  max_error_percentile: 99
```
---
## Training Strategies
### Curriculum Learning
Start simple, add complexity:
```python
def get_loss_weights(epoch, total_epochs):
    """Gradually increase physics loss weight"""
    progress = epoch / total_epochs
    if progress < 0.3:
        # Phase 1: Pure MSE
        return {'mse': 1.0, 'physics': 0.0}
    elif progress < 0.6:
        # Phase 2: Add physics
        physics_weight = (progress - 0.3) / 0.3 * 0.3
        return {'mse': 1.0 - physics_weight, 'physics': physics_weight}
    else:
        # Phase 3: Full physics
        return {'mse': 0.7, 'physics': 0.3}
```
### Adaptive Weighting
Adjust weights based on loss magnitudes:
```python
def adaptive_weights(losses):
    """Balance losses to similar magnitudes"""
    # Weight each loss by the inverse of its magnitude
    total = sum(losses.values())
    weights = {k: total / (v + 1e-8) for k, v in losses.items()}
    # Normalize weights to sum to 1
    weight_sum = sum(weights.values())
    weights = {k: v / weight_sum for k, v in weights.items()}
    return weights
```
---
## Troubleshooting
### Loss Not Decreasing
**Symptom**: Training loss stays flat.
**Solutions**:
1. Reduce learning rate
2. Check data normalization
3. Simplify loss (use MSE first)
4. Increase model capacity
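For step 2, a quick check is whether each output channel has been standardized to zero mean and unit variance; a minimal framework-free sketch (function name is illustrative):

```python
def standardize(values, epsilon=1e-8):
    """Standardize a list of values to zero mean and unit variance."""
    n = len(values)
    mean = sum(values) / n
    std = (sum((v - mean) ** 2 for v in values) / n) ** 0.5
    return [(v - mean) / (std + epsilon) for v in values]

# Stress (MPa) and displacement (mm) channels land on comparable scales:
stress = standardize([100.0, 105.0, 95.0])
displacement = standardize([0.010, 0.020, 0.015])
```

Unnormalized targets that span several orders of magnitude are a common cause of flat loss curves, since the optimizer effectively sees only the largest channel.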
### Physics Loss Dominates
**Symptom**: Physics loss >> data loss.
**Solutions**:
1. Reduce physics weight (λ)
2. Use curriculum learning
3. Check physics computation
4. Normalize constraints
### Unstable Training
**Symptom**: Loss oscillates or explodes.
**Solutions**:
1. Use gradient clipping
2. Reduce learning rate
3. Check for NaN in physics terms
4. Add epsilon to divisions
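Gradient clipping (step 1) is a single call in PyTorch, `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`; the underlying idea, sketched framework-free for a flat gradient vector:

```python
def clip_by_global_norm(grads, max_norm):
    """Scale a gradient vector so its global L2 norm is at most max_norm."""
    norm = sum(g * g for g in grads) ** 0.5
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)  # norm 5 → norm 1
```

Clipping by the global norm preserves the gradient's direction and only shrinks its magnitude, which tames the occasional exploding step without biasing well-behaved updates.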
---
## Metrics for Evaluation
### Training Metrics
```python
metrics = {
'train_loss': total_loss.item(),
'train_mse': losses['mse'].item(),
'train_physics': losses['physics'].item(),
'train_max': losses['max'].item()
}
```
### Validation Metrics
```python
def compute_validation_metrics(model, val_loader):
    """Compute physics-aware validation metrics"""
    all_errors = []
    physics_violations = []
    for batch in val_loader:
        pred = model(batch)
        # Prediction errors
        errors = torch.abs(pred - batch.y)
        all_errors.append(errors)
        # Physics violations
        violations = compute_physics_residual(pred, batch)
        physics_violations.append(violations)
    return {
        'val_mae': torch.cat(all_errors).mean(),
        'val_max': torch.cat(all_errors).max(),
        'val_physics_violation': torch.cat(physics_violations).mean(),
        'val_physics_compliance': (torch.cat(physics_violations) < 0.01).float().mean()
    }
```
---
## Summary
| Situation | Recommended Loss |
|-----------|-----------------|
| Starting out | MSE |
| Multi-scale outputs | Relative + MSE |
| Need generalization | Physics-informed |
| Safety-critical | Combined with max error |
| Limited training data | Physics-informed |
| Production deployment | Combined (tuned) |
---
## See Also
- [Neural Features Complete](NEURAL_FEATURES_COMPLETE.md) - Overview
- [GNN Architecture](GNN_ARCHITECTURE.md) - Model details
- [Neural Workflow Tutorial](NEURAL_WORKFLOW_TUTORIAL.md) - Training guide