Stabilize core correctness and sync project plan state

This commit is contained in:
2026-04-05 17:53:23 -04:00
parent b48f0c95ab
commit b0889b3925
20 changed files with 551 additions and 168 deletions

View File

@@ -33,4 +33,4 @@ where = ["src"]
testpaths = ["tests"] testpaths = ["tests"]
python_files = ["test_*.py"] python_files = ["test_*.py"]
python_functions = ["test_*"] python_functions = ["test_*"]
addopts = "--cov=atocore --cov-report=term-missing -v" addopts = "-v"

View File

@@ -192,6 +192,7 @@ def api_create_memory(req: MemoryCreateRequest) -> dict:
@router.get("/memory") @router.get("/memory")
def api_get_memories( def api_get_memories(
memory_type: str | None = None, memory_type: str | None = None,
project: str | None = None,
active_only: bool = True, active_only: bool = True,
min_confidence: float = 0.0, min_confidence: float = 0.0,
limit: int = 50, limit: int = 50,
@@ -199,6 +200,7 @@ def api_get_memories(
"""List memories, optionally filtered.""" """List memories, optionally filtered."""
memories = get_memories( memories = get_memories(
memory_type=memory_type, memory_type=memory_type,
project=project,
active_only=active_only, active_only=active_only,
min_confidence=min_confidence, min_confidence=min_confidence,
limit=limit, limit=limit,
@@ -209,6 +211,7 @@ def api_get_memories(
"id": m.id, "id": m.id,
"memory_type": m.memory_type, "memory_type": m.memory_type,
"content": m.content, "content": m.content,
"project": m.project,
"confidence": m.confidence, "confidence": m.confidence,
"status": m.status, "status": m.status,
"updated_at": m.updated_at, "updated_at": m.updated_at,

View File

@@ -10,7 +10,7 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from atocore.config import settings import atocore.config as _config
from atocore.context.project_state import format_project_state, get_state from atocore.context.project_state import format_project_state, get_state
from atocore.memory.service import get_memories_for_context from atocore.memory.service import get_memories_for_context
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
@@ -74,20 +74,27 @@ def build_context(
""" """
global _last_context_pack global _last_context_pack
start = time.time() start = time.time()
budget = budget or settings.context_budget budget = _config.settings.context_budget if budget is None else max(budget, 0)
# 1. Get Trusted Project State (highest precedence) # 1. Get Trusted Project State (highest precedence)
project_state_text = "" project_state_text = ""
project_state_chars = 0 project_state_chars = 0
project_state_budget = min(
budget,
max(0, int(budget * PROJECT_STATE_BUDGET_RATIO)),
)
if project_hint: if project_hint:
state_entries = get_state(project_hint) state_entries = get_state(project_hint)
if state_entries: if state_entries:
project_state_text = format_project_state(state_entries) project_state_text = format_project_state(state_entries)
project_state_chars = len(project_state_text) project_state_text, project_state_chars = _truncate_text_block(
project_state_text,
project_state_budget or budget,
)
# 2. Get identity + preference memories (second precedence) # 2. Get identity + preference memories (second precedence)
memory_budget = int(budget * MEMORY_BUDGET_RATIO) memory_budget = min(int(budget * MEMORY_BUDGET_RATIO), max(budget - project_state_chars, 0))
memory_text, memory_chars = get_memories_for_context( memory_text, memory_chars = get_memories_for_context(
memory_types=["identity", "preference"], memory_types=["identity", "preference"],
budget=memory_budget, budget=memory_budget,
@@ -97,7 +104,7 @@ def build_context(
retrieval_budget = budget - project_state_chars - memory_chars retrieval_budget = budget - project_state_chars - memory_chars
# 4. Retrieve candidates # 4. Retrieve candidates
candidates = retrieve(user_prompt, top_k=settings.context_top_k) candidates = retrieve(user_prompt, top_k=_config.settings.context_top_k) if retrieval_budget > 0 else []
# 5. Score and rank # 5. Score and rank
scored = _rank_chunks(candidates, project_hint) scored = _rank_chunks(candidates, project_hint)
@@ -107,12 +114,21 @@ def build_context(
# 7. Format full context # 7. Format full context
formatted = _format_full_context(project_state_text, memory_text, selected) formatted = _format_full_context(project_state_text, memory_text, selected)
if len(formatted) > budget:
formatted, selected = _trim_context_to_budget(
project_state_text,
memory_text,
selected,
budget,
)
# 8. Build full prompt # 8. Build full prompt
full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}" full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}"
project_state_chars = len(project_state_text)
memory_chars = len(memory_text)
retrieval_chars = sum(c.char_count for c in selected) retrieval_chars = sum(c.char_count for c in selected)
total_chars = project_state_chars + memory_chars + retrieval_chars total_chars = len(formatted)
duration_ms = int((time.time() - start) * 1000) duration_ms = int((time.time() - start) * 1000)
pack = ContextPack( pack = ContextPack(
@@ -235,6 +251,8 @@ def _format_full_context(
# 3. Retrieved chunks (lowest trust) # 3. Retrieved chunks (lowest trust)
if chunks: if chunks:
parts.append("--- AtoCore Retrieved Context ---") parts.append("--- AtoCore Retrieved Context ---")
if project_state_text:
parts.append("If retrieved context conflicts with Trusted Project State above, trust the Trusted Project State.")
for chunk in chunks: for chunk in chunks:
parts.append( parts.append(
f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]" f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]"
@@ -282,3 +300,44 @@ def _pack_to_dict(pack: ContextPack) -> dict:
for c in pack.chunks_used for c in pack.chunks_used
], ],
} }
def _truncate_text_block(text: str, budget: int) -> tuple[str, int]:
"""Trim a formatted text block so trusted tiers cannot exceed the total budget."""
if budget <= 0 or not text:
return "", 0
if len(text) <= budget:
return text, len(text)
if budget <= 3:
trimmed = text[:budget]
else:
trimmed = f"{text[: budget - 3].rstrip()}..."
return trimmed, len(trimmed)
def _trim_context_to_budget(
project_state_text: str,
memory_text: str,
chunks: list[ContextChunk],
budget: int,
) -> tuple[str, list[ContextChunk]]:
"""Trim retrieval first, then memory, then project state until formatted context fits."""
kept_chunks = list(chunks)
formatted = _format_full_context(project_state_text, memory_text, kept_chunks)
while len(formatted) > budget and kept_chunks:
kept_chunks.pop()
formatted = _format_full_context(project_state_text, memory_text, kept_chunks)
if len(formatted) <= budget:
return formatted, kept_chunks
memory_text, _ = _truncate_text_block(memory_text, max(budget - len(project_state_text), 0))
formatted = _format_full_context(project_state_text, memory_text, kept_chunks)
if len(formatted) <= budget:
return formatted, kept_chunks
project_state_text, _ = _truncate_text_block(project_state_text, budget)
formatted = _format_full_context(project_state_text, "", [])
if len(formatted) > budget:
formatted, _ = _truncate_text_block(formatted, budget)
return formatted, []

View File

@@ -12,10 +12,8 @@ Project state is manually curated or explicitly confirmed facts about a project.
It always wins over retrieval-based context when there's a conflict. It always wins over retrieval-based context when there's a conflict.
""" """
import json
import time
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from atocore.models.database import get_connection from atocore.models.database import get_connection
@@ -81,7 +79,7 @@ def ensure_project(name: str, description: str = "") -> str:
"""Get or create a project by name. Returns project_id.""" """Get or create a project by name. Returns project_id."""
with get_connection() as conn: with get_connection() as conn:
row = conn.execute( row = conn.execute(
"SELECT id FROM projects WHERE name = ?", (name,) "SELECT id FROM projects WHERE lower(name) = lower(?)", (name,)
).fetchone() ).fetchone()
if row: if row:
return row["id"] return row["id"]
@@ -106,6 +104,7 @@ def set_state(
"""Set or update a project state entry. Upsert semantics.""" """Set or update a project state entry. Upsert semantics."""
if category not in CATEGORIES: if category not in CATEGORIES:
raise ValueError(f"Invalid category '{category}'. Must be one of: {CATEGORIES}") raise ValueError(f"Invalid category '{category}'. Must be one of: {CATEGORIES}")
_validate_confidence(confidence)
project_id = ensure_project(project_name) project_id = ensure_project(project_name)
entry_id = str(uuid.uuid4()) entry_id = str(uuid.uuid4())
@@ -157,7 +156,7 @@ def get_state(
"""Get project state entries, optionally filtered by category.""" """Get project state entries, optionally filtered by category."""
with get_connection() as conn: with get_connection() as conn:
project = conn.execute( project = conn.execute(
"SELECT id FROM projects WHERE name = ?", (project_name,) "SELECT id FROM projects WHERE lower(name) = lower(?)", (project_name,)
).fetchone() ).fetchone()
if not project: if not project:
return [] return []
@@ -195,7 +194,7 @@ def invalidate_state(project_name: str, category: str, key: str) -> bool:
"""Mark a project state entry as superseded.""" """Mark a project state entry as superseded."""
with get_connection() as conn: with get_connection() as conn:
project = conn.execute( project = conn.execute(
"SELECT id FROM projects WHERE name = ?", (project_name,) "SELECT id FROM projects WHERE lower(name) = lower(?)", (project_name,)
).fetchone() ).fetchone()
if not project: if not project:
return False return False
@@ -229,3 +228,8 @@ def format_project_state(entries: list[ProjectStateEntry]) -> str:
lines.append("\n--- End Project State ---") lines.append("\n--- End Project State ---")
return "\n".join(lines) return "\n".join(lines)
def _validate_confidence(confidence: float) -> None:
if not 0.0 <= confidence <= 1.0:
raise ValueError("Confidence must be between 0.0 and 1.0")

View File

@@ -3,7 +3,7 @@
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from atocore.config import settings import atocore.config as _config
@dataclass @dataclass
@@ -29,9 +29,9 @@ def chunk_markdown(
3. If still > max_size, split on paragraph breaks 3. If still > max_size, split on paragraph breaks
4. If still > max_size, hard split with overlap 4. If still > max_size, hard split with overlap
""" """
max_size = max_size or settings.chunk_max_size max_size = max_size or _config.settings.chunk_max_size
overlap = overlap or settings.chunk_overlap overlap = overlap or _config.settings.chunk_overlap
min_size = min_size or settings.chunk_min_size min_size = min_size or _config.settings.chunk_min_size
base_metadata = base_metadata or {} base_metadata = base_metadata or {}
sections = _split_by_heading(body, level=2) sections = _split_by_heading(body, level=2)

View File

@@ -17,10 +17,10 @@ class ParsedDocument:
headings: list[tuple[int, str]] = field(default_factory=list) headings: list[tuple[int, str]] = field(default_factory=list)
def parse_markdown(file_path: Path) -> ParsedDocument: def parse_markdown(file_path: Path, text: str | None = None) -> ParsedDocument:
"""Parse a markdown file, extracting frontmatter and structure.""" """Parse a markdown file, extracting frontmatter and structure."""
text = file_path.read_text(encoding="utf-8") raw_text = text if text is not None else file_path.read_text(encoding="utf-8")
post = frontmatter.loads(text) post = frontmatter.loads(raw_text)
meta = dict(post.metadata) if post.metadata else {} meta = dict(post.metadata) if post.metadata else {}
body = post.content.strip() body = post.content.strip()

View File

@@ -6,7 +6,6 @@ import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from atocore.config import settings
from atocore.ingestion.chunker import chunk_markdown from atocore.ingestion.chunker import chunk_markdown
from atocore.ingestion.parser import parse_markdown from atocore.ingestion.parser import parse_markdown
from atocore.models.database import get_connection from atocore.models.database import get_connection
@@ -45,7 +44,7 @@ def ingest_file(file_path: Path) -> dict:
return {"file": str(file_path), "status": "skipped", "reason": "unchanged"} return {"file": str(file_path), "status": "skipped", "reason": "unchanged"}
# Parse # Parse
parsed = parse_markdown(file_path) parsed = parse_markdown(file_path, text=raw_content)
# Chunk # Chunk
base_meta = { base_meta = {
@@ -55,85 +54,98 @@ def ingest_file(file_path: Path) -> dict:
} }
chunks = chunk_markdown(parsed.body, base_metadata=base_meta) chunks = chunk_markdown(parsed.body, base_metadata=base_meta)
if not chunks:
log.warning("no_chunks_created", file_path=str(file_path))
return {"file": str(file_path), "status": "empty", "chunks": 0}
# Store in DB and vector store # Store in DB and vector store
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
vector_store = get_vector_store() vector_store = get_vector_store()
old_chunk_ids: list[str] = []
new_chunk_ids: list[str] = []
with get_connection() as conn: try:
# Remove old data if re-ingesting with get_connection() as conn:
if existing: # Remove old data if re-ingesting
doc_id = existing["id"] if existing:
old_chunk_ids = [ doc_id = existing["id"]
row["id"] old_chunk_ids = [
for row in conn.execute( row["id"]
"SELECT id FROM source_chunks WHERE document_id = ?", for row in conn.execute(
(doc_id,), "SELECT id FROM source_chunks WHERE document_id = ?",
).fetchall() (doc_id,),
] ).fetchall()
conn.execute( ]
"DELETE FROM source_chunks WHERE document_id = ?", (doc_id,) conn.execute(
) "DELETE FROM source_chunks WHERE document_id = ?", (doc_id,)
conn.execute( )
"UPDATE source_documents SET file_hash = ?, title = ?, tags = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", conn.execute(
(file_hash, parsed.title, json.dumps(parsed.tags), doc_id), "UPDATE source_documents SET file_hash = ?, title = ?, tags = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
) (file_hash, parsed.title, json.dumps(parsed.tags), doc_id),
# Remove old vectors )
if old_chunk_ids: else:
vector_store.delete(old_chunk_ids) conn.execute(
else: "INSERT INTO source_documents (id, file_path, file_hash, title, doc_type, tags) VALUES (?, ?, ?, ?, ?, ?)",
conn.execute( (doc_id, str(file_path), file_hash, parsed.title, "markdown", json.dumps(parsed.tags)),
"INSERT INTO source_documents (id, file_path, file_hash, title, doc_type, tags) VALUES (?, ?, ?, ?, ?, ?)", )
(doc_id, str(file_path), file_hash, parsed.title, "markdown", json.dumps(parsed.tags)),
)
# Insert chunks if not chunks:
chunk_ids = [] log.warning("no_chunks_created", file_path=str(file_path))
chunk_contents = [] else:
chunk_metadatas = [] # Insert chunks
chunk_contents = []
chunk_metadatas = []
for chunk in chunks: for chunk in chunks:
chunk_id = str(uuid.uuid4()) chunk_id = str(uuid.uuid4())
chunk_ids.append(chunk_id) new_chunk_ids.append(chunk_id)
chunk_contents.append(chunk.content) chunk_contents.append(chunk.content)
chunk_metadatas.append({ chunk_metadatas.append({
"document_id": doc_id, "document_id": doc_id,
"heading_path": chunk.heading_path, "heading_path": chunk.heading_path,
"source_file": str(file_path), "source_file": str(file_path),
"tags": json.dumps(parsed.tags), "tags": json.dumps(parsed.tags),
"title": parsed.title, "title": parsed.title,
}) })
conn.execute( conn.execute(
"INSERT INTO source_chunks (id, document_id, chunk_index, content, heading_path, char_count, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)", "INSERT INTO source_chunks (id, document_id, chunk_index, content, heading_path, char_count, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)",
( (
chunk_id, chunk_id,
doc_id, doc_id,
chunk.chunk_index, chunk.chunk_index,
chunk.content, chunk.content,
chunk.heading_path, chunk.heading_path,
chunk.char_count, chunk.char_count,
json.dumps(chunk.metadata), json.dumps(chunk.metadata),
), ),
) )
# Store embeddings # Add new vectors before commit so DB can still roll back on failure.
vector_store.add(chunk_ids, chunk_contents, chunk_metadatas) vector_store.add(new_chunk_ids, chunk_contents, chunk_metadatas)
except Exception:
if new_chunk_ids:
vector_store.delete(new_chunk_ids)
raise
# Delete stale vectors only after the DB transaction committed.
if old_chunk_ids:
vector_store.delete(old_chunk_ids)
duration_ms = int((time.time() - start) * 1000) duration_ms = int((time.time() - start) * 1000)
log.info( if chunks:
"file_ingested", log.info(
file_path=str(file_path), "file_ingested",
chunks_created=len(chunks), file_path=str(file_path),
duration_ms=duration_ms, chunks_created=len(chunks),
) duration_ms=duration_ms,
)
else:
log.info(
"file_ingested_empty",
file_path=str(file_path),
duration_ms=duration_ms,
)
return { return {
"file": str(file_path), "file": str(file_path),
"status": "ingested", "status": "ingested" if chunks else "empty",
"chunks": len(chunks), "chunks": len(chunks),
"duration_ms": duration_ms, "duration_ms": duration_ms,
} }
@@ -152,7 +164,9 @@ def ingest_folder(folder_path: Path, purge_deleted: bool = True) -> list[dict]:
raise NotADirectoryError(f"Not a directory: {folder_path}") raise NotADirectoryError(f"Not a directory: {folder_path}")
results = [] results = []
md_files = sorted(folder_path.rglob("*.md")) md_files = sorted(
list(folder_path.rglob("*.md")) + list(folder_path.rglob("*.markdown"))
)
current_paths = {str(f.resolve()) for f in md_files} current_paths = {str(f.resolve()) for f in md_files}
log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files)) log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files))
@@ -213,32 +227,35 @@ def _purge_deleted_files(folder_path: Path, current_paths: set[str]) -> int:
folder_str = str(folder_path) folder_str = str(folder_path)
deleted_count = 0 deleted_count = 0
vector_store = get_vector_store() vector_store = get_vector_store()
chunk_ids_to_delete: list[str] = []
with get_connection() as conn: with get_connection() as conn:
# Find documents under this folder
rows = conn.execute( rows = conn.execute(
"SELECT id, file_path FROM source_documents WHERE file_path LIKE ?", "SELECT id, file_path FROM source_documents"
(f"{folder_str}%",),
).fetchall() ).fetchall()
for row in rows: for row in rows:
doc_path = Path(row["file_path"])
try:
doc_path.relative_to(folder_path)
except ValueError:
continue
if row["file_path"] not in current_paths: if row["file_path"] not in current_paths:
doc_id = row["id"] doc_id = row["id"]
# Get chunk IDs for vector deletion chunk_ids_to_delete.extend(
chunk_ids = [
r["id"] r["id"]
for r in conn.execute( for r in conn.execute(
"SELECT id FROM source_chunks WHERE document_id = ?", "SELECT id FROM source_chunks WHERE document_id = ?",
(doc_id,), (doc_id,),
).fetchall() ).fetchall()
] )
# Delete from DB
conn.execute("DELETE FROM source_chunks WHERE document_id = ?", (doc_id,)) conn.execute("DELETE FROM source_chunks WHERE document_id = ?", (doc_id,))
conn.execute("DELETE FROM source_documents WHERE id = ?", (doc_id,)) conn.execute("DELETE FROM source_documents WHERE id = ?", (doc_id,))
# Delete from vectors
if chunk_ids:
vector_store.delete(chunk_ids)
log.info("purged_deleted_file", file_path=row["file_path"]) log.info("purged_deleted_file", file_path=row["file_path"])
deleted_count += 1 deleted_count += 1
if chunk_ids_to_delete:
vector_store.delete(chunk_ids_to_delete)
return deleted_count return deleted_count

View File

@@ -3,7 +3,7 @@
from fastapi import FastAPI from fastapi import FastAPI
from atocore.api.routes import router from atocore.api.routes import router
from atocore.config import settings import atocore.config as _config
from atocore.context.project_state import init_project_state_schema from atocore.context.project_state import init_project_state_schema
from atocore.models.database import init_db from atocore.models.database import init_db
from atocore.observability.logger import setup_logging from atocore.observability.logger import setup_logging
@@ -29,7 +29,7 @@ if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"atocore.main:app", "atocore.main:app",
host=settings.host, host=_config.settings.host,
port=settings.port, port=_config.settings.port,
reload=True, reload=True,
) )

View File

@@ -14,7 +14,6 @@ Memories have:
- optional link to source chunk: traceability - optional link to source chunk: traceability
""" """
import json
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -57,6 +56,7 @@ def create_memory(
"""Create a new memory entry.""" """Create a new memory entry."""
if memory_type not in MEMORY_TYPES: if memory_type not in MEMORY_TYPES:
raise ValueError(f"Invalid memory type '{memory_type}'. Must be one of: {MEMORY_TYPES}") raise ValueError(f"Invalid memory type '{memory_type}'. Must be one of: {MEMORY_TYPES}")
_validate_confidence(confidence)
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
@@ -64,8 +64,9 @@ def create_memory(
# Check for duplicate content within same type+project # Check for duplicate content within same type+project
with get_connection() as conn: with get_connection() as conn:
existing = conn.execute( existing = conn.execute(
"SELECT id FROM memories WHERE memory_type = ? AND content = ? AND status = 'active'", "SELECT id FROM memories "
(memory_type, content), "WHERE memory_type = ? AND content = ? AND project = ? AND status = 'active'",
(memory_type, content, project),
).fetchone() ).fetchone()
if existing: if existing:
log.info("memory_duplicate_skipped", memory_type=memory_type, content_preview=content[:80]) log.info("memory_duplicate_skipped", memory_type=memory_type, content_preview=content[:80])
@@ -74,9 +75,9 @@ def create_memory(
) )
conn.execute( conn.execute(
"INSERT INTO memories (id, memory_type, content, source_chunk_id, confidence, status) " "INSERT INTO memories (id, memory_type, content, project, source_chunk_id, confidence, status) "
"VALUES (?, ?, ?, ?, ?, 'active')", "VALUES (?, ?, ?, ?, ?, ?, 'active')",
(memory_id, memory_type, content, source_chunk_id or None, confidence), (memory_id, memory_type, content, project, source_chunk_id or None, confidence),
) )
log.info("memory_created", memory_type=memory_type, content_preview=content[:80]) log.info("memory_created", memory_type=memory_type, content_preview=content[:80])
@@ -96,6 +97,7 @@ def create_memory(
def get_memories( def get_memories(
memory_type: str | None = None, memory_type: str | None = None,
project: str | None = None,
active_only: bool = True, active_only: bool = True,
min_confidence: float = 0.0, min_confidence: float = 0.0,
limit: int = 50, limit: int = 50,
@@ -107,6 +109,9 @@ def get_memories(
if memory_type: if memory_type:
query += " AND memory_type = ?" query += " AND memory_type = ?"
params.append(memory_type) params.append(memory_type)
if project is not None:
query += " AND project = ?"
params.append(project)
if active_only: if active_only:
query += " AND status = 'active'" query += " AND status = 'active'"
if min_confidence > 0: if min_confidence > 0:
@@ -129,28 +134,46 @@ def update_memory(
status: str | None = None, status: str | None = None,
) -> bool: ) -> bool:
"""Update an existing memory.""" """Update an existing memory."""
updates = []
params: list = []
if content is not None:
updates.append("content = ?")
params.append(content)
if confidence is not None:
updates.append("confidence = ?")
params.append(confidence)
if status is not None:
if status not in ("active", "superseded", "invalid"):
raise ValueError(f"Invalid status '{status}'")
updates.append("status = ?")
params.append(status)
if not updates:
return False
updates.append("updated_at = CURRENT_TIMESTAMP")
params.append(memory_id)
with get_connection() as conn: with get_connection() as conn:
existing = conn.execute("SELECT * FROM memories WHERE id = ?", (memory_id,)).fetchone()
if existing is None:
return False
next_content = content if content is not None else existing["content"]
next_status = status if status is not None else existing["status"]
if confidence is not None:
_validate_confidence(confidence)
if next_status == "active":
duplicate = conn.execute(
"SELECT id FROM memories "
"WHERE memory_type = ? AND content = ? AND project = ? AND status = 'active' AND id != ?",
(existing["memory_type"], next_content, existing["project"] or "", memory_id),
).fetchone()
if duplicate:
raise ValueError("Update would create a duplicate active memory")
updates = []
params: list = []
if content is not None:
updates.append("content = ?")
params.append(content)
if confidence is not None:
updates.append("confidence = ?")
params.append(confidence)
if status is not None:
if status not in ("active", "superseded", "invalid"):
raise ValueError(f"Invalid status '{status}'")
updates.append("status = ?")
params.append(status)
if not updates:
return False
updates.append("updated_at = CURRENT_TIMESTAMP")
params.append(memory_id)
result = conn.execute( result = conn.execute(
f"UPDATE memories SET {', '.join(updates)} WHERE id = ?", f"UPDATE memories SET {', '.join(updates)} WHERE id = ?",
params, params,
@@ -174,6 +197,7 @@ def supersede_memory(memory_id: str) -> bool:
def get_memories_for_context( def get_memories_for_context(
memory_types: list[str] | None = None, memory_types: list[str] | None = None,
project: str | None = None,
budget: int = 500, budget: int = 500,
) -> tuple[str, int]: ) -> tuple[str, int]:
"""Get formatted memories for context injection. """Get formatted memories for context injection.
@@ -186,33 +210,42 @@ def get_memories_for_context(
if memory_types is None: if memory_types is None:
memory_types = ["identity", "preference"] memory_types = ["identity", "preference"]
memories = [] if budget <= 0:
for mtype in memory_types:
memories.extend(get_memories(memory_type=mtype, min_confidence=0.5, limit=10))
if not memories:
return "", 0 return "", 0
lines = ["--- AtoCore Memory ---"] header = "--- AtoCore Memory ---"
used = len(lines[0]) + 1 footer = "--- End Memory ---"
included = [] wrapper_chars = len(header) + len(footer) + 2
if budget <= wrapper_chars:
for mem in memories:
entry = f"[{mem.memory_type}] {mem.content}"
entry_len = len(entry) + 1
if used + entry_len > budget:
break
lines.append(entry)
used += entry_len
included.append(mem)
if len(included) == 0:
return "", 0 return "", 0
lines.append("--- End Memory ---") available = budget - wrapper_chars
selected_entries: list[str] = []
for index, mtype in enumerate(memory_types):
type_budget = available if index == len(memory_types) - 1 else max(0, available // (len(memory_types) - index))
type_used = 0
for mem in get_memories(
memory_type=mtype,
project=project,
min_confidence=0.5,
limit=10,
):
entry = f"[{mem.memory_type}] {mem.content}"
entry_len = len(entry) + 1
if entry_len > type_budget - type_used:
continue
selected_entries.append(entry)
type_used += entry_len
available -= type_used
if not selected_entries:
return "", 0
lines = [header, *selected_entries, footer]
text = "\n".join(lines) text = "\n".join(lines)
log.info("memories_for_context", count=len(included), chars=len(text)) log.info("memories_for_context", count=len(selected_entries), chars=len(text))
return text, len(text) return text, len(text)
@@ -222,10 +255,15 @@ def _row_to_memory(row) -> Memory:
id=row["id"], id=row["id"],
memory_type=row["memory_type"], memory_type=row["memory_type"],
content=row["content"], content=row["content"],
project="", project=row["project"] or "",
source_chunk_id=row["source_chunk_id"] or "", source_chunk_id=row["source_chunk_id"] or "",
confidence=row["confidence"], confidence=row["confidence"],
status=row["status"], status=row["status"],
created_at=row["created_at"], created_at=row["created_at"],
updated_at=row["updated_at"], updated_at=row["updated_at"],
) )
def _validate_confidence(confidence: float) -> None:
if not 0.0 <= confidence <= 1.0:
raise ValueError("Confidence must be between 0.0 and 1.0")

View File

@@ -37,6 +37,7 @@ CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
memory_type TEXT NOT NULL, memory_type TEXT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
project TEXT DEFAULT '',
source_chunk_id TEXT REFERENCES source_chunks(id), source_chunk_id TEXT REFERENCES source_chunks(id),
confidence REAL DEFAULT 1.0, confidence REAL DEFAULT 1.0,
status TEXT DEFAULT 'active', status TEXT DEFAULT 'active',
@@ -64,6 +65,7 @@ CREATE TABLE IF NOT EXISTS interactions (
CREATE INDEX IF NOT EXISTS idx_chunks_document ON source_chunks(document_id); CREATE INDEX IF NOT EXISTS idx_chunks_document ON source_chunks(document_id);
CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(memory_type); CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(memory_type);
CREATE INDEX IF NOT EXISTS idx_memories_project ON memories(project);
CREATE INDEX IF NOT EXISTS idx_memories_status ON memories(status); CREATE INDEX IF NOT EXISTS idx_memories_status ON memories(status);
CREATE INDEX IF NOT EXISTS idx_interactions_project ON interactions(project_id); CREATE INDEX IF NOT EXISTS idx_interactions_project ON interactions(project_id);
""" """
@@ -78,9 +80,22 @@ def init_db() -> None:
_ensure_data_dir() _ensure_data_dir()
with get_connection() as conn: with get_connection() as conn:
conn.executescript(SCHEMA_SQL) conn.executescript(SCHEMA_SQL)
_apply_migrations(conn)
log.info("database_initialized", path=str(_config.settings.db_path)) log.info("database_initialized", path=str(_config.settings.db_path))
def _apply_migrations(conn: sqlite3.Connection) -> None:
"""Apply lightweight schema migrations for existing local databases."""
if not _column_exists(conn, "memories", "project"):
conn.execute("ALTER TABLE memories ADD COLUMN project TEXT DEFAULT ''")
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_project ON memories(project)")
def _column_exists(conn: sqlite3.Connection, table: str, column: str) -> bool:
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
return any(row["name"] == column for row in rows)
@contextmanager @contextmanager
def get_connection() -> Generator[sqlite3.Connection, None, None]: def get_connection() -> Generator[sqlite3.Connection, None, None]:
"""Get a database connection with row factory.""" """Get a database connection with row factory."""

View File

@@ -2,10 +2,9 @@
import logging import logging
import atocore.config as _config
import structlog import structlog
from atocore.config import settings
_LOG_LEVELS = { _LOG_LEVELS = {
"DEBUG": logging.DEBUG, "DEBUG": logging.DEBUG,
"INFO": logging.INFO, "INFO": logging.INFO,
@@ -16,7 +15,7 @@ _LOG_LEVELS = {
def setup_logging() -> None: def setup_logging() -> None:
"""Configure structlog with JSON output.""" """Configure structlog with JSON output."""
log_level = "DEBUG" if settings.debug else "INFO" log_level = "DEBUG" if _config.settings.debug else "INFO"
structlog.configure( structlog.configure(
processors=[ processors=[

View File

@@ -1,8 +1,8 @@
"""Embedding model management.""" """Embedding model management."""
import atocore.config as _config
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from atocore.config import settings
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
log = get_logger("embeddings") log = get_logger("embeddings")
@@ -14,9 +14,9 @@ def get_model() -> SentenceTransformer:
"""Load and cache the embedding model.""" """Load and cache the embedding model."""
global _model global _model
if _model is None: if _model is None:
log.info("loading_embedding_model", model=settings.embedding_model) log.info("loading_embedding_model", model=_config.settings.embedding_model)
_model = SentenceTransformer(settings.embedding_model) _model = SentenceTransformer(_config.settings.embedding_model)
log.info("embedding_model_loaded", model=settings.embedding_model) log.info("embedding_model_loaded", model=_config.settings.embedding_model)
return _model return _model

View File

@@ -3,7 +3,8 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from atocore.config import settings import atocore.config as _config
from atocore.models.database import get_connection
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
from atocore.retrieval.embeddings import embed_query from atocore.retrieval.embeddings import embed_query
from atocore.retrieval.vector_store import get_vector_store from atocore.retrieval.vector_store import get_vector_store
@@ -29,7 +30,7 @@ def retrieve(
filter_tags: list[str] | None = None, filter_tags: list[str] | None = None,
) -> list[ChunkResult]: ) -> list[ChunkResult]:
"""Retrieve the most relevant chunks for a query.""" """Retrieve the most relevant chunks for a query."""
top_k = top_k or settings.context_top_k top_k = top_k or _config.settings.context_top_k
start = time.time() start = time.time()
query_embedding = embed_query(query) query_embedding = embed_query(query)
@@ -59,7 +60,10 @@ def retrieve(
chunks = [] chunks = []
if results and results["ids"] and results["ids"][0]: if results and results["ids"] and results["ids"][0]:
existing_ids = _existing_chunk_ids(results["ids"][0])
for i, chunk_id in enumerate(results["ids"][0]): for i, chunk_id in enumerate(results["ids"][0]):
if chunk_id not in existing_ids:
continue
# ChromaDB returns distances (lower = more similar for cosine) # ChromaDB returns distances (lower = more similar for cosine)
# Convert to similarity score (1 - distance) # Convert to similarity score (1 - distance)
distance = results["distances"][0][i] if results["distances"] else 0 distance = results["distances"][0][i] if results["distances"] else 0
@@ -90,3 +94,17 @@ def retrieve(
) )
return chunks return chunks
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
    """Return the subset of *chunk_ids* that still have rows in source_chunks.

    Used to drop stale vector-store hits whose backing chunk rows were
    removed (e.g. after a re-ingest deleted the document's chunks).
    """
    if not chunk_ids:
        return set()
    # One placeholder per id; the id values themselves are bound safely.
    marks = ", ".join("?" for _ in chunk_ids)
    query = f"SELECT id FROM source_chunks WHERE id IN ({marks})"
    with get_connection() as conn:
        found = conn.execute(query, chunk_ids).fetchall()
    return {record["id"] for record in found}

View File

@@ -2,7 +2,7 @@
import chromadb import chromadb
from atocore.config import settings import atocore.config as _config
from atocore.observability.logger import get_logger from atocore.observability.logger import get_logger
from atocore.retrieval.embeddings import embed_texts from atocore.retrieval.embeddings import embed_texts
@@ -17,13 +17,13 @@ class VectorStore:
"""Wrapper around ChromaDB for chunk storage and retrieval.""" """Wrapper around ChromaDB for chunk storage and retrieval."""
def __init__(self) -> None: def __init__(self) -> None:
settings.chroma_path.mkdir(parents=True, exist_ok=True) _config.settings.chroma_path.mkdir(parents=True, exist_ok=True)
self._client = chromadb.PersistentClient(path=str(settings.chroma_path)) self._client = chromadb.PersistentClient(path=str(_config.settings.chroma_path))
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=COLLECTION_NAME, name=COLLECTION_NAME,
metadata={"hnsw:space": "cosine"}, metadata={"hnsw:space": "cosine"},
) )
log.info("vector_store_initialized", path=str(settings.chroma_path)) log.info("vector_store_initialized", path=str(_config.settings.chroma_path))
def add( def add(
self, self,

View File

@@ -1,11 +1,14 @@
"""pytest configuration and shared fixtures.""" """pytest configuration and shared fixtures."""
import os import os
import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
import pytest import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
# Default test data directory — overridden per-test by fixtures # Default test data directory — overridden per-test by fixtures
_default_test_dir = tempfile.mkdtemp(prefix="atocore_test_") _default_test_dir = tempfile.mkdtemp(prefix="atocore_test_")
os.environ["ATOCORE_DATA_DIR"] = _default_test_dir os.environ["ATOCORE_DATA_DIR"] = _default_test_dir

View File

@@ -27,6 +27,7 @@ def test_context_respects_budget(tmp_data_dir, sample_markdown):
pack = build_context("What is AtoCore?", budget=500) pack = build_context("What is AtoCore?", budget=500)
assert pack.total_chars <= 500 assert pack.total_chars <= 500
assert len(pack.formatted_context) <= 500
def test_context_with_project_hint(tmp_data_dir, sample_markdown): def test_context_with_project_hint(tmp_data_dir, sample_markdown):
@@ -82,6 +83,18 @@ def test_project_state_included_in_context(tmp_data_dir, sample_markdown):
assert pack.project_state_chars > 0 assert pack.project_state_chars > 0
def test_trusted_state_precedence_is_restated_in_retrieved_context(tmp_data_dir, sample_markdown):
    """Context built from both trusted state and retrieval must spell out precedence."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)
    set_state("atocore", "status", "phase", "Phase 2")

    context_pack = build_context("What is AtoCore?", project_hint="atocore")

    expected_notice = "If retrieved context conflicts with Trusted Project State above"
    assert expected_notice in context_pack.formatted_context
def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown): def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown):
"""Test that project state is included even with tight budget.""" """Test that project state is included even with tight budget."""
init_db() init_db()
@@ -95,6 +108,32 @@ def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown):
assert "Phase 1 in progress" in pack.formatted_context assert "Phase 1 in progress" in pack.formatted_context
def test_project_state_respects_total_budget(tmp_data_dir, sample_markdown):
    """Even oversized trusted state entries must be trimmed to the total budget."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)
    # Two long entries that together far exceed the budget below.
    set_state("atocore", "status", "notes", "x" * 400)
    set_state("atocore", "decision", "details", "y" * 400)

    budget = 120
    result = build_context("status?", project_hint="atocore", budget=budget)

    assert result.total_chars <= budget
    assert result.budget_remaining >= 0
    assert len(result.formatted_context) <= budget
def test_project_hint_matches_state_case_insensitively(tmp_data_dir, sample_markdown):
    """Project state lookup should not depend on exact casing of the hint."""
    init_db()
    init_project_state_schema()
    ingest_file(sample_markdown)
    # State stored with mixed case, queried with lower case.
    set_state("AtoCore", "status", "phase", "Phase 2")

    context_pack = build_context("status?", project_hint="atocore")

    assert "Phase 2" in context_pack.formatted_context
def test_no_project_state_without_hint(tmp_data_dir, sample_markdown): def test_no_project_state_without_hint(tmp_data_dir, sample_markdown):
"""Test that project state is not included without project hint.""" """Test that project state is not included without project hint."""
init_db() init_db()

View File

@@ -1,10 +1,8 @@
"""Tests for the ingestion pipeline.""" """Tests for the ingestion pipeline."""
from pathlib import Path
from atocore.ingestion.parser import parse_markdown from atocore.ingestion.parser import parse_markdown
from atocore.models.database import get_connection, init_db from atocore.models.database import get_connection, init_db
from atocore.ingestion.pipeline import ingest_file from atocore.ingestion.pipeline import ingest_file, ingest_folder
def test_parse_markdown(sample_markdown): def test_parse_markdown(sample_markdown):
@@ -69,3 +67,104 @@ def test_ingest_updates_changed(tmp_data_dir, sample_markdown):
) )
result = ingest_file(sample_markdown) result = ingest_file(sample_markdown)
assert result["status"] == "ingested" assert result["status"] == "ingested"
def test_parse_markdown_uses_supplied_text(sample_markdown):
    """parse_markdown should honor pre-read content instead of re-reading the file."""
    supplied = """---\ntags: parser\n---\n# Parser Title\n\nBody text."""

    result = parse_markdown(sample_markdown, text=supplied)

    # Title and tags must come from the supplied text, not the file on disk.
    assert result.title == "Parser Title"
    assert "parser" in result.tags
def test_reingest_empty_replaces_stale_chunks(tmp_data_dir, sample_markdown, monkeypatch):
    """Re-ingesting a file with no chunks should clear stale DB/vector state."""
    init_db()

    # Minimal vector-store double that records deletions so we can assert
    # the pipeline purged the stale embeddings.
    class FakeVectorStore:
        def __init__(self):
            self.deleted_ids = []

        def add(self, ids, documents, metadatas):
            return None

        def delete(self, ids):
            self.deleted_ids.extend(ids)

    fake_store = FakeVectorStore()
    monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: fake_store)

    # First pass ingests normally and populates chunk rows.
    first = ingest_file(sample_markdown)
    assert first["status"] == "ingested"

    # Change the file so it re-ingests, then force the chunker to yield nothing.
    sample_markdown.write_text("# Changed\n\nThis update should now produce no chunks after monkeypatching.", encoding="utf-8")
    monkeypatch.setattr("atocore.ingestion.pipeline.chunk_markdown", lambda *args, **kwargs: [])
    second = ingest_file(sample_markdown)
    assert second["status"] == "empty"

    # The old chunk rows must be gone and their vectors deleted.
    with get_connection() as conn:
        chunk_count = conn.execute("SELECT COUNT(*) AS c FROM source_chunks").fetchone()
    assert chunk_count["c"] == 0
    assert fake_store.deleted_ids
def test_ingest_folder_includes_markdown_extension(tmp_data_dir, sample_folder, monkeypatch):
    """Folder ingestion should include both .md and .markdown files."""
    init_db()
    markdown_file = sample_folder / "third_note.markdown"
    markdown_file.write_text("# Third Note\n\nThis file should be discovered during folder ingestion.", encoding="utf-8")

    # No-op vector store so the test exercises only file discovery.
    class FakeVectorStore:
        def add(self, ids, documents, metadatas):
            return None

        def delete(self, ids):
            return None

        @property
        def count(self):
            return 0

    monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: FakeVectorStore())
    results = ingest_folder(sample_folder)
    # Each per-file result carries the resolved path under "file".
    files = {result["file"] for result in results if "file" in result}
    assert str(markdown_file.resolve()) in files
def test_purge_deleted_files_does_not_match_sibling_prefix(tmp_data_dir, sample_folder, monkeypatch):
    """Purging one folder should not delete entries from a sibling folder with the same prefix."""
    init_db()

    # No-op vector store: this test only cares about SQL-side purge scoping.
    class FakeVectorStore:
        def add(self, ids, documents, metadatas):
            return None

        def delete(self, ids):
            return None

        @property
        def count(self):
            return 0

    monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: FakeVectorStore())

    # "notes" is a string prefix of "notes-project" — a naive prefix match
    # during purge would wrongly delete documents from both folders.
    kept_folder = tmp_data_dir / "notes"
    kept_folder.mkdir()
    kept_file = kept_folder / "keep.md"
    kept_file.write_text("# Keep\n\nThis document should survive purge.", encoding="utf-8")
    ingest_file(kept_file)

    purge_folder = tmp_data_dir / "notes-project"
    purge_folder.mkdir()
    purge_file = purge_folder / "gone.md"
    purge_file.write_text("# Gone\n\nThis document will be purged.", encoding="utf-8")
    ingest_file(purge_file)

    # Delete the file on disk, then purge only the sibling-prefixed folder.
    purge_file.unlink()
    ingest_folder(purge_folder, purge_deleted=True)

    with get_connection() as conn:
        rows = conn.execute("SELECT file_path FROM source_documents").fetchall()
    remaining_paths = {row["file_path"] for row in rows}
    assert str(kept_file.resolve()) in remaining_paths

View File

@@ -47,6 +47,23 @@ def test_create_memory_dedup(isolated_db):
assert m1.id == m2.id assert m1.id == m2.id
def test_create_memory_dedup_is_project_scoped(isolated_db):
    """Identical content under different projects must create distinct memories."""
    from atocore.memory.service import create_memory
    # Same text, different project scope — dedup must not collapse these.
    m1 = create_memory("project", "Uses SQLite for local state", project="atocore")
    m2 = create_memory("project", "Uses SQLite for local state", project="openclaw")
    assert m1.id != m2.id
def test_project_is_persisted_and_filterable(isolated_db):
    """The project field should round-trip through storage and act as a filter."""
    from atocore.memory.service import create_memory, get_memories
    create_memory("project", "Uses SQLite for local state", project="atocore")
    create_memory("project", "Uses Postgres in production", project="openclaw")
    # Filtering by project must return only that project's memories.
    atocore_memories = get_memories(memory_type="project", project="atocore")
    assert len(atocore_memories) == 1
    assert atocore_memories[0].project == "atocore"
def test_get_memories_all(isolated_db): def test_get_memories_all(isolated_db):
from atocore.memory.service import create_memory, get_memories from atocore.memory.service import create_memory, get_memories
create_memory("identity", "User is an engineer") create_memory("identity", "User is an engineer")
@@ -97,6 +114,25 @@ def test_update_memory(isolated_db):
assert mems[0].confidence == 0.8 assert mems[0].confidence == 0.8
def test_update_memory_rejects_duplicate_active_memory(isolated_db):
    """Updating a memory to duplicate another active memory's content should fail."""
    from atocore.memory.service import create_memory, update_memory
    import pytest
    first = create_memory("knowledge", "Canonical fact", project="atocore")
    second = create_memory("knowledge", "Different fact", project="atocore")
    # Rewriting `second` to match `first` within the same project is rejected.
    with pytest.raises(ValueError, match="duplicate active memory"):
        update_memory(second.id, content="Canonical fact")
def test_create_memory_validates_confidence(isolated_db):
    """Out-of-range confidence values should be rejected at creation time."""
    from atocore.memory.service import create_memory
    import pytest
    # 1.5 exceeds the documented [0.0, 1.0] confidence range.
    with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"):
        create_memory("knowledge", "Out of range", confidence=1.5)
def test_invalidate_memory(isolated_db): def test_invalidate_memory(isolated_db):
from atocore.memory.service import create_memory, get_memories, invalidate_memory from atocore.memory.service import create_memory, get_memories, invalidate_memory
mem = create_memory("knowledge", "Wrong fact") mem = create_memory("knowledge", "Wrong fact")
@@ -126,6 +162,25 @@ def test_memories_for_context(isolated_db):
assert chars > 0 assert chars > 0
def test_memories_for_context_reserves_room_for_each_type(isolated_db):
    """A long entry of one type must not starve other requested types of budget."""
    from atocore.memory.service import create_memory, get_memories_for_context
    create_memory("identity", "Identity entry that is intentionally long so it could consume the whole budget on its own")
    create_memory("preference", "Preference entry that should still appear")
    # Tight budget: the preference section must still be represented.
    text, _ = get_memories_for_context(memory_types=["identity", "preference"], budget=120)
    assert "[preference]" in text
def test_memories_for_context_respects_actual_serialized_budget(isolated_db):
    """Reported char count should match the serialized text and stay within budget."""
    from atocore.memory.service import create_memory, get_memories_for_context
    create_memory("identity", "Identity text that should fit the wrapper-aware memory budget calculation")
    create_memory("preference", "Preference text that should also fit")
    text, chars = get_memories_for_context(memory_types=["identity", "preference"], budget=140)
    # The count must describe the actual output, wrappers included.
    assert chars == len(text)
    assert chars <= 140
def test_memories_for_context_empty(isolated_db): def test_memories_for_context_empty(isolated_db):
from atocore.memory.service import get_memories_for_context from atocore.memory.service import get_memories_for_context
text, chars = get_memories_for_context() text, chars = get_memories_for_context()

View File

@@ -57,6 +57,12 @@ def test_set_state_invalid_category():
set_state("myproject", "invalid_category", "key", "value") set_state("myproject", "invalid_category", "key", "value")
def test_set_state_validates_confidence():
    """Project-state confidence should stay within the documented range."""
    # 1.2 exceeds the maximum allowed confidence of 1.0.
    with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"):
        set_state("myproject", "status", "phase", "Phase 1", confidence=1.2)
def test_get_state_all(): def test_get_state_all():
"""Test getting all state entries for a project.""" """Test getting all state entries for a project."""
set_state("proj", "status", "phase", "Phase 1") set_state("proj", "status", "phase", "Phase 1")

View File

@@ -1,7 +1,7 @@
"""Tests for the retrieval system.""" """Tests for the retrieval system."""
from atocore.ingestion.pipeline import ingest_file from atocore.ingestion.pipeline import ingest_file
from atocore.models.database import init_db from atocore.models.database import get_connection, init_db
from atocore.retrieval.retriever import retrieve from atocore.retrieval.retriever import retrieve
from atocore.retrieval.vector_store import get_vector_store from atocore.retrieval.vector_store import get_vector_store
@@ -39,3 +39,31 @@ def test_vector_store_count(tmp_data_dir, sample_markdown):
ingest_file(sample_markdown) ingest_file(sample_markdown)
store = get_vector_store() store = get_vector_store()
assert store.count > 0 assert store.count > 0
def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monkeypatch):
    """Retriever should ignore vector hits whose chunk rows no longer exist."""
    init_db()
    ingest_file(sample_markdown)
    with get_connection() as conn:
        chunk_ids = [row["id"] for row in conn.execute("SELECT id FROM source_chunks").fetchall()]

    # Fake store returns one valid hit and one id with no backing chunk row.
    class FakeStore:
        def query(self, query_embedding, top_k=10, where=None):
            return {
                "ids": [[chunk_ids[0], "missing-chunk"]],
                "documents": [["valid doc", "stale doc"]],
                "metadatas": [[
                    {"heading_path": "Overview", "source_file": "valid.md", "tags": "[]", "title": "Valid", "document_id": "doc-1"},
                    {"heading_path": "Ghost", "source_file": "ghost.md", "tags": "[]", "title": "Ghost", "document_id": "doc-2"},
                ]],
                "distances": [[0.1, 0.2]],
            }

    monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore())
    # Bypass the embedding model; the fake store ignores the query vector anyway.
    monkeypatch.setattr("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1])
    results = retrieve("overview", top_k=2)
    # Only the hit whose chunk row still exists should survive filtering.
    assert len(results) == 1
    assert results[0].chunk_id == chunk_ids[0]