Files
Atomizer/atomizer-field/visualize_results.py

516 lines
19 KiB
Python
Raw Normal View History

"""
visualize_results.py
3D Visualization of FEA Results and Neural Predictions
Visualizes:
- Mesh structure
- Displacement fields
- Stress fields (von Mises)
- Comparison: FEA vs Neural predictions
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import json
import h5py
from pathlib import Path
import argparse
class FEAVisualizer:
    """Visualize FEA results (mesh, displacement, stress) in 3D with matplotlib.

    Data is read from ``neural_field_data.json`` (metadata and element
    connectivity) and ``neural_field_data.h5`` (node coordinates and result
    fields) inside a case directory.
    """

    # Element types whose connectivity is collected for plotting.
    PLOTTABLE_ELEMENT_TYPES = ('CQUAD4', 'CTRIA3', 'CTETRA', 'CHEXA')
    # Categories under metadata['mesh']['elements'] that are scanned, in this order.
    ELEMENT_CATEGORIES = ('solid', 'shell', 'beam')

    def __init__(self, case_dir):
        """
        Initialize the visualizer and load the case data.

        Args:
            case_dir: Path to case directory with neural_field_data files
        """
        self.case_dir = Path(case_dir)
        print(f"Loading data from {case_dir}...")
        self.load_data()

    def load_data(self):
        """Load JSON metadata and HDF5 field data.

        Sets ``metadata``, ``connectivity``, ``node_coords``, ``node_ids``,
        ``node_id_to_idx``, ``displacement`` (None if absent) and ``stress``
        (None if no known stress dataset is found).
        """
        json_file = self.case_dir / "neural_field_data.json"
        h5_file = self.case_dir / "neural_field_data.h5"

        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)
        self.connectivity = self._build_connectivity(self.metadata)

        with h5py.File(h5_file, 'r') as f:
            self.node_coords = f['mesh/node_coordinates'][:]
            self.node_ids = f['mesh/node_ids'][:]
            # Node ID -> row index in node_coords; int() normalizes numpy integer keys.
            self.node_id_to_idx = {int(nid): idx for idx, nid in enumerate(self.node_ids)}

            if 'results/displacement' in f:
                self.displacement = f['results/displacement'][:]
            else:
                self.displacement = None

            # Stress may live in a few places; only datasets (which have .shape) are read.
            self.stress = None
            if 'results/stress/cquad4_stress/data' in f:
                self.stress = f['results/stress/cquad4_stress/data'][:]
            elif 'results/stress/cquad4_stress' in f and hasattr(f['results/stress/cquad4_stress'], 'shape'):
                self.stress = f['results/stress/cquad4_stress'][:]
            elif 'results/stress' in f and hasattr(f['results/stress'], 'shape'):
                self.stress = f['results/stress'][:]

        print(f"Loaded {len(self.node_coords)} nodes, {len(self.connectivity)} elements")

    @classmethod
    def _build_connectivity(cls, metadata):
        """Collect ``[elem_id, n1, n2, ...]`` rows for plottable elements.

        Args:
            metadata: Parsed ``neural_field_data.json`` content.

        Returns:
            A 2-D integer array when every row has the same length. An object
            array of lists when element types with different node counts are
            mixed (e.g. CQUAD4 + CTRIA3). An empty ``(0, 0)`` array when no
            plottable elements exist.
        """
        elements = metadata.get('mesh', {}).get('elements', {})
        if not isinstance(elements, dict):
            elements = {}

        rows = []
        for category in cls.ELEMENT_CATEGORIES:
            group = elements.get(category)
            if not isinstance(group, list):
                continue
            for elem_data in group:
                if elem_data.get('type', '') in cls.PLOTTABLE_ELEMENT_TYPES:
                    rows.append([elem_data['id']] + list(elem_data.get('nodes', [])))

        if not rows:
            return np.empty((0, 0), dtype=int)
        if len({len(row) for row in rows}) == 1:
            return np.array(rows)
        # np.array() rejects ragged nested lists (NumPy >= 1.24), so store rows as objects.
        ragged = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            ragged[i] = row
        return ragged

    def _element_node_indices(self, elem):
        """Map a connectivity row to ``node_coords`` row indices.

        Only the first (up to) four node IDs after the element ID are used, so
        triangles give 3 indices and quads give 4. For solid elements this
        yields only a subset of their nodes.

        Returns:
            A list of indices, or None when the row has fewer than three nodes
            or references a node ID missing from the mesh.
        """
        node_ids = list(elem[1:5])
        if len(node_ids) < 3:
            return None
        try:
            return [self.node_id_to_idx[int(nid)] for nid in node_ids]
        except KeyError:
            return None

    @staticmethod
    def _finish_figure(save_path, description, show):
        """Optionally save the current figure, then optionally display it."""
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved {description} to {save_path}")
        if show:
            plt.show()

    def plot_mesh(self, figsize=(12, 8), save_path=None, *, show=True):
        """
        Plot 3D mesh structure (nodes plus outlines of a subset of elements).

        Args:
            figsize: Figure size
            save_path: Path to save figure (optional)
            show: Call ``plt.show()`` after drawing (keyword-only)
        """
        from mpl_toolkits.mplot3d.art3d import Line3DCollection

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        x = self.node_coords[:, 0]
        y = self.node_coords[:, 1]
        z = self.node_coords[:, 2]
        ax.scatter(x, y, z, c='blue', marker='.', s=1, alpha=0.3, label='Nodes')

        # Outline at most ~1000 elements so dense meshes stay readable; one
        # collection instead of one Line2D per edge keeps drawing fast.
        step = max(1, len(self.connectivity) // 1000)
        outlines = []
        for elem in self.connectivity[::step]:
            nodes = self._element_node_indices(elem)
            if nodes is None:
                continue
            elem_coords = self.node_coords[nodes, :3]
            # Repeat the first vertex to close the polygon outline.
            outlines.append(np.vstack([elem_coords, elem_coords[:1]]))
        if outlines:
            ax.add_collection3d(Line3DCollection(outlines, colors='k', linewidths=0.1, alpha=0.1))

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Mesh Structure\n{len(self.node_coords)} nodes, {len(self.connectivity)} elements')
        self._set_equal_aspect(ax)

        self._finish_figure(save_path, "mesh plot", show)

    def plot_displacement(self, scale=1.0, component='magnitude', figsize=(14, 8), save_path=None, *, show=True):
        """
        Plot displacement field on the original and deformed mesh.

        Args:
            scale: Scale factor for displacement visualization
            component: 'magnitude', 'x', 'y', or 'z'
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (keyword-only)
        """
        if self.displacement is None:
            print("No displacement data available")
            return

        fig = plt.figure(figsize=figsize)

        ax1 = fig.add_subplot(121, projection='3d')
        self._plot_mesh_with_field(ax1, self.displacement, component, scale=0)
        ax1.set_title('Original Mesh')

        ax2 = fig.add_subplot(122, projection='3d')
        self._plot_mesh_with_field(ax2, self.displacement, component, scale=scale)
        ax2.set_title(f'Deformed Mesh (scale={scale}x)\nDisplacement: {component}')

        plt.tight_layout()
        self._finish_figure(save_path, "displacement plot", show)

    def plot_stress(self, component='von_mises', figsize=(12, 8), save_path=None, *, show=True):
        """
        Plot stress field with elements colored by one stress column.

        Args:
            component: 'von_mises', 'xx', 'yy', 'zz', 'xy', 'yz', 'xz'
                (unknown names fall back to 'xx')
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (keyword-only)
        """
        if self.stress is None:
            print("No stress data available")
            return

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        if component == 'von_mises':
            # NOTE(review): assumes von Mises is the last column of the stress table — confirm.
            stress_values = self.stress[:, -1]
        else:
            comp_map = {'xx': 0, 'yy': 1, 'zz': 2, 'xy': 3, 'yz': 4, 'xz': 5}
            stress_values = self.stress[:, comp_map.get(component, 0)]

        self._plot_elements_with_stress(ax, stress_values)

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Stress Field: {component}\nMax: {np.max(stress_values):.2f} MPa')
        self._set_equal_aspect(ax)

        self._finish_figure(save_path, "stress plot", show)

    def plot_comparison(self, neural_predictions, figsize=(16, 6), save_path=None, *, show=True):
        """
        Plot comparison: FEA vs Neural predictions.

        Args:
            neural_predictions: Dict with 'displacement' and/or 'stress'
                (only 'displacement' is plotted)
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (keyword-only)
        """
        fig = plt.figure(figsize=figsize)

        if self.displacement is not None and 'displacement' in neural_predictions:
            ax1 = fig.add_subplot(131, projection='3d')
            self._plot_mesh_with_field(ax1, self.displacement, 'magnitude', scale=10)
            ax1.set_title('FEA Displacement')

            ax2 = fig.add_subplot(132, projection='3d')
            # asarray lets callers pass lists or other array-likes.
            neural_disp = np.asarray(neural_predictions['displacement'])
            self._plot_mesh_with_field(ax2, neural_disp, 'magnitude', scale=10)
            ax2.set_title('Neural Prediction')

            ax3 = fig.add_subplot(133, projection='3d')
            error = np.linalg.norm(self.displacement[:, :3] - neural_disp[:, :3], axis=1)
            self._plot_nodes_with_values(ax3, error)
            ax3.set_title('Prediction Error')

        plt.tight_layout()
        self._finish_figure(save_path, "comparison plot", show)

    @staticmethod
    def _field_component(field, component):
        """Return per-node values of *field* for 'x', 'y', 'z' or (default) the XYZ magnitude."""
        field = np.asarray(field)
        axis_index = {'x': 0, 'y': 1, 'z': 2}.get(component)
        if axis_index is None:
            return np.linalg.norm(field[:, :3], axis=1)
        return field[:, axis_index]

    def _plot_mesh_with_field(self, ax, field, component, scale=1.0):
        """Helper: scatter nodes displaced by ``scale * field`` and colored by a field component."""
        field = np.asarray(field)
        values = self._field_component(field, component)
        coords = self.node_coords[:, :3] + scale * field[:, :3]

        scatter = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                             c=values, cmap='jet', s=2)
        plt.colorbar(scatter, ax=ax, label=f'{component} (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        self._set_equal_aspect(ax)

    def _plot_elements_with_stress(self, ax, stress_values):
        """Helper: draw a subset of elements as filled polygons colored by stress."""
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection

        stress_values = np.asarray(stress_values)
        norm = plt.Normalize(vmin=np.min(stress_values), vmax=np.max(stress_values))
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        cmap = plt.get_cmap('jet')

        # NOTE(review): stress row i is paired with connectivity row i. That only lines
        # up if the stress table lists the same elements in the same order as
        # self.connectivity (solid, then shell, then beam) — confirm for mixed meshes.
        step = max(1, len(self.connectivity) // 500)
        count = min(len(self.connectivity), len(stress_values))
        polygons, face_colors = [], []
        for i in range(0, count, step):
            nodes = self._element_node_indices(self.connectivity[i])
            if nodes is None:
                continue
            polygons.append(self.node_coords[nodes, :3])
            face_colors.append(cmap(norm(stress_values[i])))

        if polygons:
            # One collection for all faces: far fewer artists and consistent depth sorting.
            ax.add_collection3d(Poly3DCollection(polygons, facecolors=face_colors, edgecolors='k',
                                                 linewidths=0.1, alpha=0.8))

        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Stress (MPa)')

    def _plot_nodes_with_values(self, ax, values):
        """Helper: scatter undeformed nodes colored by per-node *values* (e.g. error)."""
        scatter = ax.scatter(self.node_coords[:, 0],
                             self.node_coords[:, 1],
                             self.node_coords[:, 2],
                             c=values, cmap='hot', s=2)
        plt.colorbar(scatter, ax=ax, label='Error (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        self._set_equal_aspect(ax)

    def _set_equal_aspect(self, ax):
        """Give all three axes the same span, centered on the undeformed mesh bounding box."""
        coords = self.node_coords[:, :3]
        lower = coords.min(axis=0)
        upper = coords.max(axis=0)
        middle = (lower + upper) / 2
        half_range = (upper - lower).max() / 2
        if half_range == 0:
            # Degenerate (single-point) mesh: avoid identical low/high limits.
            half_range = 0.5

        ax.set_xlim(middle[0] - half_range, middle[0] + half_range)
        ax.set_ylim(middle[1] - half_range, middle[1] + half_range)
        ax.set_zlim(middle[2] - half_range, middle[2] + half_range)

    def _max_displacement_from_metadata(self):
        """Return ``results.displacement.max_translation`` from metadata, or None if absent."""
        results = self.metadata.get('results') or {}
        displacement = results.get('displacement') or {}
        if not isinstance(displacement, dict):
            return None
        return displacement.get('max_translation')

    def _max_von_mises_from_metadata(self):
        """Return the first non-null ``max_von_mises`` under ``results.stress``, or None."""
        results = self.metadata.get('results') or {}
        for stress_data in (results.get('stress') or {}).values():
            if isinstance(stress_data, dict) and stress_data.get('max_von_mises') is not None:
                return stress_data['max_von_mises']
        return None

    def create_report(self, output_file='visualization_report.md'):
        """
        Create a markdown report with mesh, displacement and stress images.

        Images are written to ``visualization_images/`` next to the report.
        Plots are saved without being displayed interactively.

        Args:
            output_file: Path to save report
        """
        print("\nGenerating visualization report...")

        img_dir = Path(output_file).parent / 'visualization_images'
        img_dir.mkdir(parents=True, exist_ok=True)

        print(" Creating mesh plot...")
        self.plot_mesh(save_path=img_dir / 'mesh.png', show=False)
        plt.close('all')

        print(" Creating displacement plot...")
        self.plot_displacement(scale=10, save_path=img_dir / 'displacement.png', show=False)
        plt.close('all')

        print(" Creating stress plot...")
        self.plot_stress(save_path=img_dir / 'stress.png', show=False)
        plt.close('all')

        info = self.metadata['metadata']
        stats = self.metadata['mesh']['statistics']
        max_disp = self._max_displacement_from_metadata()
        max_stress = self._max_von_mises_from_metadata()

        parts = [
            "# FEA Visualization Report\n\n",
            f"**Generated:** {info['created_at']}\n\n",
            f"**Case:** {info['case_name']}\n\n",
            "---\n\n",
            "## Model Information\n\n",
            f"- **Analysis Type:** {info['analysis_type']}\n",
            f"- **Nodes:** {stats['n_nodes']:,}\n",
            f"- **Elements:** {stats['n_elements']:,}\n",
            f"- **Materials:** {len(self.metadata.get('materials', []))}\n\n",
            "## Mesh Structure\n\n",
            "![Mesh Structure](visualization_images/mesh.png)\n\n",
            f"The model contains {stats['n_nodes']:,} nodes ",
            f"and {stats['n_elements']:,} elements.\n\n",
        ]

        if self.displacement is not None:
            parts += ["## Displacement Results\n\n",
                      "![Displacement Field](visualization_images/displacement.png)\n\n"]
            if max_disp is not None:
                parts.append(f"**Maximum Displacement:** {max_disp:.6f} mm\n\n")
            parts += ["The plots show the original mesh (left) and deformed mesh (right) ",
                      "with displacement magnitude shown in color.\n\n"]

        if self.stress is not None:
            parts += ["## Stress Results\n\n",
                      "![Stress Field](visualization_images/stress.png)\n\n"]
            if max_stress is not None:
                parts.append(f"**Maximum von Mises Stress:** {max_stress:.2f} MPa\n\n")
            parts.append("The stress distribution is shown with colors representing von Mises stress levels.\n\n")

        parts += ["## Summary Statistics\n\n",
                  "| Property | Value |\n",
                  "|----------|-------|\n",
                  f"| Nodes | {stats['n_nodes']:,} |\n",
                  f"| Elements | {stats['n_elements']:,} |\n"]
        if self.displacement is not None and max_disp is not None:
            parts.append(f"| Max Displacement | {max_disp:.6f} mm |\n")
        if self.stress is not None and max_stress is not None:
            parts.append(f"| Max von Mises Stress | {max_stress:.2f} MPa |\n")
        parts += ["\n---\n\n",
                  "*Report generated by AtomizerField Visualizer*\n"]

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("".join(parts))

        print(f"\nReport saved to: {output_file}")
def main():
    """Parse command-line options and run the requested visualizations.

    When no selection flag is given, or when ``--all`` is passed, every plot
    is shown and the markdown report is generated.
    """
    usage_examples = """
Examples:
# Visualize mesh
python visualize_results.py test_case_beam --mesh
# Visualize displacement
python visualize_results.py test_case_beam --displacement
# Visualize stress
python visualize_results.py test_case_beam --stress
# Generate full report
python visualize_results.py test_case_beam --report
# All visualizations
python visualize_results.py test_case_beam --all
"""
    cli = argparse.ArgumentParser(
        description='Visualize FEA results in 3D',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument('case_dir', help='Path to case directory')
    # Boolean selection switches, registered in the same order as the help listing.
    for flag, help_text in (
        ('--mesh', 'Plot mesh structure'),
        ('--displacement', 'Plot displacement field'),
        ('--stress', 'Plot stress field'),
        ('--report', 'Generate markdown report'),
        ('--all', 'Show all plots and generate report'),
    ):
        cli.add_argument(flag, action='store_true', help=help_text)
    cli.add_argument('--scale', type=float, default=10.0, help='Displacement scale factor (default: 10)')
    cli.add_argument('--output', default='visualization_report.md', help='Report output file')
    opts = cli.parse_args()

    visualizer = FEAVisualizer(opts.case_dir)

    # No explicit selection means "do everything".
    everything = opts.all or not any((opts.mesh, opts.displacement, opts.stress, opts.report))

    steps = (
        (opts.mesh, "\nShowing mesh structure...",
         lambda: visualizer.plot_mesh()),
        (opts.displacement, "\nShowing displacement field...",
         lambda: visualizer.plot_displacement(scale=opts.scale)),
        (opts.stress, "\nShowing stress field...",
         lambda: visualizer.plot_stress()),
        (opts.report, "\nGenerating report...",
         lambda: visualizer.create_report(opts.output)),
    )
    for requested, banner, action in steps:
        if requested or everything:
            print(banner)
            action()
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()