"""
Hook Manager for Atomizer Plugin System

Manages registration, execution, and lifecycle of hooks.
"""

from typing import Dict, List, Callable, Any, Optional
from pathlib import Path
import logging
import importlib.util
import json

from .hooks import Hook, HookPoint, HookContext

logger = logging.getLogger(__name__)


class HookManager:
    """
    Central manager for all hooks in the optimization system.

    Hooks are stored per ``HookPoint`` and kept sorted by ascending
    ``priority``. Each hook execution performed by ``execute_hooks`` is
    recorded in an in-memory history (see ``get_history``).

    Example:
        >>> manager = HookManager()
        >>> def my_hook(context):
        ...     print(f"Trial {context['trial_number']} starting")
        ...     return {'status': 'logged'}
        >>>
        >>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
        >>> manager.execute_hooks('pre_solve', {'trial_number': 5})
    """

    def __init__(self):
        # One (priority-sorted) list of hooks per hook point.
        self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
        # One record per hook execution; grows until clear_history() is called.
        self._hook_history: List[Dict[str, Any]] = []
        logger.info("HookManager initialized")

    @staticmethod
    def _resolve_hook_point(hook_point: str | HookPoint) -> HookPoint:
        """
        Convert a string to its ``HookPoint`` member; pass other values through.

        Raises:
            ValueError: If ``hook_point`` is a string naming no HookPoint.
        """
        if not isinstance(hook_point, str):
            return hook_point
        try:
            return HookPoint(hook_point)
        except ValueError:
            valid_points = [p.value for p in HookPoint]
            raise ValueError(
                f"Invalid hook_point '{hook_point}'. "
                f"Valid options: {valid_points}"
            ) from None

    def _hook_names(self) -> set:
        """Return the names of all registered hooks across every hook point."""
        return {hook.name for hooks in self._hooks.values() for hook in hooks}

    def _generate_hook_name(self, hook_point: HookPoint, function: Callable) -> str:
        """
        Build a default name ``<point>_<function name>_<index>`` that is not
        already in use.

        The index starts at the current number of hooks at ``hook_point`` (the
        historical scheme) and is bumped until the name is free, so removing a
        hook can no longer cause a later registration to reuse a live name.
        """
        # Callables such as functools.partial or instances with __call__
        # have no __name__; fall back to their type name.
        base = getattr(function, '__name__', type(function).__name__)
        existing = self._hook_names()
        index = len(self._hooks[hook_point])
        name = f"{hook_point.value}_{base}_{index}"
        while name in existing:
            index += 1
            name = f"{hook_point.value}_{base}_{index}"
        return name

    def register_hook(
        self,
        hook_point: str | HookPoint,
        function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
        description: str,
        name: Optional[str] = None,
        priority: int = 100,
        enabled: bool = True
    ) -> Hook:
        """
        Register a new hook function.

        Args:
            hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
            function: Callable that takes context dict, returns optional dict
            description: Human-readable description
            name: Name for the hook (auto-generated and guaranteed unused if
                not provided; an explicit duplicate name is accepted but
                logged as a warning)
            priority: Execution order (lower = earlier)
            enabled: Whether hook is active

        Returns:
            The created Hook object

        Raises:
            ValueError: If hook_point is invalid
        """
        hook_point = self._resolve_hook_point(hook_point)

        if name is None:
            name = self._generate_hook_name(hook_point, function)
        elif name in self._hook_names():
            # remove_hook/enable_hook/disable_hook stop at the first match,
            # so duplicates make those calls ambiguous.
            logger.warning(
                "Hook name '%s' is already registered; name-based operations "
                "will only affect the first matching hook", name
            )

        hook = Hook(
            name=name,
            hook_point=hook_point,
            function=function,
            description=description,
            priority=priority,
            enabled=enabled
        )

        # list.sort is stable, so equal priorities keep registration order.
        self._hooks[hook_point].append(hook)
        self._hooks[hook_point].sort(key=lambda h: h.priority)

        logger.info("Registered hook: %s", hook)
        return hook

    def execute_hooks(
        self,
        hook_point: str | HookPoint,
        context: Dict[str, Any],
        fail_fast: bool = False
    ) -> List[Optional[Dict[str, Any]]]:
        """
        Execute all hooks registered at a specific point.

        Args:
            hook_point: The execution point
            context: Data to pass to hooks
            fail_fast: If True, stop on first error. If False, log and continue.

        Returns:
            List of results from each hook (None if hook returned nothing or,
            when fail_fast is False, if it raised)

        Raises:
            ValueError: If hook_point is an unknown string
            Exception: If fail_fast=True and a hook fails
        """
        hook_point = self._resolve_hook_point(hook_point)

        # Iterate over a snapshot so a hook that registers or removes hooks
        # does not mutate the list while it is being traversed.
        hooks = list(self._hooks[hook_point])
        if not hooks:
            logger.debug("No hooks registered for %s", hook_point.value)
            return []

        logger.info("Executing %d hooks at %s", len(hooks), hook_point.value)

        results: List[Optional[Dict[str, Any]]] = []
        for hook in hooks:
            try:
                result = hook.execute(context)
            except Exception as e:
                logger.exception("Hook '%s' failed: %s", hook.name, e)
                self._hook_history.append({
                    'hook_name': hook.name,
                    'hook_point': hook_point.value,
                    'success': False,
                    'error': str(e),
                    'trial_number': context.get('trial_number', -1)
                })
                if fail_fast:
                    raise
                results.append(None)
            else:
                results.append(result)
                self._hook_history.append({
                    'hook_name': hook.name,
                    'hook_point': hook_point.value,
                    'success': True,
                    'trial_number': context.get('trial_number', -1)
                })

        return results

    def load_plugins_from_directory(self, directory: str | Path) -> None:
        """
        Auto-discover and load plugins from a directory.

        Expected structure:
            plugins/
                pre_mesh/
                    my_plugin.py  # Contains register_hooks(manager) function
                post_solve/
                    another_plugin.py

        Files whose names start with ``_`` are skipped. A plugin that fails to
        load is logged and does not stop the remaining plugins from loading.

        Args:
            directory: Path to plugins directory (``str`` or ``Path``)
        """
        directory = Path(directory)
        if not directory.exists():
            logger.warning("Plugins directory not found: %s", directory)
            return

        logger.info("Loading plugins from %s", directory)

        for hook_point in HookPoint:
            hook_dir = directory / hook_point.value
            if not hook_dir.exists():
                continue

            # Sort so load order (and therefore auto-generated hook names and
            # the order of equal-priority hooks) does not depend on the
            # filesystem's directory-listing order.
            for plugin_file in sorted(hook_dir.glob("*.py")):
                if plugin_file.name.startswith("_"):
                    continue  # Skip __init__.py and private files

                try:
                    self._load_plugin_file(plugin_file)
                except Exception:
                    logger.exception("Failed to load plugin %s", plugin_file)

    def _load_plugin_file(self, plugin_file: Path) -> None:
        """
        Load a single plugin file and call its register_hooks function.

        Note: this executes the file's code, so only load trusted plugins.

        Raises:
            ImportError: If no import spec/loader can be created for the file
        """
        spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for {plugin_file}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        if hasattr(module, 'register_hooks'):
            module.register_hooks(self)
            logger.info("Loaded plugin: %s", plugin_file.stem)
        else:
            logger.warning("Plugin %s has no register_hooks() function", plugin_file)

    def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
        """
        Get all hooks registered at a specific point.

        Returns:
            A shallow copy of the priority-sorted hook list.

        Raises:
            ValueError: If hook_point is an unknown string
        """
        hook_point = self._resolve_hook_point(hook_point)
        return self._hooks[hook_point].copy()

    def remove_hook(self, name: str) -> bool:
        """
        Remove a hook by name.

        Only the first hook with a matching name is removed.

        Returns:
            True if hook was found and removed, False otherwise
        """
        for hooks in self._hooks.values():
            for i, hook in enumerate(hooks):
                if hook.name == name:
                    del hooks[i]
                    logger.info("Removed hook: %s", name)
                    return True
        return False

    def enable_hook(self, name: str) -> None:
        """Enable a hook by name."""
        self._set_hook_enabled(name, True)

    def disable_hook(self, name: str) -> None:
        """Disable a hook by name."""
        self._set_hook_enabled(name, False)

    def _set_hook_enabled(self, name: str, enabled: bool) -> None:
        """Set enabled status on the first hook named ``name``; warn if absent."""
        for hooks in self._hooks.values():
            for hook in hooks:
                if hook.name == name:
                    hook.enabled = enabled
                    status = "enabled" if enabled else "disabled"
                    logger.info("Hook '%s' %s", name, status)
                    return
        logger.warning("Hook '%s' not found", name)

    def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get hook execution history.

        Args:
            limit: Maximum number of records to return (most recent). ``None``
                returns everything; ``0`` or a negative value returns ``[]``.

        Returns:
            List of execution records (a new list; the records are shared)
        """
        if limit is None:
            return self._hook_history.copy()
        if limit <= 0:
            return []
        return self._hook_history[-limit:]

    def clear_history(self) -> None:
        """Clear all execution history."""
        self._hook_history.clear()
        logger.info("Hook history cleared")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get a summary of the hook system state.

        Returns:
            Dictionary with hook counts, enabled status, etc.
        """
        total_hooks = sum(len(hooks) for hooks in self._hooks.values())
        enabled_hooks = sum(
            sum(1 for h in hooks if h.enabled)
            for hooks in self._hooks.values()
        )

        by_point = {
            point.value: {
                'total': len(hooks),
                'enabled': sum(1 for h in hooks if h.enabled),
                'names': [h.name for h in hooks]
            }
            for point, hooks in self._hooks.items()
        }

        return {
            'total_hooks': total_hooks,
            'enabled_hooks': enabled_hooks,
            'disabled_hooks': total_hooks - enabled_hooks,
            'by_hook_point': by_point,
            'history_records': len(self._hook_history)
        }

    def __repr__(self) -> str:
        summary = self.get_summary()
        return (
            f"HookManager(hooks={summary['total_hooks']}, "
            f"enabled={summary['enabled_hooks']})"
        )