# Atomizer/optimization_engine/validators/results_validator.py
"""
Results Validator for Atomizer Optimization Studies
Validates optimization results stored in study.db and provides
analysis of trial quality, constraint satisfaction, and data integrity.
Usage:
from optimization_engine.validators.results_validator import validate_results
result = validate_results("studies/my_study/2_results/study.db")
if result.is_valid:
print("Results are valid")
else:
for error in result.errors:
print(f"ERROR: {error}")
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import json
@dataclass
class ResultsError:
    """An error discovered while validating optimization results."""
    code: str
    message: str
    trial_number: Optional[int] = None

    def __str__(self) -> str:
        prefix = f"[{self.code}]"
        if self.trial_number is None:
            return f"{prefix} {self.message}"
        return f"{prefix} Trial #{self.trial_number}: {self.message}"
@dataclass
class ResultsWarning:
    """A non-fatal issue discovered while validating optimization results."""
    code: str
    message: str
    trial_number: Optional[int] = None

    def __str__(self) -> str:
        if self.trial_number is None:
            return f"[{self.code}] {self.message}"
        return f"[{self.code}] Trial #{self.trial_number}: {self.message}"
@dataclass
class ResultsInfo:
    """Information about the optimization results.

    Filled in by validate_results(); the defaults describe an empty /
    not-yet-analyzed study, so the object is safe to return even when
    loading the database fails.
    """
    study_name: str = ""                 # name of the (first) study found in the database
    n_trials: int = 0                    # total trials, regardless of state
    n_completed: int = 0                 # trials in COMPLETE state
    n_failed: int = 0                    # trials in FAIL state
    n_pruned: int = 0                    # trials in PRUNED state
    n_pareto: int = 0                    # Pareto-optimal trials (multi-objective studies only)
    feasibility_rate: float = 0.0        # % of completed trials whose 'feasible' user attr is truthy
    is_multi_objective: bool = False     # True when the study optimizes more than one direction
    objective_names: List[str] = field(default_factory=list)     # names taken from the config, if cross-validated
    best_values: Dict[str, float] = field(default_factory=dict)  # best objective value(s) found
    parameter_names: List[str] = field(default_factory=list)     # params of the first completed trial
@dataclass
class ResultsValidationResult:
    """Complete validation result for optimization results."""
    is_valid: bool
    errors: List[ResultsError]
    warnings: List[ResultsWarning]
    info: ResultsInfo

    def __str__(self) -> str:
        """Render the validation outcome as a human-readable report."""
        out = []
        out.append("[OK] Results validation passed!" if self.is_valid
                   else "[X] Results validation failed!")
        out.append("")

        # Summary section: counts and high-level study characteristics.
        out += ["RESULTS SUMMARY", "-" * 40]
        out.append(f" Study: {self.info.study_name}")
        out.append(f" Total trials: {self.info.n_trials}")
        out.append(f" Completed: {self.info.n_completed}")
        out.append(f" Failed: {self.info.n_failed}")
        if self.info.n_pruned > 0:
            out.append(f" Pruned: {self.info.n_pruned}")
        out.append(f" Multi-objective: {'Yes' if self.info.is_multi_objective else 'No'}")
        if self.info.is_multi_objective and self.info.n_pareto > 0:
            out.append(f" Pareto-optimal: {self.info.n_pareto}")
        if self.info.feasibility_rate > 0:
            out.append(f" Feasibility rate: {self.info.feasibility_rate:.1f}%")
        out.append("")

        # Best objective values, when any were recorded.
        if self.info.best_values:
            out += ["BEST VALUES", "-" * 40]
            out += [f" {name}: {value:.4f}"
                    for name, value in self.info.best_values.items()]
            out.append("")

        # Errors and warnings share the same rendering shape.
        for title, items in (("ERRORS", self.errors), ("WARNINGS", self.warnings)):
            if items:
                out += [title, "-" * 40]
                out += [f" {item}" for item in items]
                out.append("")

        return "\n".join(out)
def validate_results(
    db_path: str,
    config_path: Optional[str] = None,
    min_trials: int = 1
) -> ResultsValidationResult:
    """
    Validate optimization results stored in study.db.

    Args:
        db_path: Path to study.db file
        config_path: Optional path to optimization_config.json for cross-validation
        min_trials: Minimum number of completed trials required

    Returns:
        ResultsValidationResult with errors, warnings, and info. Never raises:
        load failures and a missing optuna install are reported as errors.
    """
    errors: List[ResultsError] = []
    warnings: List[ResultsWarning] = []
    info = ResultsInfo()

    db_path = Path(db_path)
    # Fail fast when the database file is missing -- nothing else can be checked.
    if not db_path.exists():
        errors.append(ResultsError(
            code="DB_NOT_FOUND",
            message=f"Database not found: {db_path}"
        ))
        return ResultsValidationResult(
            is_valid=False,
            errors=errors,
            warnings=warnings,
            info=info
        )

    try:
        # Imported lazily so a missing dependency becomes a reported error,
        # not an exception at import time of this module.
        import optuna

        storage_url = f"sqlite:///{db_path}"
        storage = optuna.storages.RDBStorage(url=storage_url)
        study_summaries = storage.get_all_studies()
        if not study_summaries:
            errors.append(ResultsError(
                code="NO_STUDIES",
                message="Database contains no optimization studies"
            ))
            return ResultsValidationResult(
                is_valid=False,
                errors=errors,
                warnings=warnings,
                info=info
            )

        # Use the first (usually only) study in the database.
        info.study_name = study_summaries[0].study_name
        study = optuna.load_study(
            study_name=info.study_name,
            storage=storage_url
        )

        # Basic counts by trial state.
        states = [t.state for t in study.trials]
        info.n_trials = len(states)
        info.n_completed = states.count(optuna.trial.TrialState.COMPLETE)
        info.n_failed = states.count(optuna.trial.TrialState.FAIL)
        info.n_pruned = states.count(optuna.trial.TrialState.PRUNED)

        if info.n_completed < min_trials:
            errors.append(ResultsError(
                code="INSUFFICIENT_TRIALS",
                message=f"Only {info.n_completed} completed trials (minimum: {min_trials})"
            ))

        info.is_multi_objective = len(study.directions) > 1

        # Parameter names come from the first completed trial.
        for trial in study.trials:
            if trial.state == optuna.trial.TrialState.COMPLETE:
                info.parameter_names = list(trial.params.keys())
                break

        if info.is_multi_objective:
            # Multi-objective: inspect the Pareto front.
            try:
                pareto_trials = study.best_trials
                info.n_pareto = len(pareto_trials)
                if info.n_pareto == 0 and info.n_completed > 0:
                    warnings.append(ResultsWarning(
                        code="NO_PARETO",
                        message="No Pareto-optimal solutions found despite completed trials"
                    ))
            except Exception as e:
                warnings.append(ResultsWarning(
                    code="PARETO_ERROR",
                    message=f"Could not compute Pareto front: {e}"
                ))
        elif info.n_completed > 0:
            # Single objective: record the best value. Failures here were
            # previously swallowed silently (except/pass); surface them as a
            # warning so data problems are visible.
            try:
                info.best_values["objective"] = study.best_trial.value
            except Exception as e:
                warnings.append(ResultsWarning(
                    code="BEST_VALUE_ERROR",
                    message=f"Could not determine best value: {e}"
                ))

        # Feasibility: trials may record a boolean 'feasible' user attribute;
        # an absent attribute counts as feasible.
        feasible_count = sum(
            1 for t in study.trials
            if t.state == optuna.trial.TrialState.COMPLETE
            and t.user_attrs.get('feasible', True)
        )
        if info.n_completed > 0:
            info.feasibility_rate = (feasible_count / info.n_completed) * 100
            if info.feasibility_rate < 50:
                warnings.append(ResultsWarning(
                    code="LOW_FEASIBILITY",
                    message=f"Low feasibility rate ({info.feasibility_rate:.1f}%) - consider relaxing constraints or adjusting bounds"
                ))
            elif info.feasibility_rate < 80:
                warnings.append(ResultsWarning(
                    code="MODERATE_FEASIBILITY",
                    message=f"Moderate feasibility rate ({info.feasibility_rate:.1f}%)"
                ))

        # Per-trial data-quality checks (NaN/inf objectives, missing params).
        _validate_trial_data(study, errors, warnings)

        # Cross-validate with the optimization config when one was supplied.
        if config_path:
            _cross_validate_with_config(study, config_path, info, errors, warnings)

    except ImportError:
        errors.append(ResultsError(
            code="OPTUNA_NOT_INSTALLED",
            message="Optuna is not installed. Cannot validate results."
        ))
    except Exception as e:
        errors.append(ResultsError(
            code="LOAD_ERROR",
            message=f"Failed to load study: {e}"
        ))

    return ResultsValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        info=info
    )
def _validate_trial_data(study, errors: List[ResultsError], warnings: List[ResultsWarning]):
    """Validate individual trial data quality.

    Checks each completed trial for null / wrong-type / NaN / infinite
    objective values, missing parameters, and non-positive geometry
    parameters. Appends findings to *errors* and *warnings* in place.
    """
    # Hoisted out of the per-value loop, where the original re-imported
    # math on every NaN/inf check.
    import math
    import optuna

    for trial in study.trials:
        if trial.state != optuna.trial.TrialState.COMPLETE:
            continue

        # Objective values: null and NaN are hard errors; infinity is only a
        # warning because it may be a deliberate penalty value.
        if trial.values:
            for i, val in enumerate(trial.values):
                if val is None:
                    errors.append(ResultsError(
                        code="NULL_OBJECTIVE",
                        message=f"Objective {i} has null value",
                        trial_number=trial.number
                    ))
                elif not isinstance(val, (int, float)):
                    errors.append(ResultsError(
                        code="INVALID_OBJECTIVE_TYPE",
                        message=f"Objective {i} has invalid type: {type(val)}",
                        trial_number=trial.number
                    ))
                elif isinstance(val, float):
                    if math.isnan(val):
                        errors.append(ResultsError(
                            code="NAN_OBJECTIVE",
                            message=f"Objective {i} is NaN",
                            trial_number=trial.number
                        ))
                    elif math.isinf(val):
                        warnings.append(ResultsWarning(
                            code="INF_OBJECTIVE",
                            message=f"Objective {i} is infinite",
                            trial_number=trial.number
                        ))

        # A completed trial with no recorded parameters indicates corrupt data.
        if not trial.params:
            errors.append(ResultsError(
                code="MISSING_PARAMS",
                message="Trial has no parameters recorded",
                trial_number=trial.number
            ))

        # Geometry parameters must be positive to be physically meaningful.
        # The numeric isinstance guard prevents a TypeError on categorical
        # (string) parameters, which the original comparison would raise.
        for param_name, param_value in trial.params.items():
            if not isinstance(param_value, (int, float)):
                continue
            lowered = param_name.lower()
            if 'thickness' in lowered and param_value <= 0:
                warnings.append(ResultsWarning(
                    code="INVALID_THICKNESS",
                    message=f"{param_name} = {param_value} (non-positive thickness)",
                    trial_number=trial.number
                ))
            elif 'diameter' in lowered and param_value <= 0:
                warnings.append(ResultsWarning(
                    code="INVALID_DIAMETER",
                    message=f"{param_name} = {param_value} (non-positive diameter)",
                    trial_number=trial.number
                ))
def _cross_validate_with_config(
    study,
    config_path: str,
    info: ResultsInfo,
    errors: List[ResultsError],
    warnings: List[ResultsWarning]
):
    """Cross-validate results with optimization config.

    Compares parameter names, objective count, and per-trial bound
    satisfaction against the JSON config. All findings are warnings; config
    read/parse problems are reported rather than raised.
    """
    import optuna

    config_path = Path(config_path)
    if not config_path.exists():
        warnings.append(ResultsWarning(
            code="CONFIG_NOT_FOUND",
            message=f"Config file not found for cross-validation: {config_path}"
        ))
        return

    try:
        with open(config_path, 'r') as f:
            config = json.load(f)

        design_vars = config.get('design_variables', [])

        # Parameter-name agreement between config and recorded trials.
        # Variables may name their parameter under 'parameter' or 'name'.
        config_params = set()
        for var in design_vars:
            param_name = var.get('parameter', var.get('name', ''))
            if param_name:
                config_params.add(param_name)
        result_params = set(info.parameter_names)
        missing_in_results = config_params - result_params
        extra_in_results = result_params - config_params
        if missing_in_results:
            warnings.append(ResultsWarning(
                code="MISSING_PARAMS_IN_RESULTS",
                message=f"Config params not in results: {missing_in_results}"
            ))
        if extra_in_results:
            warnings.append(ResultsWarning(
                code="EXTRA_PARAMS_IN_RESULTS",
                message=f"Results have extra params not in config: {extra_in_results}"
            ))

        # Objective count agreement.
        config_objectives = len(config.get('objectives', []))
        result_objectives = len(study.directions)
        if config_objectives != result_objectives:
            warnings.append(ResultsWarning(
                code="OBJECTIVE_COUNT_MISMATCH",
                message=f"Config has {config_objectives} objectives, results have {result_objectives}"
            ))

        # Record objective names from the config.
        for obj in config.get('objectives', []):
            obj_name = obj.get('name', '')
            if obj_name:
                info.objective_names.append(obj_name)

        # Pre-compute bound specs once; the original re-parsed the config's
        # design_variables list inside the per-trial loop.
        bound_specs = []
        for var in design_vars:
            param_name = var.get('parameter', var.get('name', ''))
            bounds = var.get('bounds', [])
            if param_name and len(bounds) == 2:
                min_val, max_val = bounds
                # Small tolerance for floating point
                tolerance = (max_val - min_val) * 0.001
                bound_specs.append((param_name, min_val, max_val, tolerance))

        # Flag trials whose parameter values violate the configured bounds.
        for trial in study.trials:
            if trial.state != optuna.trial.TrialState.COMPLETE:
                continue
            for param_name, min_val, max_val, tolerance in bound_specs:
                if param_name not in trial.params:
                    continue
                value = trial.params[param_name]
                if value < min_val - tolerance:
                    warnings.append(ResultsWarning(
                        code="BELOW_MIN_BOUND",
                        message=f"{param_name} = {value} < min ({min_val})",
                        trial_number=trial.number
                    ))
                elif value > max_val + tolerance:
                    warnings.append(ResultsWarning(
                        code="ABOVE_MAX_BOUND",
                        message=f"{param_name} = {value} > max ({max_val})",
                        trial_number=trial.number
                    ))

    except json.JSONDecodeError as e:
        warnings.append(ResultsWarning(
            code="CONFIG_PARSE_ERROR",
            message=f"Could not parse config JSON: {e}"
        ))
    except Exception as e:
        warnings.append(ResultsWarning(
            code="CONFIG_ERROR",
            message=f"Error reading config: {e}"
        ))
def validate_study_results(study_name: str) -> ResultsValidationResult:
    """
    Convenience function to validate results for a named study.

    Expects the standard study layout:
        studies/<study_name>/2_results/study.db
        studies/<study_name>/1_setup/optimization_config.json

    Args:
        study_name: Name of the study (folder in studies/)

    Returns:
        ResultsValidationResult
    """
    # Path is already imported at module level; the redundant local
    # `from pathlib import Path` was dropped.
    study_dir = Path("studies") / study_name
    db_path = study_dir / "2_results" / "study.db"
    # Cross-validate against the config only when it exists on disk.
    config_path = study_dir / "1_setup" / "optimization_config.json"
    return validate_results(
        db_path=str(db_path),
        config_path=str(config_path) if config_path.exists() else None
    )
def get_pareto_summary(db_path: str) -> Dict[str, Any]:
    """
    Get a summary of Pareto-optimal designs from results.

    Args:
        db_path: Path to study.db

    Returns:
        Dictionary with Pareto front information, or {"error": ...} on any
        failure (missing file, no studies, no completed trials).
    """
    try:
        import optuna

        storage_url = f"sqlite:///{db_path}"
        storage = optuna.storages.RDBStorage(url=storage_url)
        study_summaries = storage.get_all_studies()
        if not study_summaries:
            return {"error": "No studies found"}
        study = optuna.load_study(
            study_name=study_summaries[0].study_name,
            storage=storage_url
        )

        if len(study.directions) < 2:
            # Single objective. Accessing best_trial raises when no trial has
            # completed, so guard it explicitly -- in the original the
            # "No completed trials" branch was unreachable and callers got an
            # opaque exception string instead.
            try:
                best = study.best_trial
            except ValueError:
                return {"error": "No completed trials"}
            return {
                "type": "single_objective",
                "best_trial": best.number,
                "best_value": study.best_value,
                "best_params": study.best_params
            }

        # Multi-objective: collect every Pareto-optimal design.
        designs = [
            {
                "trial_number": trial.number,
                "objectives": trial.values,
                "parameters": trial.params,
                "user_attrs": dict(trial.user_attrs)
            }
            for trial in study.best_trials
        ]

        # Per-objective min/max/spread across the Pareto front.
        ranges = {}
        if designs:
            for i in range(len(designs[0]["objectives"])):
                values = [d["objectives"][i] for d in designs]
                ranges[f"objective_{i}"] = {
                    "min": min(values),
                    "max": max(values),
                    "spread": max(values) - min(values)
                }

        return {
            "type": "multi_objective",
            "n_pareto": len(designs),
            "designs": designs,
            "ranges": ranges
        }
    except Exception as e:
        return {"error": str(e)}
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python results_validator.py <study_name_or_db_path>")
print("Example: python results_validator.py uav_arm_optimization")
print("Example: python results_validator.py studies/my_study/2_results/study.db")
sys.exit(1)
arg = sys.argv[1]
# Check if it's a study name or db path
if arg.endswith('.db'):
result = validate_results(arg)
else:
result = validate_study_results(arg)
print(result)