Fix: SQLite duplicate-study error (set load_if_exists=True) and sampling crash when n < 11; add --clean flag

This commit is contained in:
2026-02-11 13:09:30 +00:00
parent e8b4d37667
commit 135698d96a
2 changed files with 20 additions and 1 deletions

View File

@@ -418,6 +418,10 @@ def run_study(args: argparse.Namespace) -> None:
db_path = results_dir / "optuna_study.db"
storage = f"sqlite:///{db_path}"
if args.clean and db_path.exists():
logger.info("--clean flag: deleting existing DB at %s", db_path)
db_path.unlink()
if args.resume:
logger.info("Resuming existing study: %s", args.study_name)
study = optuna.load_study(
@@ -430,7 +434,7 @@ def run_study(args: argparse.Namespace) -> None:
study_name=args.study_name,
storage=storage,
direction="minimize", # minimize mass
load_if_exists=False,
load_if_exists=True, # safe re-run: reuse if exists
sampler=optuna.samplers.TPESampler(seed=args.seed),
)
logger.info("Created new study: %s (storage: %s)", args.study_name, db_path)
@@ -622,6 +626,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Resume an existing study instead of creating a new one",
)
parser.add_argument(
"--clean",
action="store_true",
help="Delete existing results DB before starting (fresh run)",
)
parser.add_argument(
"--verbose", "-v",
action="store_true",

View File

@@ -173,6 +173,16 @@ def _ensure_integer_coverage(
logger.info("All 11 hole_count levels represented ✓")
return samples
# Skip patching when sample size is too small to cover all levels
n_samples = len(samples)
if n_samples < len(all_levels):
logger.info(
"Only %d samples — too few to cover all 11 hole_count levels "
"(need ≥11). Skipping stratified patching.",
n_samples,
)
return samples
logger.warning(
"Missing hole_count levels: %s — patching with replacements",
sorted(missing_levels),