feat: Implement Phase 1 - Plugin & Hook System

Core plugin architecture for LLM-driven optimization:

New Features:
- Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives)
- HookManager for centralized registration and execution
- Code validation with AST-based safety checks
- Feature registry (JSON) for LLM capability discovery
- Example plugin: log_trial_start
- 23 comprehensive tests (all passing)

Integration:
- OptimizationRunner now loads plugins automatically
- Hooks execute at 5 points in optimization loop
- Custom objectives can override total_objective via hooks

Safety:
- Module whitelist (numpy, scipy, pandas, optuna, pyNastran)
- Dangerous operation blocking (eval, exec, os.system, subprocess)
- Optional file operation permission flag

Files Added:
- optimization_engine/plugins/__init__.py
- optimization_engine/plugins/hooks.py
- optimization_engine/plugins/hook_manager.py
- optimization_engine/plugins/validators.py
- optimization_engine/feature_registry.json
- optimization_engine/plugins/pre_solve/log_trial_start.py
- tests/test_plugin_system.py (23 tests)

Files Modified:
- optimization_engine/runner.py (added hook integration)

Ready for Phase 2: LLM interface layer

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-15 14:46:49 -05:00
parent 0ce9ddf3e2
commit a24e3f750c
14 changed files with 1473 additions and 0 deletions

View File

@@ -0,0 +1,243 @@
{
"version": "1.0.0",
"last_updated": "2025-01-15",
"description": "Registry of all Atomizer capabilities for LLM discovery and usage",
"core_features": {
"optimization": {
"description": "Core optimization engine using Optuna",
"module": "optimization_engine.runner",
"capabilities": [
"Multi-objective optimization with weighted sum",
"TPE (Tree-structured Parzen Estimator) sampler",
"CMA-ES sampler",
"Gaussian Process sampler",
"50-trial default with 20 startup trials",
"Automatic checkpoint and resume",
"SQLite-based study persistence"
],
"usage": "python examples/test_journal_optimization.py",
"llm_hint": "Use this for Bayesian optimization with NX simulations"
},
"nx_integration": {
"description": "Siemens NX simulation automation via journal scripts",
"module": "optimization_engine.nx_solver",
"capabilities": [
"Update CAD expressions via NXOpen",
"Execute NX Nastran solver",
"Extract OP2 results (stress, displacement)",
"Extract mass properties",
"Precision control (4 decimals for mm/degrees/MPa)"
],
"usage": "from optimization_engine.nx_solver import run_nx_simulation",
"llm_hint": "Use for running FEA simulations and extracting results"
},
"result_extraction": {
"description": "Extract metrics from simulation results",
"module": "optimization_engine.result_extractors",
"extractors": {
"stress_extractor": {
"description": "Extract stress data from OP2 files",
"metrics": ["max_von_mises", "mean_von_mises", "max_principal"],
"file_type": "OP2",
"usage": "Returns stress in MPa"
},
"displacement_extractor": {
"description": "Extract displacement data from OP2 files",
"metrics": ["max_displacement", "mean_displacement"],
"file_type": "OP2",
"usage": "Returns displacement in mm"
},
"mass_extractor": {
"description": "Extract mass properties",
"metrics": ["total_mass", "center_of_gravity"],
"file_type": "NX Part",
"usage": "Returns mass in kg"
}
},
"llm_hint": "Use extractors to define objectives and constraints"
}
},
"plugin_system": {
"description": "Extensible hook system for custom functionality",
"module": "optimization_engine.plugins",
"version": "1.0.0",
"hook_points": {
"pre_mesh": {
"description": "Execute before meshing operations",
"context": ["trial_number", "design_variables", "sim_file", "working_dir"],
"use_cases": [
"Modify geometry based on parameters",
"Set up boundary conditions",
"Configure mesh settings"
]
},
"post_mesh": {
"description": "Execute after meshing, before solve",
"context": ["trial_number", "mesh_info", "element_count", "node_count"],
"use_cases": [
"Validate mesh quality",
"Export mesh for visualization",
"Log mesh statistics"
]
},
"pre_solve": {
"description": "Execute before solver launch",
"context": ["trial_number", "design_variables", "solver_settings"],
"use_cases": [
"Log trial parameters",
"Modify solver settings",
"Set up custom load cases"
]
},
"post_solve": {
"description": "Execute after solve, before result extraction",
"context": ["trial_number", "solve_status", "output_files"],
"use_cases": [
"Check solver convergence",
"Post-process results",
"Generate visualizations"
]
},
"post_extraction": {
"description": "Execute after result extraction",
"context": ["trial_number", "extracted_results", "objectives", "constraints"],
"use_cases": [
"Calculate custom metrics",
"Combine multiple objectives (RSS)",
"Apply penalty functions"
]
},
"custom_objective": {
"description": "Define custom objective functions",
"context": ["extracted_results", "design_variables"],
"use_cases": [
"RSS of stress and displacement",
"Weighted multi-criteria",
"Conditional objectives"
]
}
},
"api": {
"register_hook": {
"description": "Register a new hook function",
"signature": "hook_manager.register_hook(hook_point, function, description, name=None, priority=100)",
"parameters": {
"hook_point": "One of: pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objective",
"function": "Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]",
"description": "Human-readable description",
"priority": "Execution order (lower = earlier)"
},
"example": "See optimization_engine/plugins/pre_solve/log_trial_start.py"
},
"execute_hooks": {
"description": "Execute all hooks at a specific point",
"signature": "hook_manager.execute_hooks(hook_point, context, fail_fast=False)",
"returns": "List of hook results"
}
},
"validators": {
"validate_plugin_code": {
"description": "Validate plugin code for safety",
"checks": [
"Syntax errors",
"Dangerous imports (os.system, subprocess, etc.)",
"File operations (optional allow)",
"Function signature correctness"
],
"safe_modules": ["math", "numpy", "scipy", "pandas", "pathlib", "json", "optuna", "pyNastran"],
"llm_hint": "Always validate LLM-generated code before execution"
}
}
},
"design_variables": {
"description": "Parametric CAD variables to optimize",
"schema": {
"name": "Unique identifier",
"expression_name": "NX expression name",
"min": "Lower bound (float)",
"max": "Upper bound (float)",
"units": "Unit system (mm, degrees, etc.)"
},
"example": {
"name": "wall_thickness",
"expression_name": "wall_thickness",
"min": 3.0,
"max": 8.0,
"units": "mm"
}
},
"objectives": {
"description": "Metrics to minimize or maximize",
"schema": {
"name": "Unique identifier",
"extractor": "Result extractor to use",
"metric": "Specific metric from extractor",
"direction": "minimize or maximize",
"weight": "Importance (for multi-objective)",
"units": "Unit system"
},
"example": {
"name": "max_stress",
"extractor": "stress_extractor",
"metric": "max_von_mises",
"direction": "minimize",
"weight": 1.0,
"units": "MPa"
}
},
"constraints": {
"description": "Limits on simulation outputs",
"schema": {
"name": "Unique identifier",
"extractor": "Result extractor to use",
"metric": "Specific metric",
"type": "upper_bound or lower_bound",
"limit": "Constraint value",
"units": "Unit system"
},
"example": {
"name": "max_displacement_limit",
"extractor": "displacement_extractor",
"metric": "max_displacement",
"type": "upper_bound",
"limit": 1.0,
"units": "mm"
}
},
"examples": {
"bracket_optimization": {
"description": "Minimize stress on a bracket by varying wall thickness",
"location": "examples/bracket/",
"design_variables": ["wall_thickness"],
"objectives": ["max_von_mises"],
"trials": 50,
"typical_runtime": "2-3 hours",
"llm_hint": "Good template for single-objective structural optimization"
}
},
"llm_guidelines": {
"code_generation": {
"hook_template": "Always include: function signature with context dict, docstring, return dict",
"validation": "Use validate_plugin_code() before registration",
"error_handling": "Wrap in try-except, log errors, return None on failure"
},
"natural_language_mapping": {
"minimize stress": "objective with direction='minimize', extractor='stress_extractor'",
"vary thickness 3-8mm": "design_variable with min=3.0, max=8.0, units='mm'",
"displacement < 1mm": "constraint with type='upper_bound', limit=1.0",
"RSS of stress and displacement": "custom_objective hook with sqrt(stress² + displacement²)"
}
}
}

View File

@@ -0,0 +1,32 @@
"""
Atomizer Plugin System
Enables extensibility through hooks at various stages of the optimization lifecycle.
Hook Points:
- pre_mesh: Before meshing operations
- post_mesh: After meshing, before solve
- pre_solve: Before solver execution
- post_solve: After solve, before result extraction
- post_extraction: After result extraction, before objective calculation
- custom_objectives: Custom objective/constraint functions
Usage:
from optimization_engine.plugins import HookManager
hook_manager = HookManager()
hook_manager.register_hook('pre_solve', my_custom_function)
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
"""
from .hooks import Hook, HookPoint
from .hook_manager import HookManager
from .validators import validate_plugin_code, PluginValidationError
__all__ = [
'Hook',
'HookPoint',
'HookManager',
'validate_plugin_code',
'PluginValidationError'
]

View File

@@ -0,0 +1 @@
# custom_objectives hooks

View File

@@ -0,0 +1,303 @@
"""
Hook Manager for Atomizer Plugin System
Manages registration, execution, and lifecycle of hooks.
"""
from typing import Dict, List, Callable, Any, Optional
from pathlib import Path
import logging
import importlib.util
import json
from .hooks import Hook, HookPoint, HookContext
logger = logging.getLogger(__name__)
class HookManager:
"""
Central manager for all hooks in the optimization system.
Example:
>>> manager = HookManager()
>>> def my_hook(context):
>>> print(f"Trial {context['trial_number']} starting")
>>> return {'status': 'logged'}
>>>
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
"""
def __init__(self):
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
self._hook_history: List[Dict[str, Any]] = []
logger.info("HookManager initialized")
def register_hook(
self,
hook_point: str | HookPoint,
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
description: str,
name: Optional[str] = None,
priority: int = 100,
enabled: bool = True
) -> Hook:
"""
Register a new hook function.
Args:
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
function: Callable that takes context dict, returns optional dict
description: Human-readable description
name: Unique name (auto-generated if not provided)
priority: Execution order (lower = earlier)
enabled: Whether hook is active
Returns:
The created Hook object
Raises:
ValueError: If hook_point is invalid
"""
# Convert string to HookPoint enum
if isinstance(hook_point, str):
try:
hook_point = HookPoint(hook_point)
except ValueError:
valid_points = [p.value for p in HookPoint]
raise ValueError(
f"Invalid hook_point '{hook_point}'. "
f"Valid options: {valid_points}"
)
# Auto-generate name if not provided
if name is None:
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
# Create hook
hook = Hook(
name=name,
hook_point=hook_point,
function=function,
description=description,
priority=priority,
enabled=enabled
)
# Add to registry and sort by priority
self._hooks[hook_point].append(hook)
self._hooks[hook_point].sort(key=lambda h: h.priority)
logger.info(f"Registered hook: {hook}")
return hook
def execute_hooks(
self,
hook_point: str | HookPoint,
context: Dict[str, Any],
fail_fast: bool = False
) -> List[Optional[Dict[str, Any]]]:
"""
Execute all hooks registered at a specific point.
Args:
hook_point: The execution point
context: Data to pass to hooks
fail_fast: If True, stop on first error. If False, log and continue.
Returns:
List of results from each hook (None if hook returned nothing)
Raises:
Exception: If fail_fast=True and a hook fails
"""
# Convert string to enum
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
hooks = self._hooks[hook_point]
if not hooks:
logger.debug(f"No hooks registered for {hook_point.value}")
return []
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
results = []
for hook in hooks:
try:
result = hook.execute(context)
results.append(result)
# Record in history
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': True,
'trial_number': context.get('trial_number', -1)
})
except Exception as e:
logger.error(f"Hook '{hook.name}' failed: {e}")
# Record failure
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': False,
'error': str(e),
'trial_number': context.get('trial_number', -1)
})
if fail_fast:
raise
else:
results.append(None)
return results
def load_plugins_from_directory(self, directory: Path):
"""
Auto-discover and load plugins from a directory.
Expected structure:
plugins/
pre_mesh/
my_plugin.py # Contains register_hooks(manager) function
post_solve/
another_plugin.py
Args:
directory: Path to plugins directory
"""
if not directory.exists():
logger.warning(f"Plugins directory not found: {directory}")
return
logger.info(f"Loading plugins from {directory}")
for hook_point in HookPoint:
hook_dir = directory / hook_point.value
if not hook_dir.exists():
continue
for plugin_file in hook_dir.glob("*.py"):
if plugin_file.name.startswith("_"):
continue # Skip __init__.py and private files
try:
self._load_plugin_file(plugin_file)
except Exception as e:
logger.error(f"Failed to load plugin {plugin_file}: {e}")
def _load_plugin_file(self, plugin_file: Path):
"""Load a single plugin file and call its register_hooks function."""
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load spec for {plugin_file}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Call register_hooks if it exists
if hasattr(module, 'register_hooks'):
module.register_hooks(self)
logger.info(f"Loaded plugin: {plugin_file.stem}")
else:
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
"""Get all hooks registered at a specific point."""
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
return self._hooks[hook_point].copy()
def remove_hook(self, name: str) -> bool:
"""
Remove a hook by name.
Returns:
True if hook was found and removed, False otherwise
"""
for point, hooks in self._hooks.items():
for i, hook in enumerate(hooks):
if hook.name == name:
del hooks[i]
logger.info(f"Removed hook: {name}")
return True
return False
def enable_hook(self, name: str):
"""Enable a hook by name."""
self._set_hook_enabled(name, True)
def disable_hook(self, name: str):
"""Disable a hook by name."""
self._set_hook_enabled(name, False)
def _set_hook_enabled(self, name: str, enabled: bool):
"""Set enabled status for a hook."""
for hooks in self._hooks.values():
for hook in hooks:
if hook.name == name:
hook.enabled = enabled
status = "enabled" if enabled else "disabled"
logger.info(f"Hook '{name}' {status}")
return
logger.warning(f"Hook '{name}' not found")
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Get hook execution history.
Args:
limit: Maximum number of records to return (most recent)
Returns:
List of execution records
"""
if limit:
return self._hook_history[-limit:]
return self._hook_history.copy()
def clear_history(self):
"""Clear all execution history."""
self._hook_history.clear()
logger.info("Hook history cleared")
def get_summary(self) -> Dict[str, Any]:
"""
Get a summary of the hook system state.
Returns:
Dictionary with hook counts, enabled status, etc.
"""
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
enabled_hooks = sum(
sum(1 for h in hooks if h.enabled)
for hooks in self._hooks.values()
)
by_point = {
point.value: {
'total': len(hooks),
'enabled': sum(1 for h in hooks if h.enabled),
'names': [h.name for h in hooks]
}
for point, hooks in self._hooks.items()
}
return {
'total_hooks': total_hooks,
'enabled_hooks': enabled_hooks,
'disabled_hooks': total_hooks - enabled_hooks,
'by_hook_point': by_point,
'history_records': len(self._hook_history)
}
def __repr__(self) -> str:
summary = self.get_summary()
return (
f"HookManager(hooks={summary['total_hooks']}, "
f"enabled={summary['enabled_hooks']})"
)

View File

@@ -0,0 +1,115 @@
"""
Core Hook System for Atomizer
Defines hook points in the optimization lifecycle and hook registration mechanism.
"""
from enum import Enum
from typing import Callable, Dict, Any, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
class HookPoint(Enum):
"""Enumeration of available hook points in the optimization lifecycle."""
PRE_MESH = "pre_mesh" # Before meshing
POST_MESH = "post_mesh" # After meshing, before solve
PRE_SOLVE = "pre_solve" # Before solver execution
POST_SOLVE = "post_solve" # After solve, before extraction
POST_EXTRACTION = "post_extraction" # After result extraction
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
@dataclass
class Hook:
"""
Represents a single hook function to be executed at a specific point.
Attributes:
name: Unique identifier for this hook
hook_point: When this hook should execute (HookPoint enum)
function: The callable to execute
description: Human-readable description of what this hook does
priority: Execution order (lower = earlier, default=100)
enabled: Whether this hook is currently active
"""
name: str
hook_point: HookPoint
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
description: str
priority: int = 100
enabled: bool = True
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Execute this hook with the given context.
Args:
context: Dictionary containing relevant data for this hook point
Common keys:
- trial_number: Current trial number
- design_variables: Current design variable values
- sim_file: Path to simulation file
- working_dir: Current working directory
Returns:
Optional dictionary with results or modifications to context
Raises:
Exception: Any exception from the hook function is logged and re-raised
"""
if not self.enabled:
logger.debug(f"Hook '{self.name}' is disabled, skipping")
return None
try:
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
result = self.function(context)
logger.debug(f"Hook '{self.name}' completed successfully")
return result
except Exception as e:
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
raise
def __repr__(self) -> str:
status = "enabled" if self.enabled else "disabled"
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
class HookContext:
"""
Context object passed to hooks containing all relevant data.
This is a convenience wrapper around a dictionary that provides
both dict-like access and attribute access.
"""
def __init__(self, **kwargs):
self._data = kwargs
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, value: Any):
self._data[key] = value
def __contains__(self, key: str) -> bool:
return key in self._data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def update(self, other: Dict[str, Any]):
self._data.update(other)
def to_dict(self) -> Dict[str, Any]:
return self._data.copy()
def __repr__(self) -> str:
keys = list(self._data.keys())
return f"HookContext({', '.join(keys)})"

View File

@@ -0,0 +1 @@
# post_extraction hooks

View File

@@ -0,0 +1 @@
# post_mesh hooks

View File

@@ -0,0 +1 @@
# post_solve hooks

View File

@@ -0,0 +1 @@
# pre_mesh hooks

View File

@@ -0,0 +1 @@
# pre_solve hooks

View File

@@ -0,0 +1,56 @@
"""
Example Plugin: Log Trial Start
Simple pre-solve hook that logs trial information.
This demonstrates the plugin API and serves as a template for custom hooks.
"""
from typing import Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Log trial information before solver execution.
Args:
context: Hook context containing:
- trial_number: Current trial number
- design_variables: Dict of variable values
- sim_file: Path to simulation file
Returns:
Dict with logging status
"""
trial_num = context.get('trial_number', '?')
design_vars = context.get('design_variables', {})
logger.info("=" * 60)
logger.info(f"TRIAL {trial_num} STARTING")
logger.info("=" * 60)
for var_name, var_value in design_vars.items():
logger.info(f" {var_name}: {var_value:.4f}")
return {'logged': True, 'trial': trial_num}
def register_hooks(hook_manager):
"""
Register this plugin's hooks with the manager.
This function is called automatically when the plugin is loaded.
Args:
hook_manager: The HookManager instance
"""
hook_manager.register_hook(
hook_point='pre_solve',
function=trial_start_logger,
description='Log trial number and design variables before solve',
name='log_trial_start',
priority=10 # Run early
)

View File

@@ -0,0 +1,214 @@
"""
Plugin Code Validators
Ensures safety and correctness of plugin code before execution.
"""
import ast
import re
from typing import List, Set
import logging
logger = logging.getLogger(__name__)
class PluginValidationError(Exception):
"""Raised when plugin code fails validation."""
pass
# Whitelist of safe modules for plugin imports
SAFE_MODULES = {
'math', 'numpy', 'scipy', 'pandas',
'pathlib', 'json', 'csv', 'datetime',
'logging', 'typing', 'dataclasses',
'optuna', 'pyNastran'
}
# Blacklist of dangerous operations
DANGEROUS_OPERATIONS = {
'eval', 'exec', 'compile', '__import__',
'open', # File operations should be explicit
'subprocess', 'os.system', 'os.popen',
'shutil.rmtree', # Dangerous file operations
}
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
"""
Validate plugin code for safety before execution.
Args:
code: Python source code to validate
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
Raises:
PluginValidationError: If code contains unsafe operations
Example:
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
>>> validate_plugin_code(code) # Passes
>>>
>>> code = "import os; os.system('rm -rf /')"
>>> validate_plugin_code(code) # Raises PluginValidationError
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error in plugin code: {e}")
# Check for dangerous imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import: {alias.name}. "
f"Allowed modules: {SAFE_MODULES}"
)
elif isinstance(node, ast.ImportFrom):
if node.module and node.module not in SAFE_MODULES:
# Allow submodules of safe modules
base_module = node.module.split('.')[0]
if base_module not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import from: {node.module}"
)
# Check for dangerous function calls
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in DANGEROUS_OPERATIONS:
if node.func.id == 'open' and allow_file_ops:
continue # Allow if explicitly permitted
raise PluginValidationError(
f"Dangerous operation: {node.func.id}"
)
# Check for attribute access to dangerous methods
elif isinstance(node, ast.Attribute):
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
if node.attr in dangerous_attrs and not allow_file_ops:
raise PluginValidationError(
f"Dangerous attribute access: {node.attr}"
)
logger.info("Plugin code validation passed")
def check_hook_function_signature(func_code: str) -> bool:
"""
Verify that a hook function has the correct signature.
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
Args:
func_code: Function source code
Returns:
True if signature is valid
Raises:
PluginValidationError: If signature is invalid
"""
try:
tree = ast.parse(func_code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error: {e}")
# Find function definition
func_def = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_def = node
break
if func_def is None:
raise PluginValidationError("No function definition found")
# Check that it takes exactly one argument (context)
if len(func_def.args.args) != 1:
raise PluginValidationError(
f"Hook function must take exactly 1 argument (context), "
f"got {len(func_def.args.args)}"
)
logger.info(f"Hook function '{func_def.name}' signature is valid")
return True
def sanitize_plugin_name(name: str) -> str:
"""
Sanitize a plugin name to be safe for filesystem and imports.
Args:
name: Original plugin name
Returns:
Sanitized name (lowercase, alphanumeric + underscore only)
"""
# Remove special characters, keep only alphanumeric and underscore
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
# Ensure it doesn't start with a number
if sanitized[0].isdigit():
sanitized = f"plugin_{sanitized}"
return sanitized
def get_imported_modules(code: str) -> Set[str]:
"""
Extract all imported modules from code.
Args:
code: Python source code
Returns:
Set of module names
"""
try:
tree = ast.parse(code)
except SyntaxError:
return set()
modules = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
modules.add(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
modules.add(node.module)
return modules
def estimate_complexity(code: str) -> int:
"""
Estimate the cyclomatic complexity of code.
Higher values indicate more complex code with more branching.
Args:
code: Python source code
Returns:
Complexity estimate (1 = simple, 10+ = complex)
"""
try:
tree = ast.parse(code)
except SyntaxError:
return 0
complexity = 1 # Base complexity
for node in ast.walk(tree):
# Add 1 for each branching statement
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(node, ast.BoolOp):
complexity += len(node.values) - 1
return complexity

View File

@@ -23,6 +23,8 @@ import pandas as pd
from datetime import datetime
import pickle
from optimization_engine.plugins import HookManager
class OptimizationRunner:
"""
@@ -68,6 +70,15 @@ class OptimizationRunner:
self.output_dir = self.config_path.parent / 'optimization_results'
self.output_dir.mkdir(exist_ok=True)
# Initialize plugin/hook system
self.hook_manager = HookManager()
plugins_dir = Path(__file__).parent / 'plugins'
if plugins_dir.exists():
self.hook_manager.load_plugins_from_directory(plugins_dir)
summary = self.hook_manager.get_summary()
if summary['total_hooks'] > 0:
print(f"Loaded {summary['enabled_hooks']}/{summary['total_hooks']} plugins")
def _load_config(self) -> Dict[str, Any]:
"""Load and validate optimization configuration."""
with open(self.config_path, 'r') as f:
@@ -311,6 +322,16 @@ class OptimizationRunner:
int(dv['bounds'][1])
)
# Execute pre_solve hooks
pre_solve_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'sim_file': self.config.get('sim_file', ''),
'working_dir': str(Path.cwd()),
'config': self.config
}
self.hook_manager.execute_hooks('pre_solve', pre_solve_context, fail_fast=False)
# 2. Update NX model with new parameters
try:
self.model_updater(design_vars)
@@ -318,6 +339,15 @@ class OptimizationRunner:
print(f"Error updating model: {e}")
raise optuna.TrialPruned()
# Execute post_mesh hooks (after model update)
post_mesh_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'sim_file': self.config.get('sim_file', ''),
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_mesh', post_mesh_context, fail_fast=False)
# 3. Run simulation
try:
result_path = self.simulation_runner()
@@ -325,6 +355,15 @@ class OptimizationRunner:
print(f"Error running simulation: {e}")
raise optuna.TrialPruned()
# Execute post_solve hooks
post_solve_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'result_path': str(result_path) if result_path else '',
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_solve', post_solve_context, fail_fast=False)
# 4. Extract results with appropriate precision
extracted_results = {}
for obj in self.config['objectives']:
@@ -362,6 +401,16 @@ class OptimizationRunner:
print(f"Error extracting {const['name']}: {e}")
raise optuna.TrialPruned()
# Execute post_extraction hooks
post_extraction_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'extracted_results': extracted_results,
'result_path': str(result_path) if result_path else '',
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_extraction', post_extraction_context, fail_fast=False)
# 5. Evaluate constraints
for const in self.config.get('constraints', []):
value = extracted_results[const['name']]
@@ -389,6 +438,23 @@ class OptimizationRunner:
else: # maximize
total_objective -= weight * value
# Execute custom_objective hooks (can modify total_objective)
custom_objective_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'extracted_results': extracted_results,
'total_objective': total_objective,
'working_dir': str(Path.cwd())
}
custom_results = self.hook_manager.execute_hooks('custom_objective', custom_objective_context, fail_fast=False)
# Allow hooks to override objective value
for result in custom_results:
if result and 'total_objective' in result:
total_objective = result['total_objective']
print(f"Custom objective hook modified total_objective to {total_objective:.6f}")
break # Use first hook that provides override
# 7. Store results in history
history_entry = {
'trial_number': trial.number,

438
tests/test_plugin_system.py Normal file
View File

@@ -0,0 +1,438 @@
"""
Tests for the Plugin System
Validates hook registration, execution, validation, and integration.
"""
import pytest
from pathlib import Path
from typing import Dict, Any, Optional
from optimization_engine.plugins import HookManager, Hook, HookPoint
from optimization_engine.plugins.validators import (
validate_plugin_code,
PluginValidationError,
check_hook_function_signature,
sanitize_plugin_name,
get_imported_modules,
estimate_complexity
)
class TestHookRegistration:
"""Test hook registration and management."""
def test_register_simple_hook(self):
"""Test basic hook registration."""
manager = HookManager()
def my_hook(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
return {'status': 'success'}
hook = manager.register_hook(
hook_point='pre_solve',
function=my_hook,
description='Test hook',
name='test_hook'
)
assert hook.name == 'test_hook'
assert hook.hook_point == HookPoint.PRE_SOLVE
assert hook.enabled is True
def test_hook_priority_ordering(self):
"""Test that hooks execute in priority order."""
manager = HookManager()
execution_order = []
def hook_high_priority(context):
execution_order.append('high')
return None
def hook_low_priority(context):
execution_order.append('low')
return None
# Register low priority first, high priority second
manager.register_hook('pre_solve', hook_low_priority, 'Low', priority=200)
manager.register_hook('pre_solve', hook_high_priority, 'High', priority=10)
# Execute hooks
manager.execute_hooks('pre_solve', {'trial_number': 1})
# High priority should execute first
assert execution_order == ['high', 'low']
def test_disable_enable_hook(self):
"""Test disabling and enabling hooks."""
manager = HookManager()
execution_count = [0]
def counting_hook(context):
execution_count[0] += 1
return None
hook = manager.register_hook(
'pre_solve',
counting_hook,
'Counter',
name='counter_hook'
)
# Execute while enabled
manager.execute_hooks('pre_solve', {})
assert execution_count[0] == 1
# Disable and execute
manager.disable_hook('counter_hook')
manager.execute_hooks('pre_solve', {})
assert execution_count[0] == 1 # Should not increment
# Re-enable and execute
manager.enable_hook('counter_hook')
manager.execute_hooks('pre_solve', {})
assert execution_count[0] == 2
def test_remove_hook(self):
"""Test hook removal."""
manager = HookManager()
def test_hook(context):
return None
manager.register_hook('pre_solve', test_hook, 'Test', name='removable')
assert len(manager.get_hooks('pre_solve')) == 1
success = manager.remove_hook('removable')
assert success is True
assert len(manager.get_hooks('pre_solve')) == 0
# Try removing non-existent hook
success = manager.remove_hook('nonexistent')
assert success is False
class TestHookExecution:
"""Test hook execution behavior."""
def test_hook_receives_context(self):
"""Test that hooks receive correct context."""
manager = HookManager()
received_context = {}
def context_checker(context: Dict[str, Any]):
received_context.update(context)
return None
manager.register_hook('pre_solve', context_checker, 'Checker')
test_context = {
'trial_number': 42,
'design_variables': {'thickness': 5.0},
'sim_file': 'test.sim'
}
manager.execute_hooks('pre_solve', test_context)
assert received_context['trial_number'] == 42
assert received_context['design_variables']['thickness'] == 5.0
def test_hook_return_values(self):
"""Test that hook return values are collected."""
manager = HookManager()
def hook1(context):
return {'result': 'hook1'}
def hook2(context):
return {'result': 'hook2'}
manager.register_hook('pre_solve', hook1, 'Hook 1')
manager.register_hook('pre_solve', hook2, 'Hook 2')
results = manager.execute_hooks('pre_solve', {})
assert len(results) == 2
assert results[0]['result'] == 'hook1'
assert results[1]['result'] == 'hook2'
def test_hook_error_handling_fail_fast(self):
"""Test error handling with fail_fast=True."""
manager = HookManager()
def failing_hook(context):
raise ValueError("Intentional error")
manager.register_hook('pre_solve', failing_hook, 'Failing')
with pytest.raises(ValueError, match="Intentional error"):
manager.execute_hooks('pre_solve', {}, fail_fast=True)
def test_hook_error_handling_continue(self):
"""Test error handling with fail_fast=False."""
manager = HookManager()
execution_log = []
def failing_hook(context):
execution_log.append('failing')
raise ValueError("Intentional error")
def successful_hook(context):
execution_log.append('successful')
return {'status': 'ok'}
manager.register_hook('pre_solve', failing_hook, 'Failing', priority=10)
manager.register_hook('pre_solve', successful_hook, 'Success', priority=20)
results = manager.execute_hooks('pre_solve', {}, fail_fast=False)
# Both hooks should have attempted execution
assert execution_log == ['failing', 'successful']
# First result is None (error), second is successful
assert results[0] is None
assert results[1]['status'] == 'ok'
def test_hook_history_tracking(self):
"""Test that hook execution history is tracked."""
manager = HookManager()
def test_hook(context):
return {'result': 'success'}
manager.register_hook('pre_solve', test_hook, 'Test', name='tracked')
# Execute hooks multiple times
for i in range(3):
manager.execute_hooks('pre_solve', {'trial_number': i})
history = manager.get_history()
assert len(history) >= 3
# Check history contains success records
successful = [h for h in history if h['success']]
assert len(successful) >= 3
class TestCodeValidation:
"""Test plugin code validation."""
def test_safe_code_passes(self):
"""Test that safe code passes validation."""
safe_code = """
import numpy as np
import math
def my_hook(context):
x = context['design_variables']['thickness']
result = math.sqrt(x**2 + np.mean([1, 2, 3]))
return {'result': result}
"""
# Should not raise
validate_plugin_code(safe_code)
def test_dangerous_import_blocked(self):
"""Test that dangerous imports are blocked."""
dangerous_code = """
import os
def my_hook(context):
os.system('rm -rf /')
return None
"""
with pytest.raises(PluginValidationError, match="Unsafe import"):
validate_plugin_code(dangerous_code)
def test_dangerous_operation_blocked(self):
"""Test that dangerous operations are blocked."""
dangerous_code = """
def my_hook(context):
eval('malicious_code')
return None
"""
with pytest.raises(PluginValidationError, match="Dangerous operation"):
validate_plugin_code(dangerous_code)
def test_file_operations_with_permission(self):
"""Test that file operations work with allow_file_ops=True."""
code_with_file_ops = """
def my_hook(context):
with open('output.txt', 'w') as f:
f.write('test')
return None
"""
# Should raise without permission
with pytest.raises(PluginValidationError, match="Dangerous operation: open"):
validate_plugin_code(code_with_file_ops, allow_file_ops=False)
# Should pass with permission
validate_plugin_code(code_with_file_ops, allow_file_ops=True)
def test_syntax_error_detected(self):
"""Test that syntax errors are detected."""
bad_syntax = """
def my_hook(context)
return None # Missing colon
"""
with pytest.raises(PluginValidationError, match="Syntax error"):
validate_plugin_code(bad_syntax)
def test_hook_signature_validation(self):
"""Test hook function signature validation."""
# Valid signature
valid_code = """
def my_hook(context):
return None
"""
assert check_hook_function_signature(valid_code) is True
# Invalid: too many arguments
invalid_code = """
def my_hook(context, extra_arg):
return None
"""
with pytest.raises(PluginValidationError, match="must take exactly 1 argument"):
check_hook_function_signature(invalid_code)
# Invalid: no function
invalid_code = """
x = 5
"""
with pytest.raises(PluginValidationError, match="No function definition"):
check_hook_function_signature(invalid_code)
class TestUtilityFunctions:
"""Test plugin utility functions."""
def test_sanitize_plugin_name(self):
"""Test plugin name sanitization."""
assert sanitize_plugin_name('My-Plugin!') == 'my_plugin_'
assert sanitize_plugin_name('123_plugin') == 'plugin_123_plugin'
assert sanitize_plugin_name('valid_name') == 'valid_name'
def test_get_imported_modules(self):
"""Test module extraction from code."""
code = """
import numpy as np
import math
from pathlib import Path
from typing import Dict, Any
"""
modules = get_imported_modules(code)
assert 'numpy' in modules
assert 'math' in modules
assert 'pathlib' in modules
assert 'typing' in modules
def test_estimate_complexity(self):
"""Test cyclomatic complexity estimation."""
simple_code = """
def simple(x):
return x + 1
"""
complex_code = """
def complex(x):
if x > 0:
for i in range(x):
if i % 2 == 0:
while i > 0:
i -= 1
return x
"""
simple_complexity = estimate_complexity(simple_code)
complex_complexity = estimate_complexity(complex_code)
assert simple_complexity == 1
assert complex_complexity > simple_complexity
class TestHookManager:
"""Test HookManager functionality."""
def test_get_summary(self):
"""Test hook system summary."""
manager = HookManager()
def hook1(context):
return None
def hook2(context):
return None
manager.register_hook('pre_solve', hook1, 'Hook 1', name='hook1')
manager.register_hook('post_solve', hook2, 'Hook 2', name='hook2')
manager.disable_hook('hook2')
summary = manager.get_summary()
assert summary['total_hooks'] == 2
assert summary['enabled_hooks'] == 1
assert summary['disabled_hooks'] == 1
def test_clear_history(self):
"""Test clearing execution history."""
manager = HookManager()
def test_hook(context):
return None
manager.register_hook('pre_solve', test_hook, 'Test')
manager.execute_hooks('pre_solve', {})
assert len(manager.get_history()) > 0
manager.clear_history()
assert len(manager.get_history()) == 0
def test_hook_manager_repr(self):
"""Test HookManager string representation."""
manager = HookManager()
def hook(context):
return None
manager.register_hook('pre_solve', hook, 'Test')
repr_str = repr(manager)
assert 'HookManager' in repr_str
assert 'hooks=1' in repr_str
assert 'enabled=1' in repr_str
class TestPluginLoading:
"""Test plugin directory loading."""
def test_load_plugins_from_nonexistent_directory(self):
"""Test loading from non-existent directory."""
manager = HookManager()
# Should not raise, just log warning
manager.load_plugins_from_directory(Path('/nonexistent/path'))
def test_plugin_registration_function(self):
"""Test that plugins can register hooks via register_hooks()."""
manager = HookManager()
# Simulate what a plugin file would contain
def register_hooks(hook_manager):
def my_plugin_hook(context):
return {'plugin': 'loaded'}
hook_manager.register_hook(
'pre_solve',
my_plugin_hook,
'Plugin hook',
name='plugin_hook'
)
# Call the registration function
register_hooks(manager)
# Verify hook was registered
hooks = manager.get_hooks('pre_solve')
assert len(hooks) == 1
assert hooks[0].name == 'plugin_hook'
if __name__ == '__main__':
pytest.main([__file__, '-v'])