Files
Atomizer/optimization_engine/plugins/hook_manager.py

304 lines
9.6 KiB
Python
Raw Normal View History

"""
Hook Manager for Atomizer Plugin System
Manages registration, execution, and lifecycle of hooks.
"""
from typing import Dict, List, Callable, Any, Optional
from pathlib import Path
import logging
import importlib.util
import json
from .hooks import Hook, HookPoint, HookContext
logger = logging.getLogger(__name__)
class HookManager:
    """
    Central manager for all hooks in the optimization system.

    Hooks are kept per ``HookPoint`` in ascending ``priority`` order, and
    every execution attempt (success or failure) is appended to an
    in-memory history list that can be read back with ``get_history``.

    Example:
        >>> manager = HookManager()
        >>> def my_hook(context):
        ...     print(f"Trial {context['trial_number']} starting")
        ...     return {'status': 'logged'}
        ...
        >>> hook = manager.register_hook('pre_solve', my_hook, 'Log trial start')
        >>> results = manager.execute_hooks('pre_solve', {'trial_number': 5})
    """

    def __init__(self):
        """Create an empty registry with one hook list per ``HookPoint``."""
        self._hooks: "Dict[HookPoint, List[Hook]]" = {
            point: [] for point in HookPoint
        }
        self._hook_history: List[Dict[str, Any]] = []
        logger.info("HookManager initialized")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_hook_point(hook_point: "str | HookPoint") -> "HookPoint":
        """Convert a hook-point string to its ``HookPoint`` member.

        Non-string values are returned unchanged.

        Raises:
            ValueError: If the string does not name a ``HookPoint``; the
                message lists the valid options.
        """
        if not isinstance(hook_point, str):
            return hook_point
        try:
            return HookPoint(hook_point)
        except ValueError as exc:
            valid_points = [p.value for p in HookPoint]
            raise ValueError(
                f"Invalid hook_point '{hook_point}'. "
                f"Valid options: {valid_points}"
            ) from exc

    def _is_name_taken(self, name: str) -> bool:
        """Return True if any registered hook already uses ``name``."""
        return any(
            hook.name == name
            for hooks in self._hooks.values()
            for hook in hooks
        )

    # ------------------------------------------------------------------
    # Registration and execution
    # ------------------------------------------------------------------
    # NOTE: the union annotations are quoted so they are not evaluated at
    # definition time; an unquoted ``str | HookPoint`` needs Python 3.10+.
    def register_hook(
        self,
        hook_point: "str | HookPoint",
        function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
        description: str,
        name: Optional[str] = None,
        priority: int = 100,
        enabled: bool = True
    ) -> "Hook":
        """
        Register a new hook function.

        Args:
            hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
            function: Callable that takes context dict, returns optional dict
            description: Human-readable description
            name: Unique name (auto-generated if not provided)
            priority: Execution order (lower = earlier)
            enabled: Whether hook is active

        Returns:
            The created Hook object

        Raises:
            ValueError: If hook_point is invalid
        """
        point = self._resolve_hook_point(hook_point)

        if name is None:
            # functools.partial objects and callable instances have no
            # __name__, so fall back to the callable's type name.
            func_name = getattr(function, "__name__", type(function).__name__)
            # Start from the current list length, then bump the suffix
            # until the name is unused: after a removal the plain length
            # can repeat a suffix that is still registered.
            index = len(self._hooks[point])
            name = f"{point.value}_{func_name}_{index}"
            while self._is_name_taken(name):
                index += 1
                name = f"{point.value}_{func_name}_{index}"
        elif self._is_name_taken(name):
            # Still allowed, but remove/enable/disable stop at the first
            # hook whose name matches.
            logger.warning(
                "Hook name '%s' is already registered; name-based "
                "operations will only affect the first match",
                name,
            )

        hook = Hook(
            name=name,
            hook_point=point,
            function=function,
            description=description,
            priority=priority,
            enabled=enabled
        )

        # list.sort is stable, so hooks with equal priority keep their
        # registration order.
        self._hooks[point].append(hook)
        self._hooks[point].sort(key=lambda h: h.priority)
        logger.info("Registered hook: %s", hook)
        return hook

    def execute_hooks(
        self,
        hook_point: "str | HookPoint",
        context: Dict[str, Any],
        fail_fast: bool = False
    ) -> List[Optional[Dict[str, Any]]]:
        """
        Execute all hooks registered at a specific point.

        Args:
            hook_point: The execution point
            context: Data to pass to hooks
            fail_fast: If True, stop on first error. If False, log and continue.

        Returns:
            List of results from each hook (None if hook returned nothing
            or, when fail_fast is False, if the hook raised)

        Raises:
            ValueError: If hook_point is a string that names no HookPoint
            Exception: If fail_fast=True and a hook fails
        """
        point = self._resolve_hook_point(hook_point)

        # Iterate over a snapshot so a hook that registers or removes
        # hooks while running cannot disturb this loop.
        hooks = list(self._hooks[point])
        if not hooks:
            logger.debug("No hooks registered for %s", point.value)
            return []

        logger.info("Executing %d hooks at %s", len(hooks), point.value)
        results: List[Optional[Dict[str, Any]]] = []
        for hook in hooks:
            try:
                result = hook.execute(context)
            except Exception as exc:
                # logger.exception keeps the traceback with the message.
                logger.exception("Hook '%s' failed: %s", hook.name, exc)
                self._hook_history.append({
                    'hook_name': hook.name,
                    'hook_point': point.value,
                    'success': False,
                    'error': str(exc),
                    'trial_number': context.get('trial_number', -1)
                })
                if fail_fast:
                    raise
                results.append(None)
            else:
                results.append(result)
                self._hook_history.append({
                    'hook_name': hook.name,
                    'hook_point': point.value,
                    'success': True,
                    'trial_number': context.get('trial_number', -1)
                })
        return results

    # ------------------------------------------------------------------
    # Plugin discovery
    # ------------------------------------------------------------------
    def load_plugins_from_directory(self, directory: Path):
        """
        Auto-discover and load plugins from a directory.

        Expected structure:
            plugins/
                pre_mesh/
                    my_plugin.py  # Contains register_hooks(manager) function
                post_solve/
                    another_plugin.py

        Files whose name starts with "_" are skipped. A plugin that fails
        to load is logged and does not stop the remaining plugins.

        Args:
            directory: Path to plugins directory (a str is accepted too)
        """
        directory = Path(directory)
        if not directory.is_dir():
            logger.warning("Plugins directory not found: %s", directory)
            return

        logger.info("Loading plugins from %s", directory)
        for hook_point in HookPoint:
            hook_dir = directory / hook_point.value
            if not hook_dir.is_dir():
                continue
            # glob() order is unspecified; sort so load order (and so the
            # order of equal-priority hooks) is the same on every run.
            for plugin_file in sorted(hook_dir.glob("*.py")):
                if plugin_file.name.startswith("_"):
                    continue  # Skip __init__.py and private files
                try:
                    self._load_plugin_file(plugin_file)
                except Exception as exc:
                    logger.error(
                        "Failed to load plugin %s: %s", plugin_file, exc
                    )

    def _load_plugin_file(self, plugin_file: Path):
        """Load a single plugin file and call its register_hooks function.

        The module is executed from its file path and is not inserted
        into ``sys.modules``.

        Raises:
            ImportError: If no import spec/loader can be built for the file
        """
        spec = importlib.util.spec_from_file_location(
            plugin_file.stem, plugin_file
        )
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for {plugin_file}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Call register_hooks if it exists
        if hasattr(module, 'register_hooks'):
            module.register_hooks(self)
            logger.info("Loaded plugin: %s", plugin_file.stem)
        else:
            logger.warning(
                "Plugin %s has no register_hooks() function", plugin_file
            )

    # ------------------------------------------------------------------
    # Registry queries and maintenance
    # ------------------------------------------------------------------
    def get_hooks(self, hook_point: "str | HookPoint") -> "List[Hook]":
        """Get all hooks registered at a specific point.

        Returns a copy of the list, so mutating it does not change the
        registry.

        Raises:
            ValueError: If hook_point is a string that names no HookPoint
        """
        point = self._resolve_hook_point(hook_point)
        return self._hooks[point].copy()

    def remove_hook(self, name: str) -> bool:
        """
        Remove a hook by name.

        Only the first hook found with this name is removed.

        Returns:
            True if hook was found and removed, False otherwise
        """
        for hooks in self._hooks.values():
            for index, hook in enumerate(hooks):
                if hook.name == name:
                    # Safe to delete during iteration: we return at once.
                    del hooks[index]
                    logger.info("Removed hook: %s", name)
                    return True
        return False

    def enable_hook(self, name: str):
        """Enable a hook by name."""
        self._set_hook_enabled(name, True)

    def disable_hook(self, name: str):
        """Disable a hook by name."""
        self._set_hook_enabled(name, False)

    def _set_hook_enabled(self, name: str, enabled: bool):
        """Set enabled status for the first hook named ``name``.

        Logs a warning when no hook with that name is registered.
        """
        for hooks in self._hooks.values():
            for hook in hooks:
                if hook.name == name:
                    hook.enabled = enabled
                    status = "enabled" if enabled else "disabled"
                    logger.info("Hook '%s' %s", name, status)
                    return
        logger.warning("Hook '%s' not found", name)

    # ------------------------------------------------------------------
    # History and reporting
    # ------------------------------------------------------------------
    def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get hook execution history.

        Args:
            limit: Maximum number of records to return (most recent).
                None returns everything; zero or a negative value
                returns an empty list.

        Returns:
            List of execution records (a new list; oldest first)
        """
        if limit is None:
            return self._hook_history.copy()
        if limit <= 0:
            # [-0:] would return the whole list and a negative limit
            # would slice from the front, so handle both explicitly.
            return []
        return self._hook_history[-limit:]

    def clear_history(self):
        """Clear all execution history."""
        self._hook_history.clear()
        logger.info("Hook history cleared")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get a summary of the hook system state.

        Returns:
            Dictionary with keys 'total_hooks', 'enabled_hooks',
            'disabled_hooks', 'by_hook_point' (per-point 'total',
            'enabled' and 'names') and 'history_records'.
        """
        by_point = {
            point.value: {
                'total': len(hooks),
                'enabled': sum(1 for h in hooks if h.enabled),
                'names': [h.name for h in hooks]
            }
            for point, hooks in self._hooks.items()
        }
        # Derive the global totals from the per-point figures.
        total_hooks = sum(entry['total'] for entry in by_point.values())
        enabled_hooks = sum(entry['enabled'] for entry in by_point.values())
        return {
            'total_hooks': total_hooks,
            'enabled_hooks': enabled_hooks,
            'disabled_hooks': total_hooks - enabled_hooks,
            'by_hook_point': by_point,
            'history_records': len(self._hook_history)
        }

    def __repr__(self) -> str:
        """Return a short summary with total and enabled hook counts."""
        summary = self.get_summary()
        return (
            f"HookManager(hooks={summary['total_hooks']}, "
            f"enabled={summary['enabled_hooks']})"
        )