"""
visualize_results.py

3D Visualization of FEA Results and Neural Predictions

Visualizes:
- Mesh structure
- Displacement fields
- Stress fields (von Mises)
- Comparison: FEA vs Neural predictions
"""

import argparse
import json
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older Matplotlib)
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection


class FEAVisualizer:
    """Visualize FEA results in 3D.

    Reads ``neural_field_data.json`` (metadata and element connectivity) and
    ``neural_field_data.h5`` (node coordinates/IDs, displacement, stress)
    from a case directory and renders them with Matplotlib.
    """

    # Element types whose connectivity is kept for plotting.
    SUPPORTED_ELEMENT_TYPES = ('CQUAD4', 'CTRIA3', 'CTETRA', 'CHEXA')

    def __init__(self, case_dir):
        """
        Initialize visualizer

        Args:
            case_dir: Path to case directory with neural_field_data files
        """
        self.case_dir = Path(case_dir)

        # Load data
        print(f"Loading data from {case_dir}...")
        self.load_data()

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #

    def load_data(self):
        """Load JSON metadata and HDF5 field data.

        Sets ``metadata``, ``connectivity``, ``node_coords``, ``node_ids``,
        ``node_id_to_idx``, ``displacement`` (or None) and ``stress`` (or None).

        Raises:
            FileNotFoundError: if either data file is missing.
            KeyError: if the HDF5 file lacks ``mesh/node_coordinates`` or
                ``mesh/node_ids``.
        """
        json_file = self.case_dir / "neural_field_data.json"
        h5_file = self.case_dir / "neural_field_data.h5"

        with open(json_file, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        self.connectivity = self._build_connectivity()

        with h5py.File(h5_file, 'r') as f:
            self.node_coords = f['mesh/node_coordinates'][:]
            self.node_ids = f['mesh/node_ids'][:]
            # Plain-int keys so lookups with JSON ints and numpy integer
            # scalars both resolve to the same entry.
            self.node_id_to_idx = {int(nid): idx for idx, nid in enumerate(self.node_ids)}

            self.displacement = (
                f['results/displacement'][:] if 'results/displacement' in f else None
            )
            self.stress = self._read_stress(f)

        print(f"Loaded {len(self.node_coords)} nodes, {len(self.connectivity)} elements")

    def _build_connectivity(self):
        """Return ``[elem_id, n1, n2, ...]`` rows for supported element types.

        Rows come from ``metadata['mesh']['elements'][category]`` for the
        'solid', 'shell' and 'beam' categories. When every row has the same
        length a 2-D integer array is returned; when element types are mixed
        (e.g. CTRIA3 with CQUAD4) the rows are ragged, which NumPy cannot pack
        into a rectangular array, so a list of rows is returned instead.
        """
        elements = self._meta_section('mesh', 'elements')
        rows = []
        for category in ('solid', 'shell', 'beam'):
            group = elements.get(category)
            if not isinstance(group, list):
                continue
            for elem in group:
                if elem.get('type', '') in self.SUPPORTED_ELEMENT_TYPES:
                    rows.append([elem['id']] + list(elem.get('nodes', [])))

        if not rows:
            return np.empty((0, 0), dtype=int)
        if len({len(row) for row in rows}) == 1:
            return np.array(rows)
        return rows

    @staticmethod
    def _read_stress(f):
        """Return the first stress *dataset* found at a known HDF5 path, or None."""
        for path in ('results/stress/cquad4_stress/data',
                     'results/stress/cquad4_stress',
                     'results/stress'):
            if path in f and isinstance(f[path], h5py.Dataset):
                return f[path][:]
        return None

    def _meta_section(self, *keys):
        """Return the nested metadata dict at ``keys``; ``{}`` if any level is missing or not a dict."""
        node = self.metadata
        for key in keys:
            if not isinstance(node, dict):
                return {}
            node = node.get(key, {})
        return node if isinstance(node, dict) else {}

    def _element_node_indices(self, elem):
        """Map a connectivity row to ``node_coords`` indices of its first (up to) 4 nodes.

        Args:
            elem: connectivity row ``[elem_id, n1, n2, ...]``.

        Returns:
            List of 3 or 4 row indices into ``node_coords``, or None when the
            row has fewer than 3 nodes or references an unknown node ID.
        """
        if len(elem) < 4:
            return None
        try:
            return [self.node_id_to_idx[int(nid)] for nid in elem[1:5]]
        except (KeyError, TypeError, ValueError):
            return None

    def _finalize(self, fig, save_path, show, label):
        """Save ``fig`` to ``save_path`` (if given) and optionally display it."""
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved {label} plot to {save_path}")
        if show:
            plt.show()

    # ------------------------------------------------------------------ #
    # Public plots
    # ------------------------------------------------------------------ #

    def plot_mesh(self, figsize=(12, 8), save_path=None, show=True):
        """
        Plot 3D mesh structure

        Args:
            figsize: Figure size
            save_path: Path to save figure (optional)
            show: Call ``plt.show()`` after drawing (default True)

        Returns:
            The Matplotlib figure.
        """
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], self.node_coords[:, 2],
                   c='blue', marker='.', s=1, alpha=0.3, label='Nodes')

        # Draw edges of roughly 1000 elements at most, for visibility/speed.
        step = max(1, len(self.connectivity) // 1000)
        segments = []
        for elem in self.connectivity[::step]:
            idx = self._element_node_indices(elem)
            if idx is None:
                continue
            corners = self.node_coords[idx, :3]
            # Pair every corner with the next one (wrapping) to close the polygon.
            segments.extend(np.stack([corners, np.roll(corners, -1, axis=0)], axis=1))
        if segments:
            ax.add_collection3d(Line3DCollection(segments, colors='k', linewidths=0.1, alpha=0.1))

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Mesh Structure\n{len(self.node_coords)} nodes, {len(self.connectivity)} elements')

        self._set_equal_aspect(ax)
        self._finalize(fig, save_path, show, 'mesh')
        return fig

    def plot_displacement(self, scale=1.0, component='magnitude', figsize=(14, 8),
                          save_path=None, show=True):
        """
        Plot displacement field

        Args:
            scale: Scale factor for displacement visualization
            component: 'magnitude', 'x', 'y', or 'z'
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)

        Returns:
            The Matplotlib figure, or None when no displacement data exists.
        """
        if self.displacement is None:
            print("No displacement data available")
            return None

        fig = plt.figure(figsize=figsize)

        ax1 = fig.add_subplot(121, projection='3d')
        self._plot_mesh_with_field(ax1, self.displacement, component, scale=0)
        ax1.set_title('Original Mesh')

        ax2 = fig.add_subplot(122, projection='3d')
        self._plot_mesh_with_field(ax2, self.displacement, component, scale=scale)
        ax2.set_title(f'Deformed Mesh (scale={scale}x)\nDisplacement: {component}')

        plt.tight_layout()
        self._finalize(fig, save_path, show, 'displacement')
        return fig

    def plot_stress(self, component='von_mises', figsize=(12, 8), save_path=None, show=True):
        """
        Plot stress field

        Args:
            component: 'von_mises', 'xx', 'yy', 'zz', 'xy', 'yz', 'xz'
                (unknown names fall back to 'xx')
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)

        Returns:
            The Matplotlib figure, or None when no stress data exists.
        """
        if self.stress is None or len(self.stress) == 0:
            print("No stress data available")
            return None

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

        if component == 'von_mises':
            # Assumes von Mises is stored in the last column — TODO confirm
            # against the exporter's column layout.
            stress_values = self.stress[:, -1]
        else:
            comp_map = {'xx': 0, 'yy': 1, 'zz': 2, 'xy': 3, 'yz': 4, 'xz': 5}
            stress_values = self.stress[:, comp_map.get(component, 0)]

        self._plot_elements_with_stress(ax, stress_values)

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        ax.set_title(f'Stress Field: {component}\nMax: {np.max(stress_values):.2f} MPa')

        self._set_equal_aspect(ax)
        self._finalize(fig, save_path, show, 'stress')
        return fig

    def plot_comparison(self, neural_predictions, figsize=(16, 6), save_path=None, show=True):
        """
        Plot comparison: FEA vs Neural predictions

        Args:
            neural_predictions: Dict with 'displacement' and/or 'stress'
                (only 'displacement' is plotted)
            figsize: Figure size
            save_path: Path to save figure
            show: Call ``plt.show()`` after drawing (default True)

        Returns:
            The Matplotlib figure.
        """
        fig = plt.figure(figsize=figsize)

        if self.displacement is not None and 'displacement' in neural_predictions:
            # np.asarray accepts lists / array-likes as well as ndarrays.
            neural_disp = np.asarray(neural_predictions['displacement'])

            ax1 = fig.add_subplot(131, projection='3d')
            self._plot_mesh_with_field(ax1, self.displacement, 'magnitude', scale=10)
            ax1.set_title('FEA Displacement')

            ax2 = fig.add_subplot(132, projection='3d')
            self._plot_mesh_with_field(ax2, neural_disp, 'magnitude', scale=10)
            ax2.set_title('Neural Prediction')

            ax3 = fig.add_subplot(133, projection='3d')
            # Per-node Euclidean distance between FEA and predicted translations.
            error = np.linalg.norm(self.displacement[:, :3] - neural_disp[:, :3], axis=1)
            self._plot_nodes_with_values(ax3, error)
            ax3.set_title('Prediction Error')
        else:
            print("No displacement data available for comparison")

        plt.tight_layout()
        self._finalize(fig, save_path, show, 'comparison')
        return fig

    # ------------------------------------------------------------------ #
    # Drawing helpers
    # ------------------------------------------------------------------ #

    def _plot_mesh_with_field(self, ax, field, component, scale=1.0):
        """Helper: scatter nodes displaced by ``scale * field`` and colored by a field component.

        Assumes ``field`` rows align one-to-one with ``node_coords`` and that
        columns 0-2 are x/y/z translations — TODO confirm against the exporter.
        """
        field = np.asarray(field)
        axis = {'x': 0, 'y': 1, 'z': 2}.get(component)
        # 'magnitude' and any unrecognized component both use the translation norm.
        values = np.linalg.norm(field[:, :3], axis=1) if axis is None else field[:, axis]

        coords = self.node_coords + scale * field[:, :3]

        scatter = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=values, cmap='jet', s=2)
        plt.colorbar(scatter, ax=ax, label=f'{component} (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        # Fit the limits to the (possibly deformed) coordinates actually drawn.
        self._set_equal_aspect(ax, coords)

    def _plot_elements_with_stress(self, ax, stress_values):
        """Helper: draw a subset of element faces filled by stress value.

        Entry ``i`` of ``stress_values`` colors ``connectivity[i]``.
        NOTE(review): this assumes stress rows are one per connectivity row in
        the same order; the stress dataset may cover only CQUAD4 elements
        while connectivity also includes other types — confirm alignment.
        """
        stress_values = np.asarray(stress_values)
        norm = plt.Normalize(vmin=np.min(stress_values), vmax=np.max(stress_values))
        # plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
        cmap = plt.get_cmap('jet')

        # Roughly 500 faces at most; one collection is much faster than one per face.
        step = max(1, len(self.connectivity) // 500)
        limit = min(len(self.connectivity), len(stress_values))
        verts, colors = [], []
        for i in range(0, limit, step):
            idx = self._element_node_indices(self.connectivity[i])
            if idx is None:
                continue
            verts.append(self.node_coords[idx, :3])
            colors.append(cmap(norm(stress_values[i])))
        if verts:
            ax.add_collection3d(Poly3DCollection(verts, facecolors=colors, edgecolors='k',
                                                 linewidths=0.1, alpha=0.8))

        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Stress (MPa)')

    def _plot_nodes_with_values(self, ax, values):
        """Helper: scatter undeformed nodes colored by per-node ``values``."""
        scatter = ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], self.node_coords[:, 2],
                             c=values, cmap='hot', s=2)
        plt.colorbar(scatter, ax=ax, label='Error (mm)')

        ax.set_xlabel('X (mm)')
        ax.set_ylabel('Y (mm)')
        ax.set_zlabel('Z (mm)')
        self._set_equal_aspect(ax)

    def _set_equal_aspect(self, ax, coords=None):
        """Set a cubic bounding box centred on the data so axes share one scale.

        Args:
            ax: 3D axes to adjust.
            coords: (N, >=3) points to fit; defaults to the undeformed ``node_coords``.
        """
        pts = np.asarray(self.node_coords if coords is None else coords)[:, :3]
        lo, hi = pts.min(axis=0), pts.max(axis=0)
        half = (hi - lo).max() / 2
        if half <= 0:
            # All points coincide: avoid identical low/high limits.
            half = 0.5
        mid = (lo + hi) / 2

        ax.set_xlim(mid[0] - half, mid[0] + half)
        ax.set_ylim(mid[1] - half, mid[1] + half)
        ax.set_zlim(mid[2] - half, mid[2] + half)

    # ------------------------------------------------------------------ #
    # Report
    # ------------------------------------------------------------------ #

    def _max_displacement(self):
        """Max translation: metadata value if present, else computed from the displacement data."""
        value = self._meta_section('results', 'displacement').get('max_translation')
        if value is None:
            if len(self.displacement) == 0:
                return 0.0
            value = float(np.linalg.norm(self.displacement[:, :3], axis=1).max())
        return value

    def _max_von_mises_from_metadata(self):
        """First non-null ``max_von_mises`` under ``metadata['results']['stress']``, or None."""
        for stress_data in self._meta_section('results', 'stress').values():
            if isinstance(stress_data, dict) and stress_data.get('max_von_mises') is not None:
                return stress_data['max_von_mises']
        return None

    def create_report(self, output_file='visualization_report.md'):
        """
        Create markdown report with all visualizations

        Plots are written to a ``visualization_images`` directory next to the
        report and are not displayed interactively.

        Args:
            output_file: Path to save report
        """
        print("\nGenerating visualization report...")

        output_path = Path(output_file)
        img_dir = output_path.parent / 'visualization_images'
        img_dir.mkdir(parents=True, exist_ok=True)

        print(" Creating mesh plot...")
        self.plot_mesh(save_path=img_dir / 'mesh.png', show=False)
        plt.close('all')

        print(" Creating displacement plot...")
        self.plot_displacement(scale=10, save_path=img_dir / 'displacement.png', show=False)
        plt.close('all')

        print(" Creating stress plot...")
        self.plot_stress(save_path=img_dir / 'stress.png', show=False)
        plt.close('all')

        info = self._meta_section('metadata')
        stats = self._meta_section('mesh', 'statistics')
        n_nodes = stats.get('n_nodes', len(self.node_coords))
        n_elements = stats.get('n_elements', len(self.connectivity))
        n_materials = len(self.metadata.get('materials', []))
        # Mirror plot_stress's guard so the report only links images that were written.
        has_stress = self.stress is not None and len(self.stress) > 0
        max_disp = self._max_displacement() if self.displacement is not None else None
        max_stress = self._max_von_mises_from_metadata() if has_stress else None

        parts = [
            "# FEA Visualization Report\n\n",
            f"**Generated:** {info.get('created_at', 'N/A')}\n\n",
            f"**Case:** {info.get('case_name', 'N/A')}\n\n",
            "---\n\n",
            # Model info
            "## Model Information\n\n",
            f"- **Analysis Type:** {info.get('analysis_type', 'N/A')}\n",
            f"- **Nodes:** {n_nodes:,}\n",
            f"- **Elements:** {n_elements:,}\n",
            f"- **Materials:** {n_materials}\n\n",
            # Mesh
            "## Mesh Structure\n\n",
            "![Mesh Structure](visualization_images/mesh.png)\n\n",
            f"The model contains {n_nodes:,} nodes ",
            f"and {n_elements:,} elements.\n\n",
        ]

        if max_disp is not None:
            parts += [
                "## Displacement Results\n\n",
                "![Displacement Field](visualization_images/displacement.png)\n\n",
                f"**Maximum Displacement:** {max_disp:.6f} mm\n\n",
                "The plots show the original mesh (left) and deformed mesh (right) ",
                "with displacement magnitude shown in color.\n\n",
            ]

        if has_stress:
            parts += [
                "## Stress Results\n\n",
                "![Stress Field](visualization_images/stress.png)\n\n",
            ]
            if max_stress is not None:
                parts.append(f"**Maximum von Mises Stress:** {max_stress:.2f} MPa\n\n")
            parts.append("The stress distribution is shown with colors representing von Mises stress levels.\n\n")

        parts += [
            "## Summary Statistics\n\n",
            "| Property | Value |\n",
            "|----------|-------|\n",
            f"| Nodes | {n_nodes:,} |\n",
            f"| Elements | {n_elements:,} |\n",
        ]
        if max_disp is not None:
            parts.append(f"| Max Displacement | {max_disp:.6f} mm |\n")
        if max_stress is not None:
            parts.append(f"| Max von Mises Stress | {max_stress:.2f} MPa |\n")

        parts += ["\n---\n\n", "*Report generated by AtomizerField Visualizer*\n"]

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(''.join(parts))

        print(f"\nReport saved to: {output_file}")


def main():
    """Main entry point: parse CLI arguments and run the requested visualizations."""
    parser = argparse.ArgumentParser(
        description='Visualize FEA results in 3D',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Visualize mesh
  python visualize_results.py test_case_beam --mesh

  # Visualize displacement
  python visualize_results.py test_case_beam --displacement

  # Visualize stress
  python visualize_results.py test_case_beam --stress

  # Generate full report
  python visualize_results.py test_case_beam --report

  # All visualizations
  python visualize_results.py test_case_beam --all
"""
    )

    parser.add_argument('case_dir', help='Path to case directory')
    parser.add_argument('--mesh', action='store_true', help='Plot mesh structure')
    parser.add_argument('--displacement', action='store_true', help='Plot displacement field')
    parser.add_argument('--stress', action='store_true', help='Plot stress field')
    parser.add_argument('--report', action='store_true', help='Generate markdown report')
    parser.add_argument('--all', action='store_true', help='Show all plots and generate report')
    parser.add_argument('--scale', type=float, default=10.0,
                        help='Displacement scale factor (default: 10)')
    parser.add_argument('--output', default='visualization_report.md', help='Report output file')

    args = parser.parse_args()

    viz = FEAVisualizer(args.case_dir)

    # With no selection flag at all, behave as --all.
    show_all = args.all or not (args.mesh or args.displacement or args.stress or args.report)

    if args.mesh or show_all:
        print("\nShowing mesh structure...")
        viz.plot_mesh()

    if args.displacement or show_all:
        print("\nShowing displacement field...")
        viz.plot_displacement(scale=args.scale)

    if args.stress or show_all:
        print("\nShowing stress field...")
        viz.plot_stress()

    if args.report or show_all:
        print("\nGenerating report...")
        viz.create_report(args.output)


if __name__ == "__main__":
    main()