215 lines
6.0 KiB
Python
215 lines
6.0 KiB
Python
|
|
"""
|
||
|
|
Plugin Code Validators
|
||
|
|
|
||
|
|
Ensures safety and correctness of plugin code before execution.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import ast
import keyword
import logging
import re
from typing import List, Set
|
||
|
|
|
||
|
|
# Logger shared by every validator in this module, named after the module.
logger = logging.getLogger(name=__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class PluginValidationError(Exception):
    """Signals that plugin source code was rejected by a validation check."""
|
||
|
|
|
||
|
|
|
||
|
|
# Whitelist of modules that plugin code is allowed to import.
SAFE_MODULES = {
    # Standard library
    'math',
    'pathlib',
    'json',
    'csv',
    'datetime',
    'logging',
    'typing',
    'dataclasses',
    # Third-party scientific / optimization / FEA libraries
    'numpy',
    'scipy',
    'pandas',
    'optuna',
    'pyNastran',
}
|
||
|
|
|
||
|
|
# Blacklist of dangerous operations.
DANGEROUS_OPERATIONS = {
    # Dynamic code execution and importing
    'eval',
    'exec',
    'compile',
    '__import__',
    # File operations should be explicit
    'open',
    # Process spawning
    'subprocess',
    'os.system',
    'os.popen',
    # Dangerous file operations
    'shutil.rmtree',
}
|
||
|
|
|
||
|
|
|
||
|
|
def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None:
    """
    Validate plugin code for safety before execution.

    The check is purely static: the source is parsed with ``ast`` and every
    node of the tree is inspected. It rejects:

    * imports whose top-level package is not in ``SAFE_MODULES``, and every
      relative import (``from . import x``), since those bypass the whitelist;
    * any read of a name listed in ``DANGEROUS_OPERATIONS`` (e.g. ``eval``,
      ``exec``, ``__import__``), whether it is called directly or aliased
      first (``f = eval``); ``open`` is permitted only with ``allow_file_ops``;
    * attribute access to ``system`` / ``popen`` (always), and to the
      file-mutating attributes ``remove``, ``unlink``, ``rmtree``, ``open``,
      ``write_text`` and ``write_bytes`` unless ``allow_file_ops`` is set.

    Attribute checks match on the attribute name only, so unrelated methods
    that share a name (e.g. ``list.remove``) are rejected as well. This is a
    best-effort filter, not a sandbox.

    Args:
        code: Python source code to validate
        allow_file_ops: Whether to allow file operations (``open``,
            ``Path.write_text``, ``Path.unlink``, etc.)

    Raises:
        PluginValidationError: If the code does not parse or contains unsafe
            operations.

    Example:
        >>> code = "def my_hook(context): return {'result': context['x'] * 2}"
        >>> validate_plugin_code(code)  # Passes
        >>>
        >>> code = "import os; os.system('rm -rf /')"
        >>> validate_plugin_code(code)  # Raises PluginValidationError
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        raise PluginValidationError(f"Syntax error in plugin code: {e}") from e

    # Shell-execution attributes are never acceptable, even with file ops on.
    always_dangerous_attrs = {'system', 'popen'}
    # Attributes that touch the filesystem; only allowed when opted in.
    file_op_attrs = {'rmtree', 'remove', 'unlink', 'open', 'write_text', 'write_bytes'}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Treat ``import pkg.sub`` like ``from pkg.sub import ...``:
                # the top-level package decides.
                base_module = alias.name.split('.')[0]
                if alias.name not in SAFE_MODULES and base_module not in SAFE_MODULES:
                    raise PluginValidationError(
                        f"Unsafe import: {alias.name}. "
                        f"Allowed modules: {SAFE_MODULES}"
                    )

        elif isinstance(node, ast.ImportFrom):
            # Relative imports resolve against whatever package executes the
            # plugin and can have ``module=None``, so they escape the whitelist.
            if node.level:
                raise PluginValidationError(
                    f"Unsafe relative import: {'.' * node.level}{node.module or ''}"
                )
            if node.module and node.module not in SAFE_MODULES:
                # Allow submodules of safe modules
                base_module = node.module.split('.')[0]
                if base_module not in SAFE_MODULES:
                    raise PluginValidationError(
                        f"Unsafe import from: {node.module}"
                    )

        # Check every read of a dangerous name, not just direct calls, so that
        # aliasing (``f = eval; f(...)``) is caught as well.
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            if node.id not in DANGEROUS_OPERATIONS:
                continue
            if node.id == 'open' and allow_file_ops:
                continue  # Allowed when file operations are explicitly permitted
            raise PluginValidationError(
                f"Dangerous operation: {node.id}"
            )

        # Check for attribute access to dangerous methods
        elif isinstance(node, ast.Attribute):
            if node.attr in always_dangerous_attrs or (
                node.attr in file_op_attrs and not allow_file_ops
            ):
                raise PluginValidationError(
                    f"Dangerous attribute access: {node.attr}"
                )

    logger.info("Plugin code validation passed")
|
||
|
|
|
||
|
|
|
||
|
|
def check_hook_function_signature(func_code: str) -> bool:
    """
    Verify that a hook function has the correct signature.

    Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]]

    Only the number of positionally-fillable parameters is checked (regular
    and positional-only); annotations are not inspected.

    Args:
        func_code: Function source code

    Returns:
        True if signature is valid

    Raises:
        PluginValidationError: If the code does not parse, contains no
            ``def`` function, or the function does not take exactly one
            positional parameter.
    """
    try:
        tree = ast.parse(func_code)
    except SyntaxError as e:
        raise PluginValidationError(f"Syntax error: {e}") from e

    # ast.walk is breadth-first, so this picks the shallowest function
    # definition (normally the top-level hook).
    func_def = next(
        (node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)),
        None,
    )
    if func_def is None:
        raise PluginValidationError("No function definition found")

    # Positional-only parameters (``def hook(context, /)``) are stored in
    # ``posonlyargs`` rather than ``args``; both can receive ``context``.
    # getattr keeps this working on interpreters whose ast lacks the field.
    positional = list(getattr(func_def.args, 'posonlyargs', [])) + list(func_def.args.args)
    if len(positional) != 1:
        raise PluginValidationError(
            "Hook function must take exactly 1 argument (context), "
            f"got {len(positional)}"
        )

    logger.info("Hook function '%s' signature is valid", func_def.name)
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def sanitize_plugin_name(name: str) -> str:
    """
    Sanitize a plugin name to be safe for filesystem and imports.

    Args:
        name: Original plugin name

    Returns:
        Sanitized name (lowercase, ASCII alphanumeric + underscore only).
        If the result would be empty, start with a digit, or be a Python
        keyword, it is prefixed with ``plugin_`` so it is always a valid,
        importable identifier.
    """
    # Remove special characters, keep only alphanumeric and underscore
    sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower())

    # An empty name, a leading digit, or a keyword (e.g. "class") cannot be
    # used as a module name in an import statement.
    if not sanitized or sanitized[0].isdigit() or keyword.iskeyword(sanitized):
        sanitized = f"plugin_{sanitized}"

    return sanitized
|
||
|
|
|
||
|
|
|
||
|
|
def get_imported_modules(code: str) -> Set[str]:
    """
    Collect the module names referenced by import statements in ``code``.

    ``import a.b`` contributes ``"a.b"``; ``from x.y import z`` contributes
    ``"x.y"``. A ``from . import z`` statement has no module name and adds
    nothing. Source that fails to parse yields an empty set.

    Args:
        code: Python source code

    Returns:
        Set of module names
    """
    try:
        parsed = ast.parse(code)
    except SyntaxError:
        return set()

    found: Set[str] = set()
    for item in ast.walk(parsed):
        if isinstance(item, ast.ImportFrom):
            if item.module:
                found.add(item.module)
        elif isinstance(item, ast.Import):
            found.update(alias.name for alias in item.names)
    return found
|
||
|
|
|
||
|
|
|
||
|
|
def estimate_complexity(code: str) -> int:
    """
    Estimate the cyclomatic complexity of code.

    Higher values indicate more complex code with more branching. Starting
    from a base of 1, each decision point adds to the score:

    * ``if``/``elif``, conditional expressions, ``for``/``async for``,
      ``while`` and each ``except`` clause: +1;
    * each comprehension ``for`` clause: +1, plus +1 per ``if`` filter;
    * each ``case`` arm of a ``match`` statement (Python 3.10+): +1;
    * a boolean ``and``/``or`` chain of k operands: +(k - 1).

    Args:
        code: Python source code

    Returns:
        Complexity estimate (1 = simple, 10+ = complex), or 0 if the code
        does not parse.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return 0

    # Nodes that each introduce exactly one extra decision point.
    branch_nodes = (ast.If, ast.IfExp, ast.While, ast.For, ast.AsyncFor, ast.ExceptHandler)
    # ``match`` only exists on Python 3.10+; None disables that check.
    match_case = getattr(ast, 'match_case', None)

    complexity = 1  # Base complexity

    for node in ast.walk(tree):
        if isinstance(node, branch_nodes):
            complexity += 1
        elif isinstance(node, ast.BoolOp):
            complexity += len(node.values) - 1
        elif isinstance(node, ast.comprehension):
            # The implicit loop, plus every ``if`` filter on it.
            complexity += 1 + len(node.ifs)
        elif match_case is not None and isinstance(node, match_case):
            complexity += 1

    return complexity
|