304 lines
9.6 KiB
Python
304 lines
9.6 KiB
Python
|
|
"""
|
||
|
|
Hook Manager for Atomizer Plugin System
|
||
|
|
|
||
|
|
Manages registration, execution, and lifecycle of hooks.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from typing import Dict, List, Callable, Any, Optional
|
||
|
|
from pathlib import Path
|
||
|
|
import logging
|
||
|
|
import importlib.util
|
||
|
|
import json
|
||
|
|
|
||
|
|
from .hooks import Hook, HookPoint, HookContext
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class HookManager:
    """
    Central manager for all hooks in the optimization system.

    Hooks are grouped by :class:`HookPoint` and kept sorted by ascending
    ``priority`` (lower runs earlier); hooks with equal priority run in
    registration order. Every execution attempt made by
    :meth:`execute_hooks` is appended to an in-memory history that grows
    until :meth:`clear_history` is called.

    Example::

        manager = HookManager()

        def my_hook(context):
            print(f"Trial {context['trial_number']} starting")
            return {'status': 'logged'}

        manager.register_hook('pre_solve', my_hook, 'Log trial start')
        manager.execute_hooks('pre_solve', {'trial_number': 5})
    """

    def __init__(self):
        # One priority-sorted list per hook point, pre-created so every valid
        # HookPoint can be indexed directly.
        self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
        # Execution records appended by execute_hooks(); only cleared by
        # clear_history().
        self._hook_history: List[Dict[str, Any]] = []
        logger.info("HookManager initialized")

    @staticmethod
    def _resolve_hook_point(hook_point: str | HookPoint) -> HookPoint:
        """
        Normalize ``hook_point`` to a :class:`HookPoint` member.

        Args:
            hook_point: A HookPoint member or its value (e.g. 'pre_solve').

        Returns:
            The matching HookPoint member.

        Raises:
            ValueError: If ``hook_point`` does not name a valid hook point.
        """
        if isinstance(hook_point, HookPoint):
            return hook_point
        try:
            return HookPoint(hook_point)
        except ValueError:
            valid_points = [p.value for p in HookPoint]
            raise ValueError(
                f"Invalid hook_point '{hook_point}'. "
                f"Valid options: {valid_points}"
            ) from None

    def _find_hook(self, name: str) -> Optional[Hook]:
        """Return the first registered hook called ``name``, or None if absent."""
        for hooks in self._hooks.values():
            for hook in hooks:
                if hook.name == name:
                    return hook
        return None

    def _generate_hook_name(
        self,
        hook_point: HookPoint,
        function: Callable[..., Any]
    ) -> str:
        """Build a ``<point>_<function>_<n>`` name not used by any registered hook."""
        # functools.partial objects and callable instances have no __name__.
        func_name = getattr(function, "__name__", type(function).__name__)
        index = len(self._hooks[hook_point])
        name = f"{hook_point.value}_{func_name}_{index}"
        # After a removal the list length can repeat an index that is still
        # taken, so bump it until the name is free.
        while self._find_hook(name) is not None:
            index += 1
            name = f"{hook_point.value}_{func_name}_{index}"
        return name

    def register_hook(
        self,
        hook_point: str | HookPoint,
        function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
        description: str,
        name: Optional[str] = None,
        priority: int = 100,
        enabled: bool = True
    ) -> Hook:
        """
        Register a new hook function.

        Args:
            hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
            function: Callable that takes context dict, returns optional dict
            description: Human-readable description
            name: Unique name (auto-generated if not provided). Supplying a
                name that is already registered logs a warning; name-based
                operations (remove/enable/disable) only affect the first match.
            priority: Execution order (lower = earlier)
            enabled: Whether hook is active

        Returns:
            The created Hook object

        Raises:
            ValueError: If hook_point is invalid
        """
        hook_point = self._resolve_hook_point(hook_point)

        if name is None:
            name = self._generate_hook_name(hook_point, function)
        elif self._find_hook(name) is not None:
            logger.warning(
                f"A hook named '{name}' is already registered; "
                f"name-based lookups will only find the first one"
            )

        hook = Hook(
            name=name,
            hook_point=hook_point,
            function=function,
            description=description,
            priority=priority,
            enabled=enabled
        )

        hooks = self._hooks[hook_point]
        hooks.append(hook)
        # list.sort is stable, so equal priorities keep registration order.
        hooks.sort(key=lambda h: h.priority)

        logger.info(f"Registered hook: {hook}")
        return hook

    def execute_hooks(
        self,
        hook_point: str | HookPoint,
        context: Dict[str, Any],
        fail_fast: bool = False
    ) -> List[Optional[Dict[str, Any]]]:
        """
        Execute all hooks registered at a specific point, in priority order.

        Each hook is run via ``Hook.execute(context)``; whether disabled hooks
        are skipped is decided there (not visible in this module). Every
        attempt is recorded in the execution history.

        Args:
            hook_point: The execution point
            context: Data to pass to hooks
            fail_fast: If True, stop on first error. If False, log and continue.

        Returns:
            List of results from each hook, in execution order (None if the
            hook returned nothing or failed with fail_fast=False)

        Raises:
            ValueError: If hook_point is invalid
            Exception: If fail_fast=True and a hook fails (re-raised as-is)
        """
        hook_point = self._resolve_hook_point(hook_point)

        # Iterate over a snapshot: a hook may register or remove hooks while
        # this loop runs, which would otherwise skip or repeat entries.
        hooks = list(self._hooks[hook_point])
        if not hooks:
            logger.debug(f"No hooks registered for {hook_point.value}")
            return []

        logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")

        results: List[Optional[Dict[str, Any]]] = []
        for hook in hooks:
            try:
                result = hook.execute(context)
            except Exception as e:
                # logger.exception keeps the traceback for debugging plugins.
                logger.exception(f"Hook '{hook.name}' failed: {e}")
                self._hook_history.append({
                    'hook_name': hook.name,
                    'hook_point': hook_point.value,
                    'success': False,
                    'error': str(e),
                    # Read after the hook ran: hooks may update the context.
                    'trial_number': context.get('trial_number', -1)
                })
                if fail_fast:
                    raise
                results.append(None)
                continue

            results.append(result)
            self._hook_history.append({
                'hook_name': hook.name,
                'hook_point': hook_point.value,
                'success': True,
                'trial_number': context.get('trial_number', -1)
            })

        return results

    def load_plugins_from_directory(self, directory: str | Path):
        """
        Auto-discover and load plugins from a directory.

        Expected structure:
            plugins/
                pre_mesh/
                    my_plugin.py  # Contains register_hooks(manager) function
                post_solve/
                    another_plugin.py

        Only sub-directories named after a HookPoint value are scanned. Files
        whose names start with '_' are skipped. Files are loaded in sorted
        order so that registration order (and therefore the execution order
        of equal-priority hooks) is deterministic. A plugin that fails to load
        is logged and does not stop the others.

        Note: loading a plugin executes its code with this process's
        privileges; only point this at trusted directories.

        Args:
            directory: Path to plugins directory (str or Path)
        """
        directory = Path(directory)
        if not directory.is_dir():
            logger.warning(f"Plugins directory not found: {directory}")
            return

        logger.info(f"Loading plugins from {directory}")

        for hook_point in HookPoint:
            hook_dir = directory / hook_point.value
            if not hook_dir.is_dir():
                continue

            # Path.glob yields in arbitrary (filesystem) order; sort it.
            for plugin_file in sorted(hook_dir.glob("*.py")):
                if plugin_file.name.startswith("_"):
                    continue  # Skip __init__.py and private files

                try:
                    self._load_plugin_file(plugin_file)
                except Exception as e:
                    logger.exception(f"Failed to load plugin {plugin_file}: {e}")

    def _load_plugin_file(self, plugin_file: Path):
        """
        Load a single plugin file and call its register_hooks function.

        The module is executed in full (top-level code runs), then
        ``register_hooks(self)`` is called if the module defines it.

        Raises:
            ImportError: If no import spec/loader can be built for the file.
            Exception: Anything raised while executing the module or its
                register_hooks function propagates to the caller.
        """
        spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for {plugin_file}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        register_hooks = getattr(module, 'register_hooks', None)
        if register_hooks is None:
            logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
            return

        register_hooks(self)
        logger.info(f"Loaded plugin: {plugin_file.stem}")

    def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
        """
        Get all hooks registered at a specific point, in priority order.

        Returns a shallow copy of the list, so mutating it does not affect
        the registry.

        Raises:
            ValueError: If hook_point is invalid
        """
        hook_point = self._resolve_hook_point(hook_point)
        return self._hooks[hook_point].copy()

    def remove_hook(self, name: str) -> bool:
        """
        Remove a hook by name (the first match only).

        Returns:
            True if hook was found and removed, False otherwise
        """
        for hooks in self._hooks.values():
            for index, hook in enumerate(hooks):
                if hook.name == name:
                    # Returning immediately keeps deletion-during-iteration safe.
                    del hooks[index]
                    logger.info(f"Removed hook: {name}")
                    return True
        return False

    def enable_hook(self, name: str):
        """Enable a hook by name (logs a warning if not found)."""
        self._set_hook_enabled(name, True)

    def disable_hook(self, name: str):
        """Disable a hook by name (logs a warning if not found)."""
        self._set_hook_enabled(name, False)

    def _set_hook_enabled(self, name: str, enabled: bool):
        """Set the enabled flag on the first hook called ``name``."""
        hook = self._find_hook(name)
        if hook is None:
            logger.warning(f"Hook '{name}' not found")
            return

        hook.enabled = enabled
        status = "enabled" if enabled else "disabled"
        logger.info(f"Hook '{name}' {status}")

    def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get hook execution history.

        Args:
            limit: Maximum number of records to return (most recent).
                None returns everything; 0 or a negative value returns [].

        Returns:
            List of execution records (a new list; the record dicts
            themselves are shared with the internal history)
        """
        if limit is None:
            return self._hook_history.copy()
        if limit <= 0:
            return []
        return self._hook_history[-limit:]

    def clear_history(self):
        """Clear all execution history."""
        self._hook_history.clear()
        logger.info("Hook history cleared")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get a summary of the hook system state.

        Returns:
            Dictionary with overall total/enabled/disabled counts, a
            per-hook-point breakdown (total, enabled, names) and the number
            of history records.
        """
        by_point = {
            point.value: {
                'total': len(hooks),
                'enabled': sum(1 for h in hooks if h.enabled),
                'names': [h.name for h in hooks]
            }
            for point, hooks in self._hooks.items()
        }

        total_hooks = sum(stats['total'] for stats in by_point.values())
        enabled_hooks = sum(stats['enabled'] for stats in by_point.values())

        return {
            'total_hooks': total_hooks,
            'enabled_hooks': enabled_hooks,
            'disabled_hooks': total_hooks - enabled_hooks,
            'by_hook_point': by_point,
            'history_records': len(self._hook_history)
        }

    def __repr__(self) -> str:
        summary = self.get_summary()
        return (
            f"HookManager(hooks={summary['total_hooks']}, "
            f"enabled={summary['enabled_hooks']})"
        )
|