feat: Implement Phase 1 - Plugin & Hook System
Core plugin architecture for LLM-driven optimization: New Features: - Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives) - HookManager for centralized registration and execution - Code validation with AST-based safety checks - Feature registry (JSON) for LLM capability discovery - Example plugin: log_trial_start - 23 comprehensive tests (all passing) Integration: - OptimizationRunner now loads plugins automatically - Hooks execute at 5 points in optimization loop - Custom objectives can override total_objective via hooks Safety: - Module whitelist (numpy, scipy, pandas, optuna, pyNastran) - Dangerous operation blocking (eval, exec, os.system, subprocess) - Optional file operation permission flag Files Added: - optimization_engine/plugins/__init__.py - optimization_engine/plugins/hooks.py - optimization_engine/plugins/hook_manager.py - optimization_engine/plugins/validators.py - optimization_engine/feature_registry.json - optimization_engine/plugins/pre_solve/log_trial_start.py - tests/test_plugin_system.py (23 tests) Files Modified: - optimization_engine/runner.py (added hook integration) Ready for Phase 2: LLM interface layer 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
243
optimization_engine/feature_registry.json
Normal file
243
optimization_engine/feature_registry.json
Normal file
@@ -0,0 +1,243 @@
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"last_updated": "2025-01-15",
|
||||
"description": "Registry of all Atomizer capabilities for LLM discovery and usage",
|
||||
|
||||
"core_features": {
|
||||
"optimization": {
|
||||
"description": "Core optimization engine using Optuna",
|
||||
"module": "optimization_engine.runner",
|
||||
"capabilities": [
|
||||
"Multi-objective optimization with weighted sum",
|
||||
"TPE (Tree-structured Parzen Estimator) sampler",
|
||||
"CMA-ES sampler",
|
||||
"Gaussian Process sampler",
|
||||
"50-trial default with 20 startup trials",
|
||||
"Automatic checkpoint and resume",
|
||||
"SQLite-based study persistence"
|
||||
],
|
||||
"usage": "python examples/test_journal_optimization.py",
|
||||
"llm_hint": "Use this for Bayesian optimization with NX simulations"
|
||||
},
|
||||
|
||||
"nx_integration": {
|
||||
"description": "Siemens NX simulation automation via journal scripts",
|
||||
"module": "optimization_engine.nx_solver",
|
||||
"capabilities": [
|
||||
"Update CAD expressions via NXOpen",
|
||||
"Execute NX Nastran solver",
|
||||
"Extract OP2 results (stress, displacement)",
|
||||
"Extract mass properties",
|
||||
"Precision control (4 decimals for mm/degrees/MPa)"
|
||||
],
|
||||
"usage": "from optimization_engine.nx_solver import run_nx_simulation",
|
||||
"llm_hint": "Use for running FEA simulations and extracting results"
|
||||
},
|
||||
|
||||
"result_extraction": {
|
||||
"description": "Extract metrics from simulation results",
|
||||
"module": "optimization_engine.result_extractors",
|
||||
"extractors": {
|
||||
"stress_extractor": {
|
||||
"description": "Extract stress data from OP2 files",
|
||||
"metrics": ["max_von_mises", "mean_von_mises", "max_principal"],
|
||||
"file_type": "OP2",
|
||||
"usage": "Returns stress in MPa"
|
||||
},
|
||||
"displacement_extractor": {
|
||||
"description": "Extract displacement data from OP2 files",
|
||||
"metrics": ["max_displacement", "mean_displacement"],
|
||||
"file_type": "OP2",
|
||||
"usage": "Returns displacement in mm"
|
||||
},
|
||||
"mass_extractor": {
|
||||
"description": "Extract mass properties",
|
||||
"metrics": ["total_mass", "center_of_gravity"],
|
||||
"file_type": "NX Part",
|
||||
"usage": "Returns mass in kg"
|
||||
}
|
||||
},
|
||||
"llm_hint": "Use extractors to define objectives and constraints"
|
||||
}
|
||||
},
|
||||
|
||||
"plugin_system": {
|
||||
"description": "Extensible hook system for custom functionality",
|
||||
"module": "optimization_engine.plugins",
|
||||
"version": "1.0.0",
|
||||
|
||||
"hook_points": {
|
||||
"pre_mesh": {
|
||||
"description": "Execute before meshing operations",
|
||||
"context": ["trial_number", "design_variables", "sim_file", "working_dir"],
|
||||
"use_cases": [
|
||||
"Modify geometry based on parameters",
|
||||
"Set up boundary conditions",
|
||||
"Configure mesh settings"
|
||||
]
|
||||
},
|
||||
"post_mesh": {
|
||||
"description": "Execute after meshing, before solve",
|
||||
"context": ["trial_number", "mesh_info", "element_count", "node_count"],
|
||||
"use_cases": [
|
||||
"Validate mesh quality",
|
||||
"Export mesh for visualization",
|
||||
"Log mesh statistics"
|
||||
]
|
||||
},
|
||||
"pre_solve": {
|
||||
"description": "Execute before solver launch",
|
||||
"context": ["trial_number", "design_variables", "solver_settings"],
|
||||
"use_cases": [
|
||||
"Log trial parameters",
|
||||
"Modify solver settings",
|
||||
"Set up custom load cases"
|
||||
]
|
||||
},
|
||||
"post_solve": {
|
||||
"description": "Execute after solve, before result extraction",
|
||||
"context": ["trial_number", "solve_status", "output_files"],
|
||||
"use_cases": [
|
||||
"Check solver convergence",
|
||||
"Post-process results",
|
||||
"Generate visualizations"
|
||||
]
|
||||
},
|
||||
"post_extraction": {
|
||||
"description": "Execute after result extraction",
|
||||
"context": ["trial_number", "extracted_results", "objectives", "constraints"],
|
||||
"use_cases": [
|
||||
"Calculate custom metrics",
|
||||
"Combine multiple objectives (RSS)",
|
||||
"Apply penalty functions"
|
||||
]
|
||||
},
|
||||
"custom_objective": {
|
||||
"description": "Define custom objective functions",
|
||||
"context": ["extracted_results", "design_variables"],
|
||||
"use_cases": [
|
||||
"RSS of stress and displacement",
|
||||
"Weighted multi-criteria",
|
||||
"Conditional objectives"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
"api": {
|
||||
"register_hook": {
|
||||
"description": "Register a new hook function",
|
||||
"signature": "hook_manager.register_hook(hook_point, function, description, name=None, priority=100)",
|
||||
"parameters": {
|
||||
"hook_point": "One of: pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objective",
|
||||
"function": "Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]",
|
||||
"description": "Human-readable description",
|
||||
"priority": "Execution order (lower = earlier)"
|
||||
},
|
||||
"example": "See optimization_engine/plugins/pre_solve/log_trial_start.py"
|
||||
},
|
||||
"execute_hooks": {
|
||||
"description": "Execute all hooks at a specific point",
|
||||
"signature": "hook_manager.execute_hooks(hook_point, context, fail_fast=False)",
|
||||
"returns": "List of hook results"
|
||||
}
|
||||
},
|
||||
|
||||
"validators": {
|
||||
"validate_plugin_code": {
|
||||
"description": "Validate plugin code for safety",
|
||||
"checks": [
|
||||
"Syntax errors",
|
||||
"Dangerous imports (os.system, subprocess, etc.)",
|
||||
"File operations (optional allow)",
|
||||
"Function signature correctness"
|
||||
],
|
||||
"safe_modules": ["math", "numpy", "scipy", "pandas", "pathlib", "json", "optuna", "pyNastran"],
|
||||
"llm_hint": "Always validate LLM-generated code before execution"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"design_variables": {
|
||||
"description": "Parametric CAD variables to optimize",
|
||||
"schema": {
|
||||
"name": "Unique identifier",
|
||||
"expression_name": "NX expression name",
|
||||
"min": "Lower bound (float)",
|
||||
"max": "Upper bound (float)",
|
||||
"units": "Unit system (mm, degrees, etc.)"
|
||||
},
|
||||
"example": {
|
||||
"name": "wall_thickness",
|
||||
"expression_name": "wall_thickness",
|
||||
"min": 3.0,
|
||||
"max": 8.0,
|
||||
"units": "mm"
|
||||
}
|
||||
},
|
||||
|
||||
"objectives": {
|
||||
"description": "Metrics to minimize or maximize",
|
||||
"schema": {
|
||||
"name": "Unique identifier",
|
||||
"extractor": "Result extractor to use",
|
||||
"metric": "Specific metric from extractor",
|
||||
"direction": "minimize or maximize",
|
||||
"weight": "Importance (for multi-objective)",
|
||||
"units": "Unit system"
|
||||
},
|
||||
"example": {
|
||||
"name": "max_stress",
|
||||
"extractor": "stress_extractor",
|
||||
"metric": "max_von_mises",
|
||||
"direction": "minimize",
|
||||
"weight": 1.0,
|
||||
"units": "MPa"
|
||||
}
|
||||
},
|
||||
|
||||
"constraints": {
|
||||
"description": "Limits on simulation outputs",
|
||||
"schema": {
|
||||
"name": "Unique identifier",
|
||||
"extractor": "Result extractor to use",
|
||||
"metric": "Specific metric",
|
||||
"type": "upper_bound or lower_bound",
|
||||
"limit": "Constraint value",
|
||||
"units": "Unit system"
|
||||
},
|
||||
"example": {
|
||||
"name": "max_displacement_limit",
|
||||
"extractor": "displacement_extractor",
|
||||
"metric": "max_displacement",
|
||||
"type": "upper_bound",
|
||||
"limit": 1.0,
|
||||
"units": "mm"
|
||||
}
|
||||
},
|
||||
|
||||
"examples": {
|
||||
"bracket_optimization": {
|
||||
"description": "Minimize stress on a bracket by varying wall thickness",
|
||||
"location": "examples/bracket/",
|
||||
"design_variables": ["wall_thickness"],
|
||||
"objectives": ["max_von_mises"],
|
||||
"trials": 50,
|
||||
"typical_runtime": "2-3 hours",
|
||||
"llm_hint": "Good template for single-objective structural optimization"
|
||||
}
|
||||
},
|
||||
|
||||
"llm_guidelines": {
|
||||
"code_generation": {
|
||||
"hook_template": "Always include: function signature with context dict, docstring, return dict",
|
||||
"validation": "Use validate_plugin_code() before registration",
|
||||
"error_handling": "Wrap in try-except, log errors, return None on failure"
|
||||
},
|
||||
"natural_language_mapping": {
|
||||
"minimize stress": "objective with direction='minimize', extractor='stress_extractor'",
|
||||
"vary thickness 3-8mm": "design_variable with min=3.0, max=8.0, units='mm'",
|
||||
"displacement < 1mm": "constraint with type='upper_bound', limit=1.0",
|
||||
"RSS of stress and displacement": "custom_objective hook with sqrt(stress² + displacement²)"
|
||||
}
|
||||
}
|
||||
}
|
||||
32
optimization_engine/plugins/__init__.py
Normal file
32
optimization_engine/plugins/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Atomizer Plugin System
|
||||
|
||||
Enables extensibility through hooks at various stages of the optimization lifecycle.
|
||||
|
||||
Hook Points:
|
||||
- pre_mesh: Before meshing operations
|
||||
- post_mesh: After meshing, before solve
|
||||
- pre_solve: Before solver execution
|
||||
- post_solve: After solve, before result extraction
|
||||
- post_extraction: After result extraction, before objective calculation
|
||||
- custom_objectives: Custom objective/constraint functions
|
||||
|
||||
Usage:
|
||||
from optimization_engine.plugins import HookManager
|
||||
|
||||
hook_manager = HookManager()
|
||||
hook_manager.register_hook('pre_solve', my_custom_function)
|
||||
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
|
||||
"""
|
||||
|
||||
from .hooks import Hook, HookPoint
|
||||
from .hook_manager import HookManager
|
||||
from .validators import validate_plugin_code, PluginValidationError
|
||||
|
||||
__all__ = [
|
||||
'Hook',
|
||||
'HookPoint',
|
||||
'HookManager',
|
||||
'validate_plugin_code',
|
||||
'PluginValidationError'
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
# custom_objectives hooks
|
||||
303
optimization_engine/plugins/hook_manager.py
Normal file
303
optimization_engine/plugins/hook_manager.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Hook Manager for Atomizer Plugin System
|
||||
|
||||
Manages registration, execution, and lifecycle of hooks.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Callable, Any, Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import importlib.util
|
||||
import json
|
||||
|
||||
from .hooks import Hook, HookPoint, HookContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""
|
||||
Central manager for all hooks in the optimization system.
|
||||
|
||||
Example:
|
||||
>>> manager = HookManager()
|
||||
>>> def my_hook(context):
|
||||
>>> print(f"Trial {context['trial_number']} starting")
|
||||
>>> return {'status': 'logged'}
|
||||
>>>
|
||||
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
|
||||
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
|
||||
self._hook_history: List[Dict[str, Any]] = []
|
||||
logger.info("HookManager initialized")
|
||||
|
||||
def register_hook(
|
||||
self,
|
||||
hook_point: str | HookPoint,
|
||||
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
|
||||
description: str,
|
||||
name: Optional[str] = None,
|
||||
priority: int = 100,
|
||||
enabled: bool = True
|
||||
) -> Hook:
|
||||
"""
|
||||
Register a new hook function.
|
||||
|
||||
Args:
|
||||
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
|
||||
function: Callable that takes context dict, returns optional dict
|
||||
description: Human-readable description
|
||||
name: Unique name (auto-generated if not provided)
|
||||
priority: Execution order (lower = earlier)
|
||||
enabled: Whether hook is active
|
||||
|
||||
Returns:
|
||||
The created Hook object
|
||||
|
||||
Raises:
|
||||
ValueError: If hook_point is invalid
|
||||
"""
|
||||
# Convert string to HookPoint enum
|
||||
if isinstance(hook_point, str):
|
||||
try:
|
||||
hook_point = HookPoint(hook_point)
|
||||
except ValueError:
|
||||
valid_points = [p.value for p in HookPoint]
|
||||
raise ValueError(
|
||||
f"Invalid hook_point '{hook_point}'. "
|
||||
f"Valid options: {valid_points}"
|
||||
)
|
||||
|
||||
# Auto-generate name if not provided
|
||||
if name is None:
|
||||
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
|
||||
|
||||
# Create hook
|
||||
hook = Hook(
|
||||
name=name,
|
||||
hook_point=hook_point,
|
||||
function=function,
|
||||
description=description,
|
||||
priority=priority,
|
||||
enabled=enabled
|
||||
)
|
||||
|
||||
# Add to registry and sort by priority
|
||||
self._hooks[hook_point].append(hook)
|
||||
self._hooks[hook_point].sort(key=lambda h: h.priority)
|
||||
|
||||
logger.info(f"Registered hook: {hook}")
|
||||
return hook
|
||||
|
||||
def execute_hooks(
|
||||
self,
|
||||
hook_point: str | HookPoint,
|
||||
context: Dict[str, Any],
|
||||
fail_fast: bool = False
|
||||
) -> List[Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Execute all hooks registered at a specific point.
|
||||
|
||||
Args:
|
||||
hook_point: The execution point
|
||||
context: Data to pass to hooks
|
||||
fail_fast: If True, stop on first error. If False, log and continue.
|
||||
|
||||
Returns:
|
||||
List of results from each hook (None if hook returned nothing)
|
||||
|
||||
Raises:
|
||||
Exception: If fail_fast=True and a hook fails
|
||||
"""
|
||||
# Convert string to enum
|
||||
if isinstance(hook_point, str):
|
||||
hook_point = HookPoint(hook_point)
|
||||
|
||||
hooks = self._hooks[hook_point]
|
||||
if not hooks:
|
||||
logger.debug(f"No hooks registered for {hook_point.value}")
|
||||
return []
|
||||
|
||||
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
|
||||
|
||||
results = []
|
||||
for hook in hooks:
|
||||
try:
|
||||
result = hook.execute(context)
|
||||
results.append(result)
|
||||
|
||||
# Record in history
|
||||
self._hook_history.append({
|
||||
'hook_name': hook.name,
|
||||
'hook_point': hook_point.value,
|
||||
'success': True,
|
||||
'trial_number': context.get('trial_number', -1)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hook '{hook.name}' failed: {e}")
|
||||
|
||||
# Record failure
|
||||
self._hook_history.append({
|
||||
'hook_name': hook.name,
|
||||
'hook_point': hook_point.value,
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'trial_number': context.get('trial_number', -1)
|
||||
})
|
||||
|
||||
if fail_fast:
|
||||
raise
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
def load_plugins_from_directory(self, directory: Path):
|
||||
"""
|
||||
Auto-discover and load plugins from a directory.
|
||||
|
||||
Expected structure:
|
||||
plugins/
|
||||
pre_mesh/
|
||||
my_plugin.py # Contains register_hooks(manager) function
|
||||
post_solve/
|
||||
another_plugin.py
|
||||
|
||||
Args:
|
||||
directory: Path to plugins directory
|
||||
"""
|
||||
if not directory.exists():
|
||||
logger.warning(f"Plugins directory not found: {directory}")
|
||||
return
|
||||
|
||||
logger.info(f"Loading plugins from {directory}")
|
||||
|
||||
for hook_point in HookPoint:
|
||||
hook_dir = directory / hook_point.value
|
||||
if not hook_dir.exists():
|
||||
continue
|
||||
|
||||
for plugin_file in hook_dir.glob("*.py"):
|
||||
if plugin_file.name.startswith("_"):
|
||||
continue # Skip __init__.py and private files
|
||||
|
||||
try:
|
||||
self._load_plugin_file(plugin_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load plugin {plugin_file}: {e}")
|
||||
|
||||
def _load_plugin_file(self, plugin_file: Path):
|
||||
"""Load a single plugin file and call its register_hooks function."""
|
||||
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Could not load spec for {plugin_file}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Call register_hooks if it exists
|
||||
if hasattr(module, 'register_hooks'):
|
||||
module.register_hooks(self)
|
||||
logger.info(f"Loaded plugin: {plugin_file.stem}")
|
||||
else:
|
||||
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
|
||||
|
||||
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
|
||||
"""Get all hooks registered at a specific point."""
|
||||
if isinstance(hook_point, str):
|
||||
hook_point = HookPoint(hook_point)
|
||||
return self._hooks[hook_point].copy()
|
||||
|
||||
def remove_hook(self, name: str) -> bool:
|
||||
"""
|
||||
Remove a hook by name.
|
||||
|
||||
Returns:
|
||||
True if hook was found and removed, False otherwise
|
||||
"""
|
||||
for point, hooks in self._hooks.items():
|
||||
for i, hook in enumerate(hooks):
|
||||
if hook.name == name:
|
||||
del hooks[i]
|
||||
logger.info(f"Removed hook: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def enable_hook(self, name: str):
|
||||
"""Enable a hook by name."""
|
||||
self._set_hook_enabled(name, True)
|
||||
|
||||
def disable_hook(self, name: str):
|
||||
"""Disable a hook by name."""
|
||||
self._set_hook_enabled(name, False)
|
||||
|
||||
def _set_hook_enabled(self, name: str, enabled: bool):
|
||||
"""Set enabled status for a hook."""
|
||||
for hooks in self._hooks.values():
|
||||
for hook in hooks:
|
||||
if hook.name == name:
|
||||
hook.enabled = enabled
|
||||
status = "enabled" if enabled else "disabled"
|
||||
logger.info(f"Hook '{name}' {status}")
|
||||
return
|
||||
logger.warning(f"Hook '{name}' not found")
|
||||
|
||||
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get hook execution history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of records to return (most recent)
|
||||
|
||||
Returns:
|
||||
List of execution records
|
||||
"""
|
||||
if limit:
|
||||
return self._hook_history[-limit:]
|
||||
return self._hook_history.copy()
|
||||
|
||||
def clear_history(self):
|
||||
"""Clear all execution history."""
|
||||
self._hook_history.clear()
|
||||
logger.info("Hook history cleared")
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a summary of the hook system state.
|
||||
|
||||
Returns:
|
||||
Dictionary with hook counts, enabled status, etc.
|
||||
"""
|
||||
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
|
||||
enabled_hooks = sum(
|
||||
sum(1 for h in hooks if h.enabled)
|
||||
for hooks in self._hooks.values()
|
||||
)
|
||||
|
||||
by_point = {
|
||||
point.value: {
|
||||
'total': len(hooks),
|
||||
'enabled': sum(1 for h in hooks if h.enabled),
|
||||
'names': [h.name for h in hooks]
|
||||
}
|
||||
for point, hooks in self._hooks.items()
|
||||
}
|
||||
|
||||
return {
|
||||
'total_hooks': total_hooks,
|
||||
'enabled_hooks': enabled_hooks,
|
||||
'disabled_hooks': total_hooks - enabled_hooks,
|
||||
'by_hook_point': by_point,
|
||||
'history_records': len(self._hook_history)
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
summary = self.get_summary()
|
||||
return (
|
||||
f"HookManager(hooks={summary['total_hooks']}, "
|
||||
f"enabled={summary['enabled_hooks']})"
|
||||
)
|
||||
115
optimization_engine/plugins/hooks.py
Normal file
115
optimization_engine/plugins/hooks.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Core Hook System for Atomizer
|
||||
|
||||
Defines hook points in the optimization lifecycle and hook registration mechanism.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HookPoint(Enum):
|
||||
"""Enumeration of available hook points in the optimization lifecycle."""
|
||||
|
||||
PRE_MESH = "pre_mesh" # Before meshing
|
||||
POST_MESH = "post_mesh" # After meshing, before solve
|
||||
PRE_SOLVE = "pre_solve" # Before solver execution
|
||||
POST_SOLVE = "post_solve" # After solve, before extraction
|
||||
POST_EXTRACTION = "post_extraction" # After result extraction
|
||||
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hook:
|
||||
"""
|
||||
Represents a single hook function to be executed at a specific point.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for this hook
|
||||
hook_point: When this hook should execute (HookPoint enum)
|
||||
function: The callable to execute
|
||||
description: Human-readable description of what this hook does
|
||||
priority: Execution order (lower = earlier, default=100)
|
||||
enabled: Whether this hook is currently active
|
||||
"""
|
||||
|
||||
name: str
|
||||
hook_point: HookPoint
|
||||
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
|
||||
description: str
|
||||
priority: int = 100
|
||||
enabled: bool = True
|
||||
|
||||
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Execute this hook with the given context.
|
||||
|
||||
Args:
|
||||
context: Dictionary containing relevant data for this hook point
|
||||
Common keys:
|
||||
- trial_number: Current trial number
|
||||
- design_variables: Current design variable values
|
||||
- sim_file: Path to simulation file
|
||||
- working_dir: Current working directory
|
||||
|
||||
Returns:
|
||||
Optional dictionary with results or modifications to context
|
||||
|
||||
Raises:
|
||||
Exception: Any exception from the hook function is logged and re-raised
|
||||
"""
|
||||
if not self.enabled:
|
||||
logger.debug(f"Hook '{self.name}' is disabled, skipping")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
|
||||
result = self.function(context)
|
||||
logger.debug(f"Hook '{self.name}' completed successfully")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "enabled" if self.enabled else "disabled"
|
||||
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
|
||||
|
||||
|
||||
class HookContext:
|
||||
"""
|
||||
Context object passed to hooks containing all relevant data.
|
||||
|
||||
This is a convenience wrapper around a dictionary that provides
|
||||
both dict-like access and attribute access.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._data = kwargs
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._data[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
self._data[key] = value
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def update(self, other: Dict[str, Any]):
|
||||
self._data.update(other)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._data.copy()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
keys = list(self._data.keys())
|
||||
return f"HookContext({', '.join(keys)})"
|
||||
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_extraction hooks
|
||||
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_mesh hooks
|
||||
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_solve hooks
|
||||
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# pre_mesh hooks
|
||||
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# pre_solve hooks
|
||||
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Example Plugin: Log Trial Start
|
||||
|
||||
Simple pre-solve hook that logs trial information.
|
||||
|
||||
This demonstrates the plugin API and serves as a template for custom hooks.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Log trial information before solver execution.
|
||||
|
||||
Args:
|
||||
context: Hook context containing:
|
||||
- trial_number: Current trial number
|
||||
- design_variables: Dict of variable values
|
||||
- sim_file: Path to simulation file
|
||||
|
||||
Returns:
|
||||
Dict with logging status
|
||||
"""
|
||||
trial_num = context.get('trial_number', '?')
|
||||
design_vars = context.get('design_variables', {})
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"TRIAL {trial_num} STARTING")
|
||||
logger.info("=" * 60)
|
||||
|
||||
for var_name, var_value in design_vars.items():
|
||||
logger.info(f" {var_name}: {var_value:.4f}")
|
||||
|
||||
return {'logged': True, 'trial': trial_num}
|
||||
|
||||
|
||||
def register_hooks(hook_manager):
|
||||
"""
|
||||
Register this plugin's hooks with the manager.
|
||||
|
||||
This function is called automatically when the plugin is loaded.
|
||||
|
||||
Args:
|
||||
hook_manager: The HookManager instance
|
||||
"""
|
||||
hook_manager.register_hook(
|
||||
hook_point='pre_solve',
|
||||
function=trial_start_logger,
|
||||
description='Log trial number and design variables before solve',
|
||||
name='log_trial_start',
|
||||
priority=10 # Run early
|
||||
)
|
||||
214
optimization_engine/plugins/validators.py
Normal file
214
optimization_engine/plugins/validators.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Plugin Code Validators
|
||||
|
||||
Ensures safety and correctness of plugin code before execution.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import List, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginValidationError(Exception):
|
||||
"""Raised when plugin code fails validation."""
|
||||
pass
|
||||
|
||||
|
||||
# Whitelist of safe modules for plugin imports
|
||||
SAFE_MODULES = {
|
||||
'math', 'numpy', 'scipy', 'pandas',
|
||||
'pathlib', 'json', 'csv', 'datetime',
|
||||
'logging', 'typing', 'dataclasses',
|
||||
'optuna', 'pyNastran'
|
||||
}
|
||||
|
||||
# Blacklist of dangerous operations
|
||||
DANGEROUS_OPERATIONS = {
|
||||
'eval', 'exec', 'compile', '__import__',
|
||||
'open', # File operations should be explicit
|
||||
'subprocess', 'os.system', 'os.popen',
|
||||
'shutil.rmtree', # Dangerous file operations
|
||||
}
|
||||
|
||||
|
||||
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
|
||||
"""
|
||||
Validate plugin code for safety before execution.
|
||||
|
||||
Args:
|
||||
code: Python source code to validate
|
||||
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If code contains unsafe operations
|
||||
|
||||
Example:
|
||||
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
|
||||
>>> validate_plugin_code(code) # Passes
|
||||
>>>
|
||||
>>> code = "import os; os.system('rm -rf /')"
|
||||
>>> validate_plugin_code(code) # Raises PluginValidationError
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error in plugin code: {e}")
|
||||
|
||||
# Check for dangerous imports
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import: {alias.name}. "
|
||||
f"Allowed modules: {SAFE_MODULES}"
|
||||
)
|
||||
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module and node.module not in SAFE_MODULES:
|
||||
# Allow submodules of safe modules
|
||||
base_module = node.module.split('.')[0]
|
||||
if base_module not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import from: {node.module}"
|
||||
)
|
||||
|
||||
# Check for dangerous function calls
|
||||
elif isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in DANGEROUS_OPERATIONS:
|
||||
if node.func.id == 'open' and allow_file_ops:
|
||||
continue # Allow if explicitly permitted
|
||||
raise PluginValidationError(
|
||||
f"Dangerous operation: {node.func.id}"
|
||||
)
|
||||
|
||||
# Check for attribute access to dangerous methods
|
||||
elif isinstance(node, ast.Attribute):
|
||||
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
|
||||
if node.attr in dangerous_attrs and not allow_file_ops:
|
||||
raise PluginValidationError(
|
||||
f"Dangerous attribute access: {node.attr}"
|
||||
)
|
||||
|
||||
logger.info("Plugin code validation passed")
|
||||
|
||||
|
||||
def check_hook_function_signature(func_code: str) -> bool:
|
||||
"""
|
||||
Verify that a hook function has the correct signature.
|
||||
|
||||
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||
|
||||
Args:
|
||||
func_code: Function source code
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If signature is invalid
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(func_code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error: {e}")
|
||||
|
||||
# Find function definition
|
||||
func_def = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
func_def = node
|
||||
break
|
||||
|
||||
if func_def is None:
|
||||
raise PluginValidationError("No function definition found")
|
||||
|
||||
# Check that it takes exactly one argument (context)
|
||||
if len(func_def.args.args) != 1:
|
||||
raise PluginValidationError(
|
||||
f"Hook function must take exactly 1 argument (context), "
|
||||
f"got {len(func_def.args.args)}"
|
||||
)
|
||||
|
||||
logger.info(f"Hook function '{func_def.name}' signature is valid")
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_plugin_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize a plugin name to be safe for filesystem and imports.
|
||||
|
||||
Args:
|
||||
name: Original plugin name
|
||||
|
||||
Returns:
|
||||
Sanitized name (lowercase, alphanumeric + underscore only)
|
||||
"""
|
||||
# Remove special characters, keep only alphanumeric and underscore
|
||||
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
|
||||
|
||||
# Ensure it doesn't start with a number
|
||||
if sanitized[0].isdigit():
|
||||
sanitized = f"plugin_{sanitized}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_imported_modules(code: str) -> Set[str]:
|
||||
"""
|
||||
Extract all imported modules from code.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Set of module names
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return set()
|
||||
|
||||
modules = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
modules.add(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
modules.add(node.module)
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
def estimate_complexity(code: str) -> int:
|
||||
"""
|
||||
Estimate the cyclomatic complexity of code.
|
||||
|
||||
Higher values indicate more complex code with more branching.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Complexity estimate (1 = simple, 10+ = complex)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return 0
|
||||
|
||||
complexity = 1 # Base complexity
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Add 1 for each branching statement
|
||||
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
|
||||
complexity += 1
|
||||
elif isinstance(node, ast.BoolOp):
|
||||
complexity += len(node.values) - 1
|
||||
|
||||
return complexity
|
||||
@@ -23,6 +23,8 @@ import pandas as pd
|
||||
from datetime import datetime
|
||||
import pickle
|
||||
|
||||
from optimization_engine.plugins import HookManager
|
||||
|
||||
|
||||
class OptimizationRunner:
|
||||
"""
|
||||
@@ -68,6 +70,15 @@ class OptimizationRunner:
|
||||
self.output_dir = self.config_path.parent / 'optimization_results'
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Initialize plugin/hook system
|
||||
self.hook_manager = HookManager()
|
||||
plugins_dir = Path(__file__).parent / 'plugins'
|
||||
if plugins_dir.exists():
|
||||
self.hook_manager.load_plugins_from_directory(plugins_dir)
|
||||
summary = self.hook_manager.get_summary()
|
||||
if summary['total_hooks'] > 0:
|
||||
print(f"Loaded {summary['enabled_hooks']}/{summary['total_hooks']} plugins")
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""Load and validate optimization configuration."""
|
||||
with open(self.config_path, 'r') as f:
|
||||
@@ -311,6 +322,16 @@ class OptimizationRunner:
|
||||
int(dv['bounds'][1])
|
||||
)
|
||||
|
||||
# Execute pre_solve hooks
|
||||
pre_solve_context = {
|
||||
'trial_number': trial.number,
|
||||
'design_variables': design_vars,
|
||||
'sim_file': self.config.get('sim_file', ''),
|
||||
'working_dir': str(Path.cwd()),
|
||||
'config': self.config
|
||||
}
|
||||
self.hook_manager.execute_hooks('pre_solve', pre_solve_context, fail_fast=False)
|
||||
|
||||
# 2. Update NX model with new parameters
|
||||
try:
|
||||
self.model_updater(design_vars)
|
||||
@@ -318,6 +339,15 @@ class OptimizationRunner:
|
||||
print(f"Error updating model: {e}")
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
# Execute post_mesh hooks (after model update)
|
||||
post_mesh_context = {
|
||||
'trial_number': trial.number,
|
||||
'design_variables': design_vars,
|
||||
'sim_file': self.config.get('sim_file', ''),
|
||||
'working_dir': str(Path.cwd())
|
||||
}
|
||||
self.hook_manager.execute_hooks('post_mesh', post_mesh_context, fail_fast=False)
|
||||
|
||||
# 3. Run simulation
|
||||
try:
|
||||
result_path = self.simulation_runner()
|
||||
@@ -325,6 +355,15 @@ class OptimizationRunner:
|
||||
print(f"Error running simulation: {e}")
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
# Execute post_solve hooks
|
||||
post_solve_context = {
|
||||
'trial_number': trial.number,
|
||||
'design_variables': design_vars,
|
||||
'result_path': str(result_path) if result_path else '',
|
||||
'working_dir': str(Path.cwd())
|
||||
}
|
||||
self.hook_manager.execute_hooks('post_solve', post_solve_context, fail_fast=False)
|
||||
|
||||
# 4. Extract results with appropriate precision
|
||||
extracted_results = {}
|
||||
for obj in self.config['objectives']:
|
||||
@@ -362,6 +401,16 @@ class OptimizationRunner:
|
||||
print(f"Error extracting {const['name']}: {e}")
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
# Execute post_extraction hooks
|
||||
post_extraction_context = {
|
||||
'trial_number': trial.number,
|
||||
'design_variables': design_vars,
|
||||
'extracted_results': extracted_results,
|
||||
'result_path': str(result_path) if result_path else '',
|
||||
'working_dir': str(Path.cwd())
|
||||
}
|
||||
self.hook_manager.execute_hooks('post_extraction', post_extraction_context, fail_fast=False)
|
||||
|
||||
# 5. Evaluate constraints
|
||||
for const in self.config.get('constraints', []):
|
||||
value = extracted_results[const['name']]
|
||||
@@ -389,6 +438,23 @@ class OptimizationRunner:
|
||||
else: # maximize
|
||||
total_objective -= weight * value
|
||||
|
||||
# Execute custom_objective hooks (can modify total_objective)
|
||||
custom_objective_context = {
|
||||
'trial_number': trial.number,
|
||||
'design_variables': design_vars,
|
||||
'extracted_results': extracted_results,
|
||||
'total_objective': total_objective,
|
||||
'working_dir': str(Path.cwd())
|
||||
}
|
||||
custom_results = self.hook_manager.execute_hooks('custom_objective', custom_objective_context, fail_fast=False)
|
||||
|
||||
# Allow hooks to override objective value
|
||||
for result in custom_results:
|
||||
if result and 'total_objective' in result:
|
||||
total_objective = result['total_objective']
|
||||
print(f"Custom objective hook modified total_objective to {total_objective:.6f}")
|
||||
break # Use first hook that provides override
|
||||
|
||||
# 7. Store results in history
|
||||
history_entry = {
|
||||
'trial_number': trial.number,
|
||||
|
||||
438
tests/test_plugin_system.py
Normal file
438
tests/test_plugin_system.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
Tests for the Plugin System
|
||||
|
||||
Validates hook registration, execution, validation, and integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from optimization_engine.plugins import HookManager, Hook, HookPoint
|
||||
from optimization_engine.plugins.validators import (
|
||||
validate_plugin_code,
|
||||
PluginValidationError,
|
||||
check_hook_function_signature,
|
||||
sanitize_plugin_name,
|
||||
get_imported_modules,
|
||||
estimate_complexity
|
||||
)
|
||||
|
||||
|
||||
class TestHookRegistration:
|
||||
"""Test hook registration and management."""
|
||||
|
||||
def test_register_simple_hook(self):
|
||||
"""Test basic hook registration."""
|
||||
manager = HookManager()
|
||||
|
||||
def my_hook(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
return {'status': 'success'}
|
||||
|
||||
hook = manager.register_hook(
|
||||
hook_point='pre_solve',
|
||||
function=my_hook,
|
||||
description='Test hook',
|
||||
name='test_hook'
|
||||
)
|
||||
|
||||
assert hook.name == 'test_hook'
|
||||
assert hook.hook_point == HookPoint.PRE_SOLVE
|
||||
assert hook.enabled is True
|
||||
|
||||
def test_hook_priority_ordering(self):
|
||||
"""Test that hooks execute in priority order."""
|
||||
manager = HookManager()
|
||||
execution_order = []
|
||||
|
||||
def hook_high_priority(context):
|
||||
execution_order.append('high')
|
||||
return None
|
||||
|
||||
def hook_low_priority(context):
|
||||
execution_order.append('low')
|
||||
return None
|
||||
|
||||
# Register low priority first, high priority second
|
||||
manager.register_hook('pre_solve', hook_low_priority, 'Low', priority=200)
|
||||
manager.register_hook('pre_solve', hook_high_priority, 'High', priority=10)
|
||||
|
||||
# Execute hooks
|
||||
manager.execute_hooks('pre_solve', {'trial_number': 1})
|
||||
|
||||
# High priority should execute first
|
||||
assert execution_order == ['high', 'low']
|
||||
|
||||
def test_disable_enable_hook(self):
|
||||
"""Test disabling and enabling hooks."""
|
||||
manager = HookManager()
|
||||
execution_count = [0]
|
||||
|
||||
def counting_hook(context):
|
||||
execution_count[0] += 1
|
||||
return None
|
||||
|
||||
hook = manager.register_hook(
|
||||
'pre_solve',
|
||||
counting_hook,
|
||||
'Counter',
|
||||
name='counter_hook'
|
||||
)
|
||||
|
||||
# Execute while enabled
|
||||
manager.execute_hooks('pre_solve', {})
|
||||
assert execution_count[0] == 1
|
||||
|
||||
# Disable and execute
|
||||
manager.disable_hook('counter_hook')
|
||||
manager.execute_hooks('pre_solve', {})
|
||||
assert execution_count[0] == 1 # Should not increment
|
||||
|
||||
# Re-enable and execute
|
||||
manager.enable_hook('counter_hook')
|
||||
manager.execute_hooks('pre_solve', {})
|
||||
assert execution_count[0] == 2
|
||||
|
||||
def test_remove_hook(self):
|
||||
"""Test hook removal."""
|
||||
manager = HookManager()
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
manager.register_hook('pre_solve', test_hook, 'Test', name='removable')
|
||||
|
||||
assert len(manager.get_hooks('pre_solve')) == 1
|
||||
|
||||
success = manager.remove_hook('removable')
|
||||
assert success is True
|
||||
assert len(manager.get_hooks('pre_solve')) == 0
|
||||
|
||||
# Try removing non-existent hook
|
||||
success = manager.remove_hook('nonexistent')
|
||||
assert success is False
|
||||
|
||||
|
||||
class TestHookExecution:
|
||||
"""Test hook execution behavior."""
|
||||
|
||||
def test_hook_receives_context(self):
|
||||
"""Test that hooks receive correct context."""
|
||||
manager = HookManager()
|
||||
received_context = {}
|
||||
|
||||
def context_checker(context: Dict[str, Any]):
|
||||
received_context.update(context)
|
||||
return None
|
||||
|
||||
manager.register_hook('pre_solve', context_checker, 'Checker')
|
||||
|
||||
test_context = {
|
||||
'trial_number': 42,
|
||||
'design_variables': {'thickness': 5.0},
|
||||
'sim_file': 'test.sim'
|
||||
}
|
||||
|
||||
manager.execute_hooks('pre_solve', test_context)
|
||||
|
||||
assert received_context['trial_number'] == 42
|
||||
assert received_context['design_variables']['thickness'] == 5.0
|
||||
|
||||
def test_hook_return_values(self):
|
||||
"""Test that hook return values are collected."""
|
||||
manager = HookManager()
|
||||
|
||||
def hook1(context):
|
||||
return {'result': 'hook1'}
|
||||
|
||||
def hook2(context):
|
||||
return {'result': 'hook2'}
|
||||
|
||||
manager.register_hook('pre_solve', hook1, 'Hook 1')
|
||||
manager.register_hook('pre_solve', hook2, 'Hook 2')
|
||||
|
||||
results = manager.execute_hooks('pre_solve', {})
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]['result'] == 'hook1'
|
||||
assert results[1]['result'] == 'hook2'
|
||||
|
||||
def test_hook_error_handling_fail_fast(self):
|
||||
"""Test error handling with fail_fast=True."""
|
||||
manager = HookManager()
|
||||
|
||||
def failing_hook(context):
|
||||
raise ValueError("Intentional error")
|
||||
|
||||
manager.register_hook('pre_solve', failing_hook, 'Failing')
|
||||
|
||||
with pytest.raises(ValueError, match="Intentional error"):
|
||||
manager.execute_hooks('pre_solve', {}, fail_fast=True)
|
||||
|
||||
def test_hook_error_handling_continue(self):
|
||||
"""Test error handling with fail_fast=False."""
|
||||
manager = HookManager()
|
||||
execution_log = []
|
||||
|
||||
def failing_hook(context):
|
||||
execution_log.append('failing')
|
||||
raise ValueError("Intentional error")
|
||||
|
||||
def successful_hook(context):
|
||||
execution_log.append('successful')
|
||||
return {'status': 'ok'}
|
||||
|
||||
manager.register_hook('pre_solve', failing_hook, 'Failing', priority=10)
|
||||
manager.register_hook('pre_solve', successful_hook, 'Success', priority=20)
|
||||
|
||||
results = manager.execute_hooks('pre_solve', {}, fail_fast=False)
|
||||
|
||||
# Both hooks should have attempted execution
|
||||
assert execution_log == ['failing', 'successful']
|
||||
# First result is None (error), second is successful
|
||||
assert results[0] is None
|
||||
assert results[1]['status'] == 'ok'
|
||||
|
||||
def test_hook_history_tracking(self):
|
||||
"""Test that hook execution history is tracked."""
|
||||
manager = HookManager()
|
||||
|
||||
def test_hook(context):
|
||||
return {'result': 'success'}
|
||||
|
||||
manager.register_hook('pre_solve', test_hook, 'Test', name='tracked')
|
||||
|
||||
# Execute hooks multiple times
|
||||
for i in range(3):
|
||||
manager.execute_hooks('pre_solve', {'trial_number': i})
|
||||
|
||||
history = manager.get_history()
|
||||
assert len(history) >= 3
|
||||
|
||||
# Check history contains success records
|
||||
successful = [h for h in history if h['success']]
|
||||
assert len(successful) >= 3
|
||||
|
||||
|
||||
class TestCodeValidation:
|
||||
"""Test plugin code validation."""
|
||||
|
||||
def test_safe_code_passes(self):
|
||||
"""Test that safe code passes validation."""
|
||||
safe_code = """
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
def my_hook(context):
|
||||
x = context['design_variables']['thickness']
|
||||
result = math.sqrt(x**2 + np.mean([1, 2, 3]))
|
||||
return {'result': result}
|
||||
"""
|
||||
# Should not raise
|
||||
validate_plugin_code(safe_code)
|
||||
|
||||
def test_dangerous_import_blocked(self):
|
||||
"""Test that dangerous imports are blocked."""
|
||||
dangerous_code = """
|
||||
import os
|
||||
|
||||
def my_hook(context):
|
||||
os.system('rm -rf /')
|
||||
return None
|
||||
"""
|
||||
with pytest.raises(PluginValidationError, match="Unsafe import"):
|
||||
validate_plugin_code(dangerous_code)
|
||||
|
||||
def test_dangerous_operation_blocked(self):
|
||||
"""Test that dangerous operations are blocked."""
|
||||
dangerous_code = """
|
||||
def my_hook(context):
|
||||
eval('malicious_code')
|
||||
return None
|
||||
"""
|
||||
with pytest.raises(PluginValidationError, match="Dangerous operation"):
|
||||
validate_plugin_code(dangerous_code)
|
||||
|
||||
def test_file_operations_with_permission(self):
|
||||
"""Test that file operations work with allow_file_ops=True."""
|
||||
code_with_file_ops = """
|
||||
def my_hook(context):
|
||||
with open('output.txt', 'w') as f:
|
||||
f.write('test')
|
||||
return None
|
||||
"""
|
||||
# Should raise without permission
|
||||
with pytest.raises(PluginValidationError, match="Dangerous operation: open"):
|
||||
validate_plugin_code(code_with_file_ops, allow_file_ops=False)
|
||||
|
||||
# Should pass with permission
|
||||
validate_plugin_code(code_with_file_ops, allow_file_ops=True)
|
||||
|
||||
def test_syntax_error_detected(self):
|
||||
"""Test that syntax errors are detected."""
|
||||
bad_syntax = """
|
||||
def my_hook(context)
|
||||
return None # Missing colon
|
||||
"""
|
||||
with pytest.raises(PluginValidationError, match="Syntax error"):
|
||||
validate_plugin_code(bad_syntax)
|
||||
|
||||
def test_hook_signature_validation(self):
|
||||
"""Test hook function signature validation."""
|
||||
# Valid signature
|
||||
valid_code = """
|
||||
def my_hook(context):
|
||||
return None
|
||||
"""
|
||||
assert check_hook_function_signature(valid_code) is True
|
||||
|
||||
# Invalid: too many arguments
|
||||
invalid_code = """
|
||||
def my_hook(context, extra_arg):
|
||||
return None
|
||||
"""
|
||||
with pytest.raises(PluginValidationError, match="must take exactly 1 argument"):
|
||||
check_hook_function_signature(invalid_code)
|
||||
|
||||
# Invalid: no function
|
||||
invalid_code = """
|
||||
x = 5
|
||||
"""
|
||||
with pytest.raises(PluginValidationError, match="No function definition"):
|
||||
check_hook_function_signature(invalid_code)
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test plugin utility functions."""
|
||||
|
||||
def test_sanitize_plugin_name(self):
|
||||
"""Test plugin name sanitization."""
|
||||
assert sanitize_plugin_name('My-Plugin!') == 'my_plugin_'
|
||||
assert sanitize_plugin_name('123_plugin') == 'plugin_123_plugin'
|
||||
assert sanitize_plugin_name('valid_name') == 'valid_name'
|
||||
|
||||
def test_get_imported_modules(self):
|
||||
"""Test module extraction from code."""
|
||||
code = """
|
||||
import numpy as np
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
"""
|
||||
modules = get_imported_modules(code)
|
||||
assert 'numpy' in modules
|
||||
assert 'math' in modules
|
||||
assert 'pathlib' in modules
|
||||
assert 'typing' in modules
|
||||
|
||||
def test_estimate_complexity(self):
|
||||
"""Test cyclomatic complexity estimation."""
|
||||
simple_code = """
|
||||
def simple(x):
|
||||
return x + 1
|
||||
"""
|
||||
complex_code = """
|
||||
def complex(x):
|
||||
if x > 0:
|
||||
for i in range(x):
|
||||
if i % 2 == 0:
|
||||
while i > 0:
|
||||
i -= 1
|
||||
return x
|
||||
"""
|
||||
simple_complexity = estimate_complexity(simple_code)
|
||||
complex_complexity = estimate_complexity(complex_code)
|
||||
|
||||
assert simple_complexity == 1
|
||||
assert complex_complexity > simple_complexity
|
||||
|
||||
|
||||
class TestHookManager:
|
||||
"""Test HookManager functionality."""
|
||||
|
||||
def test_get_summary(self):
|
||||
"""Test hook system summary."""
|
||||
manager = HookManager()
|
||||
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
manager.register_hook('pre_solve', hook1, 'Hook 1', name='hook1')
|
||||
manager.register_hook('post_solve', hook2, 'Hook 2', name='hook2')
|
||||
manager.disable_hook('hook2')
|
||||
|
||||
summary = manager.get_summary()
|
||||
|
||||
assert summary['total_hooks'] == 2
|
||||
assert summary['enabled_hooks'] == 1
|
||||
assert summary['disabled_hooks'] == 1
|
||||
|
||||
def test_clear_history(self):
|
||||
"""Test clearing execution history."""
|
||||
manager = HookManager()
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
manager.register_hook('pre_solve', test_hook, 'Test')
|
||||
manager.execute_hooks('pre_solve', {})
|
||||
|
||||
assert len(manager.get_history()) > 0
|
||||
|
||||
manager.clear_history()
|
||||
assert len(manager.get_history()) == 0
|
||||
|
||||
def test_hook_manager_repr(self):
|
||||
"""Test HookManager string representation."""
|
||||
manager = HookManager()
|
||||
|
||||
def hook(context):
|
||||
return None
|
||||
|
||||
manager.register_hook('pre_solve', hook, 'Test')
|
||||
|
||||
repr_str = repr(manager)
|
||||
assert 'HookManager' in repr_str
|
||||
assert 'hooks=1' in repr_str
|
||||
assert 'enabled=1' in repr_str
|
||||
|
||||
|
||||
class TestPluginLoading:
|
||||
"""Test plugin directory loading."""
|
||||
|
||||
def test_load_plugins_from_nonexistent_directory(self):
|
||||
"""Test loading from non-existent directory."""
|
||||
manager = HookManager()
|
||||
# Should not raise, just log warning
|
||||
manager.load_plugins_from_directory(Path('/nonexistent/path'))
|
||||
|
||||
def test_plugin_registration_function(self):
|
||||
"""Test that plugins can register hooks via register_hooks()."""
|
||||
manager = HookManager()
|
||||
|
||||
# Simulate what a plugin file would contain
|
||||
def register_hooks(hook_manager):
|
||||
def my_plugin_hook(context):
|
||||
return {'plugin': 'loaded'}
|
||||
|
||||
hook_manager.register_hook(
|
||||
'pre_solve',
|
||||
my_plugin_hook,
|
||||
'Plugin hook',
|
||||
name='plugin_hook'
|
||||
)
|
||||
|
||||
# Call the registration function
|
||||
register_hooks(manager)
|
||||
|
||||
# Verify hook was registered
|
||||
hooks = manager.get_hooks('pre_solve')
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0].name == 'plugin_hook'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user