feat: Add TrialManager and DashboardDB for unified trial management
- Add TrialManager (trial_manager.py) for consistent trial_NNNN naming - Add DashboardDB (dashboard_db.py) for Optuna-compatible database schema - Update CLAUDE.md with trial management documentation - Update ATOMIZER_CONTEXT.md with v1.8 trial system - Update cheatsheet v2.2 with new utilities - Update SYS_14 protocol to v2.3 with TrialManager integration - Add LAC learnings for trial management patterns - Add archive/README.md for deprecated code policy Key principles: - Trial numbers NEVER reset (monotonic) - Folders NEVER get overwritten - Database always synced with filesystem - Surrogate predictions are NOT trials (only FEA results) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
574
optimization_engine/utils/dashboard_db.py
Normal file
574
optimization_engine/utils/dashboard_db.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
Dashboard Database Compatibility Module
|
||||
========================================
|
||||
|
||||
Provides Optuna-compatible database schema for all optimization types,
|
||||
ensuring dashboard compatibility regardless of optimization method
|
||||
(standard Optuna, turbo/surrogate, GNN, etc.)
|
||||
|
||||
Usage:
|
||||
from optimization_engine.utils.dashboard_db import DashboardDB
|
||||
|
||||
# Initialize (creates Optuna-compatible schema)
|
||||
db = DashboardDB(study_dir / "3_results" / "study.db", study_name="my_study")
|
||||
|
||||
# Log a trial
|
||||
db.log_trial(
|
||||
params={"rib_thickness": 10.5, "mass": 118.0},
|
||||
objectives={"wfe_40_20": 5.63, "wfe_60_20": 12.75},
|
||||
weighted_sum=175.87, # optional, for single-objective ranking
|
||||
is_feasible=True,
|
||||
metadata={"turbo_iteration": 1, "predicted_ws": 186.77}
|
||||
)
|
||||
|
||||
# Mark best trial
|
||||
db.mark_best(trial_id=1)
|
||||
|
||||
# Get summary
|
||||
print(db.get_summary())
|
||||
|
||||
Schema follows Optuna's native format for full dashboard compatibility.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
|
||||
class DashboardDB:
    """Optuna-compatible SQLite wrapper for dashboard integration.

    Creates (a subset of) Optuna's native schema -- studies, trials,
    trial_values, trial_params, the attribute tables, and alembic_version --
    so the Optuna dashboard can open the database no matter which optimizer
    produced the trials (standard Optuna, turbo/surrogate, GNN, manual).

    Typical use:
        db = DashboardDB(study_dir / "3_results" / "study.db", study_name="my_study")
        db.log_trial(params={...}, objectives={...}, weighted_sum=175.87)
        db.mark_best(trial_id=1)
        print(db.get_summary())
    """

    # Informational version of this wrapper's layout.  The value written into
    # version_info mirrors Optuna's own schema version (12).
    SCHEMA_VERSION = 1

    def __init__(self, db_path: Union[str, Path], study_name: str, direction: str = "MINIMIZE"):
        """
        Initialize database with Optuna-compatible schema.

        Args:
            db_path: Path to SQLite database file (parent dirs are created).
            study_name: Name of the optimization study.
            direction: "MINIMIZE" or "MAXIMIZE"; controls get_best_trial().
        """
        self.db_path = Path(db_path)
        self.study_name = study_name
        self.direction = direction
        self._init_schema()

    def _connect(self) -> sqlite3.Connection:
        """Open a fresh connection to the study database."""
        return sqlite3.connect(self.db_path)

    def _init_schema(self):
        """Create tables, the study row, and indexes (safe to call repeatedly)."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # DDL follows Optuna's native column names/types so optuna-dashboard
        # can read the file directly.
        ddl = (
            # version_info - tracks schema version
            '''CREATE TABLE IF NOT EXISTS version_info (
                version_info_id INTEGER PRIMARY KEY,
                schema_version INTEGER,
                library_version VARCHAR(256)
            )''',
            # studies - Optuna study metadata
            '''CREATE TABLE IF NOT EXISTS studies (
                study_id INTEGER PRIMARY KEY,
                study_name VARCHAR(512) UNIQUE
            )''',
            # study_directions - optimization direction per objective
            '''CREATE TABLE IF NOT EXISTS study_directions (
                study_direction_id INTEGER PRIMARY KEY,
                direction VARCHAR(8) NOT NULL,
                study_id INTEGER,
                objective INTEGER,
                FOREIGN KEY (study_id) REFERENCES studies(study_id)
            )''',
            # trials - main trial table (Optuna schema)
            '''CREATE TABLE IF NOT EXISTS trials (
                trial_id INTEGER PRIMARY KEY,
                number INTEGER,
                study_id INTEGER,
                state VARCHAR(8) NOT NULL DEFAULT 'COMPLETE',
                datetime_start DATETIME,
                datetime_complete DATETIME,
                FOREIGN KEY (study_id) REFERENCES studies(study_id)
            )''',
            # trial_values - objective values
            '''CREATE TABLE IF NOT EXISTS trial_values (
                trial_value_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                objective INTEGER,
                value FLOAT,
                value_type VARCHAR(7) DEFAULT 'FINITE',
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # trial_params - parameter values
            '''CREATE TABLE IF NOT EXISTS trial_params (
                param_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                param_name VARCHAR(512),
                param_value FLOAT,
                distribution_json TEXT,
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # trial_user_attributes - custom metadata
            '''CREATE TABLE IF NOT EXISTS trial_user_attributes (
                trial_user_attribute_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                key VARCHAR(512),
                value_json TEXT,
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # trial_system_attributes - system metadata
            '''CREATE TABLE IF NOT EXISTS trial_system_attributes (
                trial_system_attribute_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                key VARCHAR(512),
                value_json TEXT,
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # study_user_attributes
            '''CREATE TABLE IF NOT EXISTS study_user_attributes (
                study_user_attribute_id INTEGER PRIMARY KEY,
                study_id INTEGER,
                key VARCHAR(512),
                value_json TEXT,
                FOREIGN KEY (study_id) REFERENCES studies(study_id)
            )''',
            # study_system_attributes
            '''CREATE TABLE IF NOT EXISTS study_system_attributes (
                study_system_attribute_id INTEGER PRIMARY KEY,
                study_id INTEGER,
                key VARCHAR(512),
                value_json TEXT,
                FOREIGN KEY (study_id) REFERENCES studies(study_id)
            )''',
            # trial_intermediate_values (for pruning callbacks)
            '''CREATE TABLE IF NOT EXISTS trial_intermediate_values (
                trial_intermediate_value_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                step INTEGER,
                intermediate_value FLOAT,
                intermediate_value_type VARCHAR(7) DEFAULT 'FINITE',
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # trial_heartbeats (for distributed optimization)
            '''CREATE TABLE IF NOT EXISTS trial_heartbeats (
                trial_heartbeat_id INTEGER PRIMARY KEY,
                trial_id INTEGER,
                heartbeat DATETIME,
                FOREIGN KEY (trial_id) REFERENCES trials(trial_id)
            )''',
            # alembic_version (Optuna uses alembic for migrations)
            '''CREATE TABLE IF NOT EXISTS alembic_version (
                version_num VARCHAR(32) PRIMARY KEY
            )''',
        )

        conn = self._connect()
        try:
            cursor = conn.cursor()
            for statement in ddl:
                cursor.execute(statement)

            # Stamp the schema version exactly once.
            cursor.execute("SELECT COUNT(*) FROM version_info")
            if cursor.fetchone()[0] == 0:
                cursor.execute(
                    "INSERT INTO version_info (schema_version, library_version) VALUES (?, ?)",
                    (12, "atomizer-dashboard-1.0")
                )

            # Look up or create this study's row.
            cursor.execute("SELECT study_id FROM studies WHERE study_name = ?", (self.study_name,))
            row = cursor.fetchone()
            if row:
                self.study_id = row[0]
            else:
                cursor.execute("INSERT INTO studies (study_name) VALUES (?)", (self.study_name,))
                self.study_id = cursor.lastrowid

            # One direction row per objective (single objective 0 here).
            cursor.execute(
                "SELECT COUNT(*) FROM study_directions WHERE study_id = ?",
                (self.study_id,)
            )
            if cursor.fetchone()[0] == 0:
                cursor.execute(
                    "INSERT INTO study_directions (direction, study_id, objective) VALUES (?, ?, ?)",
                    (self.direction, self.study_id, 0)
                )

            # Alembic stamp so Optuna tooling treats the schema as migrated.
            cursor.execute("INSERT OR IGNORE INTO alembic_version VALUES ('v3.0.0')")

            # Indexes for dashboard query performance.
            cursor.execute("CREATE INDEX IF NOT EXISTS ix_trials_study_id ON trials(study_id)")
            cursor.execute("CREATE INDEX IF NOT EXISTS ix_trials_state ON trials(state)")
            cursor.execute("CREATE INDEX IF NOT EXISTS ix_trial_values_trial_id ON trial_values(trial_id)")
            cursor.execute("CREATE INDEX IF NOT EXISTS ix_trial_params_trial_id ON trial_params(trial_id)")

            conn.commit()
        finally:
            # Previously the connection leaked if any statement raised.
            conn.close()

    def log_trial(
        self,
        params: Dict[str, float],
        objectives: Dict[str, float],
        weighted_sum: Optional[float] = None,
        is_feasible: bool = True,
        state: str = "COMPLETE",
        datetime_start: Optional[str] = None,
        datetime_complete: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> int:
        """
        Log a trial to the database.

        Args:
            params: Parameter name -> value mapping
            objectives: Objective name -> value mapping
            weighted_sum: Optional weighted sum for single-objective ranking
            is_feasible: Whether trial meets constraints
            state: Trial state ("COMPLETE", "PRUNED", "FAIL", "RUNNING")
            datetime_start: ISO format timestamp (defaults to now)
            datetime_complete: ISO format timestamp (defaults to now)
            metadata: Additional metadata (turbo_iteration, predicted values, etc.)

        Returns:
            trial_id of inserted trial

        Raises:
            ValueError: If both objectives is empty and weighted_sum is None,
                since no objective value could be recorded.
        """
        if weighted_sum is None and not objectives:
            # Previously this path crashed with a bare IndexError.
            raise ValueError("log_trial needs at least one objective or a weighted_sum")

        conn = self._connect()
        try:
            cursor = conn.cursor()

            # Trial numbers are per-study and monotonically increasing.
            cursor.execute(
                "SELECT COALESCE(MAX(number), -1) + 1 FROM trials WHERE study_id = ?",
                (self.study_id,)
            )
            trial_number = cursor.fetchone()[0]

            # Default both timestamps to "now" when the caller omits them.
            now = datetime.now().isoformat()
            cursor.execute(
                '''INSERT INTO trials (number, study_id, state, datetime_start, datetime_complete)
                   VALUES (?, ?, ?, ?, ?)''',
                (trial_number, self.study_id, state, datetime_start or now, datetime_complete or now)
            )
            trial_id = cursor.lastrowid

            # Primary (objective 0) value: weighted_sum wins; otherwise the
            # first objective in insertion order.
            primary_value = weighted_sum if weighted_sum is not None else next(iter(objectives.values()))
            cursor.execute(
                "INSERT INTO trial_values (trial_id, objective, value, value_type) VALUES (?, ?, ?, ?)",
                (trial_id, 0, primary_value, 'FINITE')
            )

            # Every named objective is preserved as an "obj_*" user attribute,
            # followed by feasibility and any extra metadata.
            attr_rows = [(trial_id, f"obj_{name}", json.dumps(value))
                         for name, value in objectives.items()]
            attr_rows.append((trial_id, 'is_feasible', json.dumps(is_feasible)))
            if metadata:
                attr_rows.extend((trial_id, key, json.dumps(value))
                                 for key, value in metadata.items())
            cursor.executemany(
                "INSERT INTO trial_user_attributes (trial_id, key, value_json) VALUES (?, ?, ?)",
                attr_rows
            )

            # Parameters (distribution unknown here, hence the empty JSON).
            cursor.executemany(
                "INSERT INTO trial_params (trial_id, param_name, param_value, distribution_json) "
                "VALUES (?, ?, ?, ?)",
                [(trial_id, name, value, '{}') for name, value in params.items()]
            )

            conn.commit()
        finally:
            conn.close()

        return trial_id

    def mark_best(self, trial_id: int):
        """Mark a trial as the study's best via an 'is_best' user attribute.

        Any previous best marker within this study is removed first, so at
        most one trial carries the flag.
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(
                '''DELETE FROM trial_user_attributes
                   WHERE key = 'is_best' AND trial_id IN (
                       SELECT trial_id FROM trials WHERE study_id = ?
                   )''',
                (self.study_id,)
            )
            cursor.execute(
                "INSERT INTO trial_user_attributes (trial_id, key, value_json) "
                "VALUES (?, 'is_best', 'true')",
                (trial_id,)
            )
            conn.commit()
        finally:
            conn.close()

    def get_trial_count(self, state: str = "COMPLETE") -> int:
        """Return the number of this study's trials in the given state."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT COUNT(*) FROM trials WHERE study_id = ? AND state = ?",
                (self.study_id, state)
            )
            return cursor.fetchone()[0]
        finally:
            conn.close()

    def get_best_trial(self) -> Optional[Dict[str, Any]]:
        """Return the best COMPLETE trial as {'trial_id', 'number', 'value'}.

        "Best" respects the study direction: lowest value for MINIMIZE,
        highest for MAXIMIZE (previously MAXIMIZE was ignored and the lowest
        value was always returned).  Returns None when no COMPLETE trial
        exists.
        """
        # Sort order is derived from a fixed whitelist, never from user input.
        order = "DESC" if str(self.direction).upper() == "MAXIMIZE" else "ASC"
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(
                f'''SELECT t.trial_id, t.number, tv.value
                    FROM trials t
                    JOIN trial_values tv ON t.trial_id = tv.trial_id
                    WHERE t.study_id = ? AND t.state = 'COMPLETE'
                    ORDER BY tv.value {order}
                    LIMIT 1''',
                (self.study_id,)
            )
            row = cursor.fetchone()
        finally:
            conn.close()

        if row:
            return {'trial_id': row[0], 'number': row[1], 'value': row[2]}
        return None

    def get_summary(self) -> Dict[str, Any]:
        """Return a small summary dict of the study, suitable for logging."""
        best = self.get_best_trial()
        return {
            'study_name': self.study_name,
            'complete_trials': self.get_trial_count("COMPLETE"),
            'pruned_trials': self.get_trial_count("PRUNED"),
            'best_value': best['value'] if best else None,
            'best_trial': best['number'] if best else None,
        }

    def clear(self):
        """Delete every trial of this study (for re-running).

        The schema and the study/direction rows are kept; only trial rows and
        their dependent rows are removed.
        """
        # Child tables first so no dangling trial_id references remain.
        child_tables = (
            "trial_user_attributes",
            "trial_system_attributes",
            "trial_values",
            "trial_params",
            "trial_intermediate_values",
            "trial_heartbeats",
        )
        conn = self._connect()
        try:
            cursor = conn.cursor()
            for table in child_tables:
                cursor.execute(
                    f"DELETE FROM {table} WHERE trial_id IN "
                    "(SELECT trial_id FROM trials WHERE study_id = ?)",
                    (self.study_id,)
                )
            cursor.execute("DELETE FROM trials WHERE study_id = ?", (self.study_id,))
            conn.commit()
        finally:
            conn.close()
||||
def _load_json_column(raw: Optional[str]) -> Dict[str, Any]:
    """Parse a JSON text column, returning {} for NULL/empty/corrupt values."""
    if not raw:
        return {}
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # Corrupt legacy rows are tolerated rather than aborting conversion.
        return {}


def convert_custom_to_optuna(
    db_path: Union[str, Path],
    study_name: str,
    custom_table: str = "trials",
    param_columns: Optional[List[str]] = None,
    objective_column: str = "weighted_sum",
    status_column: str = "status",
    datetime_column: str = "datetime_complete",
) -> int:
    """
    Convert a custom database schema to Optuna-compatible format.

    The original database is copied to '<name>.db.bak' first, every table in
    it is dropped, a fresh Optuna schema is created via DashboardDB, and each
    row of ``custom_table`` is re-logged through DashboardDB.log_trial.

    Args:
        db_path: Path to database
        study_name: Name for the study
        custom_table: Name of custom trials table to convert
        param_columns: Unused; kept for backward compatibility
        objective_column: Column containing objective value
        status_column: Column containing trial status
        datetime_column: Column containing timestamp

    Returns:
        Number of trials converted

    Raises:
        ValueError: If ``custom_table`` is not a plain identifier or does not
            exist in the database.
    """
    # Table names cannot be bound as SQL parameters, so reject anything that
    # is not a plain identifier before interpolating it into statements.
    if not custom_table.isidentifier():
        raise ValueError(f"Invalid table name: {custom_table!r}")

    db_path = Path(db_path)

    # Backup original before any destructive work.
    import shutil
    shutil.copy(db_path, db_path.with_suffix('.db.bak'))

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Make sure the source table exists before destroying anything.
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (custom_table,)
        )
        if not cursor.fetchone():
            raise ValueError(f"Table '{custom_table}' not found")

        # Snapshot column names and all rows of the custom table.
        cursor.execute(f"PRAGMA table_info({custom_table})")
        col_names = [row[1] for row in cursor.fetchall()]
        cursor.execute(f"SELECT * FROM {custom_table}")
        custom_trials = cursor.fetchall()

        # Drop ALL existing tables so the Optuna schema starts fresh.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        for (table,) in cursor.fetchall():
            if table != 'sqlite_sequence':  # don't drop internal SQLite table
                cursor.execute(f"DROP TABLE IF EXISTS {table}")
        conn.commit()
    finally:
        conn.close()

    # Recreate a proper Optuna schema and replay every trial into it.
    db = DashboardDB(db_path, study_name)

    converted = 0
    for row in custom_trials:
        trial_data = dict(zip(col_names, row))

        params = _load_json_column(trial_data.get('params_json'))
        objectives = _load_json_column(trial_data.get('objectives_json'))

        # Normalize status -> Optuna state; tolerate NULL status values
        # (previously a NULL column crashed on .upper()).
        status = str(trial_data.get(status_column) or 'COMPLETE').upper()
        state = 'COMPLETE' if status in ('COMPLETE', 'COMPLETED') else status

        # Carry known auxiliary columns along as user attributes.
        metadata = {
            key: trial_data[key]
            for key in ('turbo_iteration', 'predicted_ws', 'prediction_error', 'solve_time')
            if trial_data.get(key) is not None
        }

        db.log_trial(
            params=params,
            objectives=objectives,
            weighted_sum=trial_data.get(objective_column),
            is_feasible=bool(trial_data.get('is_feasible', 1)),
            state=state,
            datetime_start=trial_data.get('datetime_start'),
            datetime_complete=trial_data.get(datetime_column),
            metadata=metadata,
        )
        converted += 1

    return converted
||||
# Convenience function for turbo optimization
def init_turbo_database(study_dir: Path, study_name: str) -> DashboardDB:
    """
    Initialize a dashboard-compatible database for turbo optimization.

    Args:
        study_dir: Study directory (contains 3_results/).
        study_name: Name of the study.

    Returns:
        DashboardDB instance ready for logging.
    """
    db_file = study_dir / "3_results" / "study.db"
    # Make sure 3_results/ exists before the database file is created there.
    db_file.parent.mkdir(parents=True, exist_ok=True)
    return DashboardDB(db_file, study_name)
292
optimization_engine/utils/trial_manager.py
Normal file
292
optimization_engine/utils/trial_manager.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Trial Manager - Unified trial numbering and folder management
|
||||
==============================================================
|
||||
|
||||
Provides consistent trial_NNNN naming across all optimization methods
|
||||
(Optuna, Turbo, GNN, manual) with proper database integration.
|
||||
|
||||
Usage:
|
||||
from optimization_engine.utils.trial_manager import TrialManager
|
||||
|
||||
tm = TrialManager(study_dir)
|
||||
|
||||
# Get next trial (creates folder, reserves DB row)
|
||||
trial = tm.new_trial(params={'rib_thickness': 10.5, ...})
|
||||
|
||||
# After FEA completes
|
||||
tm.complete_trial(
|
||||
trial_id=trial['trial_id'],
|
||||
objectives={'wfe_40_20': 5.63, 'mass_kg': 118.67},
|
||||
metadata={'solve_time': 211.7}
|
||||
)
|
||||
|
||||
Key principles:
|
||||
- Trial numbers NEVER reset (monotonically increasing)
|
||||
- Folders NEVER get overwritten
|
||||
- Database is always in sync with filesystem
|
||||
- Surrogate predictions are NOT trials (only FEA results)
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from filelock import FileLock
|
||||
|
||||
from .dashboard_db import DashboardDB
|
||||
|
||||
|
||||
class TrialManager:
    """Manages trial numbering, folders, and database for optimization studies.

    Folder layout: <study_dir>/2_iterations/trial_NNNN (1-based, zero-padded)
    with params.json, params.exp, _meta.json and results.json per trial.
    Completed results are mirrored into the Optuna-compatible study.db via
    DashboardDB so the dashboard always reflects the filesystem.
    """

    def __init__(self, study_dir: Union[str, Path], study_name: Optional[str] = None):
        """
        Initialize trial manager for a study.

        Args:
            study_dir: Path to study directory (contains 1_setup/, 2_iterations/, 3_results/)
            study_name: Name of study (defaults to directory name)
        """
        self.study_dir = Path(study_dir)
        self.study_name = study_name or self.study_dir.name

        self.iterations_dir = self.study_dir / "2_iterations"
        self.results_dir = self.study_dir / "3_results"
        self.db_path = self.results_dir / "study.db"
        # Lock file used to serialize trial-number allocation across processes.
        self.lock_path = self.results_dir / ".trial_lock"

        # Ensure directories exist before the database file is created.
        self.iterations_dir.mkdir(parents=True, exist_ok=True)
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Optuna-compatible database (creates schema if needed).
        self.db = DashboardDB(self.db_path, self.study_name)

    def _get_next_trial_number(self) -> int:
        """Return the next free trial number (monotonic, never resets).

        Takes the maximum of what the filesystem and the database have seen,
        so deleting folders or rows can never cause a number to be reused.
        Folder numbers are 1-based while DashboardDB trial numbers are
        0-based, so the DB's "MAX(number) + 1" lines up with folder numbering.
        """
        # Highest trial_* folder number on disk (0 when none exist).
        max_folder = 0
        for folder in self.iterations_dir.glob("trial_*"):
            try:
                max_folder = max(max_folder, int(folder.name.split('_')[1]))
            except (IndexError, ValueError):
                continue  # ignore folders that don't match trial_NNNN

        # Next 0-based number according to this study's DB rows.  The
        # study_id filter keeps numbering correct if the DB ever holds
        # more than one study (DashboardDB filters the same way).
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT COALESCE(MAX(number), -1) + 1 FROM trials WHERE study_id = ?",
                (self.db.study_id,)
            )
            next_db_number = cursor.fetchone()[0]
        finally:
            conn.close()

        return max(max_folder, next_db_number) + 1

    def new_trial(
        self,
        params: Dict[str, float],
        source: str = "turbo",
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Start a new trial: allocate a number and create its folder with
        params.json, params.exp and _meta.json.  (The database row is only
        written later, by complete_trial().)

        Args:
            params: Design parameters for this trial
            source: How this trial was generated ("turbo", "optuna", "manual")
            metadata: Additional info (turbo_batch, predicted_ws, etc.)

        Returns:
            Dict with trial_id, trial_number, folder_path, folder_name
        """
        # File lock prevents two processes from allocating the same number.
        with FileLock(self.lock_path):
            trial_number = self._get_next_trial_number()

            folder_name = f"trial_{trial_number:04d}"
            folder_path = self.iterations_dir / folder_name
            folder_path.mkdir(exist_ok=True)

            # params.json - canonical parameter record.
            with open(folder_path / "params.json", 'w') as f:
                json.dump(params, f, indent=2)

            # params.exp - NX expression file ([mm]name=value per line).
            with open(folder_path / "params.exp", 'w') as f:
                f.writelines(f"[mm]{name}={value}\n" for name, value in params.items())

            # _meta.json - lifecycle metadata, updated by complete/fail.
            meta = {
                "trial_number": trial_number,
                "source": source,
                "status": "RUNNING",
                "datetime_start": datetime.now().isoformat(),
                "params": params,
            }
            if metadata:
                meta.update(metadata)

            with open(folder_path / "_meta.json", 'w') as f:
                json.dump(meta, f, indent=2)

            return {
                "trial_id": trial_number,  # provisional; DB id assigned in complete_trial()
                "trial_number": trial_number,
                "folder_path": folder_path,
                "folder_name": folder_name,
            }

    def complete_trial(
        self,
        trial_number: int,
        objectives: Dict[str, float],
        weighted_sum: Optional[float] = None,
        is_feasible: bool = True,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Complete a trial: write results.json, update _meta.json, and log the
        trial into the dashboard database.

        Args:
            trial_number: Trial number from new_trial()
            objectives: Objective values from FEA
            weighted_sum: Combined objective for ranking
            is_feasible: Whether constraints are satisfied
            metadata: Additional info (solve_time, prediction_error, etc.)

        Returns:
            Database trial_id

        Raises:
            FileNotFoundError: If the trial folder has no _meta.json
                (i.e. new_trial() was never called for this number).
        """
        folder_path = self.get_trial_folder(trial_number)
        meta_file = folder_path / "_meta.json"
        with open(meta_file, 'r') as f:
            meta = json.load(f)

        params = meta.get("params", {})

        # Record completion in the folder metadata.
        meta["status"] = "COMPLETE"
        meta["datetime_complete"] = datetime.now().isoformat()
        meta["objectives"] = objectives
        meta["weighted_sum"] = weighted_sum
        meta["is_feasible"] = is_feasible
        if metadata:
            meta.update(metadata)

        # results.json - machine-readable result summary for this trial.
        with open(folder_path / "results.json", 'w') as f:
            json.dump({
                "objectives": objectives,
                "weighted_sum": weighted_sum,
                "is_feasible": is_feasible,
                "metadata": metadata or {}
            }, f, indent=2)

        with open(meta_file, 'w') as f:
            json.dump(meta, f, indent=2)

        # Mirror into the dashboard database, carrying provenance along.
        # Copy so the caller's metadata dict is not mutated (the original
        # wrote "source" etc. back into the caller's dict).
        db_metadata = dict(metadata or {})
        db_metadata["source"] = meta.get("source", "unknown")
        for key in ("turbo_batch", "predicted_ws"):
            if key in meta:
                db_metadata[key] = meta[key]

        trial_id = self.db.log_trial(
            params=params,
            objectives=objectives,
            weighted_sum=weighted_sum,
            is_feasible=is_feasible,
            state="COMPLETE",
            datetime_start=meta.get("datetime_start"),
            datetime_complete=meta.get("datetime_complete"),
            metadata=db_metadata,
        )

        # Flag the trial (DB + folder metadata) if it is now the study's best.
        best = self.db.get_best_trial()
        if best and best['trial_id'] == trial_id:
            self.db.mark_best(trial_id)
            meta["is_best"] = True
            with open(meta_file, 'w') as f:
                json.dump(meta, f, indent=2)

        return trial_id

    def fail_trial(self, trial_number: int, error: str):
        """Mark a trial as failed in its folder metadata (no DB row is written).

        Silently does nothing when the trial folder has no _meta.json.
        """
        meta_file = self.get_trial_folder(trial_number) / "_meta.json"
        if not meta_file.exists():
            return

        with open(meta_file, 'r') as f:
            meta = json.load(f)
        meta["status"] = "FAIL"
        meta["error"] = error
        meta["datetime_complete"] = datetime.now().isoformat()
        with open(meta_file, 'w') as f:
            json.dump(meta, f, indent=2)

    def get_trial_folder(self, trial_number: int) -> Path:
        """Return the folder path for a trial number (may not exist yet)."""
        return self.iterations_dir / f"trial_{trial_number:04d}"

    def get_all_trials(self) -> List[Dict[str, Any]]:
        """Return all COMPLETE trials of this study, ordered by trial number.

        Each entry carries trial_id, number, and the primary objective value.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # study_id filter added for consistency with DashboardDB queries.
            cursor.execute("""
                SELECT t.trial_id, t.number, tv.value
                FROM trials t
                JOIN trial_values tv ON t.trial_id = tv.trial_id
                WHERE t.study_id = ? AND t.state = 'COMPLETE'
                ORDER BY t.number
            """, (self.db.study_id,))
            rows = cursor.fetchall()
        finally:
            conn.close()

        return [{"trial_id": r[0], "number": r[1], "value": r[2]} for r in rows]

    def get_summary(self) -> Dict[str, Any]:
        """Return the database summary plus the on-disk trial folder count."""
        summary = self.db.get_summary()
        summary["folder_count"] = len(list(self.iterations_dir.glob("trial_*")))
        return summary

    def copy_model_files(self, source_dir: Path, trial_number: int) -> Path:
        """Copy NX/Nastran model artifacts from source_dir into the trial folder.

        Args:
            source_dir: Directory holding the solver output files.
            trial_number: Target trial.

        Returns:
            The trial folder the files were copied into.
        """
        dest = self.get_trial_folder(trial_number)
        # Tolerate folders created outside new_trial(); shutil.copy2 would
        # otherwise fail on a missing destination directory.
        dest.mkdir(parents=True, exist_ok=True)

        # Only solver-relevant artifacts are copied.
        for ext in ('.prt', '.fem', '.sim', '.afm', '.op2', '.f06', '.dat'):
            for src_file in source_dir.glob(f"*{ext}"):
                shutil.copy2(src_file, dest / src_file.name)

        return dest
Reference in New Issue
Block a user