diff --git a/projects/hydrotech-beam/studies/01_doe_landscape/run_doe.py b/projects/hydrotech-beam/studies/01_doe_landscape/run_doe.py index cda9e7df..f25de316 100644 --- a/projects/hydrotech-beam/studies/01_doe_landscape/run_doe.py +++ b/projects/hydrotech-beam/studies/01_doe_landscape/run_doe.py @@ -418,6 +418,10 @@ def run_study(args: argparse.Namespace) -> None: db_path = results_dir / "optuna_study.db" storage = f"sqlite:///{db_path}" + if args.clean and db_path.exists(): + logger.info("--clean flag: deleting existing DB at %s", db_path) + db_path.unlink() + if args.resume: logger.info("Resuming existing study: %s", args.study_name) study = optuna.load_study( @@ -430,7 +434,7 @@ def run_study(args: argparse.Namespace) -> None: study_name=args.study_name, storage=storage, direction="minimize", # minimize mass - load_if_exists=False, + load_if_exists=True, # safe re-run: reuse if exists sampler=optuna.samplers.TPESampler(seed=args.seed), ) logger.info("Created new study: %s (storage: %s)", args.study_name, db_path) @@ -622,6 +626,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Resume an existing study instead of creating a new one", ) + parser.add_argument( + "--clean", + action="store_true", + help="Delete existing results DB before starting (fresh run)", + ) parser.add_argument( "--verbose", "-v", action="store_true", diff --git a/projects/hydrotech-beam/studies/01_doe_landscape/sampling.py b/projects/hydrotech-beam/studies/01_doe_landscape/sampling.py index 6e88ca86..4cfa6c0c 100644 --- a/projects/hydrotech-beam/studies/01_doe_landscape/sampling.py +++ b/projects/hydrotech-beam/studies/01_doe_landscape/sampling.py @@ -173,6 +173,16 @@ def _ensure_integer_coverage( logger.info("All 11 hole_count levels represented ✓") return samples + # Skip patching when sample size is too small to cover all levels + n_samples = len(samples) + if n_samples < len(all_levels): + logger.info( + "Only %d samples — too few to cover all 11 hole_count levels " + "(need ≥11). Skipping stratified patching.", + n_samples, + ) + return samples + logger.warning( "Missing hole_count levels: %s — patching with replacements", sorted(missing_levels),