feat: Implement Phase 1 - Plugin & Hook System

Core plugin architecture for LLM-driven optimization:

New Features:
- Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives)
- HookManager for centralized registration and execution
- Code validation with AST-based safety checks
- Feature registry (JSON) for LLM capability discovery
- Example plugin: log_trial_start
- 23 comprehensive tests (all passing)

Integration:
- OptimizationRunner now loads plugins automatically
- Hooks execute at 5 points in optimization loop
- Custom objectives can override total_objective via hooks

Safety:
- Module whitelist (numpy, scipy, pandas, optuna, pyNastran)
- Dangerous operation blocking (eval, exec, os.system, subprocess)
- Optional file operation permission flag

Files Added:
- optimization_engine/plugins/__init__.py
- optimization_engine/plugins/hooks.py
- optimization_engine/plugins/hook_manager.py
- optimization_engine/plugins/validators.py
- optimization_engine/feature_registry.json
- optimization_engine/plugins/pre_solve/log_trial_start.py
- tests/test_plugin_system.py (23 tests)

Files Modified:
- optimization_engine/runner.py (added hook integration)

Ready for Phase 2: LLM interface layer

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-15 14:46:49 -05:00
parent 0ce9ddf3e2
commit a24e3f750c
14 changed files with 1473 additions and 0 deletions

View File

@@ -0,0 +1,214 @@
"""
Plugin Code Validators
Ensures safety and correctness of plugin code before execution.
"""
import ast
import re
from typing import List, Set
import logging
logger = logging.getLogger(__name__)
class PluginValidationError(Exception):
"""Raised when plugin code fails validation."""
pass
# Whitelist of safe modules for plugin imports
SAFE_MODULES = {
'math', 'numpy', 'scipy', 'pandas',
'pathlib', 'json', 'csv', 'datetime',
'logging', 'typing', 'dataclasses',
'optuna', 'pyNastran'
}
# Blacklist of dangerous operations
DANGEROUS_OPERATIONS = {
'eval', 'exec', 'compile', '__import__',
'open', # File operations should be explicit
'subprocess', 'os.system', 'os.popen',
'shutil.rmtree', # Dangerous file operations
}
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
"""
Validate plugin code for safety before execution.
Args:
code: Python source code to validate
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
Raises:
PluginValidationError: If code contains unsafe operations
Example:
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
>>> validate_plugin_code(code) # Passes
>>>
>>> code = "import os; os.system('rm -rf /')"
>>> validate_plugin_code(code) # Raises PluginValidationError
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error in plugin code: {e}")
# Check for dangerous imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import: {alias.name}. "
f"Allowed modules: {SAFE_MODULES}"
)
elif isinstance(node, ast.ImportFrom):
if node.module and node.module not in SAFE_MODULES:
# Allow submodules of safe modules
base_module = node.module.split('.')[0]
if base_module not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import from: {node.module}"
)
# Check for dangerous function calls
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in DANGEROUS_OPERATIONS:
if node.func.id == 'open' and allow_file_ops:
continue # Allow if explicitly permitted
raise PluginValidationError(
f"Dangerous operation: {node.func.id}"
)
# Check for attribute access to dangerous methods
elif isinstance(node, ast.Attribute):
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
if node.attr in dangerous_attrs and not allow_file_ops:
raise PluginValidationError(
f"Dangerous attribute access: {node.attr}"
)
logger.info("Plugin code validation passed")
def check_hook_function_signature(func_code: str) -> bool:
"""
Verify that a hook function has the correct signature.
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
Args:
func_code: Function source code
Returns:
True if signature is valid
Raises:
PluginValidationError: If signature is invalid
"""
try:
tree = ast.parse(func_code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error: {e}")
# Find function definition
func_def = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_def = node
break
if func_def is None:
raise PluginValidationError("No function definition found")
# Check that it takes exactly one argument (context)
if len(func_def.args.args) != 1:
raise PluginValidationError(
f"Hook function must take exactly 1 argument (context), "
f"got {len(func_def.args.args)}"
)
logger.info(f"Hook function '{func_def.name}' signature is valid")
return True
def sanitize_plugin_name(name: str) -> str:
"""
Sanitize a plugin name to be safe for filesystem and imports.
Args:
name: Original plugin name
Returns:
Sanitized name (lowercase, alphanumeric + underscore only)
"""
# Remove special characters, keep only alphanumeric and underscore
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
# Ensure it doesn't start with a number
if sanitized[0].isdigit():
sanitized = f"plugin_{sanitized}"
return sanitized
def get_imported_modules(code: str) -> Set[str]:
"""
Extract all imported modules from code.
Args:
code: Python source code
Returns:
Set of module names
"""
try:
tree = ast.parse(code)
except SyntaxError:
return set()
modules = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
modules.add(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
modules.add(node.module)
return modules
def estimate_complexity(code: str) -> int:
"""
Estimate the cyclomatic complexity of code.
Higher values indicate more complex code with more branching.
Args:
code: Python source code
Returns:
Complexity estimate (1 = simple, 10+ = complex)
"""
try:
tree = ast.parse(code)
except SyntaxError:
return 0
complexity = 1 # Base complexity
for node in ast.walk(tree):
# Add 1 for each branching statement
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(node, ast.BoolOp):
complexity += len(node.values) - 1
return complexity