""" Interview State Management This module handles the persistence and management of interview state across sessions. It provides: - InterviewState: Complete state dataclass - InterviewPhase: Enum for interview phases - InterviewStateManager: Save/load/history functionality - LogEntry: Audit log entries """ from dataclasses import dataclass, field, asdict from datetime import datetime from enum import Enum from pathlib import Path from typing import Dict, List, Any, Optional, Literal import json import uuid import shutil import os class InterviewPhase(Enum): """Interview phases in order of progression.""" INTROSPECTION = "introspection" PROBLEM_DEFINITION = "problem_definition" OBJECTIVES = "objectives" CONSTRAINTS = "constraints" DESIGN_VARIABLES = "design_variables" VALIDATION = "validation" REVIEW = "review" COMPLETE = "complete" @classmethod def from_string(cls, s: str) -> "InterviewPhase": """Convert string to enum.""" for phase in cls: if phase.value == s: return phase raise ValueError(f"Unknown phase: {s}") def next_phase(self) -> Optional["InterviewPhase"]: """Get the next phase in sequence.""" phases = list(InterviewPhase) idx = phases.index(self) if idx < len(phases) - 1: return phases[idx + 1] return None def previous_phase(self) -> Optional["InterviewPhase"]: """Get the previous phase in sequence.""" phases = list(InterviewPhase) idx = phases.index(self) if idx > 0: return phases[idx - 1] return None @dataclass class AnsweredQuestion: """Record of an answered question.""" question_id: str answered_at: str # ISO datetime raw_response: str parsed_value: Any inferred: Optional[Dict[str, Any]] = None # What was inferred from answer def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { "question_id": self.question_id, "answered_at": self.answered_at, "raw_response": self.raw_response, "parsed_value": self.parsed_value, "inferred": self.inferred, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AnsweredQuestion": """Create from dictionary.""" return cls( question_id=data["question_id"], answered_at=data["answered_at"], raw_response=data["raw_response"], parsed_value=data["parsed_value"], inferred=data.get("inferred"), ) @dataclass class LogEntry: """Entry for the human-readable audit log.""" timestamp: datetime question_id: str question_text: str answer_raw: str answer_parsed: Any inferred: Optional[Dict[str, Any]] = None warnings: Optional[List[str]] = None def to_markdown(self) -> str: """Format as markdown for audit log.""" lines = [ f"## [{self.timestamp.strftime('%Y-%m-%d %H:%M:%S')}] Question: {self.question_id}", "", f"**Question**: {self.question_text}", "", f"**Answer**: {self.answer_raw}", "", ] if self.answer_parsed != self.answer_raw: lines.extend([ f"**Parsed Value**: `{self.answer_parsed}`", "", ]) if self.inferred: lines.append("**Inferred**:") for key, value in self.inferred.items(): lines.append(f"- {key}: {value}") lines.append("") if self.warnings: lines.append("**Warnings**:") for warning in self.warnings: lines.append(f"- {warning}") lines.append("") lines.append("---") lines.append("") return "\n".join(lines) @dataclass class InterviewState: """ Complete interview state (JSON-serializable). This dataclass holds all state needed to resume an interview, including introspection results, answers, and derived configuration. """ version: str = "1.0" session_id: str = field(default_factory=lambda: str(uuid.uuid4())) study_name: str = "" study_path: str = "" parent_study: Optional[str] = None # Progress tracking started_at: str = field(default_factory=lambda: datetime.now().isoformat()) last_updated: str = field(default_factory=lambda: datetime.now().isoformat()) current_phase: str = InterviewPhase.INTROSPECTION.value complexity: Literal["simple", "moderate", "complex"] = "simple" # Question tracking questions_answered: List[Dict[str, Any]] = field(default_factory=list) questions_remaining: List[str] = field(default_factory=list) current_question_id: Optional[str] = None # Introspection cache introspection: Dict[str, Any] = field(default_factory=dict) # Collected answers (organized by category) answers: Dict[str, Any] = field(default_factory=lambda: { "problem_description": None, "physical_context": None, "analysis_types": [], "objectives": [], "constraints": [], "design_variables": [], "protocol": None, "n_trials": 100, "use_neural_acceleration": False, }) # Derived/inferred configuration inferred_config: Dict[str, Any] = field(default_factory=dict) # Validation results warnings: List[str] = field(default_factory=list) warnings_acknowledged: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list) # Blueprint (when complete) blueprint: Optional[Dict[str, Any]] = None def get_phase(self) -> InterviewPhase: """Get current phase as enum.""" return InterviewPhase.from_string(self.current_phase) def set_phase(self, phase: InterviewPhase) -> None: """Set current phase.""" self.current_phase = phase.value self.touch() def touch(self) -> None: """Update last_updated timestamp.""" self.last_updated = datetime.now().isoformat() def is_complete(self) -> bool: """Check if interview is complete.""" return self.current_phase == InterviewPhase.COMPLETE.value def current_question_count(self) -> int: """Get number of questions answered.""" return len(self.questions_answered) def progress_percentage(self) -> float: """ Estimate progress through interview. Based on phase, not questions, since questions are adaptive. """ phases = list(InterviewPhase) current_idx = phases.index(self.get_phase()) return (current_idx / (len(phases) - 1)) * 100 def add_answered_question(self, question: AnsweredQuestion) -> None: """Record a question as answered.""" self.questions_answered.append(question.to_dict()) if question.question_id in self.questions_remaining: self.questions_remaining.remove(question.question_id) self.touch() def get_answer(self, key: str, default: Any = None) -> Any: """Get an answer by key.""" return self.answers.get(key, default) def set_answer(self, key: str, value: Any) -> None: """Set an answer.""" self.answers[key] = value self.touch() def add_warning(self, warning: str) -> None: """Add a warning message.""" if warning not in self.warnings: self.warnings.append(warning) self.touch() def acknowledge_warning(self, warning: str) -> None: """Mark a warning as acknowledged.""" if warning in self.warnings and warning not in self.warnings_acknowledged: self.warnings_acknowledged.append(warning) self.touch() def has_unacknowledged_errors(self) -> bool: """Check if there are blocking errors.""" return len(self.errors) > 0 def has_unacknowledged_warnings(self) -> bool: """Check if there are unacknowledged warnings.""" return any(w not in self.warnings_acknowledged for w in self.warnings) def to_json(self) -> str: """Serialize to JSON string.""" return json.dumps(asdict(self), indent=2, default=str) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return asdict(self) @classmethod def from_json(cls, json_str: str) -> "InterviewState": """Deserialize from JSON string.""" data = json.loads(json_str) return cls.from_dict(data) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "InterviewState": """Create from dictionary.""" # Handle nested types return cls( version=data.get("version", "1.0"), session_id=data.get("session_id", str(uuid.uuid4())), study_name=data.get("study_name", ""), study_path=data.get("study_path", ""), parent_study=data.get("parent_study"), started_at=data.get("started_at", datetime.now().isoformat()), last_updated=data.get("last_updated", datetime.now().isoformat()), current_phase=data.get("current_phase", InterviewPhase.INTROSPECTION.value), complexity=data.get("complexity", "simple"), questions_answered=data.get("questions_answered", []), questions_remaining=data.get("questions_remaining", []), current_question_id=data.get("current_question_id"), introspection=data.get("introspection", {}), answers=data.get("answers", {}), inferred_config=data.get("inferred_config", {}), warnings=data.get("warnings", []), warnings_acknowledged=data.get("warnings_acknowledged", []), errors=data.get("errors", []), blueprint=data.get("blueprint"), ) def validate(self) -> List[str]: """Validate state, return list of errors.""" errors = [] if not self.session_id: errors.append("Missing session_id") if not self.study_name: errors.append("Missing study_name") try: InterviewPhase.from_string(self.current_phase) except ValueError: errors.append(f"Invalid current_phase: {self.current_phase}") if self.complexity not in ["simple", "moderate", "complex"]: errors.append(f"Invalid complexity: {self.complexity}") return errors @dataclass class StateSnapshot: """Snapshot of state for history/undo.""" timestamp: str phase: str questions_count: int state_hash: str file_path: str class InterviewStateManager: """ Manages interview state persistence. Handles: - Save/load state to JSON - Human-readable audit log (MD) - State backup rotation - History for undo/branch """ MAX_BACKUPS = 5 def __init__(self, study_path: Path): """ Initialize state manager. Args: study_path: Path to the study directory """ self.study_path = Path(study_path) self.interview_dir = self.study_path / ".interview" self.state_file = self.interview_dir / "interview_state.json" self.log_file = self.interview_dir / "INTERVIEW_LOG.md" self.backup_dir = self.interview_dir / "backups" self.lock_file = self.interview_dir / ".lock" # Ensure directories exist self._ensure_directories() def _ensure_directories(self) -> None: """Create necessary directories if they don't exist.""" self.interview_dir.mkdir(parents=True, exist_ok=True) self.backup_dir.mkdir(exist_ok=True) def _acquire_lock(self) -> bool: """Acquire lock file for concurrent access prevention.""" try: if self.lock_file.exists(): # Check if lock is stale (older than 5 minutes) mtime = self.lock_file.stat().st_mtime age = datetime.now().timestamp() - mtime if age > 300: # 5 minutes self.lock_file.unlink() else: return False self.lock_file.write_text(str(os.getpid())) return True except Exception: return False def _release_lock(self) -> None: """Release lock file.""" try: if self.lock_file.exists(): self.lock_file.unlink() except Exception: pass def exists(self) -> bool: """Check if a saved state exists.""" return self.state_file.exists() def save_state(self, state: InterviewState) -> None: """ Persist current state to JSON. Performs atomic write with backup rotation. """ if not self._acquire_lock(): raise RuntimeError("Could not acquire lock for state file") try: # Update timestamp state.touch() # Create backup if state file exists if self.state_file.exists(): self._rotate_backups() backup_name = f"state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" shutil.copy(self.state_file, self.backup_dir / backup_name) # Atomic write: write to temp file then rename temp_file = self.state_file.with_suffix(".tmp") temp_file.write_text(state.to_json(), encoding="utf-8") temp_file.replace(self.state_file) finally: self._release_lock() def _rotate_backups(self) -> None: """Keep only the most recent backups.""" backups = sorted( self.backup_dir.glob("state_*.json"), key=lambda p: p.stat().st_mtime, reverse=True ) # Remove old backups for backup in backups[self.MAX_BACKUPS:]: backup.unlink() def load_state(self) -> Optional[InterviewState]: """ Load existing state if available. Returns: InterviewState if exists and valid, None otherwise """ if not self.state_file.exists(): return None try: json_str = self.state_file.read_text(encoding="utf-8") state = InterviewState.from_json(json_str) # Validate state errors = state.validate() if errors: raise ValueError(f"Invalid state: {errors}") return state except (json.JSONDecodeError, ValueError) as e: # Log error but don't crash print(f"Warning: Could not load interview state: {e}") return None def append_log(self, entry: LogEntry) -> None: """ Add entry to human-readable audit log. Creates log file with header if it doesn't exist. """ # Initialize log file if needed if not self.log_file.exists(): header = self._create_log_header() self.log_file.write_text(header, encoding="utf-8") # Append entry with open(self.log_file, "a", encoding="utf-8") as f: f.write(entry.to_markdown()) def _create_log_header(self) -> str: """Create header for new log file.""" return f"""# Interview Log **Study**: {self.study_path.name} **Started**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} This log records all questions and answers from the study interview process. --- """ def finalize_log(self, state: InterviewState) -> None: """Add final summary to log when interview completes.""" summary = f""" ## Interview Complete **Completed**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} **Questions Answered**: {len(state.questions_answered)} **Complexity**: {state.complexity} ### Summary - **Problem**: {state.answers.get('problem_description', 'N/A')} - **Objectives**: {len(state.answers.get('objectives', []))} - **Constraints**: {len(state.answers.get('constraints', []))} - **Design Variables**: {len(state.answers.get('design_variables', []))} ### Warnings Acknowledged """ for warning in state.warnings_acknowledged: summary += f"- {warning}\n" if not state.warnings_acknowledged: summary += "- None\n" summary += "\n---\n" with open(self.log_file, "a", encoding="utf-8") as f: f.write(summary) def get_history(self) -> List[StateSnapshot]: """ Get modification history for undo/branch. Returns list of state snapshots from backups. """ snapshots = [] for backup in sorted(self.backup_dir.glob("state_*.json")): try: data = json.loads(backup.read_text(encoding="utf-8")) snapshot = StateSnapshot( timestamp=data.get("last_updated", "unknown"), phase=data.get("current_phase", "unknown"), questions_count=len(data.get("questions_answered", [])), state_hash=str(hash(backup.read_text())), file_path=str(backup), ) snapshots.append(snapshot) except Exception: continue return snapshots def restore_from_backup(self, backup_path: str) -> Optional[InterviewState]: """Restore state from a backup file.""" backup = Path(backup_path) if not backup.exists(): return None try: json_str = backup.read_text(encoding="utf-8") return InterviewState.from_json(json_str) except Exception: return None def delete_state(self) -> None: """Delete all interview state (for restart).""" if self.state_file.exists(): self.state_file.unlink() # Keep log file but add note if self.log_file.exists(): with open(self.log_file, "a", encoding="utf-8") as f: f.write(f"\n## State Reset\n\n**Reset at**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n---\n\n")