Fix: SQLite duplicate study (load_if_exists), sampling crash with n<11, add --clean flag
This commit is contained in:
@@ -418,6 +418,10 @@ def run_study(args: argparse.Namespace) -> None:
|
|||||||
db_path = results_dir / "optuna_study.db"
|
db_path = results_dir / "optuna_study.db"
|
||||||
storage = f"sqlite:///{db_path}"
|
storage = f"sqlite:///{db_path}"
|
||||||
|
|
||||||
|
if args.clean and db_path.exists():
|
||||||
|
logger.info("--clean flag: deleting existing DB at %s", db_path)
|
||||||
|
db_path.unlink()
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
logger.info("Resuming existing study: %s", args.study_name)
|
logger.info("Resuming existing study: %s", args.study_name)
|
||||||
study = optuna.load_study(
|
study = optuna.load_study(
|
||||||
@@ -430,7 +434,7 @@ def run_study(args: argparse.Namespace) -> None:
|
|||||||
study_name=args.study_name,
|
study_name=args.study_name,
|
||||||
storage=storage,
|
storage=storage,
|
||||||
direction="minimize", # minimize mass
|
direction="minimize", # minimize mass
|
||||||
load_if_exists=False,
|
load_if_exists=True, # safe re-run: reuse if exists
|
||||||
sampler=optuna.samplers.TPESampler(seed=args.seed),
|
sampler=optuna.samplers.TPESampler(seed=args.seed),
|
||||||
)
|
)
|
||||||
logger.info("Created new study: %s (storage: %s)", args.study_name, db_path)
|
logger.info("Created new study: %s (storage: %s)", args.study_name, db_path)
|
||||||
@@ -622,6 +626,11 @@ def parse_args() -> argparse.Namespace:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Resume an existing study instead of creating a new one",
|
help="Resume an existing study instead of creating a new one",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clean",
|
||||||
|
action="store_true",
|
||||||
|
help="Delete existing results DB before starting (fresh run)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose", "-v",
|
"--verbose", "-v",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -173,6 +173,16 @@ def _ensure_integer_coverage(
|
|||||||
logger.info("All 11 hole_count levels represented ✓")
|
logger.info("All 11 hole_count levels represented ✓")
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
# Skip patching when sample size is too small to cover all levels
|
||||||
|
n_samples = len(samples)
|
||||||
|
if n_samples < len(all_levels):
|
||||||
|
logger.info(
|
||||||
|
"Only %d samples — too few to cover all 11 hole_count levels "
|
||||||
|
"(need ≥11). Skipping stratified patching.",
|
||||||
|
n_samples,
|
||||||
|
)
|
||||||
|
return samples
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Missing hole_count levels: %s — patching with replacements",
|
"Missing hole_count levels: %s — patching with replacements",
|
||||||
sorted(missing_levels),
|
sorted(missing_levels),
|
||||||
|
|||||||
Reference in New Issue
Block a user