Permanently integrates the Atomizer-Field GNN surrogate system: - neural_models/: Graph Neural Network for FEA field prediction - batch_parser.py: Parse training data from FEA exports - train.py: Neural network training pipeline - predict.py: Inference engine for fast predictions This enables 600x-2200x speedup over traditional FEA by replacing expensive simulations with millisecond neural network predictions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
516 lines
19 KiB
Python
516 lines
19 KiB
Python
"""
|
|
visualize_results.py
|
|
3D Visualization of FEA Results and Neural Predictions
|
|
|
|
Visualizes:
|
|
- Mesh structure
|
|
- Displacement fields
|
|
- Stress fields (von Mises)
|
|
- Comparison: FEA vs Neural predictions
|
|
"""
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import cm
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
import json
|
|
import h5py
|
|
from pathlib import Path
|
|
import argparse
|
|
|
|
|
|
class FEAVisualizer:
    """Visualize FEA results (mesh, displacement, stress) in 3D with Matplotlib.

    Data is read from ``neural_field_data.json`` (metadata and element
    connectivity) and ``neural_field_data.h5`` (node coordinates, node IDs and
    result fields) inside the case directory.
    """

    # Element types whose connectivity is kept for plotting.
    _PLOTTABLE_TYPES = frozenset({'CQUAD4', 'CTRIA3', 'CTETRA', 'CHEXA'})
    # Element categories searched in the JSON mesh description.
    _ELEMENT_CATEGORIES = ('solid', 'shell', 'beam')
    # HDF5 locations tried, in order, when looking for a stress dataset.
    _STRESS_DATASETS = (
        'results/stress/cquad4_stress/data',
        'results/stress/cquad4_stress',
        'results/stress',
    )
    # Column of a displacement-like field for each named axis.
    _FIELD_AXES = {'x': 0, 'y': 1, 'z': 2}
    # Column of the stress array for each named tensor component.
    _STRESS_COLUMNS = {'xx': 0, 'yy': 1, 'zz': 2, 'xy': 3, 'yz': 4, 'xz': 5}

    def __init__(self, case_dir):
        """
        Initialize visualizer

        Args:
            case_dir: Path to case directory with neural_field_data files
        """
        self.case_dir = Path(case_dir)

        # Load data
        print(f"Loading data from {case_dir}...")
        self.load_data()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_data(self):
        """Load JSON metadata and HDF5 field data.

        Sets ``metadata``, ``connectivity``, ``node_coords``, ``node_ids``,
        ``node_id_to_idx``, ``displacement`` and ``stress``.  ``connectivity``
        is a list of ``[elem_id, n1, n2, ...]`` rows; rows may differ in
        length because element types have different node counts.
        ``displacement`` and ``stress`` are None when the file has no such
        dataset.
        """
        json_file = self.case_dir / "neural_field_data.json"
        h5_file = self.case_dir / "neural_field_data.h5"

        with open(json_file, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        # Kept as a plain list: mixed element types give ragged rows, which
        # cannot be packed into a rectangular NumPy array.
        self.connectivity = self._extract_connectivity(self.metadata)

        with h5py.File(h5_file, 'r') as f:
            self.node_coords = f['mesh/node_coordinates'][:]

            # Map node ID -> row index into node_coords / result arrays.
            self.node_ids = f['mesh/node_ids'][:]
            self.node_id_to_idx = {
                int(nid): idx for idx, nid in enumerate(self.node_ids)
            }

            if 'results/displacement' in f:
                self.displacement = f['results/displacement'][:]
            else:
                self.displacement = None

            # Stress may live at several locations; take the first dataset.
            self.stress = None
            for key in self._STRESS_DATASETS:
                if key in f and isinstance(f[key], h5py.Dataset):
                    self.stress = f[key][:]
                    break

        print(f"Loaded {len(self.node_coords)} nodes, {len(self.connectivity)} elements")

    @classmethod
    def _extract_connectivity(cls, metadata):
        """Return ``[elem_id, n1, n2, ...]`` rows for the plottable elements.

        Args:
            metadata: Parsed JSON metadata dictionary.

        Returns:
            A list of rows (possibly empty) in category order
            solid, shell, beam.
        """
        mesh = metadata.get('mesh')
        elements = mesh.get('elements') if isinstance(mesh, dict) else None
        if not isinstance(elements, dict):
            return []

        rows = []
        for category in cls._ELEMENT_CATEGORIES:
            entries = elements.get(category)
            if not isinstance(entries, list):
                continue
            for elem_data in entries:
                if elem_data.get('type', '') in cls._PLOTTABLE_TYPES:
                    rows.append([elem_data['id'], *elem_data.get('nodes', [])])
        return rows

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _subsample_step(n_items, max_shown):
        """Return the stride that visits at most ``max_shown`` of ``n_items``."""
        # Ceiling division so the cap is never exceeded.
        return max(1, -(-n_items // max_shown))

    def _element_corner_indices(self, elem):
        """Map an element's first (up to four) node IDs to node-array indices.

        Args:
            elem: Connectivity row ``[elem_id, n1, n2, ...]``.

        Returns:
            A list of indices into ``node_coords``, or None when the element
            has fewer than three nodes or references an unknown node ID.
        """
        node_ids = elem[1:5]  # position 0 is the element ID
        if len(node_ids) < 3:
            return None
        try:
            return [self.node_id_to_idx[nid] for nid in node_ids]
        except KeyError:
            return None

    @classmethod
    def _field_component(cls, field, component):
        """Return per-node values of ``field`` for the named component.

        ``'x'``, ``'y'`` and ``'z'`` select a column; ``'magnitude'`` and any
        unrecognized name give the Euclidean norm of the first three columns.
        """
        axis = cls._FIELD_AXES.get(component)
        if axis is None:
            return np.linalg.norm(field[:, :3], axis=1)
        return field[:, axis]

    @classmethod
    def _stress_component(cls, stress, component):
        """Return the stress column for the named component.

        ``'von_mises'`` reads the last column; tensor components read columns
        0-5.  An unrecognized name falls back to column 0 (``'xx'``).
        """
        if component == 'von_mises':
            return stress[:, -1]
        return stress[:, cls._STRESS_COLUMNS.get(component, 0)]

    def _max_von_mises(self):
        """Return the first non-None ``max_von_mises`` in the metadata, or None."""
        results = self.metadata.get('results', {})
        stress_meta = results.get('stress', {}) if isinstance(results, dict) else {}
        if not isinstance(stress_meta, dict):
            return None
        for stress_data in stress_meta.values():
            if isinstance(stress_data, dict) and stress_data.get('max_von_mises') is not None:
                return stress_data['max_von_mises']
        return None

    @staticmethod
    def _finish_figure(save_path, description, show):
        """Optionally save the current figure, then optionally display it."""
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved {description} plot to {save_path}")
        if show:
            plt.show()

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------
    def plot_mesh(self, figsize=(12, 8), save_path=None, show=True):
        """
        Plot 3D mesh structure

        Args:
            figsize: Figure size
            save_path: Path to save figure (optional)
            show: Call ``plt.show()`` after drawing (default True)
        """
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        x = self.node_coords[:, 0]
        y = self.node_coords[:, 1]
        z = self.node_coords[:, 2]
        ax.scatter(x, y, z, c='blue', marker='.', s=1, alpha=0.3, label='Nodes')

        # Draw a subset of element outlines (at most 1000) for visibility.
        step = self._subsample_step(len(self.connectivity), 1000)
        for elem in self.connectivity[::step]:
            nodes = self._element_corner_indices(elem)
            if nodes is None:
                continue
            # Repeat the first corner so one polyline closes the outline.
            outline = self.node_coords[nodes + nodes[:1]]
            ax.plot(outline[:, 0], outline[:, 1], outline[:, 2],
                    'k-', linewidth=0.1, alpha=0.1)

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Mesh Structure\n{len(self.node_coords)} nodes, {len(self.connectivity)} elements')

        self._set_equal_aspect(ax)
        self._finish_figure(save_path, 'mesh', show)

    def plot_displacement(self, scale=1.0, component='magnitude', figsize=(14, 8),
                          save_path=None, show=True):
        """
        Plot displacement field

        Args:
            scale: Scale factor for displacement visualization
            component: 'magnitude', 'x', 'y', or 'z'
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)
        """
        if self.displacement is None:
            print("No displacement data available")
            return

        fig = plt.figure(figsize=figsize)

        # Left: undeformed geometry (scale 0), colored by the field.
        ax1 = fig.add_subplot(121, projection='3d')
        self._plot_mesh_with_field(ax1, self.displacement, component, scale=0)
        ax1.set_title('Original Mesh')

        # Right: geometry offset by scale * displacement.
        ax2 = fig.add_subplot(122, projection='3d')
        self._plot_mesh_with_field(ax2, self.displacement, component, scale=scale)
        ax2.set_title(f'Deformed Mesh (scale={scale}x)\nDisplacement: {component}')

        plt.tight_layout()
        self._finish_figure(save_path, 'displacement', show)

    def plot_stress(self, component='von_mises', figsize=(12, 8), save_path=None, show=True):
        """
        Plot stress field

        Args:
            component: 'von_mises', 'xx', 'yy', 'zz', 'xy', 'yz', 'xz'
                (an unrecognized name falls back to 'xx')
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)
        """
        if self.stress is None:
            print("No stress data available")
            return

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        stress_values = self._stress_component(self.stress, component)
        self._plot_elements_with_stress(ax, stress_values)

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Stress Field: {component}\nMax: {np.max(stress_values):.2f} MPa')

        self._set_equal_aspect(ax)
        self._finish_figure(save_path, 'stress', show)

    def plot_comparison(self, neural_predictions, figsize=(16, 6), save_path=None, show=True):
        """
        Plot comparison: FEA vs Neural predictions

        Args:
            neural_predictions: Dict with 'displacement' and/or 'stress'
                (only 'displacement' is plotted here)
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)
        """
        fig = plt.figure(figsize=figsize)

        if self.displacement is not None and 'displacement' in neural_predictions:
            neural_disp = np.asarray(neural_predictions['displacement'])

            ax1 = fig.add_subplot(131, projection='3d')
            self._plot_mesh_with_field(ax1, self.displacement, 'magnitude', scale=10)
            ax1.set_title('FEA Displacement')

            ax2 = fig.add_subplot(132, projection='3d')
            self._plot_mesh_with_field(ax2, neural_disp, 'magnitude', scale=10)
            ax2.set_title('Neural Prediction')

            # Per-node Euclidean distance between the two displacement vectors.
            ax3 = fig.add_subplot(133, projection='3d')
            error = np.linalg.norm(self.displacement[:, :3] - neural_disp[:, :3], axis=1)
            self._plot_nodes_with_values(ax3, error)
            ax3.set_title('Prediction Error')

        plt.tight_layout()
        self._finish_figure(save_path, 'comparison', show)

    # ------------------------------------------------------------------
    # Drawing helpers
    # ------------------------------------------------------------------
    def _plot_mesh_with_field(self, ax, field, component, scale=1.0):
        """Helper: scatter the nodes, offset by ``scale * field``, colored by value."""
        values = self._field_component(field, component)

        # Apply deformation
        coords = self.node_coords + scale * field[:, :3]

        scatter = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                             c=values, cmap='jet', s=2)
        plt.colorbar(scatter, ax=ax, label=f'{component} (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')

        self._set_equal_aspect(ax)

    def _plot_elements_with_stress(self, ax, stress_values):
        """Helper: draw a subset of elements as filled polygons colored by stress.

        ``stress_values[i]`` is used for ``connectivity[i]``.
        NOTE(review): this assumes stress rows are ordered like the
        connectivity rows - confirm against the exporter.
        """
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection

        vmin, vmax = np.min(stress_values), np.max(stress_values)
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        # cm.get_cmap is gone from recent Matplotlib; pyplot.get_cmap remains.
        cmap = plt.get_cmap('jet')

        # Plot at most 500 elements, and never index past the stress array.
        n_drawable = min(len(self.connectivity), len(stress_values))
        step = self._subsample_step(len(self.connectivity), 500)
        for i in range(0, n_drawable, step):
            nodes = self._element_corner_indices(self.connectivity[i])
            if nodes is None:
                continue

            color = cmap(norm(stress_values[i]))
            poly = Poly3DCollection([self.node_coords[nodes]], facecolors=color,
                                    edgecolors='k', linewidths=0.1, alpha=0.8)
            ax.add_collection3d(poly)

        # Colorbar for the manually colored polygons.
        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Stress (MPa)')

    def _plot_nodes_with_values(self, ax, values):
        """Helper: Plot nodes colored by values"""
        scatter = ax.scatter(self.node_coords[:, 0],
                             self.node_coords[:, 1],
                             self.node_coords[:, 2],
                             c=values, cmap='hot', s=2)
        plt.colorbar(scatter, ax=ax, label='Error (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')

        self._set_equal_aspect(ax)

    def _set_equal_aspect(self, ax):
        """Set equal aspect ratio for 3D plot.

        All three axes get the same span (the largest extent of the node
        coordinates), centered on the middle of each coordinate range.
        """
        lows = self.node_coords[:, :3].min(axis=0)
        highs = self.node_coords[:, :3].max(axis=0)
        middles = (lows + highs) / 2
        half_range = (highs - lows).max() / 2
        if half_range == 0:
            # Degenerate mesh (all nodes coincide): avoid identical limits.
            half_range = 0.5

        ax.set_xlim(middles[0] - half_range, middles[0] + half_range)
        ax.set_ylim(middles[1] - half_range, middles[1] + half_range)
        ax.set_zlim(middles[2] - half_range, middles[2] + half_range)

    # ------------------------------------------------------------------
    # Report
    # ------------------------------------------------------------------
    def _report_lines(self, image_dir):
        """Build the markdown report text as a list of string fragments.

        Args:
            image_dir: Directory name, relative to the report, holding
                ``mesh.png``, ``displacement.png`` and ``stress.png``.
        """
        info = self.metadata['metadata']
        stats = self.metadata['mesh']['statistics']
        n_nodes, n_elements = stats['n_nodes'], stats['n_elements']

        max_disp = None
        if self.displacement is not None:
            max_disp = self.metadata['results']['displacement']['max_translation']
        max_stress = None
        if self.stress is not None and 'stress' in self.metadata['results']:
            max_stress = self._max_von_mises()

        lines = [
            "# FEA Visualization Report\n\n",
            f"**Generated:** {info['created_at']}\n\n",
            f"**Case:** {info['case_name']}\n\n",
            "---\n\n",
            # Model info
            "## Model Information\n\n",
            f"- **Analysis Type:** {info['analysis_type']}\n",
            f"- **Nodes:** {n_nodes:,}\n",
            f"- **Elements:** {n_elements:,}\n",
            f"- **Materials:** {len(self.metadata['materials'])}\n\n",
            # Mesh
            "## Mesh Structure\n\n",
            f"![Mesh Structure]({image_dir}/mesh.png)\n\n",
            f"The model contains {n_nodes:,} nodes ",
            f"and {n_elements:,} elements.\n\n",
        ]

        if self.displacement is not None:
            lines += [
                "## Displacement Results\n\n",
                f"![Displacement]({image_dir}/displacement.png)\n\n",
                f"**Maximum Displacement:** {max_disp:.6f} mm\n\n",
                "The plots show the original mesh (left) and deformed mesh (right) ",
                "with displacement magnitude shown in color.\n\n",
            ]

        if self.stress is not None:
            lines += [
                "## Stress Results\n\n",
                f"![Stress]({image_dir}/stress.png)\n\n",
            ]
            if max_stress is not None:
                lines.append(f"**Maximum von Mises Stress:** {max_stress:.2f} MPa\n\n")
            lines.append("The stress distribution is shown with colors representing von Mises stress levels.\n\n")

        lines += [
            "## Summary Statistics\n\n",
            "| Property | Value |\n",
            "|----------|-------|\n",
            f"| Nodes | {n_nodes:,} |\n",
            f"| Elements | {n_elements:,} |\n",
        ]
        if self.displacement is not None:
            lines.append(f"| Max Displacement | {max_disp:.6f} mm |\n")
        if max_stress is not None:
            lines.append(f"| Max von Mises Stress | {max_stress:.2f} MPa |\n")

        lines += [
            "\n---\n\n",
            "*Report generated by AtomizerField Visualizer*\n",
        ]
        return lines

    def create_report(self, output_file='visualization_report.md'):
        """
        Create markdown report with all visualizations

        Plots are saved to a ``visualization_images`` directory next to the
        report and linked from it.  Figures are not displayed interactively
        while the report is generated.

        Args:
            output_file: Path to save report
        """
        print("\nGenerating visualization report...")

        # Create images directory
        img_dir = Path(output_file).parent / 'visualization_images'
        img_dir.mkdir(parents=True, exist_ok=True)

        # Generate plots
        print(" Creating mesh plot...")
        self.plot_mesh(save_path=img_dir / 'mesh.png', show=False)
        plt.close('all')

        print(" Creating displacement plot...")
        self.plot_displacement(scale=10, save_path=img_dir / 'displacement.png', show=False)
        plt.close('all')

        print(" Creating stress plot...")
        self.plot_stress(save_path=img_dir / 'stress.png', show=False)
        plt.close('all')

        # Build the text first so a missing metadata key cannot leave a
        # half-written report on disk.
        report_text = "".join(self._report_lines(img_dir.name))
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(report_text)

        print(f"\nReport saved to: {output_file}")
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses the case directory and view flags, builds a FEAVisualizer and runs
    each requested view.  With ``--all``, or when no view flag is given,
    every view runs and the report is generated.
    """
    usage_examples = """
Examples:
  # Visualize mesh
  python visualize_results.py test_case_beam --mesh

  # Visualize displacement
  python visualize_results.py test_case_beam --displacement

  # Visualize stress
  python visualize_results.py test_case_beam --stress

  # Generate full report
  python visualize_results.py test_case_beam --report

  # All visualizations
  python visualize_results.py test_case_beam --all
"""
    cli = argparse.ArgumentParser(
        description='Visualize FEA results in 3D',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    cli.add_argument('case_dir', help='Path to case directory')

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ('--mesh', 'Plot mesh structure'),
        ('--displacement', 'Plot displacement field'),
        ('--stress', 'Plot stress field'),
        ('--report', 'Generate markdown report'),
        ('--all', 'Show all plots and generate report'),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    cli.add_argument('--scale', type=float, default=10.0,
                     help='Displacement scale factor (default: 10)')
    cli.add_argument('--output', default='visualization_report.md',
                     help='Report output file')

    options = cli.parse_args()

    # Loading happens inside the constructor.
    visualizer = FEAVisualizer(options.case_dir)

    # No explicit selection means "everything", same as --all.
    nothing_selected = not (options.mesh or options.displacement
                            or options.stress or options.report)
    run_everything = options.all or nothing_selected

    if options.mesh or run_everything:
        print("\nShowing mesh structure...")
        visualizer.plot_mesh()

    if options.displacement or run_everything:
        print("\nShowing displacement field...")
        visualizer.plot_displacement(scale=options.scale)

    if options.stress or run_everything:
        print("\nShowing stress field...")
        visualizer.plot_stress()

    if options.report or run_everything:
        print("\nGenerating report...")
        visualizer.create_report(options.output)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|