From a24e3f750c9d34a0f5a2422c297eb74d00bbd60d Mon Sep 17 00:00:00 2001 From: Anto01 Date: Sat, 15 Nov 2025 14:46:49 -0500 Subject: [PATCH] feat: Implement Phase 1 - Plugin & Hook System MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core plugin architecture for LLM-driven optimization: New Features: - Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives) - HookManager for centralized registration and execution - Code validation with AST-based safety checks - Feature registry (JSON) for LLM capability discovery - Example plugin: log_trial_start - 23 comprehensive tests (all passing) Integration: - OptimizationRunner now loads plugins automatically - Hooks execute at 5 points in optimization loop - Custom objectives can override total_objective via hooks Safety: - Module whitelist (numpy, scipy, pandas, optuna, pyNastran) - Dangerous operation blocking (eval, exec, os.system, subprocess) - Optional file operation permission flag Files Added: - optimization_engine/plugins/__init__.py - optimization_engine/plugins/hooks.py - optimization_engine/plugins/hook_manager.py - optimization_engine/plugins/validators.py - optimization_engine/feature_registry.json - optimization_engine/plugins/pre_solve/log_trial_start.py - tests/test_plugin_system.py (23 tests) Files Modified: - optimization_engine/runner.py (added hook integration) Ready for Phase 2: LLM interface layer šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- optimization_engine/feature_registry.json | 243 ++++++++++ optimization_engine/plugins/__init__.py | 32 ++ .../plugins/custom_objectives/__init__.py | 1 + optimization_engine/plugins/hook_manager.py | 303 ++++++++++++ optimization_engine/plugins/hooks.py | 115 +++++ .../plugins/post_extraction/__init__.py | 1 + .../plugins/post_mesh/__init__.py | 1 + .../plugins/post_solve/__init__.py | 1 + .../plugins/pre_mesh/__init__.py | 1 + .../plugins/pre_solve/__init__.py | 1 + .../plugins/pre_solve/log_trial_start.py | 56 +++ optimization_engine/plugins/validators.py | 214 +++++++++ optimization_engine/runner.py | 66 +++ tests/test_plugin_system.py | 438 ++++++++++++++++++ 14 files changed, 1473 insertions(+) create mode 100644 optimization_engine/feature_registry.json create mode 100644 optimization_engine/plugins/__init__.py create mode 100644 optimization_engine/plugins/custom_objectives/__init__.py create mode 100644 optimization_engine/plugins/hook_manager.py create mode 100644 optimization_engine/plugins/hooks.py create mode 100644 optimization_engine/plugins/post_extraction/__init__.py create mode 100644 optimization_engine/plugins/post_mesh/__init__.py create mode 100644 optimization_engine/plugins/post_solve/__init__.py create mode 100644 optimization_engine/plugins/pre_mesh/__init__.py create mode 100644 optimization_engine/plugins/pre_solve/__init__.py create mode 100644 optimization_engine/plugins/pre_solve/log_trial_start.py create mode 100644 optimization_engine/plugins/validators.py create mode 100644 tests/test_plugin_system.py diff --git a/optimization_engine/feature_registry.json b/optimization_engine/feature_registry.json new file mode 100644 index 00000000..99d9c2bc --- /dev/null +++ b/optimization_engine/feature_registry.json @@ -0,0 +1,243 @@ +{ + "version": "1.0.0", + "last_updated": "2025-01-15", + "description": "Registry of all Atomizer capabilities for LLM discovery and usage", + + "core_features": { + "optimization": { + "description": "Core optimization engine using Optuna", + "module": "optimization_engine.runner", + "capabilities": [ + "Multi-objective optimization with weighted sum", + "TPE (Tree-structured Parzen Estimator) sampler", + "CMA-ES sampler", + "Gaussian Process sampler", + "50-trial default with 20 startup trials", + "Automatic checkpoint and resume", + "SQLite-based study persistence" + ], + "usage": "python examples/test_journal_optimization.py", + "llm_hint": "Use this for Bayesian optimization with NX simulations" + }, + + "nx_integration": { + "description": "Siemens NX simulation automation via journal scripts", + "module": "optimization_engine.nx_solver", + "capabilities": [ + "Update CAD expressions via NXOpen", + "Execute NX Nastran solver", + "Extract OP2 results (stress, displacement)", + "Extract mass properties", + "Precision control (4 decimals for mm/degrees/MPa)" + ], + "usage": "from optimization_engine.nx_solver import run_nx_simulation", + "llm_hint": "Use for running FEA simulations and extracting results" + }, + + "result_extraction": { + "description": "Extract metrics from simulation results", + "module": "optimization_engine.result_extractors", + "extractors": { + "stress_extractor": { + "description": "Extract stress data from OP2 files", + "metrics": ["max_von_mises", "mean_von_mises", "max_principal"], + "file_type": "OP2", + "usage": "Returns stress in MPa" + }, + "displacement_extractor": { + "description": "Extract displacement data from OP2 files", + "metrics": ["max_displacement", "mean_displacement"], + "file_type": "OP2", + "usage": "Returns displacement in mm" + }, + "mass_extractor": { + "description": "Extract mass properties", + "metrics": ["total_mass", "center_of_gravity"], + "file_type": "NX Part", + "usage": "Returns mass in kg" + } + }, + "llm_hint": "Use extractors to define objectives and constraints" + } + }, + + "plugin_system": { + "description": "Extensible hook system for custom functionality", + "module": "optimization_engine.plugins", + "version": "1.0.0", + + "hook_points": { + "pre_mesh": { + "description": "Execute before meshing operations", + "context": ["trial_number", "design_variables", "sim_file", "working_dir"], + "use_cases": [ + "Modify geometry based on parameters", + "Set up boundary conditions", + "Configure mesh settings" + ] + }, + "post_mesh": { + "description": "Execute after meshing, before solve", + "context": ["trial_number", "mesh_info", "element_count", "node_count"], + "use_cases": [ + "Validate mesh quality", + "Export mesh for visualization", + "Log mesh statistics" + ] + }, + "pre_solve": { + "description": "Execute before solver launch", + "context": ["trial_number", "design_variables", "solver_settings"], + "use_cases": [ + "Log trial parameters", + "Modify solver settings", + "Set up custom load cases" + ] + }, + "post_solve": { + "description": "Execute after solve, before result extraction", + "context": ["trial_number", "solve_status", "output_files"], + "use_cases": [ + "Check solver convergence", + "Post-process results", + "Generate visualizations" + ] + }, + "post_extraction": { + "description": "Execute after result extraction", + "context": ["trial_number", "extracted_results", "objectives", "constraints"], + "use_cases": [ + "Calculate custom metrics", + "Combine multiple objectives (RSS)", + "Apply penalty functions" + ] + }, + "custom_objective": { + "description": "Define custom objective functions", + "context": ["extracted_results", "design_variables"], + "use_cases": [ + "RSS of stress and displacement", + "Weighted multi-criteria", + "Conditional objectives" + ] + } + }, + + "api": { + "register_hook": { + "description": "Register a new hook function", + "signature": "hook_manager.register_hook(hook_point, function, description, name=None, priority=100)", + "parameters": { + "hook_point": "One of: pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objective", + "function": "Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]", + "description": "Human-readable description", + "priority": "Execution order (lower = earlier)" + }, + "example": "See optimization_engine/plugins/pre_solve/log_trial_start.py" + }, + "execute_hooks": { + "description": "Execute all hooks at a specific point", + "signature": "hook_manager.execute_hooks(hook_point, context, fail_fast=False)", + "returns": "List of hook results" + } + }, + + "validators": { + "validate_plugin_code": { + "description": "Validate plugin code for safety", + "checks": [ + "Syntax errors", + "Dangerous imports (os.system, subprocess, etc.)", + "File operations (optional allow)", + "Function signature correctness" + ], + "safe_modules": ["math", "numpy", "scipy", "pandas", "pathlib", "json", "optuna", "pyNastran"], + "llm_hint": "Always validate LLM-generated code before execution" + } + } + }, + + "design_variables": { + "description": "Parametric CAD variables to optimize", + "schema": { + "name": "Unique identifier", + "expression_name": "NX expression name", + "min": "Lower bound (float)", + "max": "Upper bound (float)", + "units": "Unit system (mm, degrees, etc.)" + }, + "example": { + "name": "wall_thickness", + "expression_name": "wall_thickness", + "min": 3.0, + "max": 8.0, + "units": "mm" + } + }, + + "objectives": { + "description": "Metrics to minimize or maximize", + "schema": { + "name": "Unique identifier", + "extractor": "Result extractor to use", + "metric": "Specific metric from extractor", + "direction": "minimize or maximize", + "weight": "Importance (for multi-objective)", + "units": "Unit system" + }, + "example": { + "name": "max_stress", + "extractor": "stress_extractor", + "metric": "max_von_mises", + "direction": "minimize", + "weight": 1.0, + "units": "MPa" + } + }, + + "constraints": { + "description": "Limits on simulation outputs", + "schema": { + "name": "Unique identifier", + "extractor": "Result extractor to use", + "metric": "Specific metric", + "type": "upper_bound or lower_bound", + "limit": "Constraint value", + "units": "Unit system" + }, + "example": { + "name": "max_displacement_limit", + "extractor": "displacement_extractor", + "metric": "max_displacement", + "type": "upper_bound", + "limit": 1.0, + "units": "mm" + } + }, + + "examples": { + "bracket_optimization": { + "description": "Minimize stress on a bracket by varying wall thickness", + "location": "examples/bracket/", + "design_variables": ["wall_thickness"], + "objectives": ["max_von_mises"], + "trials": 50, + "typical_runtime": "2-3 hours", + "llm_hint": "Good template for single-objective structural optimization" + } + }, + + "llm_guidelines": { + "code_generation": { + "hook_template": "Always include: function signature with context dict, docstring, return dict", + "validation": "Use validate_plugin_code() before registration", + "error_handling": "Wrap in try-except, log errors, return None on failure" + }, + "natural_language_mapping": { + "minimize stress": "objective with direction='minimize', extractor='stress_extractor'", + "vary thickness 3-8mm": "design_variable with min=3.0, max=8.0, units='mm'", + "displacement < 1mm": "constraint with type='upper_bound', limit=1.0", + "RSS of stress and displacement": "custom_objective hook with sqrt(stress² + displacement²)" + } + } +} diff --git a/optimization_engine/plugins/__init__.py b/optimization_engine/plugins/__init__.py new file mode 100644 index 00000000..35921e4d --- /dev/null +++ b/optimization_engine/plugins/__init__.py @@ -0,0 +1,32 @@ +""" +Atomizer Plugin System + +Enables extensibility through hooks at various stages of the optimization lifecycle. + +Hook Points: +- pre_mesh: Before meshing operations +- post_mesh: After meshing, before solve +- pre_solve: Before solver execution +- post_solve: After solve, before result extraction +- post_extraction: After result extraction, before objective calculation +- custom_objectives: Custom objective/constraint functions + +Usage: + from optimization_engine.plugins import HookManager + + hook_manager = HookManager() + hook_manager.register_hook('pre_solve', my_custom_function) + hook_manager.execute_hooks('pre_solve', context={'trial_number': 5}) +""" + +from .hooks import Hook, HookPoint +from .hook_manager import HookManager +from .validators import validate_plugin_code, PluginValidationError + +__all__ = [ + 'Hook', + 'HookPoint', + 'HookManager', + 'validate_plugin_code', + 'PluginValidationError' +] diff --git a/optimization_engine/plugins/custom_objectives/__init__.py b/optimization_engine/plugins/custom_objectives/__init__.py new file mode 100644 index 00000000..13040ed4 --- /dev/null +++ b/optimization_engine/plugins/custom_objectives/__init__.py @@ -0,0 +1 @@ +# custom_objectives hooks diff --git a/optimization_engine/plugins/hook_manager.py b/optimization_engine/plugins/hook_manager.py new file mode 100644 index 00000000..1f051349 --- /dev/null +++ b/optimization_engine/plugins/hook_manager.py @@ -0,0 +1,303 @@ +""" +Hook Manager for Atomizer Plugin System + +Manages registration, execution, and lifecycle of hooks. +""" + +from typing import Dict, List, Callable, Any, Optional +from pathlib import Path +import logging +import importlib.util +import json + +from .hooks import Hook, HookPoint, HookContext + +logger = logging.getLogger(__name__) + + +class HookManager: + """ + Central manager for all hooks in the optimization system. + + Example: + >>> manager = HookManager() + >>> def my_hook(context): + >>> print(f"Trial {context['trial_number']} starting") + >>> return {'status': 'logged'} + >>> + >>> manager.register_hook('pre_solve', my_hook, 'Log trial start') + >>> manager.execute_hooks('pre_solve', {'trial_number': 5}) + """ + + def __init__(self): + self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint} + self._hook_history: List[Dict[str, Any]] = [] + logger.info("HookManager initialized") + + def register_hook( + self, + hook_point: str | HookPoint, + function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]], + description: str, + name: Optional[str] = None, + priority: int = 100, + enabled: bool = True + ) -> Hook: + """ + Register a new hook function. + + Args: + hook_point: When to execute ('pre_solve', 'post_mesh', etc.) + function: Callable that takes context dict, returns optional dict + description: Human-readable description + name: Unique name (auto-generated if not provided) + priority: Execution order (lower = earlier) + enabled: Whether hook is active + + Returns: + The created Hook object + + Raises: + ValueError: If hook_point is invalid + """ + # Convert string to HookPoint enum + if isinstance(hook_point, str): + try: + hook_point = HookPoint(hook_point) + except ValueError: + valid_points = [p.value for p in HookPoint] + raise ValueError( + f"Invalid hook_point '{hook_point}'. " + f"Valid options: {valid_points}" + ) + + # Auto-generate name if not provided + if name is None: + name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}" + + # Create hook + hook = Hook( + name=name, + hook_point=hook_point, + function=function, + description=description, + priority=priority, + enabled=enabled + ) + + # Add to registry and sort by priority + self._hooks[hook_point].append(hook) + self._hooks[hook_point].sort(key=lambda h: h.priority) + + logger.info(f"Registered hook: {hook}") + return hook + + def execute_hooks( + self, + hook_point: str | HookPoint, + context: Dict[str, Any], + fail_fast: bool = False + ) -> List[Optional[Dict[str, Any]]]: + """ + Execute all hooks registered at a specific point. + + Args: + hook_point: The execution point + context: Data to pass to hooks + fail_fast: If True, stop on first error. If False, log and continue. + + Returns: + List of results from each hook (None if hook returned nothing) + + Raises: + Exception: If fail_fast=True and a hook fails + """ + # Convert string to enum + if isinstance(hook_point, str): + hook_point = HookPoint(hook_point) + + hooks = self._hooks[hook_point] + if not hooks: + logger.debug(f"No hooks registered for {hook_point.value}") + return [] + + logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}") + + results = [] + for hook in hooks: + try: + result = hook.execute(context) + results.append(result) + + # Record in history + self._hook_history.append({ + 'hook_name': hook.name, + 'hook_point': hook_point.value, + 'success': True, + 'trial_number': context.get('trial_number', -1) + }) + + except Exception as e: + logger.error(f"Hook '{hook.name}' failed: {e}") + + # Record failure + self._hook_history.append({ + 'hook_name': hook.name, + 'hook_point': hook_point.value, + 'success': False, + 'error': str(e), + 'trial_number': context.get('trial_number', -1) + }) + + if fail_fast: + raise + else: + results.append(None) + + return results + + def load_plugins_from_directory(self, directory: Path): + """ + Auto-discover and load plugins from a directory. + + Expected structure: + plugins/ + pre_mesh/ + my_plugin.py # Contains register_hooks(manager) function + post_solve/ + another_plugin.py + + Args: + directory: Path to plugins directory + """ + if not directory.exists(): + logger.warning(f"Plugins directory not found: {directory}") + return + + logger.info(f"Loading plugins from {directory}") + + for hook_point in HookPoint: + hook_dir = directory / hook_point.value + if not hook_dir.exists(): + continue + + for plugin_file in hook_dir.glob("*.py"): + if plugin_file.name.startswith("_"): + continue # Skip __init__.py and private files + + try: + self._load_plugin_file(plugin_file) + except Exception as e: + logger.error(f"Failed to load plugin {plugin_file}: {e}") + + def _load_plugin_file(self, plugin_file: Path): + """Load a single plugin file and call its register_hooks function.""" + spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load spec for {plugin_file}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Call register_hooks if it exists + if hasattr(module, 'register_hooks'): + module.register_hooks(self) + logger.info(f"Loaded plugin: {plugin_file.stem}") + else: + logger.warning(f"Plugin {plugin_file} has no register_hooks() function") + + def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]: + """Get all hooks registered at a specific point.""" + if isinstance(hook_point, str): + hook_point = HookPoint(hook_point) + return self._hooks[hook_point].copy() + + def remove_hook(self, name: str) -> bool: + """ + Remove a hook by name. + + Returns: + True if hook was found and removed, False otherwise + """ + for point, hooks in self._hooks.items(): + for i, hook in enumerate(hooks): + if hook.name == name: + del hooks[i] + logger.info(f"Removed hook: {name}") + return True + return False + + def enable_hook(self, name: str): + """Enable a hook by name.""" + self._set_hook_enabled(name, True) + + def disable_hook(self, name: str): + """Disable a hook by name.""" + self._set_hook_enabled(name, False) + + def _set_hook_enabled(self, name: str, enabled: bool): + """Set enabled status for a hook.""" + for hooks in self._hooks.values(): + for hook in hooks: + if hook.name == name: + hook.enabled = enabled + status = "enabled" if enabled else "disabled" + logger.info(f"Hook '{name}' {status}") + return + logger.warning(f"Hook '{name}' not found") + + def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: + """ + Get hook execution history. + + Args: + limit: Maximum number of records to return (most recent) + + Returns: + List of execution records + """ + if limit: + return self._hook_history[-limit:] + return self._hook_history.copy() + + def clear_history(self): + """Clear all execution history.""" + self._hook_history.clear() + logger.info("Hook history cleared") + + def get_summary(self) -> Dict[str, Any]: + """ + Get a summary of the hook system state. + + Returns: + Dictionary with hook counts, enabled status, etc. + """ + total_hooks = sum(len(hooks) for hooks in self._hooks.values()) + enabled_hooks = sum( + sum(1 for h in hooks if h.enabled) + for hooks in self._hooks.values() + ) + + by_point = { + point.value: { + 'total': len(hooks), + 'enabled': sum(1 for h in hooks if h.enabled), + 'names': [h.name for h in hooks] + } + for point, hooks in self._hooks.items() + } + + return { + 'total_hooks': total_hooks, + 'enabled_hooks': enabled_hooks, + 'disabled_hooks': total_hooks - enabled_hooks, + 'by_hook_point': by_point, + 'history_records': len(self._hook_history) + } + + def __repr__(self) -> str: + summary = self.get_summary() + return ( + f"HookManager(hooks={summary['total_hooks']}, " + f"enabled={summary['enabled_hooks']})" + ) diff --git a/optimization_engine/plugins/hooks.py b/optimization_engine/plugins/hooks.py new file mode 100644 index 00000000..59049cc7 --- /dev/null +++ b/optimization_engine/plugins/hooks.py @@ -0,0 +1,115 @@ +""" +Core Hook System for Atomizer + +Defines hook points in the optimization lifecycle and hook registration mechanism. +""" + +from enum import Enum +from typing import Callable, Dict, Any, Optional +from dataclasses import dataclass +import logging + +logger = logging.getLogger(__name__) + + +class HookPoint(Enum): + """Enumeration of available hook points in the optimization lifecycle.""" + + PRE_MESH = "pre_mesh" # Before meshing + POST_MESH = "post_mesh" # After meshing, before solve + PRE_SOLVE = "pre_solve" # Before solver execution + POST_SOLVE = "post_solve" # After solve, before extraction + POST_EXTRACTION = "post_extraction" # After result extraction + CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions + + +@dataclass +class Hook: + """ + Represents a single hook function to be executed at a specific point. + + Attributes: + name: Unique identifier for this hook + hook_point: When this hook should execute (HookPoint enum) + function: The callable to execute + description: Human-readable description of what this hook does + priority: Execution order (lower = earlier, default=100) + enabled: Whether this hook is currently active + """ + + name: str + hook_point: HookPoint + function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]] + description: str + priority: int = 100 + enabled: bool = True + + def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Execute this hook with the given context. + + Args: + context: Dictionary containing relevant data for this hook point + Common keys: + - trial_number: Current trial number + - design_variables: Current design variable values + - sim_file: Path to simulation file + - working_dir: Current working directory + + Returns: + Optional dictionary with results or modifications to context + + Raises: + Exception: Any exception from the hook function is logged and re-raised + """ + if not self.enabled: + logger.debug(f"Hook '{self.name}' is disabled, skipping") + return None + + try: + logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}") + result = self.function(context) + logger.debug(f"Hook '{self.name}' completed successfully") + return result + + except Exception as e: + logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True) + raise + + def __repr__(self) -> str: + status = "enabled" if self.enabled else "disabled" + return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})" + + +class HookContext: + """ + Context object passed to hooks containing all relevant data. + + This is a convenience wrapper around a dictionary that provides + both dict-like access and attribute access. + """ + + def __init__(self, **kwargs): + self._data = kwargs + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __setitem__(self, key: str, value: Any): + self._data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self._data + + def get(self, key: str, default: Any = None) -> Any: + return self._data.get(key, default) + + def update(self, other: Dict[str, Any]): + self._data.update(other) + + def to_dict(self) -> Dict[str, Any]: + return self._data.copy() + + def __repr__(self) -> str: + keys = list(self._data.keys()) + return f"HookContext({', '.join(keys)})" diff --git a/optimization_engine/plugins/post_extraction/__init__.py b/optimization_engine/plugins/post_extraction/__init__.py new file mode 100644 index 00000000..a20209d6 --- /dev/null +++ b/optimization_engine/plugins/post_extraction/__init__.py @@ -0,0 +1 @@ +# post_extraction hooks diff --git a/optimization_engine/plugins/post_mesh/__init__.py b/optimization_engine/plugins/post_mesh/__init__.py new file mode 100644 index 00000000..37f2e518 --- /dev/null +++ b/optimization_engine/plugins/post_mesh/__init__.py @@ -0,0 +1 @@ +# post_mesh hooks diff --git a/optimization_engine/plugins/post_solve/__init__.py b/optimization_engine/plugins/post_solve/__init__.py new file mode 100644 index 00000000..162194f6 --- /dev/null +++ b/optimization_engine/plugins/post_solve/__init__.py @@ -0,0 +1 @@ +# post_solve hooks diff --git a/optimization_engine/plugins/pre_mesh/__init__.py b/optimization_engine/plugins/pre_mesh/__init__.py new file mode 100644 index 00000000..e2c153bc --- /dev/null +++ b/optimization_engine/plugins/pre_mesh/__init__.py @@ -0,0 +1 @@ +# pre_mesh hooks diff --git a/optimization_engine/plugins/pre_solve/__init__.py b/optimization_engine/plugins/pre_solve/__init__.py new file mode 100644 index 00000000..739f9f94 --- /dev/null +++ b/optimization_engine/plugins/pre_solve/__init__.py @@ -0,0 +1 @@ +# pre_solve hooks diff --git a/optimization_engine/plugins/pre_solve/log_trial_start.py b/optimization_engine/plugins/pre_solve/log_trial_start.py new file mode 100644 index 00000000..976054fc --- /dev/null +++ b/optimization_engine/plugins/pre_solve/log_trial_start.py @@ -0,0 +1,56 @@ +""" +Example Plugin: Log Trial Start + +Simple pre-solve hook that logs trial information. + +This demonstrates the plugin API and serves as a template for custom hooks. +""" + +from typing import Dict, Any, Optional +import logging + +logger = logging.getLogger(__name__) + + +def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Log trial information before solver execution. + + Args: + context: Hook context containing: + - trial_number: Current trial number + - design_variables: Dict of variable values + - sim_file: Path to simulation file + + Returns: + Dict with logging status + """ + trial_num = context.get('trial_number', '?') + design_vars = context.get('design_variables', {}) + + logger.info("=" * 60) + logger.info(f"TRIAL {trial_num} STARTING") + logger.info("=" * 60) + + for var_name, var_value in design_vars.items(): + logger.info(f" {var_name}: {var_value:.4f}") + + return {'logged': True, 'trial': trial_num} + + +def register_hooks(hook_manager): + """ + Register this plugin's hooks with the manager. + + This function is called automatically when the plugin is loaded. + + Args: + hook_manager: The HookManager instance + """ + hook_manager.register_hook( + hook_point='pre_solve', + function=trial_start_logger, + description='Log trial number and design variables before solve', + name='log_trial_start', + priority=10 # Run early + ) diff --git a/optimization_engine/plugins/validators.py b/optimization_engine/plugins/validators.py new file mode 100644 index 00000000..7d1da434 --- /dev/null +++ b/optimization_engine/plugins/validators.py @@ -0,0 +1,214 @@ +""" +Plugin Code Validators + +Ensures safety and correctness of plugin code before execution. +""" + +import ast +import re +from typing import List, Set +import logging + +logger = logging.getLogger(__name__) + + +class PluginValidationError(Exception): + """Raised when plugin code fails validation.""" + pass + + +# Whitelist of safe modules for plugin imports +SAFE_MODULES = { + 'math', 'numpy', 'scipy', 'pandas', + 'pathlib', 'json', 'csv', 'datetime', + 'logging', 'typing', 'dataclasses', + 'optuna', 'pyNastran' +} + +# Blacklist of dangerous operations +DANGEROUS_OPERATIONS = { + 'eval', 'exec', 'compile', '__import__', + 'open', # File operations should be explicit + 'subprocess', 'os.system', 'os.popen', + 'shutil.rmtree', # Dangerous file operations +} + + +def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None: + """ + Validate plugin code for safety before execution. + + Args: + code: Python source code to validate + allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.) + + Raises: + PluginValidationError: If code contains unsafe operations + + Example: + >>> code = "def my_hook(context): return {'result': context['x'] * 2}" + >>> validate_plugin_code(code) # Passes + >>> + >>> code = "import os; os.system('rm -rf /')" + >>> validate_plugin_code(code) # Raises PluginValidationError + """ + try: + tree = ast.parse(code) + except SyntaxError as e: + raise PluginValidationError(f"Syntax error in plugin code: {e}") + + # Check for dangerous imports + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name not in SAFE_MODULES: + raise PluginValidationError( + f"Unsafe import: {alias.name}. " + f"Allowed modules: {SAFE_MODULES}" + ) + + elif isinstance(node, ast.ImportFrom): + if node.module and node.module not in SAFE_MODULES: + # Allow submodules of safe modules + base_module = node.module.split('.')[0] + if base_module not in SAFE_MODULES: + raise PluginValidationError( + f"Unsafe import from: {node.module}" + ) + + # Check for dangerous function calls + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in DANGEROUS_OPERATIONS: + if node.func.id == 'open' and allow_file_ops: + continue # Allow if explicitly permitted + raise PluginValidationError( + f"Dangerous operation: {node.func.id}" + ) + + # Check for attribute access to dangerous methods + elif isinstance(node, ast.Attribute): + dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'} + if node.attr in dangerous_attrs and not allow_file_ops: + raise PluginValidationError( + f"Dangerous attribute access: {node.attr}" + ) + + logger.info("Plugin code validation passed") + + +def check_hook_function_signature(func_code: str) -> bool: + """ + Verify that a hook function has the correct signature. + + Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]] + + Args: + func_code: Function source code + + Returns: + True if signature is valid + + Raises: + PluginValidationError: If signature is invalid + """ + try: + tree = ast.parse(func_code) + except SyntaxError as e: + raise PluginValidationError(f"Syntax error: {e}") + + # Find function definition + func_def = None + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_def = node + break + + if func_def is None: + raise PluginValidationError("No function definition found") + + # Check that it takes exactly one argument (context) + if len(func_def.args.args) != 1: + raise PluginValidationError( + f"Hook function must take exactly 1 argument (context), " + f"got {len(func_def.args.args)}" + ) + + logger.info(f"Hook function '{func_def.name}' signature is valid") + return True + + +def sanitize_plugin_name(name: str) -> str: + """ + Sanitize a plugin name to be safe for filesystem and imports. + + Args: + name: Original plugin name + + Returns: + Sanitized name (lowercase, alphanumeric + underscore only) + """ + # Remove special characters, keep only alphanumeric and underscore + sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower()) + + # Ensure it doesn't start with a number + if sanitized[0].isdigit(): + sanitized = f"plugin_{sanitized}" + + return sanitized + + +def get_imported_modules(code: str) -> Set[str]: + """ + Extract all imported modules from code. + + Args: + code: Python source code + + Returns: + Set of module names + """ + try: + tree = ast.parse(code) + except SyntaxError: + return set() + + modules = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + modules.add(alias.name) + elif isinstance(node, ast.ImportFrom): + if node.module: + modules.add(node.module) + + return modules + + +def estimate_complexity(code: str) -> int: + """ + Estimate the cyclomatic complexity of code. + + Higher values indicate more complex code with more branching. + + Args: + code: Python source code + + Returns: + Complexity estimate (1 = simple, 10+ = complex) + """ + try: + tree = ast.parse(code) + except SyntaxError: + return 0 + + complexity = 1 # Base complexity + + for node in ast.walk(tree): + # Add 1 for each branching statement + if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)): + complexity += 1 + elif isinstance(node, ast.BoolOp): + complexity += len(node.values) - 1 + + return complexity diff --git a/optimization_engine/runner.py b/optimization_engine/runner.py index f53a8ee0..cb0d82a8 100644 --- a/optimization_engine/runner.py +++ b/optimization_engine/runner.py @@ -23,6 +23,8 @@ import pandas as pd from datetime import datetime import pickle +from optimization_engine.plugins import HookManager + class OptimizationRunner: """ @@ -68,6 +70,15 @@ class OptimizationRunner: self.output_dir = self.config_path.parent / 'optimization_results' self.output_dir.mkdir(exist_ok=True) + # Initialize plugin/hook system + self.hook_manager = HookManager() + plugins_dir = Path(__file__).parent / 'plugins' + if plugins_dir.exists(): + self.hook_manager.load_plugins_from_directory(plugins_dir) + summary = self.hook_manager.get_summary() + if summary['total_hooks'] > 0: + print(f"Loaded {summary['enabled_hooks']}/{summary['total_hooks']} plugins") + def _load_config(self) -> Dict[str, Any]: """Load and validate optimization configuration.""" with open(self.config_path, 'r') as f: @@ -311,6 +322,16 @@ class OptimizationRunner: int(dv['bounds'][1]) ) + # Execute pre_solve hooks + pre_solve_context = { + 'trial_number': trial.number, + 'design_variables': design_vars, + 'sim_file': self.config.get('sim_file', ''), + 'working_dir': str(Path.cwd()), + 'config': self.config + } + self.hook_manager.execute_hooks('pre_solve', pre_solve_context, fail_fast=False) + # 2. Update NX model with new parameters try: self.model_updater(design_vars) @@ -318,6 +339,15 @@ class OptimizationRunner: print(f"Error updating model: {e}") raise optuna.TrialPruned() + # Execute post_mesh hooks (after model update) + post_mesh_context = { + 'trial_number': trial.number, + 'design_variables': design_vars, + 'sim_file': self.config.get('sim_file', ''), + 'working_dir': str(Path.cwd()) + } + self.hook_manager.execute_hooks('post_mesh', post_mesh_context, fail_fast=False) + # 3. Run simulation try: result_path = self.simulation_runner() @@ -325,6 +355,15 @@ class OptimizationRunner: print(f"Error running simulation: {e}") raise optuna.TrialPruned() + # Execute post_solve hooks + post_solve_context = { + 'trial_number': trial.number, + 'design_variables': design_vars, + 'result_path': str(result_path) if result_path else '', + 'working_dir': str(Path.cwd()) + } + self.hook_manager.execute_hooks('post_solve', post_solve_context, fail_fast=False) + # 4. Extract results with appropriate precision extracted_results = {} for obj in self.config['objectives']: @@ -362,6 +401,16 @@ class OptimizationRunner: print(f"Error extracting {const['name']}: {e}") raise optuna.TrialPruned() + # Execute post_extraction hooks + post_extraction_context = { + 'trial_number': trial.number, + 'design_variables': design_vars, + 'extracted_results': extracted_results, + 'result_path': str(result_path) if result_path else '', + 'working_dir': str(Path.cwd()) + } + self.hook_manager.execute_hooks('post_extraction', post_extraction_context, fail_fast=False) + # 5. Evaluate constraints for const in self.config.get('constraints', []): value = extracted_results[const['name']] @@ -389,6 +438,23 @@ class OptimizationRunner: else: # maximize total_objective -= weight * value + # Execute custom_objective hooks (can modify total_objective) + custom_objective_context = { + 'trial_number': trial.number, + 'design_variables': design_vars, + 'extracted_results': extracted_results, + 'total_objective': total_objective, + 'working_dir': str(Path.cwd()) + } + custom_results = self.hook_manager.execute_hooks('custom_objective', custom_objective_context, fail_fast=False) + + # Allow hooks to override objective value + for result in custom_results: + if result and 'total_objective' in result: + total_objective = result['total_objective'] + print(f"Custom objective hook modified total_objective to {total_objective:.6f}") + break # Use first hook that provides override + # 7. Store results in history history_entry = { 'trial_number': trial.number, diff --git a/tests/test_plugin_system.py b/tests/test_plugin_system.py new file mode 100644 index 00000000..d99d47ca --- /dev/null +++ b/tests/test_plugin_system.py @@ -0,0 +1,438 @@ +""" +Tests for the Plugin System + +Validates hook registration, execution, validation, and integration. +""" + +import pytest +from pathlib import Path +from typing import Dict, Any, Optional + +from optimization_engine.plugins import HookManager, Hook, HookPoint +from optimization_engine.plugins.validators import ( + validate_plugin_code, + PluginValidationError, + check_hook_function_signature, + sanitize_plugin_name, + get_imported_modules, + estimate_complexity +) + + +class TestHookRegistration: + """Test hook registration and management.""" + + def test_register_simple_hook(self): + """Test basic hook registration.""" + manager = HookManager() + + def my_hook(context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + return {'status': 'success'} + + hook = manager.register_hook( + hook_point='pre_solve', + function=my_hook, + description='Test hook', + name='test_hook' + ) + + assert hook.name == 'test_hook' + assert hook.hook_point == HookPoint.PRE_SOLVE + assert hook.enabled is True + + def test_hook_priority_ordering(self): + """Test that hooks execute in priority order.""" + manager = HookManager() + execution_order = [] + + def hook_high_priority(context): + execution_order.append('high') + return None + + def hook_low_priority(context): + execution_order.append('low') + return None + + # Register low priority first, high priority second + manager.register_hook('pre_solve', hook_low_priority, 'Low', priority=200) + manager.register_hook('pre_solve', hook_high_priority, 'High', priority=10) + + # Execute hooks + manager.execute_hooks('pre_solve', {'trial_number': 1}) + + # High priority should execute first + assert execution_order == ['high', 'low'] + + def test_disable_enable_hook(self): + """Test disabling and enabling hooks.""" + manager = HookManager() + execution_count = [0] + + def counting_hook(context): + execution_count[0] += 1 + return None + + hook = manager.register_hook( + 'pre_solve', + counting_hook, + 'Counter', + name='counter_hook' + ) + + # Execute while enabled + manager.execute_hooks('pre_solve', {}) + assert execution_count[0] == 1 + + # Disable and execute + manager.disable_hook('counter_hook') + manager.execute_hooks('pre_solve', {}) + assert execution_count[0] == 1 # Should not increment + + # Re-enable and execute + manager.enable_hook('counter_hook') + manager.execute_hooks('pre_solve', {}) + assert execution_count[0] == 2 + + def test_remove_hook(self): + """Test hook removal.""" + manager = HookManager() + + def test_hook(context): + return None + + manager.register_hook('pre_solve', test_hook, 'Test', name='removable') + + assert len(manager.get_hooks('pre_solve')) == 1 + + success = manager.remove_hook('removable') + assert success is True + assert len(manager.get_hooks('pre_solve')) == 0 + + # Try removing non-existent hook + success = manager.remove_hook('nonexistent') + assert success is False + + +class TestHookExecution: + """Test hook execution behavior.""" + + def test_hook_receives_context(self): + """Test that hooks receive correct context.""" + manager = HookManager() + received_context = {} + + def context_checker(context: Dict[str, Any]): + received_context.update(context) + return None + + manager.register_hook('pre_solve', context_checker, 'Checker') + + test_context = { + 'trial_number': 42, + 'design_variables': {'thickness': 5.0}, + 'sim_file': 'test.sim' + } + + manager.execute_hooks('pre_solve', test_context) + + assert received_context['trial_number'] == 42 + assert received_context['design_variables']['thickness'] == 5.0 + + def test_hook_return_values(self): + """Test that hook return values are collected.""" + manager = HookManager() + + def hook1(context): + return {'result': 'hook1'} + + def hook2(context): + return {'result': 'hook2'} + + manager.register_hook('pre_solve', hook1, 'Hook 1') + manager.register_hook('pre_solve', hook2, 'Hook 2') + + results = manager.execute_hooks('pre_solve', {}) + + assert len(results) == 2 + assert results[0]['result'] == 'hook1' + assert results[1]['result'] == 'hook2' + + def test_hook_error_handling_fail_fast(self): + """Test error handling with fail_fast=True.""" + manager = HookManager() + + def failing_hook(context): + raise ValueError("Intentional error") + + manager.register_hook('pre_solve', failing_hook, 'Failing') + + with pytest.raises(ValueError, match="Intentional error"): + manager.execute_hooks('pre_solve', {}, fail_fast=True) + + def test_hook_error_handling_continue(self): + """Test error handling with fail_fast=False.""" + manager = HookManager() + execution_log = [] + + def failing_hook(context): + execution_log.append('failing') + raise ValueError("Intentional error") + + def successful_hook(context): + execution_log.append('successful') + return {'status': 'ok'} + + manager.register_hook('pre_solve', failing_hook, 'Failing', priority=10) + manager.register_hook('pre_solve', successful_hook, 'Success', priority=20) + + results = manager.execute_hooks('pre_solve', {}, fail_fast=False) + + # Both hooks should have attempted execution + assert execution_log == ['failing', 'successful'] + # First result is None (error), second is successful + assert results[0] is None + assert results[1]['status'] == 'ok' + + def test_hook_history_tracking(self): + """Test that hook execution history is tracked.""" + manager = HookManager() + + def test_hook(context): + return {'result': 'success'} + + manager.register_hook('pre_solve', test_hook, 'Test', name='tracked') + + # Execute hooks multiple times + for i in range(3): + manager.execute_hooks('pre_solve', {'trial_number': i}) + + history = manager.get_history() + assert len(history) >= 3 + + # Check history contains success records + successful = [h for h in history if h['success']] + assert len(successful) >= 3 + + +class TestCodeValidation: + """Test plugin code validation.""" + + def test_safe_code_passes(self): + """Test that safe code passes validation.""" + safe_code = """ +import numpy as np +import math + +def my_hook(context): + x = context['design_variables']['thickness'] + result = math.sqrt(x**2 + np.mean([1, 2, 3])) + return {'result': result} +""" + # Should not raise + validate_plugin_code(safe_code) + + def test_dangerous_import_blocked(self): + """Test that dangerous imports are blocked.""" + dangerous_code = """ +import os + +def my_hook(context): + os.system('rm -rf /') + return None +""" + with pytest.raises(PluginValidationError, match="Unsafe import"): + validate_plugin_code(dangerous_code) + + def test_dangerous_operation_blocked(self): + """Test that dangerous operations are blocked.""" + dangerous_code = """ +def my_hook(context): + eval('malicious_code') + return None +""" + with pytest.raises(PluginValidationError, match="Dangerous operation"): + validate_plugin_code(dangerous_code) + + def test_file_operations_with_permission(self): + """Test that file operations work with allow_file_ops=True.""" + code_with_file_ops = """ +def my_hook(context): + with open('output.txt', 'w') as f: + f.write('test') + return None +""" + # Should raise without permission + with pytest.raises(PluginValidationError, match="Dangerous operation: open"): + validate_plugin_code(code_with_file_ops, allow_file_ops=False) + + # Should pass with permission + validate_plugin_code(code_with_file_ops, allow_file_ops=True) + + def test_syntax_error_detected(self): + """Test that syntax errors are detected.""" + bad_syntax = """ +def my_hook(context) + return None # Missing colon +""" + with pytest.raises(PluginValidationError, match="Syntax error"): + validate_plugin_code(bad_syntax) + + def test_hook_signature_validation(self): + """Test hook function signature validation.""" + # Valid signature + valid_code = """ +def my_hook(context): + return None +""" + assert check_hook_function_signature(valid_code) is True + + # Invalid: too many arguments + invalid_code = """ +def my_hook(context, extra_arg): + return None +""" + with pytest.raises(PluginValidationError, match="must take exactly 1 argument"): + check_hook_function_signature(invalid_code) + + # Invalid: no function + invalid_code = """ +x = 5 +""" + with pytest.raises(PluginValidationError, match="No function definition"): + check_hook_function_signature(invalid_code) + + +class TestUtilityFunctions: + """Test plugin utility functions.""" + + def test_sanitize_plugin_name(self): + """Test plugin name sanitization.""" + assert sanitize_plugin_name('My-Plugin!') == 'my_plugin_' + assert sanitize_plugin_name('123_plugin') == 'plugin_123_plugin' + assert sanitize_plugin_name('valid_name') == 'valid_name' + + def test_get_imported_modules(self): + """Test module extraction from code.""" + code = """ +import numpy as np +import math +from pathlib import Path +from typing import Dict, Any +""" + modules = get_imported_modules(code) + assert 'numpy' in modules + assert 'math' in modules + assert 'pathlib' in modules + assert 'typing' in modules + + def test_estimate_complexity(self): + """Test cyclomatic complexity estimation.""" + simple_code = """ +def simple(x): + return x + 1 +""" + complex_code = """ +def complex(x): + if x > 0: + for i in range(x): + if i % 2 == 0: + while i > 0: + i -= 1 + return x +""" + simple_complexity = estimate_complexity(simple_code) + complex_complexity = estimate_complexity(complex_code) + + assert simple_complexity == 1 + assert complex_complexity > simple_complexity + + +class TestHookManager: + """Test HookManager functionality.""" + + def test_get_summary(self): + """Test hook system summary.""" + manager = HookManager() + + def hook1(context): + return None + + def hook2(context): + return None + + manager.register_hook('pre_solve', hook1, 'Hook 1', name='hook1') + manager.register_hook('post_solve', hook2, 'Hook 2', name='hook2') + manager.disable_hook('hook2') + + summary = manager.get_summary() + + assert summary['total_hooks'] == 2 + assert summary['enabled_hooks'] == 1 + assert summary['disabled_hooks'] == 1 + + def test_clear_history(self): + """Test clearing execution history.""" + manager = HookManager() + + def test_hook(context): + return None + + manager.register_hook('pre_solve', test_hook, 'Test') + manager.execute_hooks('pre_solve', {}) + + assert len(manager.get_history()) > 0 + + manager.clear_history() + assert len(manager.get_history()) == 0 + + def test_hook_manager_repr(self): + """Test HookManager string representation.""" + manager = HookManager() + + def hook(context): + return None + + manager.register_hook('pre_solve', hook, 'Test') + + repr_str = repr(manager) + assert 'HookManager' in repr_str + assert 'hooks=1' in repr_str + assert 'enabled=1' in repr_str + + +class TestPluginLoading: + """Test plugin directory loading.""" + + def test_load_plugins_from_nonexistent_directory(self): + """Test loading from non-existent directory.""" + manager = HookManager() + # Should not raise, just log warning + manager.load_plugins_from_directory(Path('/nonexistent/path')) + + def test_plugin_registration_function(self): + """Test that plugins can register hooks via register_hooks().""" + manager = HookManager() + + # Simulate what a plugin file would contain + def register_hooks(hook_manager): + def my_plugin_hook(context): + return {'plugin': 'loaded'} + + hook_manager.register_hook( + 'pre_solve', + my_plugin_hook, + 'Plugin hook', + name='plugin_hook' + ) + + # Call the registration function + register_hooks(manager) + + # Verify hook was registered + hooks = manager.get_hooks('pre_solve') + assert len(hooks) == 1 + assert hooks[0].name == 'plugin_hook' + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])