"""
Interview State Management
This module handles the persistence and management of interview state across sessions.
It provides:
- InterviewState: Complete state dataclass
- InterviewPhase: Enum for interview phases
- InterviewStateManager: Save/load/history functionality
- LogEntry: Audit log entries
"""
import hashlib
import json
import os
import shutil
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
class InterviewPhase(Enum):
    """Interview phases, declared in the order they are visited."""

    INTROSPECTION = "introspection"
    PROBLEM_DEFINITION = "problem_definition"
    OBJECTIVES = "objectives"
    CONSTRAINTS = "constraints"
    DESIGN_VARIABLES = "design_variables"
    VALIDATION = "validation"
    REVIEW = "review"
    COMPLETE = "complete"

    @classmethod
    def from_string(cls, s: str) -> "InterviewPhase":
        """Look up the member whose value equals *s*; raise ValueError otherwise."""
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            raise ValueError(f"Unknown phase: {s}")
        return found

    def next_phase(self) -> Optional["InterviewPhase"]:
        """Return the phase after this one, or None if this is the last."""
        ordered = list(InterviewPhase)
        pos = ordered.index(self)
        return ordered[pos + 1] if pos + 1 < len(ordered) else None

    def previous_phase(self) -> Optional["InterviewPhase"]:
        """Return the phase before this one, or None if this is the first."""
        ordered = list(InterviewPhase)
        pos = ordered.index(self)
        return ordered[pos - 1] if pos > 0 else None
@dataclass
class AnsweredQuestion:
    """Immutable record of a single answered interview question."""

    question_id: str
    answered_at: str  # ISO datetime
    raw_response: str
    parsed_value: Any
    inferred: Optional[Dict[str, Any]] = None  # What was inferred from answer

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (key order matches field order)."""
        names = ("question_id", "answered_at", "raw_response", "parsed_value", "inferred")
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnsweredQuestion":
        """Rebuild from a dict produced by to_dict; "inferred" may be absent."""
        return cls(
            data["question_id"],
            data["answered_at"],
            data["raw_response"],
            data["parsed_value"],
            data.get("inferred"),
        )
@dataclass
class LogEntry:
    """One entry of the human-readable markdown audit log."""

    timestamp: datetime
    question_id: str
    question_text: str
    answer_raw: str
    answer_parsed: Any
    inferred: Optional[Dict[str, Any]] = None
    warnings: Optional[List[str]] = None

    def to_markdown(self) -> str:
        """Render this entry as a markdown section, newline-joined."""
        stamp = self.timestamp.strftime('%Y-%m-%d %H:%M:%S')
        out = [
            f"## [{stamp}] Question: {self.question_id}",
            "",
            f"**Question**: {self.question_text}",
            "",
            f"**Answer**: {self.answer_raw}",
            "",
        ]
        # Show the parsed form only when it differs from the raw response.
        if self.answer_parsed != self.answer_raw:
            out += [f"**Parsed Value**: `{self.answer_parsed}`", ""]
        if self.inferred:
            out.append("**Inferred**:")
            out.extend(f"- {key}: {value}" for key, value in self.inferred.items())
            out.append("")
        if self.warnings:
            out.append("**Warnings**:")
            out.extend(f"- {warning}" for warning in self.warnings)
            out.append("")
        out += ["---", ""]
        return "\n".join(out)
@dataclass
class InterviewState:
    """
    Complete interview state (JSON-serializable).

    This dataclass holds all state needed to resume an interview,
    including introspection results, answers, and derived configuration.
    Timestamps are stored as ISO strings; the phase is stored as the
    enum's string value so the whole object round-trips through JSON.
    """
    version: str = "1.0"
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    study_name: str = ""
    study_path: str = ""
    parent_study: Optional[str] = None

    # Progress tracking
    started_at: str = field(default_factory=lambda: datetime.now().isoformat())
    last_updated: str = field(default_factory=lambda: datetime.now().isoformat())
    current_phase: str = InterviewPhase.INTROSPECTION.value
    complexity: Literal["simple", "moderate", "complex"] = "simple"

    # Question tracking
    questions_answered: List[Dict[str, Any]] = field(default_factory=list)
    questions_remaining: List[str] = field(default_factory=list)
    current_question_id: Optional[str] = None

    # Introspection cache
    introspection: Dict[str, Any] = field(default_factory=dict)

    # Collected answers (organized by category)
    answers: Dict[str, Any] = field(default_factory=lambda: {
        "problem_description": None,
        "physical_context": None,
        "analysis_types": [],
        "objectives": [],
        "constraints": [],
        "design_variables": [],
        "protocol": None,
        "n_trials": 100,
        "use_neural_acceleration": False,
    })

    # Derived/inferred configuration
    inferred_config: Dict[str, Any] = field(default_factory=dict)

    # Validation results
    warnings: List[str] = field(default_factory=list)
    warnings_acknowledged: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)

    # Blueprint (when complete)
    blueprint: Optional[Dict[str, Any]] = None

    def get_phase(self) -> InterviewPhase:
        """Get current phase as enum.

        Raises:
            ValueError: if current_phase holds an unknown phase string.
        """
        return InterviewPhase.from_string(self.current_phase)

    def set_phase(self, phase: InterviewPhase) -> None:
        """Set current phase (stored as the enum's string value)."""
        self.current_phase = phase.value
        self.touch()

    def touch(self) -> None:
        """Update last_updated timestamp to now (ISO format)."""
        self.last_updated = datetime.now().isoformat()

    def is_complete(self) -> bool:
        """Check if interview is complete."""
        return self.current_phase == InterviewPhase.COMPLETE.value

    def current_question_count(self) -> int:
        """Get number of questions answered."""
        return len(self.questions_answered)

    def progress_percentage(self) -> float:
        """
        Estimate progress through interview as a value in [0, 100].

        Based on phase, not questions, since questions are adaptive.
        """
        phases = list(InterviewPhase)
        current_idx = phases.index(self.get_phase())
        return (current_idx / (len(phases) - 1)) * 100

    def add_answered_question(self, question: AnsweredQuestion) -> None:
        """Record a question as answered and drop it from the remaining queue."""
        self.questions_answered.append(question.to_dict())
        if question.question_id in self.questions_remaining:
            self.questions_remaining.remove(question.question_id)
        self.touch()

    def get_answer(self, key: str, default: Any = None) -> Any:
        """Get an answer by key, or *default* if not present."""
        return self.answers.get(key, default)

    def set_answer(self, key: str, value: Any) -> None:
        """Set an answer and refresh the last_updated timestamp."""
        self.answers[key] = value
        self.touch()

    def add_warning(self, warning: str) -> None:
        """Add a warning message (deduplicated)."""
        if warning not in self.warnings:
            self.warnings.append(warning)
        self.touch()

    def acknowledge_warning(self, warning: str) -> None:
        """Mark a known warning as acknowledged (only once, only if it exists)."""
        if warning in self.warnings and warning not in self.warnings_acknowledged:
            self.warnings_acknowledged.append(warning)
        self.touch()

    def has_unacknowledged_errors(self) -> bool:
        """Check if there are blocking errors (errors cannot be acknowledged)."""
        return len(self.errors) > 0

    def has_unacknowledged_warnings(self) -> bool:
        """Check if any warning has not been acknowledged yet."""
        return any(w not in self.warnings_acknowledged for w in self.warnings)

    def to_json(self) -> str:
        """Serialize to a pretty-printed JSON string.

        default=str stringifies any non-JSON-native value (e.g. datetime)
        rather than raising — acceptable here because the state is meant
        to contain only JSON-native data.
        """
        return json.dumps(asdict(self), indent=2, default=str)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (recursively, via dataclasses.asdict)."""
        return asdict(self)

    @classmethod
    def from_json(cls, json_str: str) -> "InterviewState":
        """Deserialize from JSON string.

        Raises:
            json.JSONDecodeError: if json_str is not valid JSON.
        """
        data = json.loads(json_str)
        return cls.from_dict(data)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InterviewState":
        """Create from dictionary; missing keys fall back to defaults.

        Fix: a missing "answers" key now restores the full default answer
        structure (so defaults such as n_trials=100 survive) instead of an
        empty dict, which silently lost every documented default.
        """
        if "answers" in data:
            answers = data["answers"]
        else:
            # Single source of truth: reuse the dataclass field's factory.
            answers = cls.__dataclass_fields__["answers"].default_factory()
        return cls(
            version=data.get("version", "1.0"),
            session_id=data.get("session_id", str(uuid.uuid4())),
            study_name=data.get("study_name", ""),
            study_path=data.get("study_path", ""),
            parent_study=data.get("parent_study"),
            started_at=data.get("started_at", datetime.now().isoformat()),
            last_updated=data.get("last_updated", datetime.now().isoformat()),
            current_phase=data.get("current_phase", InterviewPhase.INTROSPECTION.value),
            complexity=data.get("complexity", "simple"),
            questions_answered=data.get("questions_answered", []),
            questions_remaining=data.get("questions_remaining", []),
            current_question_id=data.get("current_question_id"),
            introspection=data.get("introspection", {}),
            answers=answers,
            inferred_config=data.get("inferred_config", {}),
            warnings=data.get("warnings", []),
            warnings_acknowledged=data.get("warnings_acknowledged", []),
            errors=data.get("errors", []),
            blueprint=data.get("blueprint"),
        )

    def validate(self) -> List[str]:
        """Validate state; return a list of error strings (empty when valid)."""
        errors = []
        if not self.session_id:
            errors.append("Missing session_id")
        if not self.study_name:
            errors.append("Missing study_name")
        try:
            InterviewPhase.from_string(self.current_phase)
        except ValueError:
            errors.append(f"Invalid current_phase: {self.current_phase}")
        if self.complexity not in ["simple", "moderate", "complex"]:
            errors.append(f"Invalid complexity: {self.complexity}")
        return errors
@dataclass
class StateSnapshot:
    """Snapshot of state for history/undo.

    Lightweight summary of one backed-up state file, built by reading the
    backup JSON directly (not by reconstructing a full InterviewState).
    """
    timestamp: str  # "last_updated" value from the backup, or "unknown"
    phase: str  # "current_phase" value from the backup, or "unknown"
    questions_count: int  # number of entries in the backup's "questions_answered"
    state_hash: str  # hash of the backup file's raw text
    file_path: str  # path of the backup JSON file this snapshot describes
class InterviewStateManager:
    """
    Manages interview state persistence.

    Handles:
    - Save/load state to JSON (atomic writes, lock-file guarded)
    - Human-readable audit log (MD)
    - State backup rotation
    - History for undo/branch
    """

    MAX_BACKUPS = 5
    # A lock file older than this many seconds is considered abandoned
    # (e.g. a crashed holder) and may be removed.
    LOCK_STALE_SECONDS = 300

    def __init__(self, study_path: Path):
        """
        Initialize state manager.

        Args:
            study_path: Path to the study directory
        """
        self.study_path = Path(study_path)
        self.interview_dir = self.study_path / ".interview"
        self.state_file = self.interview_dir / "interview_state.json"
        self.log_file = self.interview_dir / "INTERVIEW_LOG.md"
        self.backup_dir = self.interview_dir / "backups"
        self.lock_file = self.interview_dir / ".lock"
        # Ensure directories exist
        self._ensure_directories()

    def _ensure_directories(self) -> None:
        """Create necessary directories if they don't exist."""
        self.interview_dir.mkdir(parents=True, exist_ok=True)
        self.backup_dir.mkdir(exist_ok=True)

    def _acquire_lock(self) -> bool:
        """Acquire the lock file; return True on success.

        Fix: the lock is now created with open mode "x" (O_CREAT|O_EXCL),
        which is atomic — the previous exists()-then-write sequence let two
        processes both "acquire" the lock if they raced between the check
        and the write.
        """
        try:
            if self.lock_file.exists():
                # Break the lock only if it looks stale (holder likely died).
                age = datetime.now().timestamp() - self.lock_file.stat().st_mtime
                if age > self.LOCK_STALE_SECONDS:
                    self.lock_file.unlink()
                else:
                    return False
            with open(self.lock_file, "x", encoding="utf-8") as f:
                f.write(str(os.getpid()))
            return True
        except FileExistsError:
            # Another process created the lock between our check and create.
            return False
        except Exception:
            # Any other filesystem error means "not acquired"; callers
            # already treat False as a retriable condition.
            return False

    def _release_lock(self) -> None:
        """Release lock file (best-effort; errors are swallowed)."""
        try:
            if self.lock_file.exists():
                self.lock_file.unlink()
        except Exception:
            pass

    def exists(self) -> bool:
        """Check if a saved state exists."""
        return self.state_file.exists()

    def save_state(self, state: "InterviewState") -> None:
        """
        Persist current state to JSON.

        Performs atomic write (temp file + rename) with backup rotation.

        Raises:
            RuntimeError: if the lock file could not be acquired.
        """
        if not self._acquire_lock():
            raise RuntimeError("Could not acquire lock for state file")
        try:
            # Update timestamp
            state.touch()
            # Create backup if state file exists
            if self.state_file.exists():
                self._rotate_backups()
                backup_name = f"state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
                shutil.copy(self.state_file, self.backup_dir / backup_name)
            # Atomic write: write to temp file then rename, so a crash
            # mid-write never leaves a truncated state file behind.
            temp_file = self.state_file.with_suffix(".tmp")
            temp_file.write_text(state.to_json(), encoding="utf-8")
            temp_file.replace(self.state_file)
        finally:
            self._release_lock()

    def _rotate_backups(self) -> None:
        """Keep only the MAX_BACKUPS most recent backups (by mtime)."""
        backups = sorted(
            self.backup_dir.glob("state_*.json"),
            key=lambda p: p.stat().st_mtime,
            reverse=True
        )
        # Remove old backups
        for backup in backups[self.MAX_BACKUPS:]:
            backup.unlink()

    def load_state(self) -> Optional["InterviewState"]:
        """
        Load existing state if available.

        Returns:
            InterviewState if exists and valid, None otherwise
        """
        if not self.state_file.exists():
            return None
        try:
            json_str = self.state_file.read_text(encoding="utf-8")
            state = InterviewState.from_json(json_str)
            # Validate state
            errors = state.validate()
            if errors:
                raise ValueError(f"Invalid state: {errors}")
            return state
        except (json.JSONDecodeError, ValueError) as e:
            # Log error but don't crash
            print(f"Warning: Could not load interview state: {e}")
            return None

    def append_log(self, entry: "LogEntry") -> None:
        """
        Add entry to human-readable audit log.

        Creates log file with header if it doesn't exist.
        """
        # Initialize log file if needed
        if not self.log_file.exists():
            header = self._create_log_header()
            self.log_file.write_text(header, encoding="utf-8")
        # Append entry
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(entry.to_markdown())

    def _create_log_header(self) -> str:
        """Create header for new log file."""
        return f"""# Interview Log
**Study**: {self.study_path.name}
**Started**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
This log records all questions and answers from the study interview process.
---
"""

    def finalize_log(self, state: "InterviewState") -> None:
        """Add final summary to log when interview completes."""
        summary = f"""
## Interview Complete
**Completed**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Questions Answered**: {len(state.questions_answered)}
**Complexity**: {state.complexity}
### Summary
- **Problem**: {state.answers.get('problem_description', 'N/A')}
- **Objectives**: {len(state.answers.get('objectives', []))}
- **Constraints**: {len(state.answers.get('constraints', []))}
- **Design Variables**: {len(state.answers.get('design_variables', []))}
### Warnings Acknowledged
"""
        for warning in state.warnings_acknowledged:
            summary += f"- {warning}\n"
        if not state.warnings_acknowledged:
            summary += "- None\n"
        summary += "\n---\n"
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(summary)

    def get_history(self) -> List["StateSnapshot"]:
        """
        Get modification history for undo/branch.

        Returns list of state snapshots from backups (unreadable or
        corrupt backups are skipped rather than failing the listing).

        Fix: the state hash now uses SHA-256 — the builtin hash() is
        randomized per interpreter run (PYTHONHASHSEED), so the old hash
        was useless for comparing snapshots across sessions. The file is
        also read once instead of twice.
        """
        snapshots = []
        for backup in sorted(self.backup_dir.glob("state_*.json")):
            try:
                text = backup.read_text(encoding="utf-8")
                data = json.loads(text)
                snapshot = StateSnapshot(
                    timestamp=data.get("last_updated", "unknown"),
                    phase=data.get("current_phase", "unknown"),
                    questions_count=len(data.get("questions_answered", [])),
                    state_hash=hashlib.sha256(text.encode("utf-8")).hexdigest(),
                    file_path=str(backup),
                )
                snapshots.append(snapshot)
            except Exception:
                continue
        return snapshots

    def restore_from_backup(self, backup_path: str) -> Optional["InterviewState"]:
        """Restore state from a backup file; None if missing or unreadable."""
        backup = Path(backup_path)
        if not backup.exists():
            return None
        try:
            json_str = backup.read_text(encoding="utf-8")
            return InterviewState.from_json(json_str)
        except Exception:
            return None

    def delete_state(self) -> None:
        """Delete all interview state (for restart).

        The audit log is intentionally kept; a reset marker is appended
        so the history remains truthful.
        """
        if self.state_file.exists():
            self.state_file.unlink()
        # Keep log file but add note
        if self.log_file.exists():
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(f"\n## State Reset\n\n**Reset at**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n---\n\n")