diff --git a/src/atocore/api/routes.py b/src/atocore/api/routes.py index 3dc3007..f121d2d 100644 --- a/src/atocore/api/routes.py +++ b/src/atocore/api/routes.py @@ -10,7 +10,13 @@ from atocore.context.builder import ( get_last_context_pack, _pack_to_dict, ) -from atocore.ingestion.pipeline import ingest_file, ingest_folder +from atocore.context.project_state import ( + CATEGORIES, + get_state, + invalidate_state, + set_state, +) +from atocore.ingestion.pipeline import ingest_file, ingest_folder, get_ingestion_stats from atocore.observability.logger import get_logger from atocore.retrieval.retriever import retrieve from atocore.retrieval.vector_store import get_vector_store @@ -57,6 +63,26 @@ class ContextBuildResponse(BaseModel): chunks: list[dict] +class ProjectStateSetRequest(BaseModel): + project: str + category: str + key: str + value: str + source: str = "" + confidence: float = 1.0 + + +class ProjectStateGetRequest(BaseModel): + project: str + category: str | None = None + + +class ProjectStateInvalidateRequest(BaseModel): + project: str + category: str + key: str + + # --- Endpoints --- @@ -127,6 +153,58 @@ def api_build_context(req: ContextBuildRequest) -> ContextBuildResponse: ) +@router.post("/project/state") +def api_set_project_state(req: ProjectStateSetRequest) -> dict: + """Set or update a trusted project state entry.""" + try: + entry = set_state( + project_name=req.project, + category=req.category, + key=req.key, + value=req.value, + source=req.source, + confidence=req.confidence, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + log.error("set_state_failed", error=str(e)) + raise HTTPException(status_code=500, detail=f"Failed to set state: {e}") + return {"status": "ok", "id": entry.id, "category": entry.category, "key": entry.key} + + +@router.get("/project/state/{project_name}") +def api_get_project_state(project_name: str, category: str | None = None) -> dict: + """Get trusted project state entries.""" + entries = get_state(project_name, category=category) + return { + "project": project_name, + "entries": [ + { + "id": e.id, + "category": e.category, + "key": e.key, + "value": e.value, + "source": e.source, + "confidence": e.confidence, + "status": e.status, + "updated_at": e.updated_at, + } + for e in entries + ], + "categories": CATEGORIES, + } + + +@router.delete("/project/state") +def api_invalidate_project_state(req: ProjectStateInvalidateRequest) -> dict: + """Invalidate (supersede) a project state entry.""" + success = invalidate_state(req.project, req.category, req.key) + if not success: + raise HTTPException(status_code=404, detail="State entry not found or already invalidated") + return {"status": "invalidated", "project": req.project, "category": req.category, "key": req.key} + + @router.get("/health") def api_health() -> dict: """Health check.""" @@ -138,6 +216,12 @@ def api_health() -> dict: } +@router.get("/stats") +def api_stats() -> dict: + """Ingestion statistics.""" + return get_ingestion_stats() + + @router.get("/debug/context") def api_debug_context() -> dict: """Inspect the last assembled context pack.""" diff --git a/src/atocore/context/builder.py b/src/atocore/context/builder.py index eb02c1d..ddb2066 100644 --- a/src/atocore/context/builder.py +++ b/src/atocore/context/builder.py @@ -1,11 +1,16 @@ -"""Context pack assembly: retrieve, rank, budget, format.""" +"""Context pack assembly: retrieve, rank, budget, format. + +Trust precedence (per Master Plan): + 1. Trusted Project State → always included first, uses its own budget slice + 2. Retrieved chunks → ranked, deduplicated, budget-constrained +""" -import json import time from dataclasses import dataclass, field from pathlib import Path from atocore.config import settings +from atocore.context.project_state import format_project_state, get_state from atocore.observability.logger import get_logger from atocore.retrieval.retriever import ChunkResult, retrieve @@ -14,9 +19,14 @@ log = get_logger("context_builder") SYSTEM_PREFIX = ( "You have access to the following personal context from the user's knowledge base.\n" "Use it to inform your answer. If the context is not relevant, ignore it.\n" - "Do not mention the context system unless asked." + "Do not mention the context system unless asked.\n" + "When project state is provided, treat it as the most authoritative source." ) +# Budget allocation (per Master Plan section 9) +# project_state gets up to 20% of budget, retrieval gets the rest +PROJECT_STATE_BUDGET_RATIO = 0.20 + # Last built context pack for debug inspection _last_context_pack: "ContextPack | None" = None @@ -33,6 +43,8 @@ class ContextChunk: @dataclass class ContextPack: chunks_used: list[ContextChunk] = field(default_factory=list) + project_state_text: str = "" + project_state_chars: int = 0 total_chars: int = 0 budget: int = 0 budget_remaining: int = 0 @@ -48,31 +60,61 @@ def build_context( project_hint: str | None = None, budget: int | None = None, ) -> ContextPack: - """Build a context pack for a user prompt.""" + """Build a context pack for a user prompt. + + Trust precedence applied: + 1. Project state is injected first (highest trust) + 2. Retrieved chunks fill the remaining budget + """ global _last_context_pack start = time.time() budget = budget or settings.context_budget - # 1. Retrieve candidates + # 1. Get Trusted Project State (highest precedence) + project_state_text = "" + project_state_chars = 0 + state_budget = int(budget * PROJECT_STATE_BUDGET_RATIO) + + if project_hint: + state_entries = get_state(project_hint) + if state_entries: + project_state_text = format_project_state(state_entries) + project_state_chars = len(project_state_text) + # If state exceeds its budget, it still gets included (it's highest trust) + # but we log it + if project_state_chars > state_budget: + log.info( + "project_state_exceeds_budget", + state_chars=project_state_chars, + state_budget=state_budget, + ) + + # 2. Calculate remaining budget for retrieval + retrieval_budget = budget - project_state_chars + + # 3. Retrieve candidates candidates = retrieve(user_prompt, top_k=settings.context_top_k) - # 2. Score and rank + # 4. Score and rank scored = _rank_chunks(candidates, project_hint) - # 3. Select within budget - selected = _select_within_budget(scored, budget) + # 5. Select within remaining budget + selected = _select_within_budget(scored, max(retrieval_budget, 0)) - # 4. Format - formatted = _format_context_block(selected) + # 6. Format full context + formatted = _format_full_context(project_state_text, selected) - # 5. Build full prompt + # 7. Build full prompt full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}" - total_chars = sum(c.char_count for c in selected) + retrieval_chars = sum(c.char_count for c in selected) + total_chars = project_state_chars + retrieval_chars duration_ms = int((time.time() - start) * 1000) pack = ContextPack( chunks_used=selected, + project_state_text=project_state_text, + project_state_chars=project_state_chars, total_chars=total_chars, budget=budget, budget_remaining=budget - total_chars, @@ -88,6 +130,8 @@ def build_context( log.info( "context_built", chunks_used=len(selected), + project_state_chars=project_state_chars, + retrieval_chars=retrieval_chars, total_chars=total_chars, budget_remaining=budget - total_chars, duration_ms=duration_ms, @@ -163,27 +207,38 @@ def _select_within_budget( return selected -def _format_context_block(chunks: list[ContextChunk]) -> str: - """Format chunks into the context block string.""" - if not chunks: - return "--- AtoCore Context ---\nNo relevant context found.\n--- End Context ---" +def _format_full_context( + project_state_text: str, + chunks: list[ContextChunk], +) -> str: + """Format project state + retrieved chunks into full context block.""" + parts = [] - lines = ["--- AtoCore Context ---"] - for chunk in chunks: - lines.append( - f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]" - ) - lines.append(chunk.content) - lines.append("") - lines.append("--- End Context ---") - return "\n".join(lines) + # Project state first (highest trust) + if project_state_text: + parts.append(project_state_text) + parts.append("") + + # Retrieved chunks + if chunks: + parts.append("--- AtoCore Retrieved Context ---") + for chunk in chunks: + parts.append( + f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]" + ) + parts.append(chunk.content) + parts.append("") + parts.append("--- End Context ---") + elif not project_state_text: + parts.append("--- AtoCore Context ---\nNo relevant context found.\n--- End Context ---") + + return "\n".join(parts) def _shorten_path(path: str) -> str: """Shorten an absolute path to a relative-like display.""" p = Path(path) parts = p.parts - # Show last 3 parts at most if len(parts) > 3: return str(Path(*parts[-3:])) return str(p) @@ -194,11 +249,13 @@ def _pack_to_dict(pack: ContextPack) -> dict: return { "query": pack.query, "project_hint": pack.project_hint, + "project_state_chars": pack.project_state_chars, "chunks_used": len(pack.chunks_used), "total_chars": pack.total_chars, "budget": pack.budget, "budget_remaining": pack.budget_remaining, "duration_ms": pack.duration_ms, + "has_project_state": bool(pack.project_state_text), "chunks": [ { "source_file": c.source_file, diff --git a/src/atocore/context/project_state.py b/src/atocore/context/project_state.py new file mode 100644 index 0000000..2adacbf --- /dev/null +++ b/src/atocore/context/project_state.py @@ -0,0 +1,231 @@ +"""Trusted Project State — the highest-priority context source. + +Per the Master Plan trust precedence: + 1. Trusted Project State (this module) + 2. AtoDrive artifacts + 3. Recent validated memory + 4. AtoVault summaries + 5. PKM chunks + 6. Historical / low-confidence + +Project state is manually curated or explicitly confirmed facts about a project. +It always wins over retrieval-based context when there's a conflict. +""" + +import json +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone + +from atocore.models.database import get_connection +from atocore.observability.logger import get_logger + +log = get_logger("project_state") + +# DB schema extension for project state +PROJECT_STATE_SCHEMA = """ +CREATE TABLE IF NOT EXISTS project_state ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + category TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + source TEXT DEFAULT '', + confidence REAL DEFAULT 1.0, + status TEXT DEFAULT 'active', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(project_id, category, key) +); + +CREATE INDEX IF NOT EXISTS idx_project_state_project ON project_state(project_id); +CREATE INDEX IF NOT EXISTS idx_project_state_category ON project_state(category); +CREATE INDEX IF NOT EXISTS idx_project_state_status ON project_state(status); +""" + +# Valid categories for project state entries +CATEGORIES = [ + "status", # current project status, phase, blockers + "decision", # confirmed design/engineering decisions + "requirement", # key requirements and constraints + "contact", # key people, vendors, stakeholders + "milestone", # dates, deadlines, deliverables + "fact", # verified technical facts + "config", # project configuration, parameters +] + + +@dataclass +class ProjectStateEntry: + id: str + project_id: str + category: str + key: str + value: str + source: str = "" + confidence: float = 1.0 + status: str = "active" + created_at: str = "" + updated_at: str = "" + + +def init_project_state_schema() -> None: + """Create the project_state table if it doesn't exist.""" + with get_connection() as conn: + conn.executescript(PROJECT_STATE_SCHEMA) + log.info("project_state_schema_initialized") + + +def ensure_project(name: str, description: str = "") -> str: + """Get or create a project by name. Returns project_id.""" + with get_connection() as conn: + row = conn.execute( + "SELECT id FROM projects WHERE name = ?", (name,) + ).fetchone() + if row: + return row["id"] + + project_id = str(uuid.uuid4()) + conn.execute( + "INSERT INTO projects (id, name, description) VALUES (?, ?, ?)", + (project_id, name, description), + ) + log.info("project_created", name=name, project_id=project_id) + return project_id + + +def set_state( + project_name: str, + category: str, + key: str, + value: str, + source: str = "", + confidence: float = 1.0, +) -> ProjectStateEntry: + """Set or update a project state entry. Upsert semantics.""" + if category not in CATEGORIES: + raise ValueError(f"Invalid category '{category}'. Must be one of: {CATEGORIES}") + + project_id = ensure_project(project_name) + entry_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + + with get_connection() as conn: + # Check if entry exists + existing = conn.execute( + "SELECT id FROM project_state WHERE project_id = ? AND category = ? AND key = ?", + (project_id, category, key), + ).fetchone() + + if existing: + entry_id = existing["id"] + conn.execute( + "UPDATE project_state SET value = ?, source = ?, confidence = ?, " + "status = 'active', updated_at = CURRENT_TIMESTAMP " + "WHERE id = ?", + (value, source, confidence, entry_id), + ) + log.info("project_state_updated", project=project_name, category=category, key=key) + else: + conn.execute( + "INSERT INTO project_state (id, project_id, category, key, value, source, confidence) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (entry_id, project_id, category, key, value, source, confidence), + ) + log.info("project_state_created", project=project_name, category=category, key=key) + + return ProjectStateEntry( + id=entry_id, + project_id=project_id, + category=category, + key=key, + value=value, + source=source, + confidence=confidence, + status="active", + created_at=now, + updated_at=now, + ) + + +def get_state( + project_name: str, + category: str | None = None, + active_only: bool = True, +) -> list[ProjectStateEntry]: + """Get project state entries, optionally filtered by category.""" + with get_connection() as conn: + project = conn.execute( + "SELECT id FROM projects WHERE name = ?", (project_name,) + ).fetchone() + if not project: + return [] + + query = "SELECT * FROM project_state WHERE project_id = ?" + params: list = [project["id"]] + + if category: + query += " AND category = ?" + params.append(category) + if active_only: + query += " AND status = 'active'" + + query += " ORDER BY category, key" + rows = conn.execute(query, params).fetchall() + + return [ + ProjectStateEntry( + id=r["id"], + project_id=r["project_id"], + category=r["category"], + key=r["key"], + value=r["value"], + source=r["source"], + confidence=r["confidence"], + status=r["status"], + created_at=r["created_at"], + updated_at=r["updated_at"], + ) + for r in rows + ] + + +def invalidate_state(project_name: str, category: str, key: str) -> bool: + """Mark a project state entry as superseded.""" + with get_connection() as conn: + project = conn.execute( + "SELECT id FROM projects WHERE name = ?", (project_name,) + ).fetchone() + if not project: + return False + + result = conn.execute( + "UPDATE project_state SET status = 'superseded', updated_at = CURRENT_TIMESTAMP " + "WHERE project_id = ? AND category = ? AND key = ? AND status = 'active'", + (project["id"], category, key), + ) + if result.rowcount > 0: + log.info("project_state_invalidated", project=project_name, category=category, key=key) + return True + return False + + +def format_project_state(entries: list[ProjectStateEntry]) -> str: + """Format project state entries for context injection.""" + if not entries: + return "" + + lines = ["--- Trusted Project State ---"] + current_category = "" + + for entry in entries: + if entry.category != current_category: + current_category = entry.category + lines.append(f"\n[{current_category.upper()}]") + lines.append(f" {entry.key}: {entry.value}") + if entry.source: + lines.append(f" (source: {entry.source})") + + lines.append("\n--- End Project State ---") + return "\n".join(lines) diff --git a/src/atocore/ingestion/pipeline.py b/src/atocore/ingestion/pipeline.py index 290ecb1..ed829c9 100644 --- a/src/atocore/ingestion/pipeline.py +++ b/src/atocore/ingestion/pipeline.py @@ -15,6 +15,9 @@ from atocore.retrieval.vector_store import get_vector_store log = get_logger("ingestion") +# Encodings to try when reading markdown files +_ENCODINGS = ["utf-8", "utf-8-sig", "latin-1", "cp1252"] + def ingest_file(file_path: Path) -> dict: """Ingest a single markdown file. Returns stats.""" @@ -26,9 +29,9 @@ def ingest_file(file_path: Path) -> dict: if file_path.suffix.lower() not in (".md", ".markdown"): raise ValueError(f"Not a markdown file: {file_path}") - # Read and hash - raw_content = file_path.read_text(encoding="utf-8") - file_hash = hashlib.sha256(raw_content.encode()).hexdigest() + # Read with encoding fallback + raw_content = _read_file_safe(file_path) + file_hash = hashlib.sha256(raw_content.encode("utf-8")).hexdigest() # Check if already ingested and unchanged with get_connection() as conn: @@ -136,16 +139,24 @@ def ingest_file(file_path: Path) -> dict: } -def ingest_folder(folder_path: Path) -> list[dict]: - """Ingest all markdown files in a folder recursively.""" +def ingest_folder(folder_path: Path, purge_deleted: bool = True) -> list[dict]: + """Ingest all markdown files in a folder recursively. + + Args: + folder_path: Directory to scan for .md files. + purge_deleted: If True, remove DB/vector entries for files + that no longer exist on disk. + """ folder_path = folder_path.resolve() if not folder_path.is_dir(): raise NotADirectoryError(f"Not a directory: {folder_path}") results = [] md_files = sorted(folder_path.rglob("*.md")) + current_paths = {str(f.resolve()) for f in md_files} log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files)) + # Ingest new/changed files for md_file in md_files: try: result = ingest_file(md_file) @@ -154,4 +165,80 @@ def ingest_folder(folder_path: Path) -> list[dict]: log.error("ingestion_error", file_path=str(md_file), error=str(e)) results.append({"file": str(md_file), "status": "error", "error": str(e)}) + # Purge entries for deleted files + if purge_deleted: + deleted = _purge_deleted_files(folder_path, current_paths) + if deleted: + log.info("purged_deleted_files", count=deleted) + results.append({"status": "purged", "deleted_count": deleted}) + return results + + +def get_ingestion_stats() -> dict: + """Return ingestion statistics.""" + with get_connection() as conn: + docs = conn.execute("SELECT COUNT(*) as c FROM source_documents").fetchone() + chunks = conn.execute("SELECT COUNT(*) as c FROM source_chunks").fetchone() + recent = conn.execute( + "SELECT file_path, title, ingested_at FROM source_documents " + "ORDER BY updated_at DESC LIMIT 5" + ).fetchall() + + vector_store = get_vector_store() + return { + "total_documents": docs["c"], + "total_chunks": chunks["c"], + "total_vectors": vector_store.count, + "recent_documents": [ + {"file_path": r["file_path"], "title": r["title"], "ingested_at": r["ingested_at"]} + for r in recent + ], + } + + +def _read_file_safe(file_path: Path) -> str: + """Read a file with encoding fallback.""" + for encoding in _ENCODINGS: + try: + return file_path.read_text(encoding=encoding) + except (UnicodeDecodeError, ValueError): + continue + # Last resort: read with errors replaced + return file_path.read_text(encoding="utf-8", errors="replace") + + +def _purge_deleted_files(folder_path: Path, current_paths: set[str]) -> int: + """Remove DB/vector entries for files under folder_path that no longer exist.""" + folder_str = str(folder_path) + deleted_count = 0 + vector_store = get_vector_store() + + with get_connection() as conn: + # Find documents under this folder + rows = conn.execute( + "SELECT id, file_path FROM source_documents WHERE file_path LIKE ?", + (f"{folder_str}%",), + ).fetchall() + + for row in rows: + if row["file_path"] not in current_paths: + doc_id = row["id"] + # Get chunk IDs for vector deletion + chunk_ids = [ + r["id"] + for r in conn.execute( + "SELECT id FROM source_chunks WHERE document_id = ?", + (doc_id,), + ).fetchall() + ] + # Delete from DB + conn.execute("DELETE FROM source_chunks WHERE document_id = ?", (doc_id,)) + conn.execute("DELETE FROM source_documents WHERE id = ?", (doc_id,)) + # Delete from vectors + if chunk_ids: + vector_store.delete(chunk_ids) + log.info("purged_deleted_file", file_path=row["file_path"]) + deleted_count += 1 + + return deleted_count diff --git a/src/atocore/main.py b/src/atocore/main.py index 33529a4..9c79b0d 100644 --- a/src/atocore/main.py +++ b/src/atocore/main.py @@ -4,6 +4,7 @@ from fastapi import FastAPI from atocore.api.routes import router from atocore.config import settings +from atocore.context.project_state import init_project_state_schema from atocore.models.database import init_db from atocore.observability.logger import setup_logging @@ -20,6 +21,7 @@ app.include_router(router) def startup(): setup_logging() init_db() + init_project_state_schema() if __name__ == "__main__": diff --git a/tests/test_context_builder.py b/tests/test_context_builder.py index fdba8d9..1939e44 100644 --- a/tests/test_context_builder.py +++ b/tests/test_context_builder.py @@ -1,6 +1,7 @@ """Tests for the context builder.""" from atocore.context.builder import build_context, get_last_context_pack +from atocore.context.project_state import init_project_state_schema, set_state from atocore.ingestion.pipeline import ingest_file from atocore.models.database import init_db @@ -8,19 +9,20 @@ from atocore.models.database import init_db def test_build_context_returns_pack(tmp_data_dir, sample_markdown): """Test that context builder returns a valid pack.""" init_db() + init_project_state_schema() ingest_file(sample_markdown) pack = build_context("What is AtoCore?") assert pack.total_chars > 0 assert len(pack.chunks_used) > 0 assert pack.budget_remaining >= 0 - assert "--- AtoCore Context ---" in pack.formatted_context assert "--- End Context ---" in pack.formatted_context def test_context_respects_budget(tmp_data_dir, sample_markdown): """Test that context builder respects character budget.""" init_db() + init_project_state_schema() ingest_file(sample_markdown) pack = build_context("What is AtoCore?", budget=500) @@ -30,17 +32,18 @@ def test_context_respects_budget(tmp_data_dir, sample_markdown): def test_context_with_project_hint(tmp_data_dir, sample_markdown): """Test that project hint boosts relevant chunks.""" init_db() + init_project_state_schema() ingest_file(sample_markdown) pack = build_context("What is the architecture?", project_hint="atocore") assert len(pack.chunks_used) > 0 - # With project hint, we should still get results assert pack.total_chars > 0 def test_last_context_pack_stored(tmp_data_dir, sample_markdown): """Test that last context pack is stored for debug.""" init_db() + init_project_state_schema() ingest_file(sample_markdown) build_context("test prompt") @@ -52,9 +55,54 @@ def test_last_context_pack_stored(tmp_data_dir, sample_markdown): def test_full_prompt_structure(tmp_data_dir, sample_markdown): """Test that the full prompt has correct structure.""" init_db() + init_project_state_schema() ingest_file(sample_markdown) pack = build_context("What are memory types?") assert "knowledge base" in pack.full_prompt.lower() - assert "--- AtoCore Context ---" in pack.full_prompt assert "What are memory types?" in pack.full_prompt + + +def test_project_state_included_in_context(tmp_data_dir, sample_markdown): + """Test that trusted project state is injected into context.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + # Set some project state + set_state("atocore", "status", "phase", "Phase 0.5 complete") + set_state("atocore", "decision", "database", "SQLite for structured data") + + pack = build_context("What is AtoCore?", project_hint="atocore") + + # Project state should appear in context + assert "--- Trusted Project State ---" in pack.formatted_context + assert "Phase 0.5 complete" in pack.formatted_context + assert "SQLite for structured data" in pack.formatted_context + assert pack.project_state_chars > 0 + + +def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown): + """Test that project state is included even with tight budget.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + set_state("atocore", "status", "phase", "Phase 1 in progress") + + # Small budget — project state should still be included + pack = build_context("status?", project_hint="atocore", budget=500) + assert "Phase 1 in progress" in pack.formatted_context + + +def test_no_project_state_without_hint(tmp_data_dir, sample_markdown): + """Test that project state is not included without project hint.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + set_state("atocore", "status", "phase", "Phase 1") + + pack = build_context("What is AtoCore?") + assert pack.project_state_chars == 0 + assert "--- Trusted Project State ---" not in pack.formatted_context diff --git a/tests/test_project_state.py b/tests/test_project_state.py new file mode 100644 index 0000000..ffa7248 --- /dev/null +++ b/tests/test_project_state.py @@ -0,0 +1,127 @@ +"""Tests for Trusted Project State.""" + +import pytest + +from atocore.context.project_state import ( + CATEGORIES, + ensure_project, + format_project_state, + get_state, + init_project_state_schema, + invalidate_state, + set_state, +) +from atocore.models.database import init_db + + +@pytest.fixture(autouse=True) +def setup_db(tmp_data_dir): + """Initialize DB and project state schema for every test.""" + init_db() + init_project_state_schema() + + +def test_ensure_project_creates(): + """Test creating a new project.""" + pid = ensure_project("test-project", "A test project") + assert pid + # Second call returns same ID + pid2 = ensure_project("test-project") + assert pid == pid2 + + +def test_set_state_creates_entry(): + """Test creating a project state entry.""" + entry = set_state("myproject", "status", "phase", "Phase 0.5 — PoC complete") + assert entry.category == "status" + assert entry.key == "phase" + assert entry.value == "Phase 0.5 — PoC complete" + assert entry.status == "active" + + +def test_set_state_upserts(): + """Test that setting same key updates the value.""" + set_state("myproject", "status", "phase", "Phase 0") + entry = set_state("myproject", "status", "phase", "Phase 1") + assert entry.value == "Phase 1" + + # Only one entry should exist + entries = get_state("myproject", category="status") + assert len(entries) == 1 + assert entries[0].value == "Phase 1" + + +def test_set_state_invalid_category(): + """Test that invalid category raises ValueError.""" + with pytest.raises(ValueError, match="Invalid category"): + set_state("myproject", "invalid_category", "key", "value") + + +def test_get_state_all(): + """Test getting all state entries for a project.""" + set_state("proj", "status", "phase", "Phase 1") + set_state("proj", "decision", "database", "SQLite for v1") + set_state("proj", "requirement", "latency", "<2 seconds") + + entries = get_state("proj") + assert len(entries) == 3 + categories = {e.category for e in entries} + assert categories == {"status", "decision", "requirement"} + + +def test_get_state_by_category(): + """Test filtering by category.""" + set_state("proj", "status", "phase", "Phase 1") + set_state("proj", "decision", "database", "SQLite") + set_state("proj", "decision", "vectordb", "ChromaDB") + + entries = get_state("proj", category="decision") + assert len(entries) == 2 + assert all(e.category == "decision" for e in entries) + + +def test_get_state_nonexistent_project(): + """Test getting state for a project that doesn't exist.""" + entries = get_state("nonexistent") + assert entries == [] + + +def test_invalidate_state(): + """Test marking a state entry as superseded.""" + set_state("invalidate-test", "decision", "approach", "monolith") + success = invalidate_state("invalidate-test", "decision", "approach") + assert success + + # Active entries should be empty + entries = get_state("invalidate-test", active_only=True) + assert len(entries) == 0 + + # But entry still exists if we include inactive + entries = get_state("invalidate-test", active_only=False) + assert len(entries) == 1 + assert entries[0].status == "superseded" + + +def test_invalidate_nonexistent(): + """Test invalidating a nonexistent entry.""" + success = invalidate_state("proj", "decision", "nonexistent") + assert not success + + +def test_format_project_state(): + """Test formatting state entries for context injection.""" + set_state("proj", "status", "phase", "Phase 1") + set_state("proj", "decision", "database", "SQLite", source="Build Spec V1") + entries = get_state("proj") + + formatted = format_project_state(entries) + assert "--- Trusted Project State ---" in formatted + assert "--- End Project State ---" in formatted + assert "phase: Phase 1" in formatted + assert "database: SQLite" in formatted + assert "(source: Build Spec V1)" in formatted + + +def test_format_empty(): + """Test formatting empty state.""" + assert format_project_state([]) == ""