feat: Implement Phase 1 - Plugin & Hook System
Core plugin architecture for LLM-driven optimization: New Features: - Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives) - HookManager for centralized registration and execution - Code validation with AST-based safety checks - Feature registry (JSON) for LLM capability discovery - Example plugin: log_trial_start - 23 comprehensive tests (all passing) Integration: - OptimizationRunner now loads plugins automatically - Hooks execute at 5 points in optimization loop - Custom objectives can override total_objective via hooks Safety: - Module whitelist (numpy, scipy, pandas, optuna, pyNastran) - Dangerous operation blocking (eval, exec, os.system, subprocess) - Optional file operation permission flag Files Added: - optimization_engine/plugins/__init__.py - optimization_engine/plugins/hooks.py - optimization_engine/plugins/hook_manager.py - optimization_engine/plugins/validators.py - optimization_engine/feature_registry.json - optimization_engine/plugins/pre_solve/log_trial_start.py - tests/test_plugin_system.py (23 tests) Files Modified: - optimization_engine/runner.py (added hook integration) Ready for Phase 2: LLM interface layer 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
243
optimization_engine/feature_registry.json
Normal file
243
optimization_engine/feature_registry.json
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
{
|
||||||
|
"version": "1.0.0",
|
||||||
|
"last_updated": "2025-01-15",
|
||||||
|
"description": "Registry of all Atomizer capabilities for LLM discovery and usage",
|
||||||
|
|
||||||
|
"core_features": {
|
||||||
|
"optimization": {
|
||||||
|
"description": "Core optimization engine using Optuna",
|
||||||
|
"module": "optimization_engine.runner",
|
||||||
|
"capabilities": [
|
||||||
|
"Multi-objective optimization with weighted sum",
|
||||||
|
"TPE (Tree-structured Parzen Estimator) sampler",
|
||||||
|
"CMA-ES sampler",
|
||||||
|
"Gaussian Process sampler",
|
||||||
|
"50-trial default with 20 startup trials",
|
||||||
|
"Automatic checkpoint and resume",
|
||||||
|
"SQLite-based study persistence"
|
||||||
|
],
|
||||||
|
"usage": "python examples/test_journal_optimization.py",
|
||||||
|
"llm_hint": "Use this for Bayesian optimization with NX simulations"
|
||||||
|
},
|
||||||
|
|
||||||
|
"nx_integration": {
|
||||||
|
"description": "Siemens NX simulation automation via journal scripts",
|
||||||
|
"module": "optimization_engine.nx_solver",
|
||||||
|
"capabilities": [
|
||||||
|
"Update CAD expressions via NXOpen",
|
||||||
|
"Execute NX Nastran solver",
|
||||||
|
"Extract OP2 results (stress, displacement)",
|
||||||
|
"Extract mass properties",
|
||||||
|
"Precision control (4 decimals for mm/degrees/MPa)"
|
||||||
|
],
|
||||||
|
"usage": "from optimization_engine.nx_solver import run_nx_simulation",
|
||||||
|
"llm_hint": "Use for running FEA simulations and extracting results"
|
||||||
|
},
|
||||||
|
|
||||||
|
"result_extraction": {
|
||||||
|
"description": "Extract metrics from simulation results",
|
||||||
|
"module": "optimization_engine.result_extractors",
|
||||||
|
"extractors": {
|
||||||
|
"stress_extractor": {
|
||||||
|
"description": "Extract stress data from OP2 files",
|
||||||
|
"metrics": ["max_von_mises", "mean_von_mises", "max_principal"],
|
||||||
|
"file_type": "OP2",
|
||||||
|
"usage": "Returns stress in MPa"
|
||||||
|
},
|
||||||
|
"displacement_extractor": {
|
||||||
|
"description": "Extract displacement data from OP2 files",
|
||||||
|
"metrics": ["max_displacement", "mean_displacement"],
|
||||||
|
"file_type": "OP2",
|
||||||
|
"usage": "Returns displacement in mm"
|
||||||
|
},
|
||||||
|
"mass_extractor": {
|
||||||
|
"description": "Extract mass properties",
|
||||||
|
"metrics": ["total_mass", "center_of_gravity"],
|
||||||
|
"file_type": "NX Part",
|
||||||
|
"usage": "Returns mass in kg"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"llm_hint": "Use extractors to define objectives and constraints"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"plugin_system": {
|
||||||
|
"description": "Extensible hook system for custom functionality",
|
||||||
|
"module": "optimization_engine.plugins",
|
||||||
|
"version": "1.0.0",
|
||||||
|
|
||||||
|
"hook_points": {
|
||||||
|
"pre_mesh": {
|
||||||
|
"description": "Execute before meshing operations",
|
||||||
|
"context": ["trial_number", "design_variables", "sim_file", "working_dir"],
|
||||||
|
"use_cases": [
|
||||||
|
"Modify geometry based on parameters",
|
||||||
|
"Set up boundary conditions",
|
||||||
|
"Configure mesh settings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"post_mesh": {
|
||||||
|
"description": "Execute after meshing, before solve",
|
||||||
|
"context": ["trial_number", "mesh_info", "element_count", "node_count"],
|
||||||
|
"use_cases": [
|
||||||
|
"Validate mesh quality",
|
||||||
|
"Export mesh for visualization",
|
||||||
|
"Log mesh statistics"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"pre_solve": {
|
||||||
|
"description": "Execute before solver launch",
|
||||||
|
"context": ["trial_number", "design_variables", "solver_settings"],
|
||||||
|
"use_cases": [
|
||||||
|
"Log trial parameters",
|
||||||
|
"Modify solver settings",
|
||||||
|
"Set up custom load cases"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"post_solve": {
|
||||||
|
"description": "Execute after solve, before result extraction",
|
||||||
|
"context": ["trial_number", "solve_status", "output_files"],
|
||||||
|
"use_cases": [
|
||||||
|
"Check solver convergence",
|
||||||
|
"Post-process results",
|
||||||
|
"Generate visualizations"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"post_extraction": {
|
||||||
|
"description": "Execute after result extraction",
|
||||||
|
"context": ["trial_number", "extracted_results", "objectives", "constraints"],
|
||||||
|
"use_cases": [
|
||||||
|
"Calculate custom metrics",
|
||||||
|
"Combine multiple objectives (RSS)",
|
||||||
|
"Apply penalty functions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"custom_objective": {
|
||||||
|
"description": "Define custom objective functions",
|
||||||
|
"context": ["extracted_results", "design_variables"],
|
||||||
|
"use_cases": [
|
||||||
|
"RSS of stress and displacement",
|
||||||
|
"Weighted multi-criteria",
|
||||||
|
"Conditional objectives"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"api": {
|
||||||
|
"register_hook": {
|
||||||
|
"description": "Register a new hook function",
|
||||||
|
"signature": "hook_manager.register_hook(hook_point, function, description, name=None, priority=100)",
|
||||||
|
"parameters": {
|
||||||
|
"hook_point": "One of: pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objective",
|
||||||
|
"function": "Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]",
|
||||||
|
"description": "Human-readable description",
|
||||||
|
"priority": "Execution order (lower = earlier)"
|
||||||
|
},
|
||||||
|
"example": "See optimization_engine/plugins/pre_solve/log_trial_start.py"
|
||||||
|
},
|
||||||
|
"execute_hooks": {
|
||||||
|
"description": "Execute all hooks at a specific point",
|
||||||
|
"signature": "hook_manager.execute_hooks(hook_point, context, fail_fast=False)",
|
||||||
|
"returns": "List of hook results"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"validators": {
|
||||||
|
"validate_plugin_code": {
|
||||||
|
"description": "Validate plugin code for safety",
|
||||||
|
"checks": [
|
||||||
|
"Syntax errors",
|
||||||
|
"Dangerous imports (os.system, subprocess, etc.)",
|
||||||
|
"File operations (optional allow)",
|
||||||
|
"Function signature correctness"
|
||||||
|
],
|
||||||
|
"safe_modules": ["math", "numpy", "scipy", "pandas", "pathlib", "json", "optuna", "pyNastran"],
|
||||||
|
"llm_hint": "Always validate LLM-generated code before execution"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"design_variables": {
|
||||||
|
"description": "Parametric CAD variables to optimize",
|
||||||
|
"schema": {
|
||||||
|
"name": "Unique identifier",
|
||||||
|
"expression_name": "NX expression name",
|
||||||
|
"min": "Lower bound (float)",
|
||||||
|
"max": "Upper bound (float)",
|
||||||
|
"units": "Unit system (mm, degrees, etc.)"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"name": "wall_thickness",
|
||||||
|
"expression_name": "wall_thickness",
|
||||||
|
"min": 3.0,
|
||||||
|
"max": 8.0,
|
||||||
|
"units": "mm"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"objectives": {
|
||||||
|
"description": "Metrics to minimize or maximize",
|
||||||
|
"schema": {
|
||||||
|
"name": "Unique identifier",
|
||||||
|
"extractor": "Result extractor to use",
|
||||||
|
"metric": "Specific metric from extractor",
|
||||||
|
"direction": "minimize or maximize",
|
||||||
|
"weight": "Importance (for multi-objective)",
|
||||||
|
"units": "Unit system"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"name": "max_stress",
|
||||||
|
"extractor": "stress_extractor",
|
||||||
|
"metric": "max_von_mises",
|
||||||
|
"direction": "minimize",
|
||||||
|
"weight": 1.0,
|
||||||
|
"units": "MPa"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"constraints": {
|
||||||
|
"description": "Limits on simulation outputs",
|
||||||
|
"schema": {
|
||||||
|
"name": "Unique identifier",
|
||||||
|
"extractor": "Result extractor to use",
|
||||||
|
"metric": "Specific metric",
|
||||||
|
"type": "upper_bound or lower_bound",
|
||||||
|
"limit": "Constraint value",
|
||||||
|
"units": "Unit system"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"name": "max_displacement_limit",
|
||||||
|
"extractor": "displacement_extractor",
|
||||||
|
"metric": "max_displacement",
|
||||||
|
"type": "upper_bound",
|
||||||
|
"limit": 1.0,
|
||||||
|
"units": "mm"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"examples": {
|
||||||
|
"bracket_optimization": {
|
||||||
|
"description": "Minimize stress on a bracket by varying wall thickness",
|
||||||
|
"location": "examples/bracket/",
|
||||||
|
"design_variables": ["wall_thickness"],
|
||||||
|
"objectives": ["max_von_mises"],
|
||||||
|
"trials": 50,
|
||||||
|
"typical_runtime": "2-3 hours",
|
||||||
|
"llm_hint": "Good template for single-objective structural optimization"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"llm_guidelines": {
|
||||||
|
"code_generation": {
|
||||||
|
"hook_template": "Always include: function signature with context dict, docstring, return dict",
|
||||||
|
"validation": "Use validate_plugin_code() before registration",
|
||||||
|
"error_handling": "Wrap in try-except, log errors, return None on failure"
|
||||||
|
},
|
||||||
|
"natural_language_mapping": {
|
||||||
|
"minimize stress": "objective with direction='minimize', extractor='stress_extractor'",
|
||||||
|
"vary thickness 3-8mm": "design_variable with min=3.0, max=8.0, units='mm'",
|
||||||
|
"displacement < 1mm": "constraint with type='upper_bound', limit=1.0",
|
||||||
|
"RSS of stress and displacement": "custom_objective hook with sqrt(stress² + displacement²)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
32
optimization_engine/plugins/__init__.py
Normal file
32
optimization_engine/plugins/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
Atomizer Plugin System
|
||||||
|
|
||||||
|
Enables extensibility through hooks at various stages of the optimization lifecycle.
|
||||||
|
|
||||||
|
Hook Points:
|
||||||
|
- pre_mesh: Before meshing operations
|
||||||
|
- post_mesh: After meshing, before solve
|
||||||
|
- pre_solve: Before solver execution
|
||||||
|
- post_solve: After solve, before result extraction
|
||||||
|
- post_extraction: After result extraction, before objective calculation
|
||||||
|
- custom_objectives: Custom objective/constraint functions
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from optimization_engine.plugins import HookManager
|
||||||
|
|
||||||
|
hook_manager = HookManager()
|
||||||
|
hook_manager.register_hook('pre_solve', my_custom_function)
|
||||||
|
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .hooks import Hook, HookPoint
|
||||||
|
from .hook_manager import HookManager
|
||||||
|
from .validators import validate_plugin_code, PluginValidationError
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Hook',
|
||||||
|
'HookPoint',
|
||||||
|
'HookManager',
|
||||||
|
'validate_plugin_code',
|
||||||
|
'PluginValidationError'
|
||||||
|
]
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
# custom_objectives hooks
|
||||||
303
optimization_engine/plugins/hook_manager.py
Normal file
303
optimization_engine/plugins/hook_manager.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
"""
|
||||||
|
Hook Manager for Atomizer Plugin System
|
||||||
|
|
||||||
|
Manages registration, execution, and lifecycle of hooks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Callable, Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .hooks import Hook, HookPoint, HookContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HookManager:
|
||||||
|
"""
|
||||||
|
Central manager for all hooks in the optimization system.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> manager = HookManager()
|
||||||
|
>>> def my_hook(context):
|
||||||
|
>>> print(f"Trial {context['trial_number']} starting")
|
||||||
|
>>> return {'status': 'logged'}
|
||||||
|
>>>
|
||||||
|
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
|
||||||
|
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
|
||||||
|
self._hook_history: List[Dict[str, Any]] = []
|
||||||
|
logger.info("HookManager initialized")
|
||||||
|
|
||||||
|
def register_hook(
|
||||||
|
self,
|
||||||
|
hook_point: str | HookPoint,
|
||||||
|
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
|
||||||
|
description: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
priority: int = 100,
|
||||||
|
enabled: bool = True
|
||||||
|
) -> Hook:
|
||||||
|
"""
|
||||||
|
Register a new hook function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
|
||||||
|
function: Callable that takes context dict, returns optional dict
|
||||||
|
description: Human-readable description
|
||||||
|
name: Unique name (auto-generated if not provided)
|
||||||
|
priority: Execution order (lower = earlier)
|
||||||
|
enabled: Whether hook is active
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created Hook object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If hook_point is invalid
|
||||||
|
"""
|
||||||
|
# Convert string to HookPoint enum
|
||||||
|
if isinstance(hook_point, str):
|
||||||
|
try:
|
||||||
|
hook_point = HookPoint(hook_point)
|
||||||
|
except ValueError:
|
||||||
|
valid_points = [p.value for p in HookPoint]
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid hook_point '{hook_point}'. "
|
||||||
|
f"Valid options: {valid_points}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-generate name if not provided
|
||||||
|
if name is None:
|
||||||
|
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
|
||||||
|
|
||||||
|
# Create hook
|
||||||
|
hook = Hook(
|
||||||
|
name=name,
|
||||||
|
hook_point=hook_point,
|
||||||
|
function=function,
|
||||||
|
description=description,
|
||||||
|
priority=priority,
|
||||||
|
enabled=enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to registry and sort by priority
|
||||||
|
self._hooks[hook_point].append(hook)
|
||||||
|
self._hooks[hook_point].sort(key=lambda h: h.priority)
|
||||||
|
|
||||||
|
logger.info(f"Registered hook: {hook}")
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def execute_hooks(
|
||||||
|
self,
|
||||||
|
hook_point: str | HookPoint,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
fail_fast: bool = False
|
||||||
|
) -> List[Optional[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Execute all hooks registered at a specific point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_point: The execution point
|
||||||
|
context: Data to pass to hooks
|
||||||
|
fail_fast: If True, stop on first error. If False, log and continue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of results from each hook (None if hook returned nothing)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If fail_fast=True and a hook fails
|
||||||
|
"""
|
||||||
|
# Convert string to enum
|
||||||
|
if isinstance(hook_point, str):
|
||||||
|
hook_point = HookPoint(hook_point)
|
||||||
|
|
||||||
|
hooks = self._hooks[hook_point]
|
||||||
|
if not hooks:
|
||||||
|
logger.debug(f"No hooks registered for {hook_point.value}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
result = hook.execute(context)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Record in history
|
||||||
|
self._hook_history.append({
|
||||||
|
'hook_name': hook.name,
|
||||||
|
'hook_point': hook_point.value,
|
||||||
|
'success': True,
|
||||||
|
'trial_number': context.get('trial_number', -1)
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Hook '{hook.name}' failed: {e}")
|
||||||
|
|
||||||
|
# Record failure
|
||||||
|
self._hook_history.append({
|
||||||
|
'hook_name': hook.name,
|
||||||
|
'hook_point': hook_point.value,
|
||||||
|
'success': False,
|
||||||
|
'error': str(e),
|
||||||
|
'trial_number': context.get('trial_number', -1)
|
||||||
|
})
|
||||||
|
|
||||||
|
if fail_fast:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
results.append(None)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def load_plugins_from_directory(self, directory: Path):
|
||||||
|
"""
|
||||||
|
Auto-discover and load plugins from a directory.
|
||||||
|
|
||||||
|
Expected structure:
|
||||||
|
plugins/
|
||||||
|
pre_mesh/
|
||||||
|
my_plugin.py # Contains register_hooks(manager) function
|
||||||
|
post_solve/
|
||||||
|
another_plugin.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: Path to plugins directory
|
||||||
|
"""
|
||||||
|
if not directory.exists():
|
||||||
|
logger.warning(f"Plugins directory not found: {directory}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Loading plugins from {directory}")
|
||||||
|
|
||||||
|
for hook_point in HookPoint:
|
||||||
|
hook_dir = directory / hook_point.value
|
||||||
|
if not hook_dir.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
for plugin_file in hook_dir.glob("*.py"):
|
||||||
|
if plugin_file.name.startswith("_"):
|
||||||
|
continue # Skip __init__.py and private files
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._load_plugin_file(plugin_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load plugin {plugin_file}: {e}")
|
||||||
|
|
||||||
|
def _load_plugin_file(self, plugin_file: Path):
|
||||||
|
"""Load a single plugin file and call its register_hooks function."""
|
||||||
|
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
raise ImportError(f"Could not load spec for {plugin_file}")
|
||||||
|
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# Call register_hooks if it exists
|
||||||
|
if hasattr(module, 'register_hooks'):
|
||||||
|
module.register_hooks(self)
|
||||||
|
logger.info(f"Loaded plugin: {plugin_file.stem}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
|
||||||
|
|
||||||
|
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
|
||||||
|
"""Get all hooks registered at a specific point."""
|
||||||
|
if isinstance(hook_point, str):
|
||||||
|
hook_point = HookPoint(hook_point)
|
||||||
|
return self._hooks[hook_point].copy()
|
||||||
|
|
||||||
|
def remove_hook(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Remove a hook by name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if hook was found and removed, False otherwise
|
||||||
|
"""
|
||||||
|
for point, hooks in self._hooks.items():
|
||||||
|
for i, hook in enumerate(hooks):
|
||||||
|
if hook.name == name:
|
||||||
|
del hooks[i]
|
||||||
|
logger.info(f"Removed hook: {name}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def enable_hook(self, name: str):
|
||||||
|
"""Enable a hook by name."""
|
||||||
|
self._set_hook_enabled(name, True)
|
||||||
|
|
||||||
|
def disable_hook(self, name: str):
|
||||||
|
"""Disable a hook by name."""
|
||||||
|
self._set_hook_enabled(name, False)
|
||||||
|
|
||||||
|
def _set_hook_enabled(self, name: str, enabled: bool):
|
||||||
|
"""Set enabled status for a hook."""
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
for hook in hooks:
|
||||||
|
if hook.name == name:
|
||||||
|
hook.enabled = enabled
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
logger.info(f"Hook '{name}' {status}")
|
||||||
|
return
|
||||||
|
logger.warning(f"Hook '{name}' not found")
|
||||||
|
|
||||||
|
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get hook execution history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of records to return (most recent)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of execution records
|
||||||
|
"""
|
||||||
|
if limit:
|
||||||
|
return self._hook_history[-limit:]
|
||||||
|
return self._hook_history.copy()
|
||||||
|
|
||||||
|
def clear_history(self):
|
||||||
|
"""Clear all execution history."""
|
||||||
|
self._hook_history.clear()
|
||||||
|
logger.info("Hook history cleared")
|
||||||
|
|
||||||
|
def get_summary(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a summary of the hook system state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with hook counts, enabled status, etc.
|
||||||
|
"""
|
||||||
|
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
|
||||||
|
enabled_hooks = sum(
|
||||||
|
sum(1 for h in hooks if h.enabled)
|
||||||
|
for hooks in self._hooks.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
by_point = {
|
||||||
|
point.value: {
|
||||||
|
'total': len(hooks),
|
||||||
|
'enabled': sum(1 for h in hooks if h.enabled),
|
||||||
|
'names': [h.name for h in hooks]
|
||||||
|
}
|
||||||
|
for point, hooks in self._hooks.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_hooks': total_hooks,
|
||||||
|
'enabled_hooks': enabled_hooks,
|
||||||
|
'disabled_hooks': total_hooks - enabled_hooks,
|
||||||
|
'by_hook_point': by_point,
|
||||||
|
'history_records': len(self._hook_history)
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
summary = self.get_summary()
|
||||||
|
return (
|
||||||
|
f"HookManager(hooks={summary['total_hooks']}, "
|
||||||
|
f"enabled={summary['enabled_hooks']})"
|
||||||
|
)
|
||||||
115
optimization_engine/plugins/hooks.py
Normal file
115
optimization_engine/plugins/hooks.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""
|
||||||
|
Core Hook System for Atomizer
|
||||||
|
|
||||||
|
Defines hook points in the optimization lifecycle and hook registration mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable, Dict, Any, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HookPoint(Enum):
|
||||||
|
"""Enumeration of available hook points in the optimization lifecycle."""
|
||||||
|
|
||||||
|
PRE_MESH = "pre_mesh" # Before meshing
|
||||||
|
POST_MESH = "post_mesh" # After meshing, before solve
|
||||||
|
PRE_SOLVE = "pre_solve" # Before solver execution
|
||||||
|
POST_SOLVE = "post_solve" # After solve, before extraction
|
||||||
|
POST_EXTRACTION = "post_extraction" # After result extraction
|
||||||
|
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Hook:
|
||||||
|
"""
|
||||||
|
Represents a single hook function to be executed at a specific point.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Unique identifier for this hook
|
||||||
|
hook_point: When this hook should execute (HookPoint enum)
|
||||||
|
function: The callable to execute
|
||||||
|
description: Human-readable description of what this hook does
|
||||||
|
priority: Execution order (lower = earlier, default=100)
|
||||||
|
enabled: Whether this hook is currently active
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_point: HookPoint
|
||||||
|
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
|
||||||
|
description: str
|
||||||
|
priority: int = 100
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute this hook with the given context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Dictionary containing relevant data for this hook point
|
||||||
|
Common keys:
|
||||||
|
- trial_number: Current trial number
|
||||||
|
- design_variables: Current design variable values
|
||||||
|
- sim_file: Path to simulation file
|
||||||
|
- working_dir: Current working directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional dictionary with results or modifications to context
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Any exception from the hook function is logged and re-raised
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
logger.debug(f"Hook '{self.name}' is disabled, skipping")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
|
||||||
|
result = self.function(context)
|
||||||
|
logger.debug(f"Hook '{self.name}' completed successfully")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
status = "enabled" if self.enabled else "disabled"
|
||||||
|
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
|
||||||
|
|
||||||
|
|
||||||
|
class HookContext:
|
||||||
|
"""
|
||||||
|
Context object passed to hooks containing all relevant data.
|
||||||
|
|
||||||
|
This is a convenience wrapper around a dictionary that provides
|
||||||
|
both dict-like access and attribute access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._data = kwargs
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
return self._data[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any):
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
return key in self._data
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
return self._data.get(key, default)
|
||||||
|
|
||||||
|
def update(self, other: Dict[str, Any]):
|
||||||
|
self._data.update(other)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._data.copy()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
keys = list(self._data.keys())
|
||||||
|
return f"HookContext({', '.join(keys)})"
|
||||||
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
1
optimization_engine/plugins/post_extraction/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# post_extraction hooks
|
||||||
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
1
optimization_engine/plugins/post_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# post_mesh hooks
|
||||||
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
1
optimization_engine/plugins/post_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# post_solve hooks
|
||||||
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
1
optimization_engine/plugins/pre_mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# pre_mesh hooks
|
||||||
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
1
optimization_engine/plugins/pre_solve/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# pre_solve hooks
|
||||||
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
56
optimization_engine/plugins/pre_solve/log_trial_start.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
Example Plugin: Log Trial Start
|
||||||
|
|
||||||
|
Simple pre-solve hook that logs trial information.
|
||||||
|
|
||||||
|
This demonstrates the plugin API and serves as a template for custom hooks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Log trial information before solver execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Hook context containing:
|
||||||
|
- trial_number: Current trial number
|
||||||
|
- design_variables: Dict of variable values
|
||||||
|
- sim_file: Path to simulation file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with logging status
|
||||||
|
"""
|
||||||
|
trial_num = context.get('trial_number', '?')
|
||||||
|
design_vars = context.get('design_variables', {})
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(f"TRIAL {trial_num} STARTING")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
for var_name, var_value in design_vars.items():
|
||||||
|
logger.info(f" {var_name}: {var_value:.4f}")
|
||||||
|
|
||||||
|
return {'logged': True, 'trial': trial_num}
|
||||||
|
|
||||||
|
|
||||||
|
def register_hooks(hook_manager):
|
||||||
|
"""
|
||||||
|
Register this plugin's hooks with the manager.
|
||||||
|
|
||||||
|
This function is called automatically when the plugin is loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_manager: The HookManager instance
|
||||||
|
"""
|
||||||
|
hook_manager.register_hook(
|
||||||
|
hook_point='pre_solve',
|
||||||
|
function=trial_start_logger,
|
||||||
|
description='Log trial number and design variables before solve',
|
||||||
|
name='log_trial_start',
|
||||||
|
priority=10 # Run early
|
||||||
|
)
|
||||||
214
optimization_engine/plugins/validators.py
Normal file
214
optimization_engine/plugins/validators.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
Plugin Code Validators
|
||||||
|
|
||||||
|
Ensures safety and correctness of plugin code before execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from typing import List, Set
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginValidationError(Exception):
|
||||||
|
"""Raised when plugin code fails validation."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Whitelist of safe modules for plugin imports
|
||||||
|
SAFE_MODULES = {
|
||||||
|
'math', 'numpy', 'scipy', 'pandas',
|
||||||
|
'pathlib', 'json', 'csv', 'datetime',
|
||||||
|
'logging', 'typing', 'dataclasses',
|
||||||
|
'optuna', 'pyNastran'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Blacklist of dangerous operations
|
||||||
|
DANGEROUS_OPERATIONS = {
|
||||||
|
'eval', 'exec', 'compile', '__import__',
|
||||||
|
'open', # File operations should be explicit
|
||||||
|
'subprocess', 'os.system', 'os.popen',
|
||||||
|
'shutil.rmtree', # Dangerous file operations
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Validate plugin code for safety before execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Python source code to validate
|
||||||
|
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PluginValidationError: If code contains unsafe operations
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
|
||||||
|
>>> validate_plugin_code(code) # Passes
|
||||||
|
>>>
|
||||||
|
>>> code = "import os; os.system('rm -rf /')"
|
||||||
|
>>> validate_plugin_code(code) # Raises PluginValidationError
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
except SyntaxError as e:
|
||||||
|
raise PluginValidationError(f"Syntax error in plugin code: {e}")
|
||||||
|
|
||||||
|
# Check for dangerous imports
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
for alias in node.names:
|
||||||
|
if alias.name not in SAFE_MODULES:
|
||||||
|
raise PluginValidationError(
|
||||||
|
f"Unsafe import: {alias.name}. "
|
||||||
|
f"Allowed modules: {SAFE_MODULES}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.module and node.module not in SAFE_MODULES:
|
||||||
|
# Allow submodules of safe modules
|
||||||
|
base_module = node.module.split('.')[0]
|
||||||
|
if base_module not in SAFE_MODULES:
|
||||||
|
raise PluginValidationError(
|
||||||
|
f"Unsafe import from: {node.module}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for dangerous function calls
|
||||||
|
elif isinstance(node, ast.Call):
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
if node.func.id in DANGEROUS_OPERATIONS:
|
||||||
|
if node.func.id == 'open' and allow_file_ops:
|
||||||
|
continue # Allow if explicitly permitted
|
||||||
|
raise PluginValidationError(
|
||||||
|
f"Dangerous operation: {node.func.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for attribute access to dangerous methods
|
||||||
|
elif isinstance(node, ast.Attribute):
|
||||||
|
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
|
||||||
|
if node.attr in dangerous_attrs and not allow_file_ops:
|
||||||
|
raise PluginValidationError(
|
||||||
|
f"Dangerous attribute access: {node.attr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Plugin code validation passed")
|
||||||
|
|
||||||
|
|
||||||
|
def check_hook_function_signature(func_code: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify that a hook function has the correct signature.
|
||||||
|
|
||||||
|
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_code: Function source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if signature is valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PluginValidationError: If signature is invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(func_code)
|
||||||
|
except SyntaxError as e:
|
||||||
|
raise PluginValidationError(f"Syntax error: {e}")
|
||||||
|
|
||||||
|
# Find function definition
|
||||||
|
func_def = None
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.FunctionDef):
|
||||||
|
func_def = node
|
||||||
|
break
|
||||||
|
|
||||||
|
if func_def is None:
|
||||||
|
raise PluginValidationError("No function definition found")
|
||||||
|
|
||||||
|
# Check that it takes exactly one argument (context)
|
||||||
|
if len(func_def.args.args) != 1:
|
||||||
|
raise PluginValidationError(
|
||||||
|
f"Hook function must take exactly 1 argument (context), "
|
||||||
|
f"got {len(func_def.args.args)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Hook function '{func_def.name}' signature is valid")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_plugin_name(name: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize a plugin name to be safe for filesystem and imports.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Original plugin name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized name (lowercase, alphanumeric + underscore only)
|
||||||
|
"""
|
||||||
|
# Remove special characters, keep only alphanumeric and underscore
|
||||||
|
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
|
||||||
|
|
||||||
|
# Ensure it doesn't start with a number
|
||||||
|
if sanitized[0].isdigit():
|
||||||
|
sanitized = f"plugin_{sanitized}"
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def get_imported_modules(code: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Extract all imported modules from code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Python source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of module names
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
except SyntaxError:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
modules = set()
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
for alias in node.names:
|
||||||
|
modules.add(alias.name)
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.module:
|
||||||
|
modules.add(node.module)
|
||||||
|
|
||||||
|
return modules
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_complexity(code: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate the cyclomatic complexity of code.
|
||||||
|
|
||||||
|
Higher values indicate more complex code with more branching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Python source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complexity estimate (1 = simple, 10+ = complex)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
except SyntaxError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
complexity = 1 # Base complexity
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
# Add 1 for each branching statement
|
||||||
|
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
|
||||||
|
complexity += 1
|
||||||
|
elif isinstance(node, ast.BoolOp):
|
||||||
|
complexity += len(node.values) - 1
|
||||||
|
|
||||||
|
return complexity
|
||||||
@@ -23,6 +23,8 @@ import pandas as pd
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
from optimization_engine.plugins import HookManager
|
||||||
|
|
||||||
|
|
||||||
class OptimizationRunner:
|
class OptimizationRunner:
|
||||||
"""
|
"""
|
||||||
@@ -68,6 +70,15 @@ class OptimizationRunner:
|
|||||||
self.output_dir = self.config_path.parent / 'optimization_results'
|
self.output_dir = self.config_path.parent / 'optimization_results'
|
||||||
self.output_dir.mkdir(exist_ok=True)
|
self.output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize plugin/hook system
|
||||||
|
self.hook_manager = HookManager()
|
||||||
|
plugins_dir = Path(__file__).parent / 'plugins'
|
||||||
|
if plugins_dir.exists():
|
||||||
|
self.hook_manager.load_plugins_from_directory(plugins_dir)
|
||||||
|
summary = self.hook_manager.get_summary()
|
||||||
|
if summary['total_hooks'] > 0:
|
||||||
|
print(f"Loaded {summary['enabled_hooks']}/{summary['total_hooks']} plugins")
|
||||||
|
|
||||||
def _load_config(self) -> Dict[str, Any]:
|
def _load_config(self) -> Dict[str, Any]:
|
||||||
"""Load and validate optimization configuration."""
|
"""Load and validate optimization configuration."""
|
||||||
with open(self.config_path, 'r') as f:
|
with open(self.config_path, 'r') as f:
|
||||||
@@ -311,6 +322,16 @@ class OptimizationRunner:
|
|||||||
int(dv['bounds'][1])
|
int(dv['bounds'][1])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Execute pre_solve hooks
|
||||||
|
pre_solve_context = {
|
||||||
|
'trial_number': trial.number,
|
||||||
|
'design_variables': design_vars,
|
||||||
|
'sim_file': self.config.get('sim_file', ''),
|
||||||
|
'working_dir': str(Path.cwd()),
|
||||||
|
'config': self.config
|
||||||
|
}
|
||||||
|
self.hook_manager.execute_hooks('pre_solve', pre_solve_context, fail_fast=False)
|
||||||
|
|
||||||
# 2. Update NX model with new parameters
|
# 2. Update NX model with new parameters
|
||||||
try:
|
try:
|
||||||
self.model_updater(design_vars)
|
self.model_updater(design_vars)
|
||||||
@@ -318,6 +339,15 @@ class OptimizationRunner:
|
|||||||
print(f"Error updating model: {e}")
|
print(f"Error updating model: {e}")
|
||||||
raise optuna.TrialPruned()
|
raise optuna.TrialPruned()
|
||||||
|
|
||||||
|
# Execute post_mesh hooks (after model update)
|
||||||
|
post_mesh_context = {
|
||||||
|
'trial_number': trial.number,
|
||||||
|
'design_variables': design_vars,
|
||||||
|
'sim_file': self.config.get('sim_file', ''),
|
||||||
|
'working_dir': str(Path.cwd())
|
||||||
|
}
|
||||||
|
self.hook_manager.execute_hooks('post_mesh', post_mesh_context, fail_fast=False)
|
||||||
|
|
||||||
# 3. Run simulation
|
# 3. Run simulation
|
||||||
try:
|
try:
|
||||||
result_path = self.simulation_runner()
|
result_path = self.simulation_runner()
|
||||||
@@ -325,6 +355,15 @@ class OptimizationRunner:
|
|||||||
print(f"Error running simulation: {e}")
|
print(f"Error running simulation: {e}")
|
||||||
raise optuna.TrialPruned()
|
raise optuna.TrialPruned()
|
||||||
|
|
||||||
|
# Execute post_solve hooks
|
||||||
|
post_solve_context = {
|
||||||
|
'trial_number': trial.number,
|
||||||
|
'design_variables': design_vars,
|
||||||
|
'result_path': str(result_path) if result_path else '',
|
||||||
|
'working_dir': str(Path.cwd())
|
||||||
|
}
|
||||||
|
self.hook_manager.execute_hooks('post_solve', post_solve_context, fail_fast=False)
|
||||||
|
|
||||||
# 4. Extract results with appropriate precision
|
# 4. Extract results with appropriate precision
|
||||||
extracted_results = {}
|
extracted_results = {}
|
||||||
for obj in self.config['objectives']:
|
for obj in self.config['objectives']:
|
||||||
@@ -362,6 +401,16 @@ class OptimizationRunner:
|
|||||||
print(f"Error extracting {const['name']}: {e}")
|
print(f"Error extracting {const['name']}: {e}")
|
||||||
raise optuna.TrialPruned()
|
raise optuna.TrialPruned()
|
||||||
|
|
||||||
|
# Execute post_extraction hooks
|
||||||
|
post_extraction_context = {
|
||||||
|
'trial_number': trial.number,
|
||||||
|
'design_variables': design_vars,
|
||||||
|
'extracted_results': extracted_results,
|
||||||
|
'result_path': str(result_path) if result_path else '',
|
||||||
|
'working_dir': str(Path.cwd())
|
||||||
|
}
|
||||||
|
self.hook_manager.execute_hooks('post_extraction', post_extraction_context, fail_fast=False)
|
||||||
|
|
||||||
# 5. Evaluate constraints
|
# 5. Evaluate constraints
|
||||||
for const in self.config.get('constraints', []):
|
for const in self.config.get('constraints', []):
|
||||||
value = extracted_results[const['name']]
|
value = extracted_results[const['name']]
|
||||||
@@ -389,6 +438,23 @@ class OptimizationRunner:
|
|||||||
else: # maximize
|
else: # maximize
|
||||||
total_objective -= weight * value
|
total_objective -= weight * value
|
||||||
|
|
||||||
|
# Execute custom_objective hooks (can modify total_objective)
|
||||||
|
custom_objective_context = {
|
||||||
|
'trial_number': trial.number,
|
||||||
|
'design_variables': design_vars,
|
||||||
|
'extracted_results': extracted_results,
|
||||||
|
'total_objective': total_objective,
|
||||||
|
'working_dir': str(Path.cwd())
|
||||||
|
}
|
||||||
|
custom_results = self.hook_manager.execute_hooks('custom_objective', custom_objective_context, fail_fast=False)
|
||||||
|
|
||||||
|
# Allow hooks to override objective value
|
||||||
|
for result in custom_results:
|
||||||
|
if result and 'total_objective' in result:
|
||||||
|
total_objective = result['total_objective']
|
||||||
|
print(f"Custom objective hook modified total_objective to {total_objective:.6f}")
|
||||||
|
break # Use first hook that provides override
|
||||||
|
|
||||||
# 7. Store results in history
|
# 7. Store results in history
|
||||||
history_entry = {
|
history_entry = {
|
||||||
'trial_number': trial.number,
|
'trial_number': trial.number,
|
||||||
|
|||||||
438
tests/test_plugin_system.py
Normal file
438
tests/test_plugin_system.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
"""
|
||||||
|
Tests for the Plugin System
|
||||||
|
|
||||||
|
Validates hook registration, execution, validation, and integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
from optimization_engine.plugins import HookManager, Hook, HookPoint
|
||||||
|
from optimization_engine.plugins.validators import (
|
||||||
|
validate_plugin_code,
|
||||||
|
PluginValidationError,
|
||||||
|
check_hook_function_signature,
|
||||||
|
sanitize_plugin_name,
|
||||||
|
get_imported_modules,
|
||||||
|
estimate_complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHookRegistration:
|
||||||
|
"""Test hook registration and management."""
|
||||||
|
|
||||||
|
def test_register_simple_hook(self):
|
||||||
|
"""Test basic hook registration."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def my_hook(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
return {'status': 'success'}
|
||||||
|
|
||||||
|
hook = manager.register_hook(
|
||||||
|
hook_point='pre_solve',
|
||||||
|
function=my_hook,
|
||||||
|
description='Test hook',
|
||||||
|
name='test_hook'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hook.name == 'test_hook'
|
||||||
|
assert hook.hook_point == HookPoint.PRE_SOLVE
|
||||||
|
assert hook.enabled is True
|
||||||
|
|
||||||
|
def test_hook_priority_ordering(self):
|
||||||
|
"""Test that hooks execute in priority order."""
|
||||||
|
manager = HookManager()
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
def hook_high_priority(context):
|
||||||
|
execution_order.append('high')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook_low_priority(context):
|
||||||
|
execution_order.append('low')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register low priority first, high priority second
|
||||||
|
manager.register_hook('pre_solve', hook_low_priority, 'Low', priority=200)
|
||||||
|
manager.register_hook('pre_solve', hook_high_priority, 'High', priority=10)
|
||||||
|
|
||||||
|
# Execute hooks
|
||||||
|
manager.execute_hooks('pre_solve', {'trial_number': 1})
|
||||||
|
|
||||||
|
# High priority should execute first
|
||||||
|
assert execution_order == ['high', 'low']
|
||||||
|
|
||||||
|
def test_disable_enable_hook(self):
|
||||||
|
"""Test disabling and enabling hooks."""
|
||||||
|
manager = HookManager()
|
||||||
|
execution_count = [0]
|
||||||
|
|
||||||
|
def counting_hook(context):
|
||||||
|
execution_count[0] += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
hook = manager.register_hook(
|
||||||
|
'pre_solve',
|
||||||
|
counting_hook,
|
||||||
|
'Counter',
|
||||||
|
name='counter_hook'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute while enabled
|
||||||
|
manager.execute_hooks('pre_solve', {})
|
||||||
|
assert execution_count[0] == 1
|
||||||
|
|
||||||
|
# Disable and execute
|
||||||
|
manager.disable_hook('counter_hook')
|
||||||
|
manager.execute_hooks('pre_solve', {})
|
||||||
|
assert execution_count[0] == 1 # Should not increment
|
||||||
|
|
||||||
|
# Re-enable and execute
|
||||||
|
manager.enable_hook('counter_hook')
|
||||||
|
manager.execute_hooks('pre_solve', {})
|
||||||
|
assert execution_count[0] == 2
|
||||||
|
|
||||||
|
def test_remove_hook(self):
|
||||||
|
"""Test hook removal."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def test_hook(context):
|
||||||
|
return None
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', test_hook, 'Test', name='removable')
|
||||||
|
|
||||||
|
assert len(manager.get_hooks('pre_solve')) == 1
|
||||||
|
|
||||||
|
success = manager.remove_hook('removable')
|
||||||
|
assert success is True
|
||||||
|
assert len(manager.get_hooks('pre_solve')) == 0
|
||||||
|
|
||||||
|
# Try removing non-existent hook
|
||||||
|
success = manager.remove_hook('nonexistent')
|
||||||
|
assert success is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestHookExecution:
|
||||||
|
"""Test hook execution behavior."""
|
||||||
|
|
||||||
|
def test_hook_receives_context(self):
|
||||||
|
"""Test that hooks receive correct context."""
|
||||||
|
manager = HookManager()
|
||||||
|
received_context = {}
|
||||||
|
|
||||||
|
def context_checker(context: Dict[str, Any]):
|
||||||
|
received_context.update(context)
|
||||||
|
return None
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', context_checker, 'Checker')
|
||||||
|
|
||||||
|
test_context = {
|
||||||
|
'trial_number': 42,
|
||||||
|
'design_variables': {'thickness': 5.0},
|
||||||
|
'sim_file': 'test.sim'
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.execute_hooks('pre_solve', test_context)
|
||||||
|
|
||||||
|
assert received_context['trial_number'] == 42
|
||||||
|
assert received_context['design_variables']['thickness'] == 5.0
|
||||||
|
|
||||||
|
def test_hook_return_values(self):
|
||||||
|
"""Test that hook return values are collected."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def hook1(context):
|
||||||
|
return {'result': 'hook1'}
|
||||||
|
|
||||||
|
def hook2(context):
|
||||||
|
return {'result': 'hook2'}
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', hook1, 'Hook 1')
|
||||||
|
manager.register_hook('pre_solve', hook2, 'Hook 2')
|
||||||
|
|
||||||
|
results = manager.execute_hooks('pre_solve', {})
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]['result'] == 'hook1'
|
||||||
|
assert results[1]['result'] == 'hook2'
|
||||||
|
|
||||||
|
def test_hook_error_handling_fail_fast(self):
|
||||||
|
"""Test error handling with fail_fast=True."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def failing_hook(context):
|
||||||
|
raise ValueError("Intentional error")
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', failing_hook, 'Failing')
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Intentional error"):
|
||||||
|
manager.execute_hooks('pre_solve', {}, fail_fast=True)
|
||||||
|
|
||||||
|
def test_hook_error_handling_continue(self):
|
||||||
|
"""Test error handling with fail_fast=False."""
|
||||||
|
manager = HookManager()
|
||||||
|
execution_log = []
|
||||||
|
|
||||||
|
def failing_hook(context):
|
||||||
|
execution_log.append('failing')
|
||||||
|
raise ValueError("Intentional error")
|
||||||
|
|
||||||
|
def successful_hook(context):
|
||||||
|
execution_log.append('successful')
|
||||||
|
return {'status': 'ok'}
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', failing_hook, 'Failing', priority=10)
|
||||||
|
manager.register_hook('pre_solve', successful_hook, 'Success', priority=20)
|
||||||
|
|
||||||
|
results = manager.execute_hooks('pre_solve', {}, fail_fast=False)
|
||||||
|
|
||||||
|
# Both hooks should have attempted execution
|
||||||
|
assert execution_log == ['failing', 'successful']
|
||||||
|
# First result is None (error), second is successful
|
||||||
|
assert results[0] is None
|
||||||
|
assert results[1]['status'] == 'ok'
|
||||||
|
|
||||||
|
def test_hook_history_tracking(self):
|
||||||
|
"""Test that hook execution history is tracked."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def test_hook(context):
|
||||||
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', test_hook, 'Test', name='tracked')
|
||||||
|
|
||||||
|
# Execute hooks multiple times
|
||||||
|
for i in range(3):
|
||||||
|
manager.execute_hooks('pre_solve', {'trial_number': i})
|
||||||
|
|
||||||
|
history = manager.get_history()
|
||||||
|
assert len(history) >= 3
|
||||||
|
|
||||||
|
# Check history contains success records
|
||||||
|
successful = [h for h in history if h['success']]
|
||||||
|
assert len(successful) >= 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeValidation:
|
||||||
|
"""Test plugin code validation."""
|
||||||
|
|
||||||
|
def test_safe_code_passes(self):
|
||||||
|
"""Test that safe code passes validation."""
|
||||||
|
safe_code = """
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
def my_hook(context):
|
||||||
|
x = context['design_variables']['thickness']
|
||||||
|
result = math.sqrt(x**2 + np.mean([1, 2, 3]))
|
||||||
|
return {'result': result}
|
||||||
|
"""
|
||||||
|
# Should not raise
|
||||||
|
validate_plugin_code(safe_code)
|
||||||
|
|
||||||
|
def test_dangerous_import_blocked(self):
|
||||||
|
"""Test that dangerous imports are blocked."""
|
||||||
|
dangerous_code = """
|
||||||
|
import os
|
||||||
|
|
||||||
|
def my_hook(context):
|
||||||
|
os.system('rm -rf /')
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
with pytest.raises(PluginValidationError, match="Unsafe import"):
|
||||||
|
validate_plugin_code(dangerous_code)
|
||||||
|
|
||||||
|
def test_dangerous_operation_blocked(self):
|
||||||
|
"""Test that dangerous operations are blocked."""
|
||||||
|
dangerous_code = """
|
||||||
|
def my_hook(context):
|
||||||
|
eval('malicious_code')
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
with pytest.raises(PluginValidationError, match="Dangerous operation"):
|
||||||
|
validate_plugin_code(dangerous_code)
|
||||||
|
|
||||||
|
def test_file_operations_with_permission(self):
|
||||||
|
"""Test that file operations work with allow_file_ops=True."""
|
||||||
|
code_with_file_ops = """
|
||||||
|
def my_hook(context):
|
||||||
|
with open('output.txt', 'w') as f:
|
||||||
|
f.write('test')
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
# Should raise without permission
|
||||||
|
with pytest.raises(PluginValidationError, match="Dangerous operation: open"):
|
||||||
|
validate_plugin_code(code_with_file_ops, allow_file_ops=False)
|
||||||
|
|
||||||
|
# Should pass with permission
|
||||||
|
validate_plugin_code(code_with_file_ops, allow_file_ops=True)
|
||||||
|
|
||||||
|
def test_syntax_error_detected(self):
|
||||||
|
"""Test that syntax errors are detected."""
|
||||||
|
bad_syntax = """
|
||||||
|
def my_hook(context)
|
||||||
|
return None # Missing colon
|
||||||
|
"""
|
||||||
|
with pytest.raises(PluginValidationError, match="Syntax error"):
|
||||||
|
validate_plugin_code(bad_syntax)
|
||||||
|
|
||||||
|
def test_hook_signature_validation(self):
|
||||||
|
"""Test hook function signature validation."""
|
||||||
|
# Valid signature
|
||||||
|
valid_code = """
|
||||||
|
def my_hook(context):
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
assert check_hook_function_signature(valid_code) is True
|
||||||
|
|
||||||
|
# Invalid: too many arguments
|
||||||
|
invalid_code = """
|
||||||
|
def my_hook(context, extra_arg):
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
with pytest.raises(PluginValidationError, match="must take exactly 1 argument"):
|
||||||
|
check_hook_function_signature(invalid_code)
|
||||||
|
|
||||||
|
# Invalid: no function
|
||||||
|
invalid_code = """
|
||||||
|
x = 5
|
||||||
|
"""
|
||||||
|
with pytest.raises(PluginValidationError, match="No function definition"):
|
||||||
|
check_hook_function_signature(invalid_code)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilityFunctions:
|
||||||
|
"""Test plugin utility functions."""
|
||||||
|
|
||||||
|
def test_sanitize_plugin_name(self):
|
||||||
|
"""Test plugin name sanitization."""
|
||||||
|
assert sanitize_plugin_name('My-Plugin!') == 'my_plugin_'
|
||||||
|
assert sanitize_plugin_name('123_plugin') == 'plugin_123_plugin'
|
||||||
|
assert sanitize_plugin_name('valid_name') == 'valid_name'
|
||||||
|
|
||||||
|
def test_get_imported_modules(self):
|
||||||
|
"""Test module extraction from code."""
|
||||||
|
code = """
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
"""
|
||||||
|
modules = get_imported_modules(code)
|
||||||
|
assert 'numpy' in modules
|
||||||
|
assert 'math' in modules
|
||||||
|
assert 'pathlib' in modules
|
||||||
|
assert 'typing' in modules
|
||||||
|
|
||||||
|
def test_estimate_complexity(self):
|
||||||
|
"""Test cyclomatic complexity estimation."""
|
||||||
|
simple_code = """
|
||||||
|
def simple(x):
|
||||||
|
return x + 1
|
||||||
|
"""
|
||||||
|
complex_code = """
|
||||||
|
def complex(x):
|
||||||
|
if x > 0:
|
||||||
|
for i in range(x):
|
||||||
|
if i % 2 == 0:
|
||||||
|
while i > 0:
|
||||||
|
i -= 1
|
||||||
|
return x
|
||||||
|
"""
|
||||||
|
simple_complexity = estimate_complexity(simple_code)
|
||||||
|
complex_complexity = estimate_complexity(complex_code)
|
||||||
|
|
||||||
|
assert simple_complexity == 1
|
||||||
|
assert complex_complexity > simple_complexity
|
||||||
|
|
||||||
|
|
||||||
|
class TestHookManager:
|
||||||
|
"""Test HookManager functionality."""
|
||||||
|
|
||||||
|
def test_get_summary(self):
|
||||||
|
"""Test hook system summary."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def hook1(context):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook2(context):
|
||||||
|
return None
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', hook1, 'Hook 1', name='hook1')
|
||||||
|
manager.register_hook('post_solve', hook2, 'Hook 2', name='hook2')
|
||||||
|
manager.disable_hook('hook2')
|
||||||
|
|
||||||
|
summary = manager.get_summary()
|
||||||
|
|
||||||
|
assert summary['total_hooks'] == 2
|
||||||
|
assert summary['enabled_hooks'] == 1
|
||||||
|
assert summary['disabled_hooks'] == 1
|
||||||
|
|
||||||
|
def test_clear_history(self):
|
||||||
|
"""Test clearing execution history."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def test_hook(context):
|
||||||
|
return None
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', test_hook, 'Test')
|
||||||
|
manager.execute_hooks('pre_solve', {})
|
||||||
|
|
||||||
|
assert len(manager.get_history()) > 0
|
||||||
|
|
||||||
|
manager.clear_history()
|
||||||
|
assert len(manager.get_history()) == 0
|
||||||
|
|
||||||
|
def test_hook_manager_repr(self):
|
||||||
|
"""Test HookManager string representation."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
def hook(context):
|
||||||
|
return None
|
||||||
|
|
||||||
|
manager.register_hook('pre_solve', hook, 'Test')
|
||||||
|
|
||||||
|
repr_str = repr(manager)
|
||||||
|
assert 'HookManager' in repr_str
|
||||||
|
assert 'hooks=1' in repr_str
|
||||||
|
assert 'enabled=1' in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginLoading:
|
||||||
|
"""Test plugin directory loading."""
|
||||||
|
|
||||||
|
def test_load_plugins_from_nonexistent_directory(self):
|
||||||
|
"""Test loading from non-existent directory."""
|
||||||
|
manager = HookManager()
|
||||||
|
# Should not raise, just log warning
|
||||||
|
manager.load_plugins_from_directory(Path('/nonexistent/path'))
|
||||||
|
|
||||||
|
def test_plugin_registration_function(self):
|
||||||
|
"""Test that plugins can register hooks via register_hooks()."""
|
||||||
|
manager = HookManager()
|
||||||
|
|
||||||
|
# Simulate what a plugin file would contain
|
||||||
|
def register_hooks(hook_manager):
|
||||||
|
def my_plugin_hook(context):
|
||||||
|
return {'plugin': 'loaded'}
|
||||||
|
|
||||||
|
hook_manager.register_hook(
|
||||||
|
'pre_solve',
|
||||||
|
my_plugin_hook,
|
||||||
|
'Plugin hook',
|
||||||
|
name='plugin_hook'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the registration function
|
||||||
|
register_hooks(manager)
|
||||||
|
|
||||||
|
# Verify hook was registered
|
||||||
|
hooks = manager.get_hooks('pre_solve')
|
||||||
|
assert len(hooks) == 1
|
||||||
|
assert hooks[0].name == 'plugin_hook'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user