feat: Implement Phase 1 - Plugin & Hook System
Core plugin architecture for LLM-driven optimization: New Features: - Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives) - HookManager for centralized registration and execution - Code validation with AST-based safety checks - Feature registry (JSON) for LLM capability discovery - Example plugin: log_trial_start - 23 comprehensive tests (all passing) Integration: - OptimizationRunner now loads plugins automatically - Hooks execute at 5 points in optimization loop - Custom objectives can override total_objective via hooks Safety: - Module whitelist (numpy, scipy, pandas, optuna, pyNastran) - Dangerous operation blocking (eval, exec, os.system, subprocess) - Optional file operation permission flag Files Added: - optimization_engine/plugins/__init__.py - optimization_engine/plugins/hooks.py - optimization_engine/plugins/hook_manager.py - optimization_engine/plugins/validators.py - optimization_engine/feature_registry.json - optimization_engine/plugins/pre_solve/log_trial_start.py - tests/test_plugin_system.py (23 tests) Files Modified: - optimization_engine/runner.py (added hook integration) Ready for Phase 2: LLM interface layer 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
32
optimization_engine/plugins/__init__.py
Normal file
32
optimization_engine/plugins/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Atomizer Plugin System
|
||||
|
||||
Enables extensibility through hooks at various stages of the optimization lifecycle.
|
||||
|
||||
Hook Points:
|
||||
- pre_mesh: Before meshing operations
|
||||
- post_mesh: After meshing, before solve
|
||||
- pre_solve: Before solver execution
|
||||
- post_solve: After solve, before result extraction
|
||||
- post_extraction: After result extraction, before objective calculation
|
||||
- custom_objectives: Custom objective/constraint functions
|
||||
|
||||
Usage:
|
||||
from optimization_engine.plugins import HookManager
|
||||
|
||||
hook_manager = HookManager()
|
||||
hook_manager.register_hook('pre_solve', my_custom_function)
|
||||
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
|
||||
"""
|
||||
|
||||
from .hooks import Hook, HookPoint
|
||||
from .hook_manager import HookManager
|
||||
from .validators import validate_plugin_code, PluginValidationError
|
||||
|
||||
__all__ = [
|
||||
'Hook',
|
||||
'HookPoint',
|
||||
'HookManager',
|
||||
'validate_plugin_code',
|
||||
'PluginValidationError'
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
# custom_objectives hooks
|
||||
303
optimization_engine/plugins/hook_manager.py
Normal file
303
optimization_engine/plugins/hook_manager.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Hook Manager for Atomizer Plugin System
|
||||
|
||||
Manages registration, execution, and lifecycle of hooks.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Callable, Any, Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import importlib.util
|
||||
import json
|
||||
|
||||
from .hooks import Hook, HookPoint, HookContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""
|
||||
Central manager for all hooks in the optimization system.
|
||||
|
||||
Example:
|
||||
>>> manager = HookManager()
|
||||
>>> def my_hook(context):
|
||||
>>> print(f"Trial {context['trial_number']} starting")
|
||||
>>> return {'status': 'logged'}
|
||||
>>>
|
||||
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
|
||||
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
|
||||
self._hook_history: List[Dict[str, Any]] = []
|
||||
logger.info("HookManager initialized")
|
||||
|
||||
def register_hook(
|
||||
self,
|
||||
hook_point: str | HookPoint,
|
||||
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
|
||||
description: str,
|
||||
name: Optional[str] = None,
|
||||
priority: int = 100,
|
||||
enabled: bool = True
|
||||
) -> Hook:
|
||||
"""
|
||||
Register a new hook function.
|
||||
|
||||
Args:
|
||||
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
|
||||
function: Callable that takes context dict, returns optional dict
|
||||
description: Human-readable description
|
||||
name: Unique name (auto-generated if not provided)
|
||||
priority: Execution order (lower = earlier)
|
||||
enabled: Whether hook is active
|
||||
|
||||
Returns:
|
||||
The created Hook object
|
||||
|
||||
Raises:
|
||||
ValueError: If hook_point is invalid
|
||||
"""
|
||||
# Convert string to HookPoint enum
|
||||
if isinstance(hook_point, str):
|
||||
try:
|
||||
hook_point = HookPoint(hook_point)
|
||||
except ValueError:
|
||||
valid_points = [p.value for p in HookPoint]
|
||||
raise ValueError(
|
||||
f"Invalid hook_point '{hook_point}'. "
|
||||
f"Valid options: {valid_points}"
|
||||
)
|
||||
|
||||
# Auto-generate name if not provided
|
||||
if name is None:
|
||||
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
|
||||
|
||||
# Create hook
|
||||
hook = Hook(
|
||||
name=name,
|
||||
hook_point=hook_point,
|
||||
function=function,
|
||||
description=description,
|
||||
priority=priority,
|
||||
enabled=enabled
|
||||
)
|
||||
|
||||
# Add to registry and sort by priority
|
||||
self._hooks[hook_point].append(hook)
|
||||
self._hooks[hook_point].sort(key=lambda h: h.priority)
|
||||
|
||||
logger.info(f"Registered hook: {hook}")
|
||||
return hook
|
||||
|
||||
def execute_hooks(
|
||||
self,
|
||||
hook_point: str | HookPoint,
|
||||
context: Dict[str, Any],
|
||||
fail_fast: bool = False
|
||||
) -> List[Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Execute all hooks registered at a specific point.
|
||||
|
||||
Args:
|
||||
hook_point: The execution point
|
||||
context: Data to pass to hooks
|
||||
fail_fast: If True, stop on first error. If False, log and continue.
|
||||
|
||||
Returns:
|
||||
List of results from each hook (None if hook returned nothing)
|
||||
|
||||
Raises:
|
||||
Exception: If fail_fast=True and a hook fails
|
||||
"""
|
||||
# Convert string to enum
|
||||
if isinstance(hook_point, str):
|
||||
hook_point = HookPoint(hook_point)
|
||||
|
||||
hooks = self._hooks[hook_point]
|
||||
if not hooks:
|
||||
logger.debug(f"No hooks registered for {hook_point.value}")
|
||||
return []
|
||||
|
||||
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
|
||||
|
||||
results = []
|
||||
for hook in hooks:
|
||||
try:
|
||||
result = hook.execute(context)
|
||||
results.append(result)
|
||||
|
||||
# Record in history
|
||||
self._hook_history.append({
|
||||
'hook_name': hook.name,
|
||||
'hook_point': hook_point.value,
|
||||
'success': True,
|
||||
'trial_number': context.get('trial_number', -1)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hook '{hook.name}' failed: {e}")
|
||||
|
||||
# Record failure
|
||||
self._hook_history.append({
|
||||
'hook_name': hook.name,
|
||||
'hook_point': hook_point.value,
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'trial_number': context.get('trial_number', -1)
|
||||
})
|
||||
|
||||
if fail_fast:
|
||||
raise
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
def load_plugins_from_directory(self, directory: Path):
|
||||
"""
|
||||
Auto-discover and load plugins from a directory.
|
||||
|
||||
Expected structure:
|
||||
plugins/
|
||||
pre_mesh/
|
||||
my_plugin.py # Contains register_hooks(manager) function
|
||||
post_solve/
|
||||
another_plugin.py
|
||||
|
||||
Args:
|
||||
directory: Path to plugins directory
|
||||
"""
|
||||
if not directory.exists():
|
||||
logger.warning(f"Plugins directory not found: {directory}")
|
||||
return
|
||||
|
||||
logger.info(f"Loading plugins from {directory}")
|
||||
|
||||
for hook_point in HookPoint:
|
||||
hook_dir = directory / hook_point.value
|
||||
if not hook_dir.exists():
|
||||
continue
|
||||
|
||||
for plugin_file in hook_dir.glob("*.py"):
|
||||
if plugin_file.name.startswith("_"):
|
||||
continue # Skip __init__.py and private files
|
||||
|
||||
try:
|
||||
self._load_plugin_file(plugin_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load plugin {plugin_file}: {e}")
|
||||
|
||||
def _load_plugin_file(self, plugin_file: Path):
|
||||
"""Load a single plugin file and call its register_hooks function."""
|
||||
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Could not load spec for {plugin_file}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Call register_hooks if it exists
|
||||
if hasattr(module, 'register_hooks'):
|
||||
module.register_hooks(self)
|
||||
logger.info(f"Loaded plugin: {plugin_file.stem}")
|
||||
else:
|
||||
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
|
||||
|
||||
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
|
||||
"""Get all hooks registered at a specific point."""
|
||||
if isinstance(hook_point, str):
|
||||
hook_point = HookPoint(hook_point)
|
||||
return self._hooks[hook_point].copy()
|
||||
|
||||
def remove_hook(self, name: str) -> bool:
|
||||
"""
|
||||
Remove a hook by name.
|
||||
|
||||
Returns:
|
||||
True if hook was found and removed, False otherwise
|
||||
"""
|
||||
for point, hooks in self._hooks.items():
|
||||
for i, hook in enumerate(hooks):
|
||||
if hook.name == name:
|
||||
del hooks[i]
|
||||
logger.info(f"Removed hook: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def enable_hook(self, name: str):
|
||||
"""Enable a hook by name."""
|
||||
self._set_hook_enabled(name, True)
|
||||
|
||||
def disable_hook(self, name: str):
|
||||
"""Disable a hook by name."""
|
||||
self._set_hook_enabled(name, False)
|
||||
|
||||
def _set_hook_enabled(self, name: str, enabled: bool):
|
||||
"""Set enabled status for a hook."""
|
||||
for hooks in self._hooks.values():
|
||||
for hook in hooks:
|
||||
if hook.name == name:
|
||||
hook.enabled = enabled
|
||||
status = "enabled" if enabled else "disabled"
|
||||
logger.info(f"Hook '{name}' {status}")
|
||||
return
|
||||
logger.warning(f"Hook '{name}' not found")
|
||||
|
||||
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get hook execution history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of records to return (most recent)
|
||||
|
||||
Returns:
|
||||
List of execution records
|
||||
"""
|
||||
if limit:
|
||||
return self._hook_history[-limit:]
|
||||
return self._hook_history.copy()
|
||||
|
||||
def clear_history(self):
|
||||
"""Clear all execution history."""
|
||||
self._hook_history.clear()
|
||||
logger.info("Hook history cleared")
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a summary of the hook system state.
|
||||
|
||||
Returns:
|
||||
Dictionary with hook counts, enabled status, etc.
|
||||
"""
|
||||
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
|
||||
enabled_hooks = sum(
|
||||
sum(1 for h in hooks if h.enabled)
|
||||
for hooks in self._hooks.values()
|
||||
)
|
||||
|
||||
by_point = {
|
||||
point.value: {
|
||||
'total': len(hooks),
|
||||
'enabled': sum(1 for h in hooks if h.enabled),
|
||||
'names': [h.name for h in hooks]
|
||||
}
|
||||
for point, hooks in self._hooks.items()
|
||||
}
|
||||
|
||||
return {
|
||||
'total_hooks': total_hooks,
|
||||
'enabled_hooks': enabled_hooks,
|
||||
'disabled_hooks': total_hooks - enabled_hooks,
|
||||
'by_hook_point': by_point,
|
||||
'history_records': len(self._hook_history)
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
summary = self.get_summary()
|
||||
return (
|
||||
f"HookManager(hooks={summary['total_hooks']}, "
|
||||
f"enabled={summary['enabled_hooks']})"
|
||||
)
|
||||
115
optimization_engine/plugins/hooks.py
Normal file
115
optimization_engine/plugins/hooks.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Core Hook System for Atomizer
|
||||
|
||||
Defines hook points in the optimization lifecycle and hook registration mechanism.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HookPoint(Enum):
|
||||
"""Enumeration of available hook points in the optimization lifecycle."""
|
||||
|
||||
PRE_MESH = "pre_mesh" # Before meshing
|
||||
POST_MESH = "post_mesh" # After meshing, before solve
|
||||
PRE_SOLVE = "pre_solve" # Before solver execution
|
||||
POST_SOLVE = "post_solve" # After solve, before extraction
|
||||
POST_EXTRACTION = "post_extraction" # After result extraction
|
||||
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hook:
|
||||
"""
|
||||
Represents a single hook function to be executed at a specific point.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for this hook
|
||||
hook_point: When this hook should execute (HookPoint enum)
|
||||
function: The callable to execute
|
||||
description: Human-readable description of what this hook does
|
||||
priority: Execution order (lower = earlier, default=100)
|
||||
enabled: Whether this hook is currently active
|
||||
"""
|
||||
|
||||
name: str
|
||||
hook_point: HookPoint
|
||||
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
|
||||
description: str
|
||||
priority: int = 100
|
||||
enabled: bool = True
|
||||
|
||||
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Execute this hook with the given context.
|
||||
|
||||
Args:
|
||||
context: Dictionary containing relevant data for this hook point
|
||||
Common keys:
|
||||
- trial_number: Current trial number
|
||||
- design_variables: Current design variable values
|
||||
- sim_file: Path to simulation file
|
||||
- working_dir: Current working directory
|
||||
|
||||
Returns:
|
||||
Optional dictionary with results or modifications to context
|
||||
|
||||
Raises:
|
||||
Exception: Any exception from the hook function is logged and re-raised
|
||||
"""
|
||||
if not self.enabled:
|
||||
logger.debug(f"Hook '{self.name}' is disabled, skipping")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
|
||||
result = self.function(context)
|
||||
logger.debug(f"Hook '{self.name}' completed successfully")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "enabled" if self.enabled else "disabled"
|
||||
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
|
||||
|
||||
|
||||
class HookContext:
|
||||
"""
|
||||
Context object passed to hooks containing all relevant data.
|
||||
|
||||
This is a convenience wrapper around a dictionary that provides
|
||||
both dict-like access and attribute access.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._data = kwargs
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._data[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
self._data[key] = value
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def update(self, other: Dict[str, Any]):
|
||||
self._data.update(other)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._data.copy()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
keys = list(self._data.keys())
|
||||
return f"HookContext({', '.join(keys)})"
|
||||
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_extraction hooks
|
||||
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_mesh hooks
|
||||
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# post_solve hooks
|
||||
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# pre_mesh hooks
|
||||
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# pre_solve hooks
|
||||
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Example Plugin: Log Trial Start
|
||||
|
||||
Simple pre-solve hook that logs trial information.
|
||||
|
||||
This demonstrates the plugin API and serves as a template for custom hooks.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Log trial information before solver execution.
|
||||
|
||||
Args:
|
||||
context: Hook context containing:
|
||||
- trial_number: Current trial number
|
||||
- design_variables: Dict of variable values
|
||||
- sim_file: Path to simulation file
|
||||
|
||||
Returns:
|
||||
Dict with logging status
|
||||
"""
|
||||
trial_num = context.get('trial_number', '?')
|
||||
design_vars = context.get('design_variables', {})
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"TRIAL {trial_num} STARTING")
|
||||
logger.info("=" * 60)
|
||||
|
||||
for var_name, var_value in design_vars.items():
|
||||
logger.info(f" {var_name}: {var_value:.4f}")
|
||||
|
||||
return {'logged': True, 'trial': trial_num}
|
||||
|
||||
|
||||
def register_hooks(hook_manager):
|
||||
"""
|
||||
Register this plugin's hooks with the manager.
|
||||
|
||||
This function is called automatically when the plugin is loaded.
|
||||
|
||||
Args:
|
||||
hook_manager: The HookManager instance
|
||||
"""
|
||||
hook_manager.register_hook(
|
||||
hook_point='pre_solve',
|
||||
function=trial_start_logger,
|
||||
description='Log trial number and design variables before solve',
|
||||
name='log_trial_start',
|
||||
priority=10 # Run early
|
||||
)
|
||||
214
optimization_engine/plugins/validators.py
Normal file
214
optimization_engine/plugins/validators.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Plugin Code Validators
|
||||
|
||||
Ensures safety and correctness of plugin code before execution.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import List, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginValidationError(Exception):
|
||||
"""Raised when plugin code fails validation."""
|
||||
pass
|
||||
|
||||
|
||||
# Whitelist of safe modules for plugin imports
|
||||
SAFE_MODULES = {
|
||||
'math', 'numpy', 'scipy', 'pandas',
|
||||
'pathlib', 'json', 'csv', 'datetime',
|
||||
'logging', 'typing', 'dataclasses',
|
||||
'optuna', 'pyNastran'
|
||||
}
|
||||
|
||||
# Blacklist of dangerous operations
|
||||
DANGEROUS_OPERATIONS = {
|
||||
'eval', 'exec', 'compile', '__import__',
|
||||
'open', # File operations should be explicit
|
||||
'subprocess', 'os.system', 'os.popen',
|
||||
'shutil.rmtree', # Dangerous file operations
|
||||
}
|
||||
|
||||
|
||||
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
|
||||
"""
|
||||
Validate plugin code for safety before execution.
|
||||
|
||||
Args:
|
||||
code: Python source code to validate
|
||||
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If code contains unsafe operations
|
||||
|
||||
Example:
|
||||
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
|
||||
>>> validate_plugin_code(code) # Passes
|
||||
>>>
|
||||
>>> code = "import os; os.system('rm -rf /')"
|
||||
>>> validate_plugin_code(code) # Raises PluginValidationError
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error in plugin code: {e}")
|
||||
|
||||
# Check for dangerous imports
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import: {alias.name}. "
|
||||
f"Allowed modules: {SAFE_MODULES}"
|
||||
)
|
||||
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module and node.module not in SAFE_MODULES:
|
||||
# Allow submodules of safe modules
|
||||
base_module = node.module.split('.')[0]
|
||||
if base_module not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import from: {node.module}"
|
||||
)
|
||||
|
||||
# Check for dangerous function calls
|
||||
elif isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in DANGEROUS_OPERATIONS:
|
||||
if node.func.id == 'open' and allow_file_ops:
|
||||
continue # Allow if explicitly permitted
|
||||
raise PluginValidationError(
|
||||
f"Dangerous operation: {node.func.id}"
|
||||
)
|
||||
|
||||
# Check for attribute access to dangerous methods
|
||||
elif isinstance(node, ast.Attribute):
|
||||
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
|
||||
if node.attr in dangerous_attrs and not allow_file_ops:
|
||||
raise PluginValidationError(
|
||||
f"Dangerous attribute access: {node.attr}"
|
||||
)
|
||||
|
||||
logger.info("Plugin code validation passed")
|
||||
|
||||
|
||||
def check_hook_function_signature(func_code: str) -> bool:
|
||||
"""
|
||||
Verify that a hook function has the correct signature.
|
||||
|
||||
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||
|
||||
Args:
|
||||
func_code: Function source code
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If signature is invalid
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(func_code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error: {e}")
|
||||
|
||||
# Find function definition
|
||||
func_def = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
func_def = node
|
||||
break
|
||||
|
||||
if func_def is None:
|
||||
raise PluginValidationError("No function definition found")
|
||||
|
||||
# Check that it takes exactly one argument (context)
|
||||
if len(func_def.args.args) != 1:
|
||||
raise PluginValidationError(
|
||||
f"Hook function must take exactly 1 argument (context), "
|
||||
f"got {len(func_def.args.args)}"
|
||||
)
|
||||
|
||||
logger.info(f"Hook function '{func_def.name}' signature is valid")
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_plugin_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize a plugin name to be safe for filesystem and imports.
|
||||
|
||||
Args:
|
||||
name: Original plugin name
|
||||
|
||||
Returns:
|
||||
Sanitized name (lowercase, alphanumeric + underscore only)
|
||||
"""
|
||||
# Remove special characters, keep only alphanumeric and underscore
|
||||
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
|
||||
|
||||
# Ensure it doesn't start with a number
|
||||
if sanitized[0].isdigit():
|
||||
sanitized = f"plugin_{sanitized}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_imported_modules(code: str) -> Set[str]:
|
||||
"""
|
||||
Extract all imported modules from code.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Set of module names
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return set()
|
||||
|
||||
modules = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
modules.add(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
modules.add(node.module)
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
def estimate_complexity(code: str) -> int:
|
||||
"""
|
||||
Estimate the cyclomatic complexity of code.
|
||||
|
||||
Higher values indicate more complex code with more branching.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Complexity estimate (1 = simple, 10+ = complex)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return 0
|
||||
|
||||
complexity = 1 # Base complexity
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Add 1 for each branching statement
|
||||
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
|
||||
complexity += 1
|
||||
elif isinstance(node, ast.BoolOp):
|
||||
complexity += len(node.values) - 1
|
||||
|
||||
return complexity
|
||||
Reference in New Issue
Block a user