feat: Implement Phase 1 - Plugin & Hook System
Core plugin architecture for LLM-driven optimization: New Features: - Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives) - HookManager for centralized registration and execution - Code validation with AST-based safety checks - Feature registry (JSON) for LLM capability discovery - Example plugin: log_trial_start - 23 comprehensive tests (all passing) Integration: - OptimizationRunner now loads plugins automatically - Hooks execute at 5 points in optimization loop - Custom objectives can override total_objective via hooks Safety: - Module whitelist (numpy, scipy, pandas, optuna, pyNastran) - Dangerous operation blocking (eval, exec, os.system, subprocess) - Optional file operation permission flag Files Added: - optimization_engine/plugins/__init__.py - optimization_engine/plugins/hooks.py - optimization_engine/plugins/hook_manager.py - optimization_engine/plugins/validators.py - optimization_engine/feature_registry.json - optimization_engine/plugins/pre_solve/log_trial_start.py - tests/test_plugin_system.py (23 tests) Files Modified: - optimization_engine/runner.py (added hook integration) Ready for Phase 2: LLM interface layer 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
214
optimization_engine/plugins/validators.py
Normal file
214
optimization_engine/plugins/validators.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Plugin Code Validators
|
||||
|
||||
Ensures safety and correctness of plugin code before execution.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import List, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginValidationError(Exception):
|
||||
"""Raised when plugin code fails validation."""
|
||||
pass
|
||||
|
||||
|
||||
# Whitelist of safe modules for plugin imports
|
||||
SAFE_MODULES = {
|
||||
'math', 'numpy', 'scipy', 'pandas',
|
||||
'pathlib', 'json', 'csv', 'datetime',
|
||||
'logging', 'typing', 'dataclasses',
|
||||
'optuna', 'pyNastran'
|
||||
}
|
||||
|
||||
# Blacklist of dangerous operations
|
||||
DANGEROUS_OPERATIONS = {
|
||||
'eval', 'exec', 'compile', '__import__',
|
||||
'open', # File operations should be explicit
|
||||
'subprocess', 'os.system', 'os.popen',
|
||||
'shutil.rmtree', # Dangerous file operations
|
||||
}
|
||||
|
||||
|
||||
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
|
||||
"""
|
||||
Validate plugin code for safety before execution.
|
||||
|
||||
Args:
|
||||
code: Python source code to validate
|
||||
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If code contains unsafe operations
|
||||
|
||||
Example:
|
||||
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
|
||||
>>> validate_plugin_code(code) # Passes
|
||||
>>>
|
||||
>>> code = "import os; os.system('rm -rf /')"
|
||||
>>> validate_plugin_code(code) # Raises PluginValidationError
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error in plugin code: {e}")
|
||||
|
||||
# Check for dangerous imports
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import: {alias.name}. "
|
||||
f"Allowed modules: {SAFE_MODULES}"
|
||||
)
|
||||
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module and node.module not in SAFE_MODULES:
|
||||
# Allow submodules of safe modules
|
||||
base_module = node.module.split('.')[0]
|
||||
if base_module not in SAFE_MODULES:
|
||||
raise PluginValidationError(
|
||||
f"Unsafe import from: {node.module}"
|
||||
)
|
||||
|
||||
# Check for dangerous function calls
|
||||
elif isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in DANGEROUS_OPERATIONS:
|
||||
if node.func.id == 'open' and allow_file_ops:
|
||||
continue # Allow if explicitly permitted
|
||||
raise PluginValidationError(
|
||||
f"Dangerous operation: {node.func.id}"
|
||||
)
|
||||
|
||||
# Check for attribute access to dangerous methods
|
||||
elif isinstance(node, ast.Attribute):
|
||||
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
|
||||
if node.attr in dangerous_attrs and not allow_file_ops:
|
||||
raise PluginValidationError(
|
||||
f"Dangerous attribute access: {node.attr}"
|
||||
)
|
||||
|
||||
logger.info("Plugin code validation passed")
|
||||
|
||||
|
||||
def check_hook_function_signature(func_code: str) -> bool:
|
||||
"""
|
||||
Verify that a hook function has the correct signature.
|
||||
|
||||
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||
|
||||
Args:
|
||||
func_code: Function source code
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
PluginValidationError: If signature is invalid
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(func_code)
|
||||
except SyntaxError as e:
|
||||
raise PluginValidationError(f"Syntax error: {e}")
|
||||
|
||||
# Find function definition
|
||||
func_def = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
func_def = node
|
||||
break
|
||||
|
||||
if func_def is None:
|
||||
raise PluginValidationError("No function definition found")
|
||||
|
||||
# Check that it takes exactly one argument (context)
|
||||
if len(func_def.args.args) != 1:
|
||||
raise PluginValidationError(
|
||||
f"Hook function must take exactly 1 argument (context), "
|
||||
f"got {len(func_def.args.args)}"
|
||||
)
|
||||
|
||||
logger.info(f"Hook function '{func_def.name}' signature is valid")
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_plugin_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize a plugin name to be safe for filesystem and imports.
|
||||
|
||||
Args:
|
||||
name: Original plugin name
|
||||
|
||||
Returns:
|
||||
Sanitized name (lowercase, alphanumeric + underscore only)
|
||||
"""
|
||||
# Remove special characters, keep only alphanumeric and underscore
|
||||
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
|
||||
|
||||
# Ensure it doesn't start with a number
|
||||
if sanitized[0].isdigit():
|
||||
sanitized = f"plugin_{sanitized}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_imported_modules(code: str) -> Set[str]:
|
||||
"""
|
||||
Extract all imported modules from code.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Set of module names
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return set()
|
||||
|
||||
modules = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
modules.add(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
modules.add(node.module)
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
def estimate_complexity(code: str) -> int:
|
||||
"""
|
||||
Estimate the cyclomatic complexity of code.
|
||||
|
||||
Higher values indicate more complex code with more branching.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
Complexity estimate (1 = simple, 10+ = complex)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return 0
|
||||
|
||||
complexity = 1 # Base complexity
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Add 1 for each branching statement
|
||||
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
|
||||
complexity += 1
|
||||
elif isinstance(node, ast.BoolOp):
|
||||
complexity += len(node.values) - 1
|
||||
|
||||
return complexity
|
||||
Reference in New Issue
Block a user