feat: Phase 1 ingestion hardening + Phase 5 Trusted Project State

Phase 1 - Ingestion hardening:
- Encoding fallback (UTF-8/UTF-8-sig/Latin-1/CP1252)
- Delete detection: purge DB/vector entries for removed files
- Ingestion stats endpoint (GET /stats)

Phase 5 - Trusted Project State:
- project_state table with categories (status, decision, requirement, contact, milestone, fact, config)
- CRUD API: POST /project/state, GET /project/state/{project_name}, DELETE /project/state
- Upsert semantics, invalidation (supersede) support
- Context builder integrates project state at highest trust precedence
- Project state is allotted 20% of the context budget (soft cap: over-budget state is still included and logged) and appears first in context
- Trust precedence: Project State > Retrieved Chunks (per Master Plan)

33/33 tests passing. Validated end-to-end with GigaBIT M1 project data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-05 09:41:59 -04:00
parent 6081462058
commit 531c560db7
7 changed files with 671 additions and 35 deletions

View File

@@ -10,7 +10,13 @@ from atocore.context.builder import (
get_last_context_pack, get_last_context_pack,
_pack_to_dict, _pack_to_dict,
) )
from atocore.ingestion.pipeline import ingest_file, ingest_folder from atocore.context.project_state import (
CATEGORIES,
get_state,
invalidate_state,
set_state,
)
from atocore.ingestion.pipeline import ingest_file, ingest_folder, get_ingestion_stats
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
from atocore.retrieval.retriever import retrieve from atocore.retrieval.retriever import retrieve
from atocore.retrieval.vector_store import get_vector_store from atocore.retrieval.vector_store import get_vector_store
@@ -57,6 +63,26 @@ class ContextBuildResponse(BaseModel):
chunks: list[dict] chunks: list[dict]
class ProjectStateSetRequest(BaseModel):
    """Request body for creating or updating one project state entry."""

    # Name of the project the entry is stored under.
    project: str
    # Entry category; set_state rejects values outside the allowed categories.
    category: str
    # Identifier of the entry within (project, category).
    key: str
    # The stored text for this entry.
    value: str
    # Optional provenance note.
    source: str = ""
    # Confidence score attached to the entry.
    confidence: float = 1.0
class ProjectStateGetRequest(BaseModel):
    """Parameters for looking up project state: a project and an optional category."""

    # Name of the project to look up.
    project: str
    # When given, restricts the lookup to this category.
    category: str | None = None
class ProjectStateInvalidateRequest(BaseModel):
    """Request body identifying the single entry to invalidate (supersede)."""

    # Name of the project that owns the entry.
    project: str
    # Category of the entry.
    category: str
    # Key of the entry within that category.
    key: str
# --- Endpoints --- # --- Endpoints ---
@@ -127,6 +153,58 @@ def api_build_context(req: ContextBuildRequest) -> ContextBuildResponse:
) )
@router.post("/project/state")
def api_set_project_state(req: ProjectStateSetRequest) -> dict:
    """Set or update a trusted project state entry.

    Returns:
        An acknowledgement dict with ``status`` plus the entry's ``id``,
        ``category`` and ``key``.

    Raises:
        HTTPException: 400 when ``set_state`` raises ``ValueError`` (for
            example an invalid category); 500 for any other failure.
    """
    try:
        entry = set_state(
            project_name=req.project,
            category=req.category,
            key=req.key,
            value=req.value,
            source=req.source,
            confidence=req.confidence,
        )
    except ValueError as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        log.error("set_state_failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to set state: {e}") from e

    return {"status": "ok", "id": entry.id, "category": entry.category, "key": entry.key}
@router.get("/project/state/{project_name}")
def api_get_project_state(project_name: str, category: str | None = None) -> dict:
    """Return the project state entries of a project.

    ``category`` optionally narrows the result. The response also carries the
    list of valid categories.
    """
    # Attributes of each entry that are exposed in the response, in order.
    exposed_fields = (
        "id",
        "category",
        "key",
        "value",
        "source",
        "confidence",
        "status",
        "updated_at",
    )
    serialized = [
        {name: getattr(entry, name) for name in exposed_fields}
        for entry in get_state(project_name, category=category)
    ]
    return {
        "project": project_name,
        "entries": serialized,
        "categories": CATEGORIES,
    }
@router.delete("/project/state")
def api_invalidate_project_state(req: ProjectStateInvalidateRequest) -> dict:
    """Invalidate (supersede) one project state entry.

    Raises:
        HTTPException: 404 when ``invalidate_state`` reports that nothing
            was invalidated.
    """
    if invalidate_state(req.project, req.category, req.key):
        return {
            "status": "invalidated",
            "project": req.project,
            "category": req.category,
            "key": req.key,
        }
    raise HTTPException(status_code=404, detail="State entry not found or already invalidated")
@router.get("/health") @router.get("/health")
def api_health() -> dict: def api_health() -> dict:
"""Health check.""" """Health check."""
@@ -138,6 +216,12 @@ def api_health() -> dict:
} }
@router.get("/stats")
def api_stats() -> dict:
    """Expose the ingestion statistics gathered by ``get_ingestion_stats``."""
    stats = get_ingestion_stats()
    return stats
@router.get("/debug/context") @router.get("/debug/context")
def api_debug_context() -> dict: def api_debug_context() -> dict:
"""Inspect the last assembled context pack.""" """Inspect the last assembled context pack."""

View File

@@ -1,11 +1,16 @@
"""Context pack assembly: retrieve, rank, budget, format.""" """Context pack assembly: retrieve, rank, budget, format.
Trust precedence (per Master Plan):
1. Trusted Project State → always included first, uses its own budget slice
2. Retrieved chunks → ranked, deduplicated, budget-constrained
"""
import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from atocore.config import settings from atocore.config import settings
from atocore.context.project_state import format_project_state, get_state
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
from atocore.retrieval.retriever import ChunkResult, retrieve from atocore.retrieval.retriever import ChunkResult, retrieve
@@ -14,9 +19,14 @@ log = get_logger("context_builder")
SYSTEM_PREFIX = ( SYSTEM_PREFIX = (
"You have access to the following personal context from the user's knowledge base.\n" "You have access to the following personal context from the user's knowledge base.\n"
"Use it to inform your answer. If the context is not relevant, ignore it.\n" "Use it to inform your answer. If the context is not relevant, ignore it.\n"
"Do not mention the context system unless asked." "Do not mention the context system unless asked.\n"
"When project state is provided, treat it as the most authoritative source."
) )
# Budget allocation (per Master Plan section 9)
# project_state gets up to 20% of budget, retrieval gets the rest
PROJECT_STATE_BUDGET_RATIO = 0.20
# Last built context pack for debug inspection # Last built context pack for debug inspection
_last_context_pack: "ContextPack | None" = None _last_context_pack: "ContextPack | None" = None
@@ -33,6 +43,8 @@ class ContextChunk:
@dataclass @dataclass
class ContextPack: class ContextPack:
chunks_used: list[ContextChunk] = field(default_factory=list) chunks_used: list[ContextChunk] = field(default_factory=list)
project_state_text: str = ""
project_state_chars: int = 0
total_chars: int = 0 total_chars: int = 0
budget: int = 0 budget: int = 0
budget_remaining: int = 0 budget_remaining: int = 0
@@ -48,31 +60,61 @@ def build_context(
project_hint: str | None = None, project_hint: str | None = None,
budget: int | None = None, budget: int | None = None,
) -> ContextPack: ) -> ContextPack:
"""Build a context pack for a user prompt.""" """Build a context pack for a user prompt.
Trust precedence applied:
1. Project state is injected first (highest trust)
2. Retrieved chunks fill the remaining budget
"""
global _last_context_pack global _last_context_pack
start = time.time() start = time.time()
budget = budget or settings.context_budget budget = budget or settings.context_budget
# 1. Retrieve candidates # 1. Get Trusted Project State (highest precedence)
project_state_text = ""
project_state_chars = 0
state_budget = int(budget * PROJECT_STATE_BUDGET_RATIO)
if project_hint:
state_entries = get_state(project_hint)
if state_entries:
project_state_text = format_project_state(state_entries)
project_state_chars = len(project_state_text)
# If state exceeds its budget, it still gets included (it's highest trust)
# but we log it
if project_state_chars > state_budget:
log.info(
"project_state_exceeds_budget",
state_chars=project_state_chars,
state_budget=state_budget,
)
# 2. Calculate remaining budget for retrieval
retrieval_budget = budget - project_state_chars
# 3. Retrieve candidates
candidates = retrieve(user_prompt, top_k=settings.context_top_k) candidates = retrieve(user_prompt, top_k=settings.context_top_k)
# 2. Score and rank # 4. Score and rank
scored = _rank_chunks(candidates, project_hint) scored = _rank_chunks(candidates, project_hint)
# 3. Select within budget # 5. Select within remaining budget
selected = _select_within_budget(scored, budget) selected = _select_within_budget(scored, max(retrieval_budget, 0))
# 4. Format # 6. Format full context
formatted = _format_context_block(selected) formatted = _format_full_context(project_state_text, selected)
# 5. Build full prompt # 7. Build full prompt
full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}" full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}"
total_chars = sum(c.char_count for c in selected) retrieval_chars = sum(c.char_count for c in selected)
total_chars = project_state_chars + retrieval_chars
duration_ms = int((time.time() - start) * 1000) duration_ms = int((time.time() - start) * 1000)
pack = ContextPack( pack = ContextPack(
chunks_used=selected, chunks_used=selected,
project_state_text=project_state_text,
project_state_chars=project_state_chars,
total_chars=total_chars, total_chars=total_chars,
budget=budget, budget=budget,
budget_remaining=budget - total_chars, budget_remaining=budget - total_chars,
@@ -88,6 +130,8 @@ def build_context(
log.info( log.info(
"context_built", "context_built",
chunks_used=len(selected), chunks_used=len(selected),
project_state_chars=project_state_chars,
retrieval_chars=retrieval_chars,
total_chars=total_chars, total_chars=total_chars,
budget_remaining=budget - total_chars, budget_remaining=budget - total_chars,
duration_ms=duration_ms, duration_ms=duration_ms,
@@ -163,27 +207,38 @@ def _select_within_budget(
return selected return selected
def _format_context_block(chunks: list[ContextChunk]) -> str: def _format_full_context(
"""Format chunks into the context block string.""" project_state_text: str,
if not chunks: chunks: list[ContextChunk],
return "--- AtoCore Context ---\nNo relevant context found.\n--- End Context ---" ) -> str:
"""Format project state + retrieved chunks into full context block."""
parts = []
lines = ["--- AtoCore Context ---"] # Project state first (highest trust)
for chunk in chunks: if project_state_text:
lines.append( parts.append(project_state_text)
f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]" parts.append("")
)
lines.append(chunk.content) # Retrieved chunks
lines.append("") if chunks:
lines.append("--- End Context ---") parts.append("--- AtoCore Retrieved Context ---")
return "\n".join(lines) for chunk in chunks:
parts.append(
f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]"
)
parts.append(chunk.content)
parts.append("")
parts.append("--- End Context ---")
elif not project_state_text:
parts.append("--- AtoCore Context ---\nNo relevant context found.\n--- End Context ---")
return "\n".join(parts)
def _shorten_path(path: str) -> str: def _shorten_path(path: str) -> str:
"""Shorten an absolute path to a relative-like display.""" """Shorten an absolute path to a relative-like display."""
p = Path(path) p = Path(path)
parts = p.parts parts = p.parts
# Show last 3 parts at most
if len(parts) > 3: if len(parts) > 3:
return str(Path(*parts[-3:])) return str(Path(*parts[-3:]))
return str(p) return str(p)
@@ -194,11 +249,13 @@ def _pack_to_dict(pack: ContextPack) -> dict:
return { return {
"query": pack.query, "query": pack.query,
"project_hint": pack.project_hint, "project_hint": pack.project_hint,
"project_state_chars": pack.project_state_chars,
"chunks_used": len(pack.chunks_used), "chunks_used": len(pack.chunks_used),
"total_chars": pack.total_chars, "total_chars": pack.total_chars,
"budget": pack.budget, "budget": pack.budget,
"budget_remaining": pack.budget_remaining, "budget_remaining": pack.budget_remaining,
"duration_ms": pack.duration_ms, "duration_ms": pack.duration_ms,
"has_project_state": bool(pack.project_state_text),
"chunks": [ "chunks": [
{ {
"source_file": c.source_file, "source_file": c.source_file,

View File

@@ -0,0 +1,231 @@
"""Trusted Project State — the highest-priority context source.
Per the Master Plan trust precedence:
1. Trusted Project State (this module)
2. AtoDrive artifacts
3. Recent validated memory
4. AtoVault summaries
5. PKM chunks
6. Historical / low-confidence
Project state is manually curated or explicitly confirmed facts about a project.
It always wins over retrieval-based context when there's a conflict.
"""
import json
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from atocore.models.database import get_connection
from atocore.observability.logger import get_logger
log = get_logger("project_state")
# DB schema extension for project state.
# Executed as a script by init_project_state_schema(); every statement uses
# IF NOT EXISTS. UNIQUE(project_id, category, key) allows at most one row per
# key within a project/category, and project_id references projects(id) with
# ON DELETE CASCADE.
PROJECT_STATE_SCHEMA = """
CREATE TABLE IF NOT EXISTS project_state (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
category TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
source TEXT DEFAULT '',
confidence REAL DEFAULT 1.0,
status TEXT DEFAULT 'active',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(project_id, category, key)
);
CREATE INDEX IF NOT EXISTS idx_project_state_project ON project_state(project_id);
CREATE INDEX IF NOT EXISTS idx_project_state_category ON project_state(category);
CREATE INDEX IF NOT EXISTS idx_project_state_status ON project_state(status);
"""
# Valid categories for project state entries.
# set_state() raises ValueError for any category that is not in this list.
CATEGORIES = [
"status", # current project status, phase, blockers
"decision", # confirmed design/engineering decisions
"requirement", # key requirements and constraints
"contact", # key people, vendors, stakeholders
"milestone", # dates, deadlines, deliverables
"fact", # verified technical facts
"config", # project configuration, parameters
]
@dataclass
class ProjectStateEntry:
    """In-memory form of one project_state row."""

    # Identity of the entry and of the project it belongs to.
    id: str
    project_id: str

    # What the entry says: category, key within the category, and its value.
    category: str
    key: str
    value: str

    # Optional metadata; the defaults match the ones used elsewhere in this module.
    source: str = ""
    confidence: float = 1.0
    status: str = "active"

    # Timestamps as strings; empty when not provided.
    created_at: str = ""
    updated_at: str = ""
def init_project_state_schema() -> None:
    """Run the PROJECT_STATE_SCHEMA script against the database.

    The script only contains ``IF NOT EXISTS`` statements.
    """
    with get_connection() as db:
        db.executescript(PROJECT_STATE_SCHEMA)
    log.info("project_state_schema_initialized")
def ensure_project(name: str, description: str = "") -> str:
    """Look up a project by name, inserting it when missing.

    Args:
        name: Project name used for the lookup.
        description: Stored only when a new project row is inserted.

    Returns:
        The id of the existing or newly created project.
    """
    with get_connection() as db:
        found = db.execute(
            "SELECT id FROM projects WHERE name = ?", (name,)
        ).fetchone()
        if not found:
            new_id = str(uuid.uuid4())
            db.execute(
                "INSERT INTO projects (id, name, description) VALUES (?, ?, ?)",
                (new_id, name, description),
            )
            log.info("project_created", name=name, project_id=new_id)
            return new_id
        return found["id"]
def set_state(
    project_name: str,
    category: str,
    key: str,
    value: str,
    source: str = "",
    confidence: float = 1.0,
) -> ProjectStateEntry:
    """Set or update a project state entry. Upsert semantics.

    An existing (project, category, key) row is updated in place and set back
    to ``status = 'active'``; otherwise a new row is inserted. The project is
    created through ``ensure_project`` when it does not exist yet.

    Args:
        project_name: Name of the owning project.
        category: Must be one of ``CATEGORIES``.
        key: Entry key within the category.
        value: Text to store.
        source: Optional provenance note.
        confidence: Confidence score stored with the entry.

    Returns:
        The entry as it is stored after the write, including its
        ``created_at`` and ``updated_at`` values.

    Raises:
        ValueError: If ``category`` is not in ``CATEGORIES``.
    """
    if category not in CATEGORIES:
        raise ValueError(f"Invalid category '{category}'. Must be one of: {CATEGORIES}")

    project_id = ensure_project(project_name)

    with get_connection() as conn:
        existing = conn.execute(
            "SELECT id FROM project_state WHERE project_id = ? AND category = ? AND key = ?",
            (project_id, category, key),
        ).fetchone()

        if existing:
            entry_id = existing["id"]
            conn.execute(
                "UPDATE project_state SET value = ?, source = ?, confidence = ?, "
                "status = 'active', updated_at = CURRENT_TIMESTAMP "
                "WHERE id = ?",
                (value, source, confidence, entry_id),
            )
            log.info("project_state_updated", project=project_name, category=category, key=key)
        else:
            entry_id = str(uuid.uuid4())
            conn.execute(
                "INSERT INTO project_state (id, project_id, category, key, value, source, confidence) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (entry_id, project_id, category, key, value, source, confidence),
            )
            log.info("project_state_created", project=project_name, category=category, key=key)

        # Read the row back on the same connection so the returned entry
        # carries the stored timestamps. A locally generated "now" would
        # misreport created_at on updates and use a different format than
        # the values get_state() returns.
        row = conn.execute(
            "SELECT * FROM project_state WHERE id = ?", (entry_id,)
        ).fetchone()

    return ProjectStateEntry(
        id=row["id"],
        project_id=row["project_id"],
        category=row["category"],
        key=row["key"],
        value=row["value"],
        source=row["source"],
        confidence=row["confidence"],
        status=row["status"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )
def get_state(
    project_name: str,
    category: str | None = None,
    active_only: bool = True,
) -> list[ProjectStateEntry]:
    """Fetch the state entries of a project, ordered by category then key.

    Args:
        project_name: Name of the project to look up.
        category: When truthy, only entries of this category are returned.
        active_only: When true, only rows with status 'active' are returned.

    Returns:
        The matching entries; an empty list when the project is unknown.
    """
    with get_connection() as conn:
        project_row = conn.execute(
            "SELECT id FROM projects WHERE name = ?", (project_name,)
        ).fetchone()
        if not project_row:
            return []

        # Collect WHERE conditions and their bound values side by side.
        conditions = ["project_id = ?"]
        bound: list = [project_row["id"]]
        if category:
            conditions.append("category = ?")
            bound.append(category)
        if active_only:
            conditions.append("status = 'active'")

        sql = (
            "SELECT * FROM project_state WHERE "
            + " AND ".join(conditions)
            + " ORDER BY category, key"
        )
        records = conn.execute(sql, bound).fetchall()

    columns = (
        "id",
        "project_id",
        "category",
        "key",
        "value",
        "source",
        "confidence",
        "status",
        "created_at",
        "updated_at",
    )
    return [
        ProjectStateEntry(**{column: record[column] for column in columns})
        for record in records
    ]
def invalidate_state(project_name: str, category: str, key: str) -> bool:
    """Flip an active entry to status 'superseded'.

    Returns:
        True when an active row was updated; False when the project is
        unknown or no active row matched.
    """
    with get_connection() as conn:
        owner = conn.execute(
            "SELECT id FROM projects WHERE name = ?", (project_name,)
        ).fetchone()
        if not owner:
            return False

        cursor = conn.execute(
            "UPDATE project_state SET status = 'superseded', updated_at = CURRENT_TIMESTAMP "
            "WHERE project_id = ? AND category = ? AND key = ? AND status = 'active'",
            (owner["id"], category, key),
        )
        if cursor.rowcount <= 0:
            return False

        log.info("project_state_invalidated", project=project_name, category=category, key=key)
        return True
def format_project_state(entries: list[ProjectStateEntry]) -> str:
    """Render entries as the delimited text block injected into the context.

    A ``[CATEGORY]`` heading is emitted whenever the category differs from the
    previous entry's, so callers should pass entries grouped by category.
    Returns an empty string for an empty list.
    """
    if not entries:
        return ""

    rendered = ["--- Trusted Project State ---"]
    last_heading = ""
    for item in entries:
        heading_changed = item.category != last_heading
        if heading_changed:
            last_heading = item.category
            rendered.append(f"\n[{last_heading.upper()}]")
        rendered.append(f" {item.key}: {item.value}")
        if item.source:
            rendered.append(f" (source: {item.source})")
    rendered.append("\n--- End Project State ---")

    return "\n".join(rendered)

View File

@@ -15,6 +15,9 @@ from atocore.retrieval.vector_store import get_vector_store
log = get_logger("ingestion") log = get_logger("ingestion")
# Encodings to try when reading markdown files
_ENCODINGS = ["utf-8", "utf-8-sig", "latin-1", "cp1252"]
def ingest_file(file_path: Path) -> dict: def ingest_file(file_path: Path) -> dict:
"""Ingest a single markdown file. Returns stats.""" """Ingest a single markdown file. Returns stats."""
@@ -26,9 +29,9 @@ def ingest_file(file_path: Path) -> dict:
if file_path.suffix.lower() not in (".md", ".markdown"): if file_path.suffix.lower() not in (".md", ".markdown"):
raise ValueError(f"Not a markdown file: {file_path}") raise ValueError(f"Not a markdown file: {file_path}")
# Read and hash # Read with encoding fallback
raw_content = file_path.read_text(encoding="utf-8") raw_content = _read_file_safe(file_path)
file_hash = hashlib.sha256(raw_content.encode()).hexdigest() file_hash = hashlib.sha256(raw_content.encode("utf-8")).hexdigest()
# Check if already ingested and unchanged # Check if already ingested and unchanged
with get_connection() as conn: with get_connection() as conn:
@@ -136,16 +139,24 @@ def ingest_file(file_path: Path) -> dict:
} }
def ingest_folder(folder_path: Path) -> list[dict]: def ingest_folder(folder_path: Path, purge_deleted: bool = True) -> list[dict]:
"""Ingest all markdown files in a folder recursively.""" """Ingest all markdown files in a folder recursively.
Args:
folder_path: Directory to scan for .md files.
purge_deleted: If True, remove DB/vector entries for files
that no longer exist on disk.
"""
folder_path = folder_path.resolve() folder_path = folder_path.resolve()
if not folder_path.is_dir(): if not folder_path.is_dir():
raise NotADirectoryError(f"Not a directory: {folder_path}") raise NotADirectoryError(f"Not a directory: {folder_path}")
results = [] results = []
md_files = sorted(folder_path.rglob("*.md")) md_files = sorted(folder_path.rglob("*.md"))
current_paths = {str(f.resolve()) for f in md_files}
log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files)) log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files))
# Ingest new/changed files
for md_file in md_files: for md_file in md_files:
try: try:
result = ingest_file(md_file) result = ingest_file(md_file)
@@ -154,4 +165,80 @@ def ingest_folder(folder_path: Path) -> list[dict]:
log.error("ingestion_error", file_path=str(md_file), error=str(e)) log.error("ingestion_error", file_path=str(md_file), error=str(e))
results.append({"file": str(md_file), "status": "error", "error": str(e)}) results.append({"file": str(md_file), "status": "error", "error": str(e)})
# Purge entries for deleted files
if purge_deleted:
deleted = _purge_deleted_files(folder_path, current_paths)
if deleted:
log.info("purged_deleted_files", count=deleted)
results.append({"status": "purged", "deleted_count": deleted})
return results return results
def get_ingestion_stats() -> dict:
    """Collect document, chunk and vector counts plus the five latest documents.

    "Latest" follows the query's ordering by ``updated_at`` descending.
    """
    with get_connection() as conn:
        doc_row = conn.execute("SELECT COUNT(*) as c FROM source_documents").fetchone()
        chunk_row = conn.execute("SELECT COUNT(*) as c FROM source_chunks").fetchone()
        latest = conn.execute(
            "SELECT file_path, title, ingested_at FROM source_documents "
            "ORDER BY updated_at DESC LIMIT 5"
        ).fetchall()

    store = get_vector_store()

    recent_documents = []
    for row in latest:
        recent_documents.append(
            {
                "file_path": row["file_path"],
                "title": row["title"],
                "ingested_at": row["ingested_at"],
            }
        )

    return {
        "total_documents": doc_row["c"],
        "total_chunks": chunk_row["c"],
        "total_vectors": store.count,
        "recent_documents": recent_documents,
    }
def _read_file_safe(file_path: Path) -> str:
"""Read a file with encoding fallback."""
for encoding in _ENCODINGS:
try:
return file_path.read_text(encoding=encoding)
except (UnicodeDecodeError, ValueError):
continue
# Last resort: read with errors replaced
return file_path.read_text(encoding="utf-8", errors="replace")
def _purge_deleted_files(folder_path: Path, current_paths: set[str]) -> int:
    """Remove DB/vector entries for files under folder_path that no longer exist.

    Args:
        folder_path: Folder that was just scanned.
        current_paths: Paths of the files found by that scan.

    Returns:
        Number of documents purged.
    """
    folder_str = str(folder_path)
    deleted_count = 0
    vector_store = get_vector_store()

    with get_connection() as conn:
        # The LIKE prefix is only a coarse pre-filter: it also matches sibling
        # folders sharing the prefix (e.g. ".../notes-archive" for
        # ".../notes"), and "%" / "_" inside the path act as wildcards.
        rows = conn.execute(
            "SELECT id, file_path FROM source_documents WHERE file_path LIKE ?",
            (f"{folder_str}%",),
        ).fetchall()

        for row in rows:
            stored_path = row["file_path"]
            if stored_path in current_paths:
                continue

            candidate = Path(stored_path)
            # Only touch documents that really live inside folder_path.
            if not candidate.is_relative_to(folder_path):
                continue
            # Keep documents whose file is still on disk but was not part of
            # the scan (for example a different extension).
            if candidate.exists():
                continue

            doc_id = row["id"]
            # Collect chunk ids before the rows are deleted; the vector store
            # is keyed by them.
            chunk_ids = [
                r["id"]
                for r in conn.execute(
                    "SELECT id FROM source_chunks WHERE document_id = ?",
                    (doc_id,),
                ).fetchall()
            ]
            conn.execute("DELETE FROM source_chunks WHERE document_id = ?", (doc_id,))
            conn.execute("DELETE FROM source_documents WHERE id = ?", (doc_id,))
            if chunk_ids:
                vector_store.delete(chunk_ids)
            log.info("purged_deleted_file", file_path=stored_path)
            deleted_count += 1

    return deleted_count

View File

@@ -4,6 +4,7 @@ from fastapi import FastAPI
from atocore.api.routes import router from atocore.api.routes import router
from atocore.config import settings from atocore.config import settings
from atocore.context.project_state import init_project_state_schema
from atocore.models.database import init_db from atocore.models.database import init_db
from atocore.observability.logger import setup_logging from atocore.observability.logger import setup_logging
@@ -20,6 +21,7 @@ app.include_router(router)
def startup(): def startup():
setup_logging() setup_logging()
init_db() init_db()
init_project_state_schema()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,6 +1,7 @@
"""Tests for the context builder.""" """Tests for the context builder."""
from atocore.context.builder import build_context, get_last_context_pack from atocore.context.builder import build_context, get_last_context_pack
from atocore.context.project_state import init_project_state_schema, set_state
from atocore.ingestion.pipeline import ingest_file from atocore.ingestion.pipeline import ingest_file
from atocore.models.database import init_db from atocore.models.database import init_db
@@ -8,19 +9,20 @@ from atocore.models.database import init_db
def test_build_context_returns_pack(tmp_data_dir, sample_markdown): def test_build_context_returns_pack(tmp_data_dir, sample_markdown):
"""Test that context builder returns a valid pack.""" """Test that context builder returns a valid pack."""
init_db() init_db()
init_project_state_schema()
ingest_file(sample_markdown) ingest_file(sample_markdown)
pack = build_context("What is AtoCore?") pack = build_context("What is AtoCore?")
assert pack.total_chars > 0 assert pack.total_chars > 0
assert len(pack.chunks_used) > 0 assert len(pack.chunks_used) > 0
assert pack.budget_remaining >= 0 assert pack.budget_remaining >= 0
assert "--- AtoCore Context ---" in pack.formatted_context
assert "--- End Context ---" in pack.formatted_context assert "--- End Context ---" in pack.formatted_context
def test_context_respects_budget(tmp_data_dir, sample_markdown): def test_context_respects_budget(tmp_data_dir, sample_markdown):
"""Test that context builder respects character budget.""" """Test that context builder respects character budget."""
init_db() init_db()
init_project_state_schema()
ingest_file(sample_markdown) ingest_file(sample_markdown)
pack = build_context("What is AtoCore?", budget=500) pack = build_context("What is AtoCore?", budget=500)
@@ -30,17 +32,18 @@ def test_context_respects_budget(tmp_data_dir, sample_markdown):
def test_context_with_project_hint(tmp_data_dir, sample_markdown): def test_context_with_project_hint(tmp_data_dir, sample_markdown):
"""Test that project hint boosts relevant chunks.""" """Test that project hint boosts relevant chunks."""
init_db() init_db()
init_project_state_schema()
ingest_file(sample_markdown) ingest_file(sample_markdown)
pack = build_context("What is the architecture?", project_hint="atocore") pack = build_context("What is the architecture?", project_hint="atocore")
assert len(pack.chunks_used) > 0 assert len(pack.chunks_used) > 0
# With project hint, we should still get results
assert pack.total_chars > 0 assert pack.total_chars > 0
def test_last_context_pack_stored(tmp_data_dir, sample_markdown): def test_last_context_pack_stored(tmp_data_dir, sample_markdown):
"""Test that last context pack is stored for debug.""" """Test that last context pack is stored for debug."""
init_db() init_db()
init_project_state_schema()
ingest_file(sample_markdown) ingest_file(sample_markdown)
build_context("test prompt") build_context("test prompt")
@@ -52,9 +55,54 @@ def test_last_context_pack_stored(tmp_data_dir, sample_markdown):
def test_full_prompt_structure(tmp_data_dir, sample_markdown): def test_full_prompt_structure(tmp_data_dir, sample_markdown):
"""Test that the full prompt has correct structure.""" """Test that the full prompt has correct structure."""
init_db() init_db()
init_project_state_schema()
ingest_file(sample_markdown) ingest_file(sample_markdown)
pack = build_context("What are memory types?") pack = build_context("What are memory types?")
assert "knowledge base" in pack.full_prompt.lower() assert "knowledge base" in pack.full_prompt.lower()
assert "--- AtoCore Context ---" in pack.full_prompt
assert "What are memory types?" in pack.full_prompt assert "What are memory types?" in pack.full_prompt
def test_project_state_included_in_context(tmp_data_dir, sample_markdown):
    """Project state set for the hinted project shows up in the formatted context."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)

    seeded = [
        ("status", "phase", "Phase 0.5 complete"),
        ("decision", "database", "SQLite for structured data"),
    ]
    for category, key, value in seeded:
        set_state("atocore", category, key, value)

    pack = build_context("What is AtoCore?", project_hint="atocore")
    context = pack.formatted_context

    # Both the block header and every seeded value must be present.
    assert "--- Trusted Project State ---" in context
    for _category, _key, value in seeded:
        assert value in context
    assert pack.project_state_chars > 0
def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown):
    """A tight budget must not push project state out of the context."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)
    set_state("atocore", "status", "phase", "Phase 1 in progress")

    tight_budget = 500
    pack = build_context("status?", project_hint="atocore", budget=tight_budget)

    assert "Phase 1 in progress" in pack.formatted_context
def test_no_project_state_without_hint(tmp_data_dir, sample_markdown):
    """Without a project hint no project state is looked up or injected."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)
    set_state("atocore", "status", "phase", "Phase 1")

    # No project_hint argument here.
    pack = build_context("What is AtoCore?")

    assert pack.project_state_chars == 0
    assert "--- Trusted Project State ---" not in pack.formatted_context

127
tests/test_project_state.py Normal file
View File

@@ -0,0 +1,127 @@
"""Tests for Trusted Project State."""
import pytest
from atocore.context.project_state import (
CATEGORIES,
ensure_project,
format_project_state,
get_state,
init_project_state_schema,
invalidate_state,
set_state,
)
from atocore.models.database import init_db
@pytest.fixture(autouse=True)
def setup_db(tmp_data_dir):
    """Prepare the base DB and then the project_state schema before each test."""
    for initialise in (init_db, init_project_state_schema):
        initialise()
def test_ensure_project_creates():
    """ensure_project returns a non-empty id that is stable across calls."""
    first_id = ensure_project("test-project", "A test project")
    assert first_id

    # A repeat call must return the existing project's id.
    assert first_id == ensure_project("test-project")
def test_set_state_creates_entry():
    """A freshly set entry echoes its inputs and is active."""
    created = set_state("myproject", "status", "phase", "Phase 0.5 — PoC complete")

    observed = (created.category, created.key, created.value, created.status)
    assert observed[0] == "status"
    assert observed[1] == "phase"
    assert observed[2] == "Phase 0.5 — PoC complete"
    assert observed[3] == "active"
def test_set_state_upserts():
    """Writing the same key twice replaces the value instead of adding a row."""
    set_state("myproject", "status", "phase", "Phase 0")
    second = set_state("myproject", "status", "phase", "Phase 1")
    assert second.value == "Phase 1"

    # Exactly one stored entry, holding the latest value.
    stored = get_state("myproject", category="status")
    assert len(stored) == 1
    assert stored[0].value == "Phase 1"
def test_set_state_invalid_category():
    """An unknown category is rejected with a ValueError naming the problem."""
    with pytest.raises(ValueError) as caught:
        set_state("myproject", "invalid_category", "key", "value")
    assert "Invalid category" in str(caught.value)
def test_get_state_all():
    """Without a category filter every entry of the project is returned."""
    fixtures = (
        ("status", "phase", "Phase 1"),
        ("decision", "database", "SQLite for v1"),
        ("requirement", "latency", "<2 seconds"),
    )
    for category, key, value in fixtures:
        set_state("proj", category, key, value)

    found = get_state("proj")
    assert len(found) == 3
    seen_categories = {entry.category for entry in found}
    assert seen_categories == {"status", "decision", "requirement"}
def test_get_state_by_category():
    """The category filter returns only entries of that category."""
    set_state("proj", "status", "phase", "Phase 1")
    for key, value in (("database", "SQLite"), ("vectordb", "ChromaDB")):
        set_state("proj", "decision", key, value)

    decisions = get_state("proj", category="decision")
    assert len(decisions) == 2
    assert {entry.category for entry in decisions} == {"decision"}
def test_get_state_nonexistent_project():
    """An unknown project yields an empty list."""
    result = get_state("nonexistent")
    assert result == []
def test_invalidate_state():
    """Invalidation hides the entry from active lookups but keeps the row."""
    project = "invalidate-test"
    set_state(project, "decision", "approach", "monolith")

    assert invalidate_state(project, "decision", "approach")

    # Nothing active remains.
    active = get_state(project, active_only=True)
    assert len(active) == 0

    # The row is still there, now marked as superseded.
    everything = get_state(project, active_only=False)
    assert len(everything) == 1
    assert everything[0].status == "superseded"
def test_invalidate_nonexistent():
    """invalidate_state reports failure when there is no active entry to supersede."""
    # Unknown project: there is no project row at all.
    assert not invalidate_state("proj", "decision", "nonexistent")

    # Known project, unknown key: the project lookup succeeds but no row matches.
    set_state("proj", "decision", "database", "SQLite")
    assert not invalidate_state("proj", "decision", "nonexistent")

    # Already superseded: a second invalidation finds no active row.
    assert invalidate_state("proj", "decision", "database")
    assert not invalidate_state("proj", "decision", "database")
def test_format_project_state():
    """Formatted output has both delimiters, each key/value pair and the source note."""
    set_state("proj", "status", "phase", "Phase 1")
    set_state("proj", "decision", "database", "SQLite", source="Build Spec V1")

    formatted = format_project_state(get_state("proj"))

    expected_fragments = (
        "--- Trusted Project State ---",
        "--- End Project State ---",
        "phase: Phase 1",
        "database: SQLite",
        "(source: Build Spec V1)",
    )
    for fragment in expected_fragments:
        assert fragment in formatted
def test_format_empty():
    """No entries produce an empty string rather than an empty block."""
    rendered = format_project_state([])
    assert rendered == ""