feat: Merge Atomizer-Field neural network module into main repository
Permanently integrates the Atomizer-Field GNN surrogate system: - neural_models/: Graph Neural Network for FEA field prediction - batch_parser.py: Parse training data from FEA exports - train.py: Neural network training pipeline - predict.py: Inference engine for fast predictions This enables 600x-2200x speedup over traditional FEA by replacing expensive simulations with millisecond neural network predictions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
515
atomizer-field/visualize_results.py
Normal file
515
atomizer-field/visualize_results.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
visualize_results.py
|
||||
3D Visualization of FEA Results and Neural Predictions
|
||||
|
||||
Visualizes:
|
||||
- Mesh structure
|
||||
- Displacement fields
|
||||
- Stress fields (von Mises)
|
||||
- Comparison: FEA vs Neural predictions
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import cm
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
import json
|
||||
import h5py
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
class FEAVisualizer:
|
||||
"""Visualize FEA results in 3D"""
|
||||
|
||||
def __init__(self, case_dir):
|
||||
"""
|
||||
Initialize visualizer
|
||||
|
||||
Args:
|
||||
case_dir: Path to case directory with neural_field_data files
|
||||
"""
|
||||
self.case_dir = Path(case_dir)
|
||||
|
||||
# Load data
|
||||
print(f"Loading data from {case_dir}...")
|
||||
self.load_data()
|
||||
|
||||
def load_data(self):
|
||||
"""Load JSON metadata and HDF5 field data"""
|
||||
json_file = self.case_dir / "neural_field_data.json"
|
||||
h5_file = self.case_dir / "neural_field_data.h5"
|
||||
|
||||
# Load JSON
|
||||
with open(json_file, 'r') as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
# Get connectivity from JSON
|
||||
self.connectivity = []
|
||||
if 'mesh' in self.metadata and 'elements' in self.metadata['mesh']:
|
||||
elements = self.metadata['mesh']['elements']
|
||||
# Elements are categorized by type: solid, shell, beam, rigid
|
||||
for elem_category in ['solid', 'shell', 'beam']:
|
||||
if elem_category in elements and isinstance(elements[elem_category], list):
|
||||
for elem_data in elements[elem_category]:
|
||||
elem_type = elem_data.get('type', '')
|
||||
if elem_type in ['CQUAD4', 'CTRIA3', 'CTETRA', 'CHEXA']:
|
||||
# Store connectivity: [elem_id, n1, n2, n3, n4, ...]
|
||||
nodes = elem_data.get('nodes', [])
|
||||
self.connectivity.append([elem_data['id']] + nodes)
|
||||
self.connectivity = np.array(self.connectivity) if self.connectivity else np.array([[]])
|
||||
|
||||
# Load HDF5
|
||||
with h5py.File(h5_file, 'r') as f:
|
||||
self.node_coords = f['mesh/node_coordinates'][:]
|
||||
|
||||
# Get node IDs to create mapping
|
||||
self.node_ids = f['mesh/node_ids'][:]
|
||||
# Create mapping from node ID to index
|
||||
self.node_id_to_idx = {nid: idx for idx, nid in enumerate(self.node_ids)}
|
||||
|
||||
# Displacement
|
||||
if 'results/displacement' in f:
|
||||
self.displacement = f['results/displacement'][:]
|
||||
else:
|
||||
self.displacement = None
|
||||
|
||||
# Stress (try different possible locations)
|
||||
self.stress = None
|
||||
if 'results/stress/cquad4_stress/data' in f:
|
||||
self.stress = f['results/stress/cquad4_stress/data'][:]
|
||||
elif 'results/stress/cquad4_stress' in f and hasattr(f['results/stress/cquad4_stress'], 'shape'):
|
||||
self.stress = f['results/stress/cquad4_stress'][:]
|
||||
elif 'results/stress' in f and hasattr(f['results/stress'], 'shape'):
|
||||
self.stress = f['results/stress'][:]
|
||||
|
||||
print(f"Loaded {len(self.node_coords)} nodes, {len(self.connectivity)} elements")
|
||||
|
||||
def plot_mesh(self, figsize=(12, 8), save_path=None):
|
||||
"""
|
||||
Plot 3D mesh structure
|
||||
|
||||
Args:
|
||||
figsize: Figure size
|
||||
save_path: Path to save figure (optional)
|
||||
"""
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
# Extract coordinates
|
||||
x = self.node_coords[:, 0]
|
||||
y = self.node_coords[:, 1]
|
||||
z = self.node_coords[:, 2]
|
||||
|
||||
# Plot nodes
|
||||
ax.scatter(x, y, z, c='blue', marker='.', s=1, alpha=0.3, label='Nodes')
|
||||
|
||||
# Plot a subset of elements (for visibility)
|
||||
step = max(1, len(self.connectivity) // 1000) # Show max 1000 elements
|
||||
for i in range(0, len(self.connectivity), step):
|
||||
elem = self.connectivity[i]
|
||||
if len(elem) < 5: # Skip invalid elements
|
||||
continue
|
||||
|
||||
# Get node IDs (skip element ID at position 0)
|
||||
node_ids = elem[1:5] # First 4 nodes for CQUAD4
|
||||
|
||||
# Convert node IDs to indices
|
||||
try:
|
||||
nodes = [self.node_id_to_idx[nid] for nid in node_ids]
|
||||
except KeyError:
|
||||
continue # Skip if node not found
|
||||
|
||||
# Get coordinates
|
||||
elem_coords = self.node_coords[nodes]
|
||||
|
||||
# Plot element edges
|
||||
for j in range(min(4, len(nodes))):
|
||||
next_j = (j + 1) % len(nodes)
|
||||
ax.plot([elem_coords[j, 0], elem_coords[next_j, 0]],
|
||||
[elem_coords[j, 1], elem_coords[next_j, 1]],
|
||||
[elem_coords[j, 2], elem_coords[next_j, 2]],
|
||||
'k-', linewidth=0.1, alpha=0.1)
|
||||
|
||||
ax.set_xlabel('X (mm)')
|
||||
ax.set_ylabel('Y (mm)')
|
||||
ax.set_zlabel('Z (mm)')
|
||||
ax.set_title(f'Mesh Structure\n{len(self.node_coords)} nodes, {len(self.connectivity)} elements')
|
||||
|
||||
# Equal aspect ratio
|
||||
self._set_equal_aspect(ax)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved mesh plot to {save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_displacement(self, scale=1.0, component='magnitude', figsize=(14, 8), save_path=None):
|
||||
"""
|
||||
Plot displacement field
|
||||
|
||||
Args:
|
||||
scale: Scale factor for displacement visualization
|
||||
component: 'magnitude', 'x', 'y', or 'z'
|
||||
figsize: Figure size
|
||||
save_path: Path to save figure
|
||||
"""
|
||||
if self.displacement is None:
|
||||
print("No displacement data available")
|
||||
return
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
|
||||
# Original mesh
|
||||
ax1 = fig.add_subplot(121, projection='3d')
|
||||
self._plot_mesh_with_field(ax1, self.displacement, component, scale=0)
|
||||
ax1.set_title('Original Mesh')
|
||||
|
||||
# Deformed mesh
|
||||
ax2 = fig.add_subplot(122, projection='3d')
|
||||
self._plot_mesh_with_field(ax2, self.displacement, component, scale=scale)
|
||||
ax2.set_title(f'Deformed Mesh (scale={scale}x)\nDisplacement: {component}')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved displacement plot to {save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_stress(self, component='von_mises', figsize=(12, 8), save_path=None):
|
||||
"""
|
||||
Plot stress field
|
||||
|
||||
Args:
|
||||
component: 'von_mises', 'xx', 'yy', 'zz', 'xy', 'yz', 'xz'
|
||||
figsize: Figure size
|
||||
save_path: Path to save figure
|
||||
"""
|
||||
if self.stress is None:
|
||||
print("No stress data available")
|
||||
return
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
# Get stress component
|
||||
if component == 'von_mises':
|
||||
# Von Mises already computed (last column)
|
||||
stress_values = self.stress[:, -1]
|
||||
else:
|
||||
# Map component name to index
|
||||
comp_map = {'xx': 0, 'yy': 1, 'zz': 2, 'xy': 3, 'yz': 4, 'xz': 5}
|
||||
idx = comp_map.get(component, 0)
|
||||
stress_values = self.stress[:, idx]
|
||||
|
||||
# Plot elements colored by stress
|
||||
self._plot_elements_with_stress(ax, stress_values)
|
||||
|
||||
ax.set_xlabel('X (mm)')
|
||||
ax.set_ylabel('Y (mm)')
|
||||
ax.set_zlabel('Z (mm)')
|
||||
ax.set_title(f'Stress Field: {component}\nMax: {np.max(stress_values):.2f} MPa')
|
||||
|
||||
self._set_equal_aspect(ax)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved stress plot to {save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_comparison(self, neural_predictions, figsize=(16, 6), save_path=None):
|
||||
"""
|
||||
Plot comparison: FEA vs Neural predictions
|
||||
|
||||
Args:
|
||||
neural_predictions: Dict with 'displacement' and/or 'stress'
|
||||
figsize: Figure size
|
||||
save_path: Path to save figure
|
||||
"""
|
||||
fig = plt.figure(figsize=figsize)
|
||||
|
||||
# Displacement comparison
|
||||
if self.displacement is not None and 'displacement' in neural_predictions:
|
||||
ax1 = fig.add_subplot(131, projection='3d')
|
||||
self._plot_mesh_with_field(ax1, self.displacement, 'magnitude', scale=10)
|
||||
ax1.set_title('FEA Displacement')
|
||||
|
||||
ax2 = fig.add_subplot(132, projection='3d')
|
||||
neural_disp = neural_predictions['displacement']
|
||||
self._plot_mesh_with_field(ax2, neural_disp, 'magnitude', scale=10)
|
||||
ax2.set_title('Neural Prediction')
|
||||
|
||||
# Error
|
||||
ax3 = fig.add_subplot(133, projection='3d')
|
||||
error = np.linalg.norm(self.displacement[:, :3] - neural_disp[:, :3], axis=1)
|
||||
self._plot_nodes_with_values(ax3, error)
|
||||
ax3.set_title('Prediction Error')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved comparison plot to {save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def _plot_mesh_with_field(self, ax, field, component, scale=1.0):
|
||||
"""Helper: Plot mesh colored by field values"""
|
||||
# Get field component
|
||||
if component == 'magnitude':
|
||||
values = np.linalg.norm(field[:, :3], axis=1)
|
||||
elif component == 'x':
|
||||
values = field[:, 0]
|
||||
elif component == 'y':
|
||||
values = field[:, 1]
|
||||
elif component == 'z':
|
||||
values = field[:, 2]
|
||||
else:
|
||||
values = np.linalg.norm(field[:, :3], axis=1)
|
||||
|
||||
# Apply deformation
|
||||
coords = self.node_coords + scale * field[:, :3]
|
||||
|
||||
# Plot nodes colored by values
|
||||
scatter = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
|
||||
c=values, cmap='jet', s=2)
|
||||
plt.colorbar(scatter, ax=ax, label=f'{component} (mm)')
|
||||
|
||||
ax.set_xlabel('X (mm)')
|
||||
ax.set_ylabel('Y (mm)')
|
||||
ax.set_zlabel('Z (mm)')
|
||||
|
||||
self._set_equal_aspect(ax)
|
||||
|
||||
def _plot_elements_with_stress(self, ax, stress_values):
|
||||
"""Helper: Plot elements colored by stress"""
|
||||
# Normalize stress for colormap
|
||||
vmin, vmax = np.min(stress_values), np.max(stress_values)
|
||||
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
||||
cmap = cm.get_cmap('jet')
|
||||
|
||||
# Plot subset of elements
|
||||
step = max(1, len(self.connectivity) // 500)
|
||||
for i in range(0, min(len(self.connectivity), len(stress_values)), step):
|
||||
elem = self.connectivity[i]
|
||||
if len(elem) < 5:
|
||||
continue
|
||||
|
||||
# Get node IDs and convert to indices
|
||||
node_ids = elem[1:5]
|
||||
try:
|
||||
nodes = [self.node_id_to_idx[nid] for nid in node_ids]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
elem_coords = self.node_coords[nodes]
|
||||
|
||||
# Get stress color
|
||||
color = cmap(norm(stress_values[i]))
|
||||
|
||||
# Plot filled quadrilateral
|
||||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
||||
verts = [elem_coords]
|
||||
poly = Poly3DCollection(verts, facecolors=color, edgecolors='k',
|
||||
linewidths=0.1, alpha=0.8)
|
||||
ax.add_collection3d(poly)
|
||||
|
||||
# Colorbar
|
||||
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
|
||||
sm.set_array([])
|
||||
plt.colorbar(sm, ax=ax, label='Stress (MPa)')
|
||||
|
||||
def _plot_nodes_with_values(self, ax, values):
|
||||
"""Helper: Plot nodes colored by values"""
|
||||
scatter = ax.scatter(self.node_coords[:, 0],
|
||||
self.node_coords[:, 1],
|
||||
self.node_coords[:, 2],
|
||||
c=values, cmap='hot', s=2)
|
||||
plt.colorbar(scatter, ax=ax, label='Error (mm)')
|
||||
|
||||
ax.set_xlabel('X (mm)')
|
||||
ax.set_ylabel('Y (mm)')
|
||||
ax.set_zlabel('Z (mm)')
|
||||
|
||||
self._set_equal_aspect(ax)
|
||||
|
||||
def _set_equal_aspect(self, ax):
|
||||
"""Set equal aspect ratio for 3D plot"""
|
||||
# Get limits
|
||||
x_limits = [self.node_coords[:, 0].min(), self.node_coords[:, 0].max()]
|
||||
y_limits = [self.node_coords[:, 1].min(), self.node_coords[:, 1].max()]
|
||||
z_limits = [self.node_coords[:, 2].min(), self.node_coords[:, 2].max()]
|
||||
|
||||
# Find max range
|
||||
max_range = max(x_limits[1] - x_limits[0],
|
||||
y_limits[1] - y_limits[0],
|
||||
z_limits[1] - z_limits[0])
|
||||
|
||||
# Set limits
|
||||
x_middle = np.mean(x_limits)
|
||||
y_middle = np.mean(y_limits)
|
||||
z_middle = np.mean(z_limits)
|
||||
|
||||
ax.set_xlim(x_middle - max_range/2, x_middle + max_range/2)
|
||||
ax.set_ylim(y_middle - max_range/2, y_middle + max_range/2)
|
||||
ax.set_zlim(z_middle - max_range/2, z_middle + max_range/2)
|
||||
|
||||
def create_report(self, output_file='visualization_report.md'):
|
||||
"""
|
||||
Create markdown report with all visualizations
|
||||
|
||||
Args:
|
||||
output_file: Path to save report
|
||||
"""
|
||||
print(f"\nGenerating visualization report...")
|
||||
|
||||
# Create images directory
|
||||
img_dir = Path(output_file).parent / 'visualization_images'
|
||||
img_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate plots
|
||||
print(" Creating mesh plot...")
|
||||
self.plot_mesh(save_path=img_dir / 'mesh.png')
|
||||
plt.close('all')
|
||||
|
||||
print(" Creating displacement plot...")
|
||||
self.plot_displacement(scale=10, save_path=img_dir / 'displacement.png')
|
||||
plt.close('all')
|
||||
|
||||
print(" Creating stress plot...")
|
||||
self.plot_stress(save_path=img_dir / 'stress.png')
|
||||
plt.close('all')
|
||||
|
||||
# Write report
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(f"# FEA Visualization Report\n\n")
|
||||
f.write(f"**Generated:** {self.metadata['metadata']['created_at']}\n\n")
|
||||
f.write(f"**Case:** {self.metadata['metadata']['case_name']}\n\n")
|
||||
|
||||
f.write("---\n\n")
|
||||
|
||||
# Model info
|
||||
f.write("## Model Information\n\n")
|
||||
f.write(f"- **Analysis Type:** {self.metadata['metadata']['analysis_type']}\n")
|
||||
f.write(f"- **Nodes:** {self.metadata['mesh']['statistics']['n_nodes']:,}\n")
|
||||
f.write(f"- **Elements:** {self.metadata['mesh']['statistics']['n_elements']:,}\n")
|
||||
f.write(f"- **Materials:** {len(self.metadata['materials'])}\n\n")
|
||||
|
||||
# Mesh
|
||||
f.write("## Mesh Structure\n\n")
|
||||
f.write("\n\n")
|
||||
f.write(f"The model contains {self.metadata['mesh']['statistics']['n_nodes']:,} nodes ")
|
||||
f.write(f"and {self.metadata['mesh']['statistics']['n_elements']:,} elements.\n\n")
|
||||
|
||||
# Displacement
|
||||
if self.displacement is not None:
|
||||
max_disp = self.metadata['results']['displacement']['max_translation']
|
||||
f.write("## Displacement Results\n\n")
|
||||
f.write("\n\n")
|
||||
f.write(f"**Maximum Displacement:** {max_disp:.6f} mm\n\n")
|
||||
f.write("The plots show the original mesh (left) and deformed mesh (right) ")
|
||||
f.write("with displacement magnitude shown in color.\n\n")
|
||||
|
||||
# Stress
|
||||
if self.stress is not None:
|
||||
f.write("## Stress Results\n\n")
|
||||
f.write("\n\n")
|
||||
|
||||
# Get max stress from metadata
|
||||
if 'stress' in self.metadata['results']:
|
||||
for stress_type, stress_data in self.metadata['results']['stress'].items():
|
||||
if 'max_von_mises' in stress_data and stress_data['max_von_mises'] is not None:
|
||||
max_stress = stress_data['max_von_mises']
|
||||
f.write(f"**Maximum von Mises Stress:** {max_stress:.2f} MPa\n\n")
|
||||
break
|
||||
|
||||
f.write("The stress distribution is shown with colors representing von Mises stress levels.\n\n")
|
||||
|
||||
# Statistics
|
||||
f.write("## Summary Statistics\n\n")
|
||||
f.write("| Property | Value |\n")
|
||||
f.write("|----------|-------|\n")
|
||||
f.write(f"| Nodes | {self.metadata['mesh']['statistics']['n_nodes']:,} |\n")
|
||||
f.write(f"| Elements | {self.metadata['mesh']['statistics']['n_elements']:,} |\n")
|
||||
|
||||
if self.displacement is not None:
|
||||
max_disp = self.metadata['results']['displacement']['max_translation']
|
||||
f.write(f"| Max Displacement | {max_disp:.6f} mm |\n")
|
||||
|
||||
if self.stress is not None and 'stress' in self.metadata['results']:
|
||||
for stress_type, stress_data in self.metadata['results']['stress'].items():
|
||||
if 'max_von_mises' in stress_data and stress_data['max_von_mises'] is not None:
|
||||
max_stress = stress_data['max_von_mises']
|
||||
f.write(f"| Max von Mises Stress | {max_stress:.2f} MPa |\n")
|
||||
break
|
||||
|
||||
f.write("\n---\n\n")
|
||||
f.write("*Report generated by AtomizerField Visualizer*\n")
|
||||
|
||||
print(f"\nReport saved to: {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Visualize FEA results in 3D',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Visualize mesh
|
||||
python visualize_results.py test_case_beam --mesh
|
||||
|
||||
# Visualize displacement
|
||||
python visualize_results.py test_case_beam --displacement
|
||||
|
||||
# Visualize stress
|
||||
python visualize_results.py test_case_beam --stress
|
||||
|
||||
# Generate full report
|
||||
python visualize_results.py test_case_beam --report
|
||||
|
||||
# All visualizations
|
||||
python visualize_results.py test_case_beam --all
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument('case_dir', help='Path to case directory')
|
||||
parser.add_argument('--mesh', action='store_true', help='Plot mesh structure')
|
||||
parser.add_argument('--displacement', action='store_true', help='Plot displacement field')
|
||||
parser.add_argument('--stress', action='store_true', help='Plot stress field')
|
||||
parser.add_argument('--report', action='store_true', help='Generate markdown report')
|
||||
parser.add_argument('--all', action='store_true', help='Show all plots and generate report')
|
||||
parser.add_argument('--scale', type=float, default=10.0, help='Displacement scale factor (default: 10)')
|
||||
parser.add_argument('--output', default='visualization_report.md', help='Report output file')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create visualizer
|
||||
viz = FEAVisualizer(args.case_dir)
|
||||
|
||||
# Determine what to show
|
||||
show_all = args.all or not (args.mesh or args.displacement or args.stress or args.report)
|
||||
|
||||
if args.mesh or show_all:
|
||||
print("\nShowing mesh structure...")
|
||||
viz.plot_mesh()
|
||||
|
||||
if args.displacement or show_all:
|
||||
print("\nShowing displacement field...")
|
||||
viz.plot_displacement(scale=args.scale)
|
||||
|
||||
if args.stress or show_all:
|
||||
print("\nShowing stress field...")
|
||||
viz.plot_stress()
|
||||
|
||||
if args.report or show_all:
|
||||
print("\nGenerating report...")
|
||||
viz.create_report(args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user