feat: Implement Phase 1 - Plugin & Hook System

Core plugin architecture for LLM-driven optimization:

New Features:
- Hook system with 6 lifecycle points (pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objectives)
- HookManager for centralized registration and execution
- Code validation with AST-based safety checks
- Feature registry (JSON) for LLM capability discovery
- Example plugin: log_trial_start
- 23 comprehensive tests (all passing)

Integration:
- OptimizationRunner now loads plugins automatically
- Hooks execute at 5 points in optimization loop
- Custom objectives can override total_objective via hooks

Safety:
- Module whitelist (numpy, scipy, pandas, optuna, pyNastran)
- Dangerous operation blocking (eval, exec, os.system, subprocess)
- Optional file operation permission flag

Files Added:
- optimization_engine/plugins/__init__.py
- optimization_engine/plugins/hooks.py
- optimization_engine/plugins/hook_manager.py
- optimization_engine/plugins/validators.py
- optimization_engine/feature_registry.json
- optimization_engine/plugins/pre_solve/log_trial_start.py
- tests/test_plugin_system.py (23 tests)

Files Modified:
- optimization_engine/runner.py (added hook integration)

Ready for Phase 2: LLM interface layer

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-15 14:46:49 -05:00
parent 0ce9ddf3e2
commit a24e3f750c
14 changed files with 1473 additions and 0 deletions

View File

@@ -0,0 +1,243 @@
{
"version": "1.0.0",
"last_updated": "2025-01-15",
"description": "Registry of all Atomizer capabilities for LLM discovery and usage",
"core_features": {
"optimization": {
"description": "Core optimization engine using Optuna",
"module": "optimization_engine.runner",
"capabilities": [
"Multi-objective optimization with weighted sum",
"TPE (Tree-structured Parzen Estimator) sampler",
"CMA-ES sampler",
"Gaussian Process sampler",
"50-trial default with 20 startup trials",
"Automatic checkpoint and resume",
"SQLite-based study persistence"
],
"usage": "python examples/test_journal_optimization.py",
"llm_hint": "Use this for Bayesian optimization with NX simulations"
},
"nx_integration": {
"description": "Siemens NX simulation automation via journal scripts",
"module": "optimization_engine.nx_solver",
"capabilities": [
"Update CAD expressions via NXOpen",
"Execute NX Nastran solver",
"Extract OP2 results (stress, displacement)",
"Extract mass properties",
"Precision control (4 decimals for mm/degrees/MPa)"
],
"usage": "from optimization_engine.nx_solver import run_nx_simulation",
"llm_hint": "Use for running FEA simulations and extracting results"
},
"result_extraction": {
"description": "Extract metrics from simulation results",
"module": "optimization_engine.result_extractors",
"extractors": {
"stress_extractor": {
"description": "Extract stress data from OP2 files",
"metrics": ["max_von_mises", "mean_von_mises", "max_principal"],
"file_type": "OP2",
"usage": "Returns stress in MPa"
},
"displacement_extractor": {
"description": "Extract displacement data from OP2 files",
"metrics": ["max_displacement", "mean_displacement"],
"file_type": "OP2",
"usage": "Returns displacement in mm"
},
"mass_extractor": {
"description": "Extract mass properties",
"metrics": ["total_mass", "center_of_gravity"],
"file_type": "NX Part",
"usage": "Returns mass in kg"
}
},
"llm_hint": "Use extractors to define objectives and constraints"
}
},
"plugin_system": {
"description": "Extensible hook system for custom functionality",
"module": "optimization_engine.plugins",
"version": "1.0.0",
"hook_points": {
"pre_mesh": {
"description": "Execute before meshing operations",
"context": ["trial_number", "design_variables", "sim_file", "working_dir"],
"use_cases": [
"Modify geometry based on parameters",
"Set up boundary conditions",
"Configure mesh settings"
]
},
"post_mesh": {
"description": "Execute after meshing, before solve",
"context": ["trial_number", "mesh_info", "element_count", "node_count"],
"use_cases": [
"Validate mesh quality",
"Export mesh for visualization",
"Log mesh statistics"
]
},
"pre_solve": {
"description": "Execute before solver launch",
"context": ["trial_number", "design_variables", "solver_settings"],
"use_cases": [
"Log trial parameters",
"Modify solver settings",
"Set up custom load cases"
]
},
"post_solve": {
"description": "Execute after solve, before result extraction",
"context": ["trial_number", "solve_status", "output_files"],
"use_cases": [
"Check solver convergence",
"Post-process results",
"Generate visualizations"
]
},
"post_extraction": {
"description": "Execute after result extraction",
"context": ["trial_number", "extracted_results", "objectives", "constraints"],
"use_cases": [
"Calculate custom metrics",
"Combine multiple objectives (RSS)",
"Apply penalty functions"
]
},
"custom_objective": {
"description": "Define custom objective functions",
"context": ["extracted_results", "design_variables"],
"use_cases": [
"RSS of stress and displacement",
"Weighted multi-criteria",
"Conditional objectives"
]
}
},
"api": {
"register_hook": {
"description": "Register a new hook function",
"signature": "hook_manager.register_hook(hook_point, function, description, name=None, priority=100)",
"parameters": {
"hook_point": "One of: pre_mesh, post_mesh, pre_solve, post_solve, post_extraction, custom_objective",
"function": "Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]",
"description": "Human-readable description",
"priority": "Execution order (lower = earlier)"
},
"example": "See optimization_engine/plugins/pre_solve/log_trial_start.py"
},
"execute_hooks": {
"description": "Execute all hooks at a specific point",
"signature": "hook_manager.execute_hooks(hook_point, context, fail_fast=False)",
"returns": "List of hook results"
}
},
"validators": {
"validate_plugin_code": {
"description": "Validate plugin code for safety",
"checks": [
"Syntax errors",
"Dangerous imports (os.system, subprocess, etc.)",
"File operations (optional allow)",
"Function signature correctness"
],
"safe_modules": ["math", "numpy", "scipy", "pandas", "pathlib", "json", "optuna", "pyNastran"],
"llm_hint": "Always validate LLM-generated code before execution"
}
}
},
"design_variables": {
"description": "Parametric CAD variables to optimize",
"schema": {
"name": "Unique identifier",
"expression_name": "NX expression name",
"min": "Lower bound (float)",
"max": "Upper bound (float)",
"units": "Unit system (mm, degrees, etc.)"
},
"example": {
"name": "wall_thickness",
"expression_name": "wall_thickness",
"min": 3.0,
"max": 8.0,
"units": "mm"
}
},
"objectives": {
"description": "Metrics to minimize or maximize",
"schema": {
"name": "Unique identifier",
"extractor": "Result extractor to use",
"metric": "Specific metric from extractor",
"direction": "minimize or maximize",
"weight": "Importance (for multi-objective)",
"units": "Unit system"
},
"example": {
"name": "max_stress",
"extractor": "stress_extractor",
"metric": "max_von_mises",
"direction": "minimize",
"weight": 1.0,
"units": "MPa"
}
},
"constraints": {
"description": "Limits on simulation outputs",
"schema": {
"name": "Unique identifier",
"extractor": "Result extractor to use",
"metric": "Specific metric",
"type": "upper_bound or lower_bound",
"limit": "Constraint value",
"units": "Unit system"
},
"example": {
"name": "max_displacement_limit",
"extractor": "displacement_extractor",
"metric": "max_displacement",
"type": "upper_bound",
"limit": 1.0,
"units": "mm"
}
},
"examples": {
"bracket_optimization": {
"description": "Minimize stress on a bracket by varying wall thickness",
"location": "examples/bracket/",
"design_variables": ["wall_thickness"],
"objectives": ["max_von_mises"],
"trials": 50,
"typical_runtime": "2-3 hours",
"llm_hint": "Good template for single-objective structural optimization"
}
},
"llm_guidelines": {
"code_generation": {
"hook_template": "Always include: function signature with context dict, docstring, return dict",
"validation": "Use validate_plugin_code() before registration",
"error_handling": "Wrap in try-except, log errors, return None on failure"
},
"natural_language_mapping": {
"minimize stress": "objective with direction='minimize', extractor='stress_extractor'",
"vary thickness 3-8mm": "design_variable with min=3.0, max=8.0, units='mm'",
"displacement < 1mm": "constraint with type='upper_bound', limit=1.0",
"RSS of stress and displacement": "custom_objective hook with sqrt(stress² + displacement²)"
}
}
}

View File

@@ -0,0 +1,32 @@
"""
Atomizer Plugin System
Enables extensibility through hooks at various stages of the optimization lifecycle.
Hook Points:
- pre_mesh: Before meshing operations
- post_mesh: After meshing, before solve
- pre_solve: Before solver execution
- post_solve: After solve, before result extraction
- post_extraction: After result extraction, before objective calculation
- custom_objectives: Custom objective/constraint functions
Usage:
from optimization_engine.plugins import HookManager
hook_manager = HookManager()
hook_manager.register_hook('pre_solve', my_custom_function)
hook_manager.execute_hooks('pre_solve', context={'trial_number': 5})
"""
from .hooks import Hook, HookPoint
from .hook_manager import HookManager
from .validators import validate_plugin_code, PluginValidationError
__all__ = [
'Hook',
'HookPoint',
'HookManager',
'validate_plugin_code',
'PluginValidationError'
]

View File

@@ -0,0 +1 @@
# custom_objectives hooks

View File

@@ -0,0 +1,303 @@
"""
Hook Manager for Atomizer Plugin System
Manages registration, execution, and lifecycle of hooks.
"""
from typing import Dict, List, Callable, Any, Optional
from pathlib import Path
import logging
import importlib.util
import json
from .hooks import Hook, HookPoint, HookContext
logger = logging.getLogger(__name__)
class HookManager:
"""
Central manager for all hooks in the optimization system.
Example:
>>> manager = HookManager()
>>> def my_hook(context):
>>> print(f"Trial {context['trial_number']} starting")
>>> return {'status': 'logged'}
>>>
>>> manager.register_hook('pre_solve', my_hook, 'Log trial start')
>>> manager.execute_hooks('pre_solve', {'trial_number': 5})
"""
def __init__(self):
self._hooks: Dict[HookPoint, List[Hook]] = {point: [] for point in HookPoint}
self._hook_history: List[Dict[str, Any]] = []
logger.info("HookManager initialized")
def register_hook(
self,
hook_point: str | HookPoint,
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
description: str,
name: Optional[str] = None,
priority: int = 100,
enabled: bool = True
) -> Hook:
"""
Register a new hook function.
Args:
hook_point: When to execute ('pre_solve', 'post_mesh', etc.)
function: Callable that takes context dict, returns optional dict
description: Human-readable description
name: Unique name (auto-generated if not provided)
priority: Execution order (lower = earlier)
enabled: Whether hook is active
Returns:
The created Hook object
Raises:
ValueError: If hook_point is invalid
"""
# Convert string to HookPoint enum
if isinstance(hook_point, str):
try:
hook_point = HookPoint(hook_point)
except ValueError:
valid_points = [p.value for p in HookPoint]
raise ValueError(
f"Invalid hook_point '{hook_point}'. "
f"Valid options: {valid_points}"
)
# Auto-generate name if not provided
if name is None:
name = f"{hook_point.value}_{function.__name__}_{len(self._hooks[hook_point])}"
# Create hook
hook = Hook(
name=name,
hook_point=hook_point,
function=function,
description=description,
priority=priority,
enabled=enabled
)
# Add to registry and sort by priority
self._hooks[hook_point].append(hook)
self._hooks[hook_point].sort(key=lambda h: h.priority)
logger.info(f"Registered hook: {hook}")
return hook
def execute_hooks(
self,
hook_point: str | HookPoint,
context: Dict[str, Any],
fail_fast: bool = False
) -> List[Optional[Dict[str, Any]]]:
"""
Execute all hooks registered at a specific point.
Args:
hook_point: The execution point
context: Data to pass to hooks
fail_fast: If True, stop on first error. If False, log and continue.
Returns:
List of results from each hook (None if hook returned nothing)
Raises:
Exception: If fail_fast=True and a hook fails
"""
# Convert string to enum
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
hooks = self._hooks[hook_point]
if not hooks:
logger.debug(f"No hooks registered for {hook_point.value}")
return []
logger.info(f"Executing {len(hooks)} hooks at {hook_point.value}")
results = []
for hook in hooks:
try:
result = hook.execute(context)
results.append(result)
# Record in history
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': True,
'trial_number': context.get('trial_number', -1)
})
except Exception as e:
logger.error(f"Hook '{hook.name}' failed: {e}")
# Record failure
self._hook_history.append({
'hook_name': hook.name,
'hook_point': hook_point.value,
'success': False,
'error': str(e),
'trial_number': context.get('trial_number', -1)
})
if fail_fast:
raise
else:
results.append(None)
return results
def load_plugins_from_directory(self, directory: Path):
"""
Auto-discover and load plugins from a directory.
Expected structure:
plugins/
pre_mesh/
my_plugin.py # Contains register_hooks(manager) function
post_solve/
another_plugin.py
Args:
directory: Path to plugins directory
"""
if not directory.exists():
logger.warning(f"Plugins directory not found: {directory}")
return
logger.info(f"Loading plugins from {directory}")
for hook_point in HookPoint:
hook_dir = directory / hook_point.value
if not hook_dir.exists():
continue
for plugin_file in hook_dir.glob("*.py"):
if plugin_file.name.startswith("_"):
continue # Skip __init__.py and private files
try:
self._load_plugin_file(plugin_file)
except Exception as e:
logger.error(f"Failed to load plugin {plugin_file}: {e}")
def _load_plugin_file(self, plugin_file: Path):
"""Load a single plugin file and call its register_hooks function."""
spec = importlib.util.spec_from_file_location(plugin_file.stem, plugin_file)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load spec for {plugin_file}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Call register_hooks if it exists
if hasattr(module, 'register_hooks'):
module.register_hooks(self)
logger.info(f"Loaded plugin: {plugin_file.stem}")
else:
logger.warning(f"Plugin {plugin_file} has no register_hooks() function")
def get_hooks(self, hook_point: str | HookPoint) -> List[Hook]:
"""Get all hooks registered at a specific point."""
if isinstance(hook_point, str):
hook_point = HookPoint(hook_point)
return self._hooks[hook_point].copy()
def remove_hook(self, name: str) -> bool:
"""
Remove a hook by name.
Returns:
True if hook was found and removed, False otherwise
"""
for point, hooks in self._hooks.items():
for i, hook in enumerate(hooks):
if hook.name == name:
del hooks[i]
logger.info(f"Removed hook: {name}")
return True
return False
def enable_hook(self, name: str):
"""Enable a hook by name."""
self._set_hook_enabled(name, True)
def disable_hook(self, name: str):
"""Disable a hook by name."""
self._set_hook_enabled(name, False)
def _set_hook_enabled(self, name: str, enabled: bool):
"""Set enabled status for a hook."""
for hooks in self._hooks.values():
for hook in hooks:
if hook.name == name:
hook.enabled = enabled
status = "enabled" if enabled else "disabled"
logger.info(f"Hook '{name}' {status}")
return
logger.warning(f"Hook '{name}' not found")
def get_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Get hook execution history.
Args:
limit: Maximum number of records to return (most recent)
Returns:
List of execution records
"""
if limit:
return self._hook_history[-limit:]
return self._hook_history.copy()
def clear_history(self):
"""Clear all execution history."""
self._hook_history.clear()
logger.info("Hook history cleared")
def get_summary(self) -> Dict[str, Any]:
"""
Get a summary of the hook system state.
Returns:
Dictionary with hook counts, enabled status, etc.
"""
total_hooks = sum(len(hooks) for hooks in self._hooks.values())
enabled_hooks = sum(
sum(1 for h in hooks if h.enabled)
for hooks in self._hooks.values()
)
by_point = {
point.value: {
'total': len(hooks),
'enabled': sum(1 for h in hooks if h.enabled),
'names': [h.name for h in hooks]
}
for point, hooks in self._hooks.items()
}
return {
'total_hooks': total_hooks,
'enabled_hooks': enabled_hooks,
'disabled_hooks': total_hooks - enabled_hooks,
'by_hook_point': by_point,
'history_records': len(self._hook_history)
}
def __repr__(self) -> str:
summary = self.get_summary()
return (
f"HookManager(hooks={summary['total_hooks']}, "
f"enabled={summary['enabled_hooks']})"
)

View File

@@ -0,0 +1,115 @@
"""
Core Hook System for Atomizer
Defines hook points in the optimization lifecycle and hook registration mechanism.
"""
from enum import Enum
from typing import Callable, Dict, Any, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
class HookPoint(Enum):
"""Enumeration of available hook points in the optimization lifecycle."""
PRE_MESH = "pre_mesh" # Before meshing
POST_MESH = "post_mesh" # After meshing, before solve
PRE_SOLVE = "pre_solve" # Before solver execution
POST_SOLVE = "post_solve" # After solve, before extraction
POST_EXTRACTION = "post_extraction" # After result extraction
CUSTOM_OBJECTIVE = "custom_objective" # Custom objective functions
@dataclass
class Hook:
"""
Represents a single hook function to be executed at a specific point.
Attributes:
name: Unique identifier for this hook
hook_point: When this hook should execute (HookPoint enum)
function: The callable to execute
description: Human-readable description of what this hook does
priority: Execution order (lower = earlier, default=100)
enabled: Whether this hook is currently active
"""
name: str
hook_point: HookPoint
function: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
description: str
priority: int = 100
enabled: bool = True
def execute(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Execute this hook with the given context.
Args:
context: Dictionary containing relevant data for this hook point
Common keys:
- trial_number: Current trial number
- design_variables: Current design variable values
- sim_file: Path to simulation file
- working_dir: Current working directory
Returns:
Optional dictionary with results or modifications to context
Raises:
Exception: Any exception from the hook function is logged and re-raised
"""
if not self.enabled:
logger.debug(f"Hook '{self.name}' is disabled, skipping")
return None
try:
logger.info(f"Executing hook '{self.name}' at {self.hook_point.value}")
result = self.function(context)
logger.debug(f"Hook '{self.name}' completed successfully")
return result
except Exception as e:
logger.error(f"Hook '{self.name}' failed: {e}", exc_info=True)
raise
def __repr__(self) -> str:
status = "enabled" if self.enabled else "disabled"
return f"Hook(name='{self.name}', point={self.hook_point.value}, priority={self.priority}, {status})"
class HookContext:
"""
Context object passed to hooks containing all relevant data.
This is a convenience wrapper around a dictionary that provides
both dict-like access and attribute access.
"""
def __init__(self, **kwargs):
self._data = kwargs
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, value: Any):
self._data[key] = value
def __contains__(self, key: str) -> bool:
return key in self._data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def update(self, other: Dict[str, Any]):
self._data.update(other)
def to_dict(self) -> Dict[str, Any]:
return self._data.copy()
def __repr__(self) -> str:
keys = list(self._data.keys())
return f"HookContext({', '.join(keys)})"

View File

@@ -0,0 +1 @@
# post_extraction hooks

View File

@@ -0,0 +1 @@
# post_mesh hooks

View File

@@ -0,0 +1 @@
# post_solve hooks

View File

@@ -0,0 +1 @@
# pre_mesh hooks

View File

@@ -0,0 +1 @@
# pre_solve hooks

View File

@@ -0,0 +1,56 @@
"""
Example Plugin: Log Trial Start
Simple pre-solve hook that logs trial information.
This demonstrates the plugin API and serves as a template for custom hooks.
"""
from typing import Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
def trial_start_logger(context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Log trial information before solver execution.
Args:
context: Hook context containing:
- trial_number: Current trial number
- design_variables: Dict of variable values
- sim_file: Path to simulation file
Returns:
Dict with logging status
"""
trial_num = context.get('trial_number', '?')
design_vars = context.get('design_variables', {})
logger.info("=" * 60)
logger.info(f"TRIAL {trial_num} STARTING")
logger.info("=" * 60)
for var_name, var_value in design_vars.items():
logger.info(f" {var_name}: {var_value:.4f}")
return {'logged': True, 'trial': trial_num}
def register_hooks(hook_manager):
"""
Register this plugin's hooks with the manager.
This function is called automatically when the plugin is loaded.
Args:
hook_manager: The HookManager instance
"""
hook_manager.register_hook(
hook_point='pre_solve',
function=trial_start_logger,
description='Log trial number and design variables before solve',
name='log_trial_start',
priority=10 # Run early
)

View File

@@ -0,0 +1,214 @@
"""
Plugin Code Validators
Ensures safety and correctness of plugin code before execution.
"""
import ast
import re
from typing import List, Set
import logging
logger = logging.getLogger(__name__)
class PluginValidationError(Exception):
"""Raised when plugin code fails validation."""
pass
# Whitelist of safe modules for plugin imports
SAFE_MODULES = {
'math', 'numpy', 'scipy', 'pandas',
'pathlib', 'json', 'csv', 'datetime',
'logging', 'typing', 'dataclasses',
'optuna', 'pyNastran'
}
# Blacklist of dangerous operations
DANGEROUS_OPERATIONS = {
'eval', 'exec', 'compile', '__import__',
'open', # File operations should be explicit
'subprocess', 'os.system', 'os.popen',
'shutil.rmtree', # Dangerous file operations
}
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
"""
Validate plugin code for safety before execution.
Args:
code: Python source code to validate
allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.)
Raises:
PluginValidationError: If code contains unsafe operations
Example:
>>> code = "def my_hook(context): return {'result': context['x'] * 2}"
>>> validate_plugin_code(code) # Passes
>>>
>>> code = "import os; os.system('rm -rf /')"
>>> validate_plugin_code(code) # Raises PluginValidationError
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error in plugin code: {e}")
# Check for dangerous imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import: {alias.name}. "
f"Allowed modules: {SAFE_MODULES}"
)
elif isinstance(node, ast.ImportFrom):
if node.module and node.module not in SAFE_MODULES:
# Allow submodules of safe modules
base_module = node.module.split('.')[0]
if base_module not in SAFE_MODULES:
raise PluginValidationError(
f"Unsafe import from: {node.module}"
)
# Check for dangerous function calls
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in DANGEROUS_OPERATIONS:
if node.func.id == 'open' and allow_file_ops:
continue # Allow if explicitly permitted
raise PluginValidationError(
f"Dangerous operation: {node.func.id}"
)
# Check for attribute access to dangerous methods
elif isinstance(node, ast.Attribute):
dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'}
if node.attr in dangerous_attrs and not allow_file_ops:
raise PluginValidationError(
f"Dangerous attribute access: {node.attr}"
)
logger.info("Plugin code validation passed")
def check_hook_function_signature(func_code: str) -> bool:
"""
Verify that a hook function has the correct signature.
Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]
Args:
func_code: Function source code
Returns:
True if signature is valid
Raises:
PluginValidationError: If signature is invalid
"""
try:
tree = ast.parse(func_code)
except SyntaxError as e:
raise PluginValidationError(f"Syntax error: {e}")
# Find function definition
func_def = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_def = node
break
if func_def is None:
raise PluginValidationError("No function definition found")
# Check that it takes exactly one argument (context)
if len(func_def.args.args) != 1:
raise PluginValidationError(
f"Hook function must take exactly 1 argument (context), "
f"got {len(func_def.args.args)}"
)
logger.info(f"Hook function '{func_def.name}' signature is valid")
return True
def sanitize_plugin_name(name: str) -> str:
"""
Sanitize a plugin name to be safe for filesystem and imports.
Args:
name: Original plugin name
Returns:
Sanitized name (lowercase, alphanumeric + underscore only)
"""
# Remove special characters, keep only alphanumeric and underscore
sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
# Ensure it doesn't start with a number
if sanitized[0].isdigit():
sanitized = f"plugin_{sanitized}"
return sanitized
def get_imported_modules(code: str) -> Set[str]:
"""
Extract all imported modules from code.
Args:
code: Python source code
Returns:
Set of module names
"""
try:
tree = ast.parse(code)
except SyntaxError:
return set()
modules = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
modules.add(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
modules.add(node.module)
return modules
def estimate_complexity(code: str) -> int:
"""
Estimate the cyclomatic complexity of code.
Higher values indicate more complex code with more branching.
Args:
code: Python source code
Returns:
Complexity estimate (1 = simple, 10+ = complex)
"""
try:
tree = ast.parse(code)
except SyntaxError:
return 0
complexity = 1 # Base complexity
for node in ast.walk(tree):
# Add 1 for each branching statement
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(node, ast.BoolOp):
complexity += len(node.values) - 1
return complexity

View File

@@ -23,6 +23,8 @@ import pandas as pd
from datetime import datetime
import pickle
from optimization_engine.plugins import HookManager
class OptimizationRunner:
"""
@@ -68,6 +70,15 @@ class OptimizationRunner:
self.output_dir = self.config_path.parent / 'optimization_results'
self.output_dir.mkdir(exist_ok=True)
# Initialize plugin/hook system
self.hook_manager = HookManager()
plugins_dir = Path(__file__).parent / 'plugins'
if plugins_dir.exists():
self.hook_manager.load_plugins_from_directory(plugins_dir)
summary = self.hook_manager.get_summary()
if summary['total_hooks'] > 0:
print(f"Loaded {summary['enabled_hooks']}/{summary['total_hooks']} plugins")
def _load_config(self) -> Dict[str, Any]:
"""Load and validate optimization configuration."""
with open(self.config_path, 'r') as f:
@@ -311,6 +322,16 @@ class OptimizationRunner:
int(dv['bounds'][1])
)
# Execute pre_solve hooks
pre_solve_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'sim_file': self.config.get('sim_file', ''),
'working_dir': str(Path.cwd()),
'config': self.config
}
self.hook_manager.execute_hooks('pre_solve', pre_solve_context, fail_fast=False)
# 2. Update NX model with new parameters
try:
self.model_updater(design_vars)
@@ -318,6 +339,15 @@ class OptimizationRunner:
print(f"Error updating model: {e}")
raise optuna.TrialPruned()
# Execute post_mesh hooks (after model update)
post_mesh_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'sim_file': self.config.get('sim_file', ''),
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_mesh', post_mesh_context, fail_fast=False)
# 3. Run simulation
try:
result_path = self.simulation_runner()
@@ -325,6 +355,15 @@ class OptimizationRunner:
print(f"Error running simulation: {e}")
raise optuna.TrialPruned()
# Execute post_solve hooks
post_solve_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'result_path': str(result_path) if result_path else '',
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_solve', post_solve_context, fail_fast=False)
# 4. Extract results with appropriate precision
extracted_results = {}
for obj in self.config['objectives']:
@@ -362,6 +401,16 @@ class OptimizationRunner:
print(f"Error extracting {const['name']}: {e}")
raise optuna.TrialPruned()
# Execute post_extraction hooks
post_extraction_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'extracted_results': extracted_results,
'result_path': str(result_path) if result_path else '',
'working_dir': str(Path.cwd())
}
self.hook_manager.execute_hooks('post_extraction', post_extraction_context, fail_fast=False)
# 5. Evaluate constraints
for const in self.config.get('constraints', []):
value = extracted_results[const['name']]
@@ -389,6 +438,23 @@ class OptimizationRunner:
else: # maximize
total_objective -= weight * value
# Execute custom_objective hooks (can modify total_objective)
custom_objective_context = {
'trial_number': trial.number,
'design_variables': design_vars,
'extracted_results': extracted_results,
'total_objective': total_objective,
'working_dir': str(Path.cwd())
}
custom_results = self.hook_manager.execute_hooks('custom_objective', custom_objective_context, fail_fast=False)
# Allow hooks to override objective value
for result in custom_results:
if result and 'total_objective' in result:
total_objective = result['total_objective']
print(f"Custom objective hook modified total_objective to {total_objective:.6f}")
break # Use first hook that provides override
# 7. Store results in history
history_entry = {
'trial_number': trial.number,