Files
Atomizer/optimization_engine/plugins/validators.py

215 lines
6.0 KiB
Python
Raw Normal View History

"""
Plugin Code Validators
Ensures safety and correctness of plugin code before execution.
"""
import ast
import keyword
import logging
import re
from typing import List, Set
# Module-level logger; validators below log successful checks at INFO level.
logger = logging.getLogger(__name__)
class PluginValidationError(Exception):
    """Raised when plugin code fails validation.

    Used by the validators in this module when plugin source cannot be
    parsed, imports a non-whitelisted module, uses a blacklisted operation,
    or defines a hook with an invalid signature.
    """
# Whitelist of modules that plugin code is permitted to import.
# Kept in alphabetical order to make additions easy to review.
SAFE_MODULES = {
    'csv',
    'dataclasses',
    'datetime',
    'json',
    'logging',
    'math',
    'numpy',
    'optuna',
    'pandas',
    'pathlib',
    'pyNastran',
    'scipy',
    'typing',
}
# Blacklist of dangerous operations.
# NOTE(review): validate_plugin_code compares these against bare names
# (ast.Name ids), so dotted entries such as 'os.system' can never match
# there; attribute access like `.system` is handled by a separate check.
DANGEROUS_OPERATIONS = {
    # Dynamic code execution / dynamic import
    'eval',
    'exec',
    'compile',
    '__import__',
    # File operations should be explicit
    'open',
    # Process execution
    'subprocess',
    'os.system',
    'os.popen',
    # Dangerous file operations
    'shutil.rmtree',
}
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
    """
    Validate plugin code for safety before execution.

    The check is static and best-effort. It walks the AST and rejects:
    non-whitelisted imports (SAFE_MODULES), references to blacklisted names
    (DANGEROUS_OPERATIONS), and a few dangerous attribute names. It is not a
    sandbox: dynamic access such as ``getattr(obj, "system")`` is not
    detected, so it should not be the only safeguard for untrusted code.

    Args:
        code: Python source code to validate
        allow_file_ops: Whether to allow file operations, i.e. calling
            ``open`` and accessing the ``remove``/``unlink``/``rmtree``
            attributes. Process-spawning attributes (``system``/``popen``)
            are rejected regardless of this flag.

    Raises:
        PluginValidationError: If code does not parse or contains unsafe
            operations

    Example:
        >>> code = "def my_hook(context): return {'result': context['x'] * 2}"
        >>> validate_plugin_code(code)  # Passes
        >>>
        >>> code = "import os; os.system('rm -rf /')"
        >>> validate_plugin_code(code)  # Raises PluginValidationError
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        raise PluginValidationError(f"Syntax error in plugin code: {e}") from e

    # Process execution is not a "file operation", so it stays forbidden even
    # when allow_file_ops is True; only file deletion is opt-in.
    process_attrs = {'system', 'popen'}
    file_delete_attrs = {'rmtree', 'remove', 'unlink'}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Check the top-level package so `import numpy.linalg` is
                # treated the same as `from numpy.linalg import ...`.
                if alias.name.split('.')[0] not in SAFE_MODULES:
                    raise PluginValidationError(
                        f"Unsafe import: {alias.name}. "
                        f"Allowed modules: {SAFE_MODULES}"
                    )
        elif isinstance(node, ast.ImportFrom):
            # Relative imports resolve against the importing package, not
            # against the whitelist, so reject them outright.
            if node.level:
                raise PluginValidationError(
                    f"Unsafe relative import from: "
                    f"{'.' * node.level}{node.module or ''}"
                )
            # Allow submodules of safe modules
            if node.module.split('.')[0] not in SAFE_MODULES:
                raise PluginValidationError(
                    f"Unsafe import from: {node.module}"
                )
        # Check every read of a blacklisted name, not only direct calls, so
        # aliasing such as `run = eval; run(...)` is caught as well.
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            if node.id in DANGEROUS_OPERATIONS:
                if node.id == 'open' and allow_file_ops:
                    continue  # Allow if explicitly permitted
                raise PluginValidationError(
                    f"Dangerous operation: {node.id}"
                )
        # Check for attribute access to dangerous methods
        elif isinstance(node, ast.Attribute):
            if node.attr in process_attrs or (
                node.attr in file_delete_attrs and not allow_file_ops
            ):
                raise PluginValidationError(
                    f"Dangerous attribute access: {node.attr}"
                )

    logger.info("Plugin code validation passed")
def check_hook_function_signature(func_code: str) -> bool:
    """
    Verify that a hook function has the correct signature.

    Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]

    The first ``def`` yielded by ``ast.walk`` is checked. It must declare
    exactly one positional parameter. Positional-only parameters
    (``def hook(context, /)``) count toward that total.

    Args:
        func_code: Function source code

    Returns:
        True if signature is valid

    Raises:
        PluginValidationError: If the code does not parse, contains no
            function definition, or the function does not take exactly one
            positional parameter.
    """
    try:
        tree = ast.parse(func_code)
    except SyntaxError as e:
        raise PluginValidationError(f"Syntax error: {e}") from e

    # Find function definition
    func_def = next(
        (node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)),
        None,
    )
    if func_def is None:
        raise PluginValidationError("No function definition found")

    # Positional-only parameters are stored in `posonlyargs`, separately from
    # `args`; getattr keeps this working on interpreters that predate them.
    positional = [
        *getattr(func_def.args, 'posonlyargs', []),
        *func_def.args.args,
    ]
    # Check that it takes exactly one argument (context)
    if len(positional) != 1:
        raise PluginValidationError(
            f"Hook function must take exactly 1 argument (context), "
            f"got {len(positional)}"
        )

    logger.info("Hook function '%s' signature is valid", func_def.name)
    return True
def sanitize_plugin_name(name: str) -> str:
    """
    Sanitize a plugin name to be safe for filesystem and imports.

    After lower-casing, every character outside ``[a-z0-9_]`` is replaced
    with an underscore. The result is prefixed with ``plugin_`` whenever it
    would not otherwise be a usable module name. That covers three cases:
    the name is empty, it starts with a digit, or it is a Python keyword
    (e.g. ``"class"``).

    Args:
        name: Original plugin name

    Returns:
        Sanitized name (lowercase, alphanumeric + underscore only)
    """
    # Remove special characters, keep only alphanumeric and underscore
    sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())
    # Empty strings, leading digits and keywords are not valid identifiers,
    # so such names could not be imported as modules.
    if not sanitized or sanitized[0].isdigit() or keyword.iskeyword(sanitized):
        sanitized = f"plugin_{sanitized}"
    return sanitized
def get_imported_modules(code: str) -> Set[str]:
    """
    Collect the module names referenced by import statements in ``code``.

    ``import a.b`` contributes ``"a.b"``; ``from a.b import c`` contributes
    ``"a.b"``. A ``from`` import without a module part (``from . import x``)
    contributes nothing.

    Args:
        code: Python source code

    Returns:
        Set of module names; empty if the code does not parse.
    """
    try:
        parsed = ast.parse(code)
    except SyntaxError:
        return set()

    found = set()
    for stmt in ast.walk(parsed):
        if isinstance(stmt, ast.Import):
            found.update(entry.name for entry in stmt.names)
        elif isinstance(stmt, ast.ImportFrom) and stmt.module:
            found.add(stmt.module)
    return found
def estimate_complexity(code: str) -> int:
    """
    Estimate the cyclomatic complexity of code.

    Higher values indicate more complex code with more branching. Starting
    from 1, the count adds:
    - 1 for each ``if``/``while``/``for``/``async for``, conditional
      expression (``a if b else c``) and ``except`` handler;
    - 1 for each comprehension ``for`` clause plus 1 per ``if`` filter;
    - ``n - 1`` for each boolean operation with ``n`` operands.

    Args:
        code: Python source code

    Returns:
        Complexity estimate (1 = simple, 10+ = complex); 0 if the code does
        not parse.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return 0

    # Statement/expression forms that each introduce exactly one branch.
    branch_nodes = (
        ast.If, ast.IfExp, ast.While, ast.For, ast.AsyncFor, ast.ExceptHandler,
    )

    complexity = 1  # Base complexity
    for node in ast.walk(tree):
        if isinstance(node, branch_nodes):
            complexity += 1
        elif isinstance(node, ast.comprehension):
            # A comprehension clause is a loop, and each filter a condition.
            complexity += 1 + len(node.ifs)
        elif isinstance(node, ast.BoolOp):
            # `a and b and c` adds two short-circuit decision points.
            complexity += len(node.values) - 1
    return complexity