Core plugin architecture for LLM-driven optimization:

New Features:
- Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives)
- HookManager for centralized registration and execution
- Code validation with AST-based safety checks
- Feature registry (JSON) for LLM capability discovery
- Example plugin: log_trial_start
- 23 comprehensive tests (all passing)

Integration:
- OptimizationRunner now loads plugins automatically
- Hooks execute at 5 points in optimization loop
- Custom objectives can override total_objective via hooks

Safety:
- Module whitelist (numpy, scipy, pandas, optuna, pyNastran)
- Dangerous operation blocking (eval, exec, os.system, subprocess)
- Optional file operation permission flag

Files Added:
- optimization_engine/plugins/__init__.py
- optimization_engine/plugins/hooks.py
- optimization_engine/plugins/hook_manager.py
- optimization_engine/plugins/validators.py
- optimization_engine/feature_registry.json
- optimization_engine/plugins/pre_solve/log_trial_start.py
- tests/test_plugin_system.py (23 tests)

Files Modified:
- optimization_engine/runner.py (added hook integration)

Ready for Phase 2: LLM interface layer

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
439 lines
13 KiB
Python
439 lines
13 KiB
Python
"""
Tests for the Plugin System

Validates hook registration, execution, validation, and integration.
"""
|
|
|
|
import textwrap
from pathlib import Path
from typing import Dict, Any, Optional

import pytest

from optimization_engine.plugins import HookManager, Hook, HookPoint
from optimization_engine.plugins.validators import (
    validate_plugin_code,
    PluginValidationError,
    check_hook_function_signature,
    sanitize_plugin_name,
    get_imported_modules,
    estimate_complexity
)
|
|
|
|
|
|
class TestHookRegistration:
    """Registering, ordering, toggling and removing hooks on a HookManager."""

    def test_register_simple_hook(self):
        """A registered hook reports its name, hook point and enabled flag."""

        def my_hook(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
            return {'status': 'success'}

        registry = HookManager()
        registered = registry.register_hook(
            hook_point='pre_solve',
            function=my_hook,
            description='Test hook',
            name='test_hook',
        )

        assert registered.name == 'test_hook'
        assert registered.hook_point == HookPoint.PRE_SOLVE
        assert registered.enabled is True

    def test_hook_priority_ordering(self):
        """Hooks run by priority value, not by registration order."""
        calls = []

        def hook_high_priority(context):
            calls.append('high')

        def hook_low_priority(context):
            calls.append('low')

        registry = HookManager()
        # The priority=200 hook is registered before the priority=10 hook on
        # purpose, so registration order cannot explain the expected result.
        registry.register_hook('pre_solve', hook_low_priority, 'Low', priority=200)
        registry.register_hook('pre_solve', hook_high_priority, 'High', priority=10)

        registry.execute_hooks('pre_solve', {'trial_number': 1})

        # The hook registered with priority=10 must have run first.
        assert calls == ['high', 'low']

    def test_disable_enable_hook(self):
        """A disabled hook is skipped until it is enabled again."""
        runs = 0

        def counting_hook(context):
            nonlocal runs
            runs += 1

        registry = HookManager()
        registry.register_hook(
            'pre_solve',
            counting_hook,
            'Counter',
            name='counter_hook',
        )

        # Enabled by default: the first execution is counted.
        registry.execute_hooks('pre_solve', {})
        assert runs == 1

        # Disabled: the counter must stay where it was.
        registry.disable_hook('counter_hook')
        registry.execute_hooks('pre_solve', {})
        assert runs == 1

        # Re-enabled: counting resumes.
        registry.enable_hook('counter_hook')
        registry.execute_hooks('pre_solve', {})
        assert runs == 2

    def test_remove_hook(self):
        """remove_hook reports True for a known name and False otherwise."""

        def test_hook(context):
            return None

        registry = HookManager()
        registry.register_hook('pre_solve', test_hook, 'Test', name='removable')
        assert len(registry.get_hooks('pre_solve')) == 1

        assert registry.remove_hook('removable') is True
        assert len(registry.get_hooks('pre_solve')) == 0

        # Removing a name that was never registered reports failure.
        assert registry.remove_hook('nonexistent') is False
|
|
|
|
|
|
class TestHookExecution:
    """Behaviour of HookManager.execute_hooks: context, results, errors, history."""

    def test_hook_receives_context(self):
        """The context dict passed to execute_hooks reaches the hook."""
        seen = {}

        def context_checker(context: Dict[str, Any]):
            seen.update(context)

        registry = HookManager()
        registry.register_hook('pre_solve', context_checker, 'Checker')

        payload = {
            'trial_number': 42,
            'design_variables': {'thickness': 5.0},
            'sim_file': 'test.sim',
        }
        registry.execute_hooks('pre_solve', payload)

        assert seen['trial_number'] == 42
        assert seen['design_variables']['thickness'] == 5.0

    def test_hook_return_values(self):
        """execute_hooks returns one result per hook, in execution order."""

        def hook1(context):
            return {'result': 'hook1'}

        def hook2(context):
            return {'result': 'hook2'}

        registry = HookManager()
        registry.register_hook('pre_solve', hook1, 'Hook 1')
        registry.register_hook('pre_solve', hook2, 'Hook 2')

        outcomes = registry.execute_hooks('pre_solve', {})

        assert len(outcomes) == 2
        first, second = outcomes[0], outcomes[1]
        assert first['result'] == 'hook1'
        assert second['result'] == 'hook2'

    def test_hook_error_handling_fail_fast(self):
        """With fail_fast=True the hook's exception propagates to the caller."""

        def failing_hook(context):
            raise ValueError("Intentional error")

        registry = HookManager()
        registry.register_hook('pre_solve', failing_hook, 'Failing')

        with pytest.raises(ValueError, match="Intentional error"):
            registry.execute_hooks('pre_solve', {}, fail_fast=True)

    def test_hook_error_handling_continue(self):
        """With fail_fast=False a failing hook yields None and later hooks still run."""
        trace = []

        def failing_hook(context):
            trace.append('failing')
            raise ValueError("Intentional error")

        def successful_hook(context):
            trace.append('successful')
            return {'status': 'ok'}

        registry = HookManager()
        registry.register_hook('pre_solve', failing_hook, 'Failing', priority=10)
        registry.register_hook('pre_solve', successful_hook, 'Success', priority=20)

        outcomes = registry.execute_hooks('pre_solve', {}, fail_fast=False)

        # Both hooks were attempted, the failing one first.
        assert trace == ['failing', 'successful']
        # The failed hook contributes None; the healthy one its return value.
        assert outcomes[0] is None
        assert outcomes[1]['status'] == 'ok'

    def test_hook_history_tracking(self):
        """Each execution leaves at least one successful history record."""

        def test_hook(context):
            return {'result': 'success'}

        registry = HookManager()
        registry.register_hook('pre_solve', test_hook, 'Test', name='tracked')

        # Three separate executions, each with its own trial number.
        for trial in range(3):
            registry.execute_hooks('pre_solve', {'trial_number': trial})

        records = registry.get_history()
        assert len(records) >= 3

        # At least three of the records must be flagged as successful.
        success_count = sum(1 for record in records if record['success'])
        assert success_count >= 3
|
|
|
|
|
|
class TestCodeValidation:
    """Test plugin code validation.

    Each test feeds a small plugin source snippet to the validators. The
    snippets are written indented for readability and passed through
    ``textwrap.dedent`` so the validator receives ordinary top-level Python.
    """

    def test_safe_code_passes(self):
        """Plugin code using numpy and math passes validation."""
        safe_code = textwrap.dedent("""
            import numpy as np
            import math

            def my_hook(context):
                x = context['design_variables']['thickness']
                result = math.sqrt(x**2 + np.mean([1, 2, 3]))
                return {'result': result}
        """)
        # Should not raise
        validate_plugin_code(safe_code)

    def test_dangerous_import_blocked(self):
        """Importing ``os`` is rejected as an unsafe import."""
        dangerous_code = textwrap.dedent("""
            import os

            def my_hook(context):
                os.system('rm -rf /')
                return None
        """)
        with pytest.raises(PluginValidationError, match="Unsafe import"):
            validate_plugin_code(dangerous_code)

    def test_dangerous_operation_blocked(self):
        """Calling ``eval`` is rejected as a dangerous operation."""
        dangerous_code = textwrap.dedent("""
            def my_hook(context):
                eval('malicious_code')
                return None
        """)
        with pytest.raises(PluginValidationError, match="Dangerous operation"):
            validate_plugin_code(dangerous_code)

    def test_file_operations_with_permission(self):
        """``open`` is rejected unless ``allow_file_ops=True`` is passed."""
        code_with_file_ops = textwrap.dedent("""
            def my_hook(context):
                with open('output.txt', 'w') as f:
                    f.write('test')
                return None
        """)
        # Should raise without permission
        with pytest.raises(PluginValidationError, match="Dangerous operation: open"):
            validate_plugin_code(code_with_file_ops, allow_file_ops=False)

        # Should pass with permission
        validate_plugin_code(code_with_file_ops, allow_file_ops=True)

    def test_syntax_error_detected(self):
        """Source that does not parse is reported as a syntax error."""
        # The only intended flaw is the missing colon on the ``def`` line;
        # the body is indented so nothing else can trigger the failure.
        bad_syntax = textwrap.dedent("""
            def my_hook(context)
                return None # Missing colon
        """)
        with pytest.raises(PluginValidationError, match="Syntax error"):
            validate_plugin_code(bad_syntax)

    def test_hook_signature_validation(self):
        """Hook source must define a function taking exactly one argument."""
        # Valid signature
        valid_code = textwrap.dedent("""
            def my_hook(context):
                return None
        """)
        assert check_hook_function_signature(valid_code) is True

        # Invalid: too many arguments
        invalid_code = textwrap.dedent("""
            def my_hook(context, extra_arg):
                return None
        """)
        with pytest.raises(PluginValidationError, match="must take exactly 1 argument"):
            check_hook_function_signature(invalid_code)

        # Invalid: no function
        invalid_code = textwrap.dedent("""
            x = 5
        """)
        with pytest.raises(PluginValidationError, match="No function definition"):
            check_hook_function_signature(invalid_code)
|
|
|
|
|
|
class TestUtilityFunctions:
    """Test plugin utility functions.

    Source snippets handed to the utilities are written indented and passed
    through ``textwrap.dedent`` so they reach the helper as ordinary
    top-level Python.
    """

    def test_sanitize_plugin_name(self):
        """Names are lowercased, cleaned, and prefixed when they start with a digit."""
        assert sanitize_plugin_name('My-Plugin!') == 'my_plugin_'
        assert sanitize_plugin_name('123_plugin') == 'plugin_123_plugin'
        assert sanitize_plugin_name('valid_name') == 'valid_name'

    def test_get_imported_modules(self):
        """Both ``import x`` and ``from x import y`` forms are reported."""
        code = textwrap.dedent("""
            import numpy as np
            import math
            from pathlib import Path
            from typing import Dict, Any
        """)
        modules = get_imported_modules(code)
        for expected in ('numpy', 'math', 'pathlib', 'typing'):
            assert expected in modules

    def test_estimate_complexity(self):
        """Straight-line code scores 1; branching and looping code scores higher."""
        simple_code = textwrap.dedent("""
            def simple(x):
                return x + 1
        """)
        # Nested if / for / if / while, so several decision points are
        # present for the estimator to count.
        complex_code = textwrap.dedent("""
            def complex(x):
                if x > 0:
                    for i in range(x):
                        if i % 2 == 0:
                            while i > 0:
                                i -= 1
                return x
        """)
        simple_complexity = estimate_complexity(simple_code)
        complex_complexity = estimate_complexity(complex_code)

        assert simple_complexity == 1
        assert complex_complexity > simple_complexity
|
|
|
|
|
|
class TestHookManager:
    """Summary, history and repr helpers exposed by HookManager."""

    def test_get_summary(self):
        """get_summary counts total, enabled and disabled hooks."""

        def hook1(context):
            return None

        def hook2(context):
            return None

        registry = HookManager()
        registry.register_hook('pre_solve', hook1, 'Hook 1', name='hook1')
        registry.register_hook('post_solve', hook2, 'Hook 2', name='hook2')
        registry.disable_hook('hook2')

        report = registry.get_summary()

        expected_counts = (
            ('total_hooks', 2),
            ('enabled_hooks', 1),
            ('disabled_hooks', 1),
        )
        for key, count in expected_counts:
            assert report[key] == count

    def test_clear_history(self):
        """clear_history empties the execution history."""

        def test_hook(context):
            return None

        registry = HookManager()
        registry.register_hook('pre_solve', test_hook, 'Test')
        registry.execute_hooks('pre_solve', {})
        assert len(registry.get_history()) > 0

        registry.clear_history()
        assert len(registry.get_history()) == 0

    def test_hook_manager_repr(self):
        """repr() names the class and reports hook and enabled counts."""

        def hook(context):
            return None

        registry = HookManager()
        registry.register_hook('pre_solve', hook, 'Test')

        text = repr(registry)
        for fragment in ('HookManager', 'hooks=1', 'enabled=1'):
            assert fragment in text
|
|
|
|
|
|
class TestPluginLoading:
    """Loading plugins from disk and via a register_hooks() entry point."""

    def test_load_plugins_from_nonexistent_directory(self):
        """Pointing the loader at a missing directory must not raise."""
        missing_dir = Path('/nonexistent/path')
        registry = HookManager()
        # Passing means no exception escaped; nothing further is asserted.
        registry.load_plugins_from_directory(missing_dir)

    def test_plugin_registration_function(self):
        """A plugin-style register_hooks(manager) function can add hooks."""

        # Stand-in for the register_hooks() function a plugin file would define.
        def register_hooks(hook_manager):
            def my_plugin_hook(context):
                return {'plugin': 'loaded'}

            hook_manager.register_hook(
                'pre_solve',
                my_plugin_hook,
                'Plugin hook',
                name='plugin_hook',
            )

        registry = HookManager()
        register_hooks(registry)

        # Exactly one hook, carrying the explicit name, is now registered.
        registered = registry.get_hooks('pre_solve')
        assert len(registered) == 1
        assert registered[0].name == 'plugin_hook'
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly with verbose output.
    pytest.main(args=[__file__, '-v'])
|