feat: Implement Phase 1 - Plugin & Hook System

Core plugin architecture for LLM-driven optimization:

New Features:
- Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives)
- HookManager for centralized registration and execution
- Code validation with AST-based safety checks
- Feature registry (JSON) for LLM capability discovery
- Example plugin: log_trial_start
- 23 comprehensive tests (all passing)

Integration:
- OptimizationRunner now loads plugins automatically
- Hooks execute at 5 points in optimization loop
- Custom objectives can override total_objective via hooks

Safety:
- Module whitelist (numpy, scipy, pandas, optuna, pyNastran)
- Dangerous operation blocking (eval, exec, os.system, subprocess)
- Optional file operation permission flag

Files Added:
- optimization_engine/plugins/__init__.py
- optimization_engine/plugins/hooks.py
- optimization_engine/plugins/hook_manager.py
- optimization_engine/plugins/validators.py
- optimization_engine/feature_registry.json
- optimization_engine/plugins/pre_solve/log_trial_start.py
- tests/test_plugin_system.py (23 tests)

Files Modified:
- optimization_engine/runner.py (added hook integration)

Ready for Phase 2: LLM interface layer

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-15 14:46:49 -05:00
parent 0ce9ddf3e2
commit a24e3f750c
14 changed files with 1473 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
"""
Atomizer Plugin System
Enables extensibility through hooks at various stages of the optimization lifecycle.
Hook Points:
- pre_mesh: Before meshing operations
- post_mesh: After meshing, before solve
- pre_solve: Before solver execution
- post_solve: After solve, before result extraction
- post_extraction: After result extraction, before objective calculation
- custom_objectives: Custom objective/constraint functions
Usage:
from optimization_engine.plugins import HookManager
hook_manager = HookManager()
hook_manager.register_hook('pre_solve', my_custom_function)
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
"""
from .hooks import Hook, HookPoint
from .hook_manager import HookManager
from .validators import validate_plugin_code, PluginValidationError
__all__ = [
'Hook',
'HookPoint',
'HookManager',
'validate_plugin_code',
'PluginValidationError'
]

View File

@@ -0,0 +1 @@
# custom_objectives hooks

View File

@@ -0,0 +1,303 @@
"""
Hook Manager for Atomizer Plugin System
Manages registration, execution, and lifecycle of hooks.
"""
from typing import Dict, List, Callable, Any, Optional
from pathlib import Path
import logging
import importlib.util
import json
from .hooks import Hook, HookPoint, HookContext
logger = logging.getLogger(__name__)
class HookManager:
"""
Central manager for all hooks in the optimization system.
Example:
>>> manager = HookManager()
>>> def my_hook(context):
>>> print(f"Trial {context['trial_number']} starting")
>>> return {'status': 'logged'}
>>>
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
"""
def __init__(self):
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
self._hook_history: List[Dict[str, Any]] = []
logger.info("HookManager initialized")
def register_hook(
self,
hook_point: str | HookPoint,
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
description: str,
name: Optional[str] = None,
priority: int = 100,
enabled: bool = True
) -> Hook:
"""
Register a new hook function.
Args:
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
function: Callable that takes context dict, returns optional dict
description: Human-readable description
name: Unique name (auto-generated if not provided)
priority: Execution order (lower = earlier)
enabled: Whether hook is active
Returns:
The created Hook object
Raises:
ValueError: If hook_point is invalid
"""
# Convert string to HookPoint enum
if isinstance(hook_point, str):
try:
hook_point = HookPoint(hook_point)
except ValueError:
valid_points = [p.value for p in HookPoint]
raise ValueError(
f"Invalid hook_point '{hook_point}'. "
f"Valid options: {valid_points}"
)
# Auto-generate name if not provided
if name is None:
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
# Create hook
hook = Hook(
name=name,
hook_point=hook_point,
function=function,
description=description,
priority=priority,
enabled=enabled
)
# Add to registry and sort by priority
self._hooks[hook_point].append(hook)
self._hooks[hook_point].sort(key=lambda h: h.priority)
logger.info(f"Registered hook: {hook}")
return hook
def execute_hooks(
self,
hook_point: str | HookPoint,
context: Dict[str, Any],
fail_fast: bool = False
) -> List[Optional[Dict[str, Any]]]:
"""
Execute all hooks registered at a specific point.
Args:
hook_point: The execution point
context: Data to pass to hooks
fail_fast: If True, stop on first error. If False, log and continue.
Returns:
List of results from each hook (None if hook returned nothing)
Raises:
Exception: If fail_fast=True and a hook fails
"""
# Convert string to enum
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
hooks = self._hooks[hook_point]
if not hooks:
logger.debug(f"No hooks registered for {hook_point.value}")
return []
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
results = []
for hook in hooks:
try:
result = hook.execute(context)
results.append(result)
# Record in history
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': True,
'trial_number': context.get('trial_number', -1)
})
except Exception as e:
logger.error(f"Hook '{hook.name}' failed: {e}")
# Record failure
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': False,
'error': str(e),
'trial_number': context.get('trial_number', -1)
})
if fail_fast:
raise
else:
results.append(None)
return results
def load_plugins_from_directory(self, directory: Path):
"""
Auto-discover and load plugins from a directory.
Expected structure:
plugins/
pre_mesh/
my_plugin.py # Contains register_hooks(manager) function
post_solve/
another_plugin.py
Args:
directory: Path to plugins directory
"""
if not directory.exists():
logger.warning(f"Plugins directory not found: {directory}")
return
logger.info(f"Loading plugins from {directory}")
for hook_point in HookPoint:
hook_dir = directory / hook_point.value
if not hook_dir.exists():
continue
for plugin_file in hook_dir.glob("*.py"):
if plugin_file.name.startswith("_"):
continue # Skip __init__.py and private files
try:
self._load_plugin_file(plugin_file)
except Exception as e:
logger.error(f"Failed to load plugin {plugin_file}: {e}")
def _load_plugin_file(self, plugin_file: Path):
"""Load a single plugin file and call its register_hooks function."""
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load spec for {plugin_file}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Call register_hooks if it exists
if hasattr(module, 'register_hooks'):
module.register_hooks(self)
logger.info(f"Loaded plugin: {plugin_file.stem}")
else:
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
"""Get all hooks registered at a specific point."""
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
return self._hooks[hook_point].copy()
def remove_hook(self, name: str) -> bool:
"""
Remove a hook by name.
Returns:
True if hook was found and removed, False otherwise
"""
for point, hooks in self._hooks.items():
for i, hook in enumerate(hooks):
if hook.name == name:
del hooks[i]
logger.info(f"Removed hook: {name}")
return True
return False
def enable_hook(self, name: str):
"""Enable a hook by name."""
self._set_hook_enabled(name, True)
def disable_hook(self, name: str):
"""Disable a hook by name."""
self._set_hook_enabled(name, False)
def _set_hook_enabled(self, name: str, enabled: bool):
"""Set enabled status for a hook."""
for hooks in self._hooks.values():
for hook in hooks:
if hook.name == name:
hook.enabled = enabled
status = "enabled" if enabled else "disabled"
logger.info(f"Hook '{name}' {status}")
return
logger.warning(f"Hook '{name}' not found")
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Get hook execution history.
Args:
limit: Maximum number of records to return (most recent)
Returns:
List of execution records
"""
if limit:
return self._hook_history[-limit:]
return self._hook_history.copy()
def clear_history(self):
"""Clear all execution history."""
self._hook_history.clear()
logger.info("Hook history cleared")
def get_summary(self) -> Dict[str, Any]:
"""
Get a summary of the hook system state.
Returns:
Dictionary with hook counts, enabled status, etc.
"""
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
enabled_hooks = sum(
sum(1 for h in hooks if h.enabled)
for hooks in self._hooks.values()
)
by_point = {
point.value: {
'total': len(hooks),
'enabled': sum(1 for h in hooks if h.enabled),
'names': [h.name for h in hooks]
}
for point, hooks in self._hooks.items()
}
return {
'total_hooks': total_hooks,
'enabled_hooks': enabled_hooks,
'disabled_hooks': total_hooks - enabled_hooks,
'by_hook_point': by_point,
'history_records': len(self._hook_history)
}
def __repr__(self) -> str:
summary = self.get_summary()
return (
f"HookManager(hooks={summary['total_hooks']}, "
f"enabled={summary['enabled_hooks']})"
)

View File

@@ -0,0 +1,115 @@
"""
Core Hook System for Atomizer
Defines hook points in the optimization lifecycle and hook registration mechanism.
"""
from enum import Enum
from typing import Callable, Dict, Any, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
class HookPoint(Enum):
"""Enumeration of available hook points in the optimization lifecycle."""
PRE_MESH = "pre_mesh" # Before meshing
POST_MESH = "post_mesh" # After meshing, before solve
PRE_SOLVE = "pre_solve" # Before solver execution
POST_SOLVE = "post_solve" # After solve, before extraction
POST_EXTRACTION = "post_extraction" # After result extraction
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
@dataclass
class Hook:
"""
Represents a single hook function to be executed at a specific point.
Attributes:
name: Unique identifier for this hook
hook_point: When this hook should execute (HookPoint enum)
function: The callable to execute
description: Human-readable description of what this hook does
priority: Execution order (lower = earlier, default=100)
enabled: Whether this hook is currently active
"""
name: str
hook_point: HookPoint
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
description: str
priority: int = 100
enabled: bool = True
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Execute this hook with the given context.
Args:
context: Dictionary containing relevant data for this hook point
Common keys:
- trial_number: Current trial number
- design_variables: Current design variable values
- sim_file: Path to simulation file
- working_dir: Current working directory
Returns:
Optional dictionary with results or modifications to context
Raises:
Exception: Any exception from the hook function is logged and re-raised
"""
if not self.enabled:
logger.debug(f"Hook '{self.name}' is disabled, skipping")
return None
try:
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
result = self.function(context)
logger.debug(f"Hook '{self.name}' completed successfully")
return result
except Exception as e:
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
raise
def __repr__(self) -> str:
status = "enabled" if self.enabled else "disabled"
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
class HookContext:
"""
Context object passed to hooks containing all relevant data.
This is a convenience wrapper around a dictionary that provides
both dict-like access and attribute access.
"""
def __init__(self, **kwargs):
self._data = kwargs
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, value: Any):
self._data[key] = value
def __contains__(self, key: str) -> bool:
return key in self._data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def update(self, other: Dict[str, Any]):
self._data.update(other)
def to_dict(self) -> Dict[str, Any]:
return self._data.copy()
def __repr__(self) -> str:
keys = list(self._data.keys())
return f"HookContext({', '.join(keys)})"

View File

@@ -0,0 +1 @@
# post_extraction hooks

View File

@@ -0,0 +1 @@
# post_mesh hooks

View File

@@ -0,0 +1 @@
# post_solve hooks

View File

@@ -0,0 +1 @@
# pre_mesh hooks

View File

@@ -0,0 +1 @@
# pre_solve hooks

View File

@@ -0,0 +1,56 @@
"""
Example Plugin: Log Trial Start
Simple pre-solve hook that logs trial information.
This demonstrates the plugin API and serves as a template for custom hooks.
"""
from typing import Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Log trial information before solver execution.
Args:
context: Hook context containing:
- trial_number: Current trial number
- design_variables: Dict of variable values
- sim_file: Path to simulation file
Returns:
Dict with logging status
"""
trial_num = context.get('trial_number', '?')
design_vars = context.get('design_variables', {})
logger.info("=" * 60)
logger.info(f"TRIAL {trial_num} STARTING")
logger.info("=" * 60)
for var_name, var_value in design_vars.items():
logger.info(f" {var_name}: {var_value:.4f}")
return {'logged': True, 'trial': trial_num}
def register_hooks(hook_manager):
"""
Register this plugin's hooks with the manager.
This function is called automatically when the plugin is loaded.
Args:
hook_manager: The HookManager instance
"""
hook_manager.register_hook(
hook_point='pre_solve',
function=trial_start_logger,
description='Log trial number and design variables before solve',
name='log_trial_start',
priority=10 # Run early
)

View File

@@ -0,0 +1,214 @@
"""
Plugin Code Validators
Ensures safety and correctness of plugin code before execution.
"""
import ast
import re
from typing import List, Set
import logging
logger = logging.getLogger(__name__)
class PluginValidationError(Exception):
"""Raised when plugin code fails validation."""
pass
# Whitelist of safe modules for plugin imports
SAFE_MODULES = {
'math', 'numpy', 'scipy', 'pandas',
'pathlib', 'json', 'csv', 'datetime',
'logging', 'typing', 'dataclasses',
'optuna', 'pyNastran'
}
# Blacklist of dangerous operations
DANGEROUS_OPERATIONS = {
'eval', 'exec', 'compile', '__import__',
'open', # File operations should be explicit
'subprocess', 'os.system', 'os.popen',
'shutil.rmtree', # Dangerous file operations
}
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
"""
Validate plugin code for safety before execution.
Args:
code: Python source code to validate
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
Raises:
PluginValidationError: If code contains unsafe operations
Example:
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
>>> validate_plugin_code(code) # Passes
>>>
>>> code = "import os; os.system('rm -rf /')"
>>> validate_plugin_code(code) # Raises PluginValidationError
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error in plugin code: {e}")
# Check for dangerous imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import: {alias.name}. "
f"Allowed modules: {SAFE_MODULES}"
)
elif isinstance(node, ast.ImportFrom):
if node.module and node.module not in SAFE_MODULES:
# Allow submodules of safe modules
base_module = node.module.split('.')[0]
if base_module not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import from: {node.module}"
)
# Check for dangerous function calls
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in DANGEROUS_OPERATIONS:
if node.func.id == 'open' and allow_file_ops:
continue # Allow if explicitly permitted
raise PluginValidationError(
f"Dangerous operation: {node.func.id}"
)
# Check for attribute access to dangerous methods
elif isinstance(node, ast.Attribute):
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
if node.attr in dangerous_attrs and not allow_file_ops:
raise PluginValidationError(
f"Dangerous attribute access: {node.attr}"
)
logger.info("Plugin code validation passed")
def check_hook_function_signature(func_code: str) -> bool:
"""
Verify that a hook function has the correct signature.
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
Args:
func_code: Function source code
Returns:
True if signature is valid
Raises:
PluginValidationError: If signature is invalid
"""
try:
tree = ast.parse(func_code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error: {e}")
# Find function definition
func_def = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_def = node
break
if func_def is None:
raise PluginValidationError("No function definition found")
# Check that it takes exactly one argument (context)
if len(func_def.args.args) != 1:
raise PluginValidationError(
f"Hook function must take exactly 1 argument (context), "
f"got {len(func_def.args.args)}"
)
logger.info(f"Hook function '{func_def.name}' signature is valid")
return True
def sanitize_plugin_name(name: str) -> str:
"""
Sanitize a plugin name to be safe for filesystem and imports.
Args:
name: Original plugin name
Returns:
Sanitized name (lowercase, alphanumeric + underscore only)
"""
# Remove special characters, keep only alphanumeric and underscore
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
# Ensure it doesn't start with a number
if sanitized[0].isdigit():
sanitized = f"plugin_{sanitized}"
return sanitized
def get_imported_modules(code: str) -> Set[str]:
"""
Extract all imported modules from code.
Args:
code: Python source code
Returns:
Set of module names
"""
try:
tree = ast.parse(code)
except SyntaxError:
return set()
modules = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
modules.add(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
modules.add(node.module)
return modules
def estimate_complexity(code: str) -> int:
"""
Estimate the cyclomatic complexity of code.
Higher values indicate more complex code with more branching.
Args:
code: Python source code
Returns:
Complexity estimate (1 = simple, 10+ = complex)
"""
try:
tree = ast.parse(code)
except SyntaxError:
return 0
complexity = 1 # Base complexity
for node in ast.walk(tree):
# Add 1 for each branching statement
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(node, ast.BoolOp):
complexity += len(node.values) - 1
return complexity