""" Plugin Code Validators Ensures safety and correctness of plugin code before execution. """ import ast import re from typing import List, Set import logging logger = logging.getLogger(__name__) class PluginValidationError(Exception): """Raised when plugin code fails validation.""" pass # Whitelist of safe modules for plugin imports SAFE_MODULES = { 'math', 'numpy', 'scipy', 'pandas', 'pathlib', 'json', 'csv', 'datetime', 'logging', 'typing', 'dataclasses', 'optuna', 'pyNastran' } # Blacklist of dangerous operations DANGEROUS_OPERATIONS = { 'eval', 'exec', 'compile', '__import__', 'open', # File operations should be explicit 'subprocess', 'os.system', 'os.popen', 'shutil.rmtree', # Dangerous file operations } def validate_plugin_code(code: str, allow_file_ops: bool = False) -> None: """ Validate plugin code for safety before execution. Args: code: Python source code to validate allow_file_ops: Whether to allow file operations (open, Path.write_text, etc.) Raises: PluginValidationError: If code contains unsafe operations Example: >>> code = "def my_hook(context): return {'result': context['x'] * 2}" >>> validate_plugin_code(code) # Passes >>> >>> code = "import os; os.system('rm -rf /')" >>> validate_plugin_code(code) # Raises PluginValidationError """ try: tree = ast.parse(code) except SyntaxError as e: raise PluginValidationError(f"Syntax error in plugin code: {e}") # Check for dangerous imports for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: if alias.name not in SAFE_MODULES: raise PluginValidationError( f"Unsafe import: {alias.name}. " f"Allowed modules: {SAFE_MODULES}" ) elif isinstance(node, ast.ImportFrom): if node.module and node.module not in SAFE_MODULES: # Allow submodules of safe modules base_module = node.module.split('.')[0] if base_module not in SAFE_MODULES: raise PluginValidationError( f"Unsafe import from: {node.module}" ) # Check for dangerous function calls elif isinstance(node, ast.Call): if isinstance(node.func, ast.Name): if node.func.id in DANGEROUS_OPERATIONS: if node.func.id == 'open' and allow_file_ops: continue # Allow if explicitly permitted raise PluginValidationError( f"Dangerous operation: {node.func.id}" ) # Check for attribute access to dangerous methods elif isinstance(node, ast.Attribute): dangerous_attrs = {'system', 'popen', 'rmtree', 'remove', 'unlink'} if node.attr in dangerous_attrs and not allow_file_ops: raise PluginValidationError( f"Dangerous attribute access: {node.attr}" ) logger.info("Plugin code validation passed") def check_hook_function_signature(func_code: str) -> bool: """ Verify that a hook function has the correct signature. Expected: def hook_name(context: Dict[str, Any]) -> Optional[Dict[str, Any]] Args: func_code: Function source code Returns: True if signature is valid Raises: PluginValidationError: If signature is invalid """ try: tree = ast.parse(func_code) except SyntaxError as e: raise PluginValidationError(f"Syntax error: {e}") # Find function definition func_def = None for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): func_def = node break if func_def is None: raise PluginValidationError("No function definition found") # Check that it takes exactly one argument (context) if len(func_def.args.args) != 1: raise PluginValidationError( f"Hook function must take exactly 1 argument (context), " f"got {len(func_def.args.args)}" ) logger.info(f"Hook function '{func_def.name}' signature is valid") return True def sanitize_plugin_name(name: str) -> str: """ Sanitize a plugin name to be safe for filesystem and imports. Args: name: Original plugin name Returns: Sanitized name (lowercase, alphanumeric + underscore only) """ # Remove special characters, keep only alphanumeric and underscore sanitized = re.sub(r'[^a-z0-9_]', '_', name.lower()) # Ensure it doesn't start with a number if sanitized[0].isdigit(): sanitized = f"plugin_{sanitized}" return sanitized def get_imported_modules(code: str) -> Set[str]: """ Extract all imported modules from code. Args: code: Python source code Returns: Set of module names """ try: tree = ast.parse(code) except SyntaxError: return set() modules = set() for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: modules.add(alias.name) elif isinstance(node, ast.ImportFrom): if node.module: modules.add(node.module) return modules def estimate_complexity(code: str) -> int: """ Estimate the cyclomatic complexity of code. Higher values indicate more complex code with more branching. Args: code: Python source code Returns: Complexity estimate (1 = simple, 10+ = complex) """ try: tree = ast.parse(code) except SyntaxError: return 0 complexity = 1 # Base complexity for node in ast.walk(tree): # Add 1 for each branching statement if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)): complexity += 1 elif isinstance(node, ast.BoolOp): complexity += len(node.values) - 1 return complexity