"""
Configuration Validator for Atomizer
====================================
Validates optimization_config.json files before running optimizations.
Catches common errors and provides helpful suggestions.
Usage:
from optimization_engine.validators import validate_config, validate_config_file
# Validate from file path
result = validate_config_file("studies/my_study/1_setup/optimization_config.json")
# Validate from dict
result = validate_config(config_dict)
if result.is_valid:
print("Config is valid!")
else:
for error in result.errors:
print(f"ERROR: {error}")
for warning in result.warnings:
print(f"WARNING: {warning}")
"""
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
@dataclass
class ConfigError:
    """Represents a configuration error that blocks execution."""
    # Dotted path of the offending config field, e.g. "objectives[0].goal".
    field: str
    # Human-readable description of the problem.
    message: str
    # Optional hint on how to fix the problem.
    suggestion: Optional[str] = None

    def __str__(self):
        msg = f"[{self.field}] {self.message}"
        if self.suggestion:
            msg += f" (Suggestion: {self.suggestion})"
        return msg
@dataclass
class ConfigWarning:
    """Represents a configuration warning that doesn't block execution."""
    # Dotted path of the config field the warning refers to.
    field: str
    # Human-readable description of the concern.
    message: str
    # Optional hint on how to address the concern.
    suggestion: Optional[str] = None

    def __str__(self):
        msg = f"[{self.field}] {self.message}"
        if self.suggestion:
            msg += f" (Suggestion: {self.suggestion})"
        return msg
@dataclass
class ValidationResult:
    """Result of configuration validation."""
    # Blocking problems; any entry here makes the config invalid.
    errors: List[ConfigError] = field(default_factory=list)
    # Non-blocking advisories; the config is still usable.
    warnings: List[ConfigWarning] = field(default_factory=list)
    # The parsed configuration dict, if parsing succeeded.
    config: Optional[Dict[str, Any]] = None

    @property
    def is_valid(self) -> bool:
        """Config is valid if there are no errors (warnings are OK)."""
        return len(self.errors) == 0

    def __str__(self):
        lines = []
        if self.errors:
            lines.append(f"ERRORS ({len(self.errors)}):")
            for e in self.errors:
                lines.append(f" - {e}")
        if self.warnings:
            lines.append(f"WARNINGS ({len(self.warnings)}):")
            for w in self.warnings:
                lines.append(f" - {w}")
        if self.is_valid and not self.warnings:
            lines.append("Configuration is valid.")
        return "\n".join(lines)
# Valid values for certain fields
VALID_PROTOCOLS = [
    'protocol_10_single_objective',
    'protocol_11_multi_objective',
    'protocol_12_hybrid_surrogate',
    'legacy',
]

# Optuna sampler class names recognized by the engine.
VALID_SAMPLERS = [
    'TPESampler',
    'NSGAIISampler',
    'CmaEsSampler',
    'RandomSampler',
    'GridSampler',
]

VALID_GOALS = ['minimize', 'maximize']
VALID_CONSTRAINT_TYPES = ['less_than', 'greater_than', 'equal_to', 'range']
VALID_VAR_TYPES = ['float', 'integer', 'categorical']

# Known result-extraction actions; unknown actions only produce warnings.
VALID_EXTRACTION_ACTIONS = [
    'extract_displacement',
    'extract_solid_stress',
    'extract_frequency',
    'extract_mass_from_expression',
    'extract_mass_from_bdf',
    'extract_mass',
    'extract_stress',
]
def validate_config_file(config_path: Union[str, Path]) -> ValidationResult:
    """
    Validate an optimization_config.json file.

    Args:
        config_path: Path to the configuration file

    Returns:
        ValidationResult with errors, warnings, and parsed config
    """
    config_path = Path(config_path)
    result = ValidationResult()

    # Check file exists
    if not config_path.exists():
        result.errors.append(ConfigError(
            field="file",
            message=f"Configuration file not found: {config_path}",
            suggestion="Create optimization_config.json using the create-study skill"
        ))
        return result

    # Parse JSON
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError as e:
        result.errors.append(ConfigError(
            field="file",
            message=f"Invalid JSON: {e}",
            suggestion="Check for syntax errors (missing commas, quotes, brackets)"
        ))
        return result

    # Validate content
    return validate_config(config, result)
def validate_config(config: Dict[str, Any],
                    result: Optional[ValidationResult] = None) -> ValidationResult:
    """
    Validate an optimization configuration dictionary.

    Args:
        config: Configuration dictionary
        result: Existing ValidationResult to append to (optional)

    Returns:
        ValidationResult with errors, warnings, and config
    """
    if result is None:
        result = ValidationResult()
    result.config = config

    # Required top-level fields
    _validate_required_fields(config, result)

    # Validate each section (each helper tolerates a missing section)
    if 'design_variables' in config:
        _validate_design_variables(config['design_variables'], result)
    if 'objectives' in config:
        _validate_objectives(config['objectives'], result)
    if 'constraints' in config:
        _validate_constraints(config['constraints'], result)
    if 'optimization_settings' in config:
        _validate_optimization_settings(config['optimization_settings'], result)
    if 'simulation' in config:
        _validate_simulation_settings(config['simulation'], result)
    if 'surrogate_settings' in config:
        _validate_surrogate_settings(config['surrogate_settings'], result)

    # Cross-field validations
    _validate_cross_references(config, result)
    return result
def _validate_required_fields(config: Dict[str, Any], result: ValidationResult):
    """Check that required top-level fields exist."""
    # NOTE: loop variable renamed from `field` to `field_name` to avoid
    # shadowing `dataclasses.field` imported at module level.
    required = ['study_name', 'design_variables', 'objectives']
    for field_name in required:
        if field_name not in config:
            result.errors.append(ConfigError(
                field=field_name,
                message=f"Required field '{field_name}' is missing",
                suggestion=f"Add '{field_name}' to your configuration"
            ))

    # Recommended fields (missing ones only produce warnings)
    recommended = ['description', 'engineering_context', 'optimization_settings', 'simulation']
    for field_name in recommended:
        if field_name not in config:
            result.warnings.append(ConfigWarning(
                field=field_name,
                message=f"Recommended field '{field_name}' is missing",
                suggestion=f"Consider adding '{field_name}' for better documentation"
            ))
def _validate_design_variables(variables: List[Dict], result: ValidationResult):
    """Validate design variables section."""
    if not isinstance(variables, list):
        result.errors.append(ConfigError(
            field="design_variables",
            message="design_variables must be a list",
            suggestion="Use array format: [{parameter: ..., bounds: ...}, ...]"
        ))
        return
    if len(variables) == 0:
        result.errors.append(ConfigError(
            field="design_variables",
            message="At least one design variable is required",
            suggestion="Add design variables with parameter names and bounds"
        ))
        return

    param_names = set()
    for i, var in enumerate(variables):
        prefix = f"design_variables[{i}]"

        # Required fields
        if 'parameter' not in var:
            result.errors.append(ConfigError(
                field=prefix,
                message="'parameter' name is required",
                suggestion="Add 'parameter': 'your_nx_expression_name'"
            ))
        else:
            param = var['parameter']
            if param in param_names:
                result.errors.append(ConfigError(
                    field=prefix,
                    message=f"Duplicate parameter name: '{param}'",
                    suggestion="Each parameter name must be unique"
                ))
            param_names.add(param)

        if 'bounds' not in var:
            result.errors.append(ConfigError(
                field=prefix,
                message="'bounds' are required",
                suggestion="Add 'bounds': [min_value, max_value]"
            ))
        else:
            bounds = var['bounds']
            if not isinstance(bounds, list) or len(bounds) != 2:
                result.errors.append(ConfigError(
                    field=f"{prefix}.bounds",
                    message="Bounds must be [min, max] array",
                    suggestion="Use format: 'bounds': [1.0, 10.0]"
                ))
            # BUG FIX: this branch originally tested `>=`, which swallowed the
            # equal-bounds case and made the warning below unreachable.
            # Strictly-greater keeps equal bounds as the intended warning.
            elif bounds[0] > bounds[1]:
                result.errors.append(ConfigError(
                    field=f"{prefix}.bounds",
                    message=f"Min ({bounds[0]}) must be less than max ({bounds[1]})",
                    suggestion="Swap values or adjust range"
                ))
            elif bounds[0] == bounds[1]:
                result.warnings.append(ConfigWarning(
                    field=f"{prefix}.bounds",
                    message="Min equals max - variable will be constant",
                    suggestion="If intentional, consider removing this variable"
                ))

        # Type validation (defaults to 'float' when omitted)
        var_type = var.get('type', 'float')
        if var_type not in VALID_VAR_TYPES:
            result.warnings.append(ConfigWarning(
                field=f"{prefix}.type",
                message=f"Unknown type '{var_type}'",
                suggestion=f"Use one of: {', '.join(VALID_VAR_TYPES)}"
            ))

        # Integer bounds check: integer variables should use whole-number bounds
        if var_type == 'integer' and 'bounds' in var:
            bounds = var['bounds']
            if isinstance(bounds, list) and len(bounds) == 2:
                if not (isinstance(bounds[0], int) and isinstance(bounds[1], int)):
                    result.warnings.append(ConfigWarning(
                        field=f"{prefix}.bounds",
                        message="Integer variable bounds should be integers",
                        suggestion="Use whole numbers for integer bounds"
                    ))
def _validate_objectives(objectives: List[Dict], result: ValidationResult):
    """Validate objectives section."""
    if not isinstance(objectives, list):
        result.errors.append(ConfigError(
            field="objectives",
            message="objectives must be a list",
            suggestion="Use array format: [{name: ..., goal: ...}, ...]"
        ))
        return
    if len(objectives) == 0:
        result.errors.append(ConfigError(
            field="objectives",
            message="At least one objective is required",
            suggestion="Add an objective with name and goal (minimize/maximize)"
        ))
        return
    if len(objectives) > 3:
        result.warnings.append(ConfigWarning(
            field="objectives",
            message=f"{len(objectives)} objectives may make optimization difficult",
            suggestion="Consider reducing to 2-3 objectives for clearer trade-offs"
        ))

    obj_names = set()
    for i, obj in enumerate(objectives):
        prefix = f"objectives[{i}]"

        # Required fields
        if 'name' not in obj:
            result.errors.append(ConfigError(
                field=prefix,
                message="'name' is required",
                suggestion="Add 'name': 'mass' or similar"
            ))
        else:
            name = obj['name']
            if name in obj_names:
                result.errors.append(ConfigError(
                    field=prefix,
                    message=f"Duplicate objective name: '{name}'",
                    suggestion="Each objective name must be unique"
                ))
            obj_names.add(name)

        if 'goal' not in obj:
            result.errors.append(ConfigError(
                field=prefix,
                message="'goal' is required",
                suggestion="Add 'goal': 'minimize' or 'goal': 'maximize'"
            ))
        elif obj['goal'] not in VALID_GOALS:
            result.errors.append(ConfigError(
                field=f"{prefix}.goal",
                message=f"Invalid goal '{obj['goal']}'",
                suggestion=f"Use one of: {', '.join(VALID_GOALS)}"
            ))

        # Extraction validation
        if 'extraction' in obj:
            _validate_extraction(obj['extraction'], f"{prefix}.extraction", result)
def _validate_constraints(constraints: List[Dict], result: ValidationResult):
    """Validate constraints section."""
    if not isinstance(constraints, list):
        result.errors.append(ConfigError(
            field="constraints",
            message="constraints must be a list",
            suggestion="Use array format: [{name: ..., type: ..., threshold: ...}, ...]"
        ))
        return

    constraint_names = set()
    for i, const in enumerate(constraints):
        prefix = f"constraints[{i}]"

        # Required fields
        if 'name' not in const:
            result.errors.append(ConfigError(
                field=prefix,
                message="'name' is required",
                suggestion="Add 'name': 'max_stress' or similar"
            ))
        else:
            name = const['name']
            # Duplicate constraint names are only a warning (unlike objectives)
            if name in constraint_names:
                result.warnings.append(ConfigWarning(
                    field=prefix,
                    message=f"Duplicate constraint name: '{name}'",
                    suggestion="Consider using unique names for clarity"
                ))
            constraint_names.add(name)

        if 'type' not in const:
            result.errors.append(ConfigError(
                field=prefix,
                message="'type' is required",
                suggestion="Add 'type': 'less_than' or 'type': 'greater_than'"
            ))
        elif const['type'] not in VALID_CONSTRAINT_TYPES:
            result.errors.append(ConfigError(
                field=f"{prefix}.type",
                message=f"Invalid constraint type '{const['type']}'",
                suggestion=f"Use one of: {', '.join(VALID_CONSTRAINT_TYPES)}"
            ))

        if 'threshold' not in const:
            result.errors.append(ConfigError(
                field=prefix,
                message="'threshold' is required",
                suggestion="Add 'threshold': 200 (the limit value)"
            ))

        # Extraction validation
        if 'extraction' in const:
            _validate_extraction(const['extraction'], f"{prefix}.extraction", result)
def _validate_extraction(extraction: Dict, prefix: str, result: ValidationResult):
    """Validate extraction configuration (shared by objectives and constraints)."""
    if not isinstance(extraction, dict):
        result.errors.append(ConfigError(
            field=prefix,
            message="extraction must be an object",
            suggestion="Use format: {action: '...', params: {...}}"
        ))
        return

    if 'action' not in extraction:
        result.errors.append(ConfigError(
            field=prefix,
            message="'action' is required in extraction",
            suggestion="Add 'action': 'extract_displacement' or similar"
        ))
    # Unknown actions are only warnings, so custom actions remain usable.
    elif extraction['action'] not in VALID_EXTRACTION_ACTIONS:
        result.warnings.append(ConfigWarning(
            field=f"{prefix}.action",
            message=f"Unknown extraction action '{extraction['action']}'",
            suggestion=f"Standard actions: {', '.join(VALID_EXTRACTION_ACTIONS)}"
        ))
def _validate_optimization_settings(settings: Dict, result: ValidationResult):
    """Validate optimization settings section."""
    # Protocol (unknown protocol is only a warning)
    if 'protocol' in settings:
        protocol = settings['protocol']
        if protocol not in VALID_PROTOCOLS:
            result.warnings.append(ConfigWarning(
                field="optimization_settings.protocol",
                message=f"Unknown protocol '{protocol}'",
                suggestion=f"Standard protocols: {', '.join(VALID_PROTOCOLS)}"
            ))

    # Number of trials: must be a positive integer; small counts get a warning
    if 'n_trials' in settings:
        n_trials = settings['n_trials']
        if not isinstance(n_trials, int) or n_trials < 1:
            result.errors.append(ConfigError(
                field="optimization_settings.n_trials",
                message="n_trials must be a positive integer",
                suggestion="Use a value like 30, 50, or 100"
            ))
        elif n_trials < 10:
            result.warnings.append(ConfigWarning(
                field="optimization_settings.n_trials",
                message=f"Only {n_trials} trials may not be enough for good optimization",
                suggestion="Consider at least 20-30 trials for meaningful results"
            ))

    # Sampler (unknown sampler is only a warning)
    if 'sampler' in settings:
        sampler = settings['sampler']
        if sampler not in VALID_SAMPLERS:
            result.warnings.append(ConfigWarning(
                field="optimization_settings.sampler",
                message=f"Unknown sampler '{sampler}'",
                suggestion=f"Standard samplers: {', '.join(VALID_SAMPLERS)}"
            ))
def _validate_simulation_settings(simulation: Dict, result: ValidationResult):
    """Validate simulation settings section."""
    # NOTE: loop variable renamed from `field` to `field_name` to avoid
    # shadowing `dataclasses.field` imported at module level.
    # Missing file references are only warnings, not errors.
    required = ['model_file', 'sim_file']
    for field_name in required:
        if field_name not in simulation:
            result.warnings.append(ConfigWarning(
                field=f"simulation.{field_name}",
                message=f"'{field_name}' not specified",
                suggestion="Add file name for better documentation"
            ))
def _validate_surrogate_settings(surrogate: Dict, result: ValidationResult):
    """Validate surrogate (NN) settings section."""
    # Only inspect the section when the surrogate is actually enabled.
    if surrogate.get('enabled', False):
        # Check training settings: too few FEA trials starves the NN of data
        if 'training' in surrogate:
            training = surrogate['training']
            if training.get('initial_fea_trials', 0) < 20:
                result.warnings.append(ConfigWarning(
                    field="surrogate_settings.training.initial_fea_trials",
                    message="Less than 20 initial FEA trials may not provide enough training data",
                    suggestion="Recommend at least 20-30 initial trials"
                ))

        # Check model settings: warn on a loose accuracy threshold
        if 'model' in surrogate:
            model = surrogate['model']
            if 'min_accuracy_mape' in model:
                mape = model['min_accuracy_mape']
                if mape > 20:
                    result.warnings.append(ConfigWarning(
                        field="surrogate_settings.model.min_accuracy_mape",
                        message=f"MAPE threshold {mape}% is quite high",
                        suggestion="Consider 5-10% for better surrogate accuracy"
                    ))
def _validate_cross_references(config: Dict, result: ValidationResult):
    """Validate cross-references between sections."""
    # Check sampler matches objective count
    objectives = config.get('objectives', [])
    settings = config.get('optimization_settings', {})
    sampler = settings.get('sampler', 'TPESampler')

    if len(objectives) > 1 and sampler == 'TPESampler':
        result.warnings.append(ConfigWarning(
            field="optimization_settings.sampler",
            message="TPESampler with multiple objectives will scalarize them",
            suggestion="Consider NSGAIISampler for true multi-objective optimization"
        ))
    if len(objectives) == 1 and sampler == 'NSGAIISampler':
        result.warnings.append(ConfigWarning(
            field="optimization_settings.sampler",
            message="NSGAIISampler is designed for multi-objective; single-objective may be slower",
            suggestion="Consider TPESampler or CmaEsSampler for single-objective"
        ))

    # Protocol consistency: protocol name should match the objective count
    protocol = settings.get('protocol', '')
    if 'multi_objective' in protocol and len(objectives) == 1:
        result.warnings.append(ConfigWarning(
            field="optimization_settings.protocol",
            message="Multi-objective protocol with single objective",
            suggestion="Use protocol_10_single_objective instead"
        ))
    if 'single_objective' in protocol and len(objectives) > 1:
        result.warnings.append(ConfigWarning(
            field="optimization_settings.protocol",
            message="Single-objective protocol with multiple objectives",
            suggestion="Use protocol_11_multi_objective for multiple objectives"
        ))
# CLI interface for direct execution
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python config_validator.py <path_to_config.json>")
        sys.exit(1)

    config_path = sys.argv[1]
    result = validate_config_file(config_path)
    print(result)

    # Exit code 0 = valid config, 1 = errors found (usable in CI pipelines)
    if result.is_valid:
        print("\n✓ Configuration is valid!")
        sys.exit(0)
    else:
        print(f"\n✗ Configuration has {len(result.errors)} error(s)")
        sys.exit(1)