From b0889b392587c31c964bd168c198784bfba1f125 Mon Sep 17 00:00:00 2001 From: Anto01 Date: Sun, 5 Apr 2026 17:53:23 -0400 Subject: [PATCH] Stabilize core correctness and sync project plan state --- pyproject.toml | 2 +- src/atocore/api/routes.py | 3 + src/atocore/context/builder.py | 71 ++++++++++- src/atocore/context/project_state.py | 16 ++- src/atocore/ingestion/chunker.py | 8 +- src/atocore/ingestion/parser.py | 6 +- src/atocore/ingestion/pipeline.py | 175 ++++++++++++++------------ src/atocore/main.py | 6 +- src/atocore/memory/service.py | 136 ++++++++++++-------- src/atocore/models/database.py | 15 +++ src/atocore/observability/logger.py | 5 +- src/atocore/retrieval/embeddings.py | 8 +- src/atocore/retrieval/retriever.py | 22 +++- src/atocore/retrieval/vector_store.py | 8 +- tests/conftest.py | 3 + tests/test_context_builder.py | 39 ++++++ tests/test_ingestion.py | 105 +++++++++++++++- tests/test_memory.py | 55 ++++++++ tests/test_project_state.py | 6 + tests/test_retrieval.py | 30 ++++- 20 files changed, 551 insertions(+), 168 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ffba0d3..72da305 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,4 +33,4 @@ where = ["src"] testpaths = ["tests"] python_files = ["test_*.py"] python_functions = ["test_*"] -addopts = "--cov=atocore --cov-report=term-missing -v" +addopts = "-v" diff --git a/src/atocore/api/routes.py b/src/atocore/api/routes.py index b9f4bb6..350987f 100644 --- a/src/atocore/api/routes.py +++ b/src/atocore/api/routes.py @@ -192,6 +192,7 @@ def api_create_memory(req: MemoryCreateRequest) -> dict: @router.get("/memory") def api_get_memories( memory_type: str | None = None, + project: str | None = None, active_only: bool = True, min_confidence: float = 0.0, limit: int = 50, @@ -199,6 +200,7 @@ def api_get_memories( """List memories, optionally filtered.""" memories = get_memories( memory_type=memory_type, + project=project, active_only=active_only, min_confidence=min_confidence, limit=limit, @@ -209,6 +211,7 @@ def api_get_memories( "id": m.id, "memory_type": m.memory_type, "content": m.content, + "project": m.project, "confidence": m.confidence, "status": m.status, "updated_at": m.updated_at, diff --git a/src/atocore/context/builder.py b/src/atocore/context/builder.py index 087eacc..8ec41ec 100644 --- a/src/atocore/context/builder.py +++ b/src/atocore/context/builder.py @@ -10,7 +10,7 @@ import time from dataclasses import dataclass, field from pathlib import Path -from atocore.config import settings +import atocore.config as _config from atocore.context.project_state import format_project_state, get_state from atocore.memory.service import get_memories_for_context from atocore.observability.logger import get_logger @@ -74,20 +74,27 @@ def build_context( """ global _last_context_pack start = time.time() - budget = budget or settings.context_budget + budget = _config.settings.context_budget if budget is None else max(budget, 0) # 1. Get Trusted Project State (highest precedence) project_state_text = "" project_state_chars = 0 + project_state_budget = min( + budget, + max(0, int(budget * PROJECT_STATE_BUDGET_RATIO)), + ) if project_hint: state_entries = get_state(project_hint) if state_entries: project_state_text = format_project_state(state_entries) - project_state_chars = len(project_state_text) + project_state_text, project_state_chars = _truncate_text_block( + project_state_text, + project_state_budget or budget, + ) # 2. Get identity + preference memories (second precedence) - memory_budget = int(budget * MEMORY_BUDGET_RATIO) + memory_budget = min(int(budget * MEMORY_BUDGET_RATIO), max(budget - project_state_chars, 0)) memory_text, memory_chars = get_memories_for_context( memory_types=["identity", "preference"], budget=memory_budget, @@ -97,7 +104,7 @@ def build_context( retrieval_budget = budget - project_state_chars - memory_chars # 4. Retrieve candidates - candidates = retrieve(user_prompt, top_k=settings.context_top_k) + candidates = retrieve(user_prompt, top_k=_config.settings.context_top_k) if retrieval_budget > 0 else [] # 5. Score and rank scored = _rank_chunks(candidates, project_hint) @@ -107,12 +114,21 @@ def build_context( # 7. Format full context formatted = _format_full_context(project_state_text, memory_text, selected) + if len(formatted) > budget: + formatted, selected = _trim_context_to_budget( + project_state_text, + memory_text, + selected, + budget, + ) # 8. Build full prompt full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}" + project_state_chars = len(project_state_text) + memory_chars = len(memory_text) retrieval_chars = sum(c.char_count for c in selected) - total_chars = project_state_chars + memory_chars + retrieval_chars + total_chars = len(formatted) duration_ms = int((time.time() - start) * 1000) pack = ContextPack( @@ -235,6 +251,8 @@ def _format_full_context( # 3. Retrieved chunks (lowest trust) if chunks: parts.append("--- AtoCore Retrieved Context ---") + if project_state_text: + parts.append("If retrieved context conflicts with Trusted Project State above, trust the Trusted Project State.") for chunk in chunks: parts.append( f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]" @@ -282,3 +300,44 @@ def _pack_to_dict(pack: ContextPack) -> dict: for c in pack.chunks_used ], } + + +def _truncate_text_block(text: str, budget: int) -> tuple[str, int]: + """Trim a formatted text block so trusted tiers cannot exceed the total budget.""" + if budget <= 0 or not text: + return "", 0 + if len(text) <= budget: + return text, len(text) + if budget <= 3: + trimmed = text[:budget] + else: + trimmed = f"{text[: budget - 3].rstrip()}..." + return trimmed, len(trimmed) + + +def _trim_context_to_budget( + project_state_text: str, + memory_text: str, + chunks: list[ContextChunk], + budget: int, +) -> tuple[str, list[ContextChunk]]: + """Trim retrieval first, then memory, then project state until formatted context fits.""" + kept_chunks = list(chunks) + formatted = _format_full_context(project_state_text, memory_text, kept_chunks) + while len(formatted) > budget and kept_chunks: + kept_chunks.pop() + formatted = _format_full_context(project_state_text, memory_text, kept_chunks) + + if len(formatted) <= budget: + return formatted, kept_chunks + + memory_text, _ = _truncate_text_block(memory_text, max(budget - len(project_state_text), 0)) + formatted = _format_full_context(project_state_text, memory_text, kept_chunks) + if len(formatted) <= budget: + return formatted, kept_chunks + + project_state_text, _ = _truncate_text_block(project_state_text, budget) + formatted = _format_full_context(project_state_text, "", []) + if len(formatted) > budget: + formatted, _ = _truncate_text_block(formatted, budget) + return formatted, [] diff --git a/src/atocore/context/project_state.py b/src/atocore/context/project_state.py index 2adacbf..82124a5 100644 --- a/src/atocore/context/project_state.py +++ b/src/atocore/context/project_state.py @@ -12,10 +12,8 @@ Project state is manually curated or explicitly confirmed facts about a project. It always wins over retrieval-based context when there's a conflict. """ -import json -import time import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from atocore.models.database import get_connection @@ -81,7 +79,7 @@ def ensure_project(name: str, description: str = "") -> str: """Get or create a project by name. Returns project_id.""" with get_connection() as conn: row = conn.execute( - "SELECT id FROM projects WHERE name = ?", (name,) + "SELECT id FROM projects WHERE lower(name) = lower(?)", (name,) ).fetchone() if row: return row["id"] @@ -106,6 +104,7 @@ def set_state( """Set or update a project state entry. Upsert semantics.""" if category not in CATEGORIES: raise ValueError(f"Invalid category '{category}'. Must be one of: {CATEGORIES}") + _validate_confidence(confidence) project_id = ensure_project(project_name) entry_id = str(uuid.uuid4()) @@ -157,7 +156,7 @@ def get_state( """Get project state entries, optionally filtered by category.""" with get_connection() as conn: project = conn.execute( - "SELECT id FROM projects WHERE name = ?", (project_name,) + "SELECT id FROM projects WHERE lower(name) = lower(?)", (project_name,) ).fetchone() if not project: return [] @@ -195,7 +194,7 @@ def invalidate_state(project_name: str, category: str, key: str) -> bool: """Mark a project state entry as superseded.""" with get_connection() as conn: project = conn.execute( - "SELECT id FROM projects WHERE name = ?", (project_name,) + "SELECT id FROM projects WHERE lower(name) = lower(?)", (project_name,) ).fetchone() if not project: return False @@ -229,3 +228,8 @@ def format_project_state(entries: list[ProjectStateEntry]) -> str: lines.append("\n--- End Project State ---") return "\n".join(lines) + + +def _validate_confidence(confidence: float) -> None: + if not 0.0 <= confidence <= 1.0: + raise ValueError("Confidence must be between 0.0 and 1.0") diff --git a/src/atocore/ingestion/chunker.py b/src/atocore/ingestion/chunker.py index 6d7d201..461e981 100644 --- a/src/atocore/ingestion/chunker.py +++ b/src/atocore/ingestion/chunker.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass, field -from atocore.config import settings +import atocore.config as _config @dataclass @@ -29,9 +29,9 @@ def chunk_markdown( 3. If still > max_size, split on paragraph breaks 4. If still > max_size, hard split with overlap """ - max_size = max_size or settings.chunk_max_size - overlap = overlap or settings.chunk_overlap - min_size = min_size or settings.chunk_min_size + max_size = max_size or _config.settings.chunk_max_size + overlap = overlap or _config.settings.chunk_overlap + min_size = min_size or _config.settings.chunk_min_size base_metadata = base_metadata or {} sections = _split_by_heading(body, level=2) diff --git a/src/atocore/ingestion/parser.py b/src/atocore/ingestion/parser.py index 2684ec0..6e5897e 100644 --- a/src/atocore/ingestion/parser.py +++ b/src/atocore/ingestion/parser.py @@ -17,10 +17,10 @@ class ParsedDocument: headings: list[tuple[int, str]] = field(default_factory=list) -def parse_markdown(file_path: Path) -> ParsedDocument: +def parse_markdown(file_path: Path, text: str | None = None) -> ParsedDocument: """Parse a markdown file, extracting frontmatter and structure.""" - text = file_path.read_text(encoding="utf-8") - post = frontmatter.loads(text) + raw_text = text if text is not None else file_path.read_text(encoding="utf-8") + post = frontmatter.loads(raw_text) meta = dict(post.metadata) if post.metadata else {} body = post.content.strip() diff --git a/src/atocore/ingestion/pipeline.py b/src/atocore/ingestion/pipeline.py index ed829c9..d93fb63 100644 --- a/src/atocore/ingestion/pipeline.py +++ b/src/atocore/ingestion/pipeline.py @@ -6,7 +6,6 @@ import time import uuid from pathlib import Path -from atocore.config import settings from atocore.ingestion.chunker import chunk_markdown from atocore.ingestion.parser import parse_markdown from atocore.models.database import get_connection @@ -45,7 +44,7 @@ def ingest_file(file_path: Path) -> dict: return {"file": str(file_path), "status": "skipped", "reason": "unchanged"} # Parse - parsed = parse_markdown(file_path) + parsed = parse_markdown(file_path, text=raw_content) # Chunk base_meta = { @@ -55,85 +54,98 @@ def ingest_file(file_path: Path) -> dict: } chunks = chunk_markdown(parsed.body, base_metadata=base_meta) - if not chunks: - log.warning("no_chunks_created", file_path=str(file_path)) - return {"file": str(file_path), "status": "empty", "chunks": 0} - # Store in DB and vector store doc_id = str(uuid.uuid4()) vector_store = get_vector_store() + old_chunk_ids: list[str] = [] + new_chunk_ids: list[str] = [] - with get_connection() as conn: - # Remove old data if re-ingesting - if existing: - doc_id = existing["id"] - old_chunk_ids = [ - row["id"] - for row in conn.execute( - "SELECT id FROM source_chunks WHERE document_id = ?", - (doc_id,), - ).fetchall() - ] - conn.execute( - "DELETE FROM source_chunks WHERE document_id = ?", (doc_id,) - ) - conn.execute( - "UPDATE source_documents SET file_hash = ?, title = ?, tags = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", - (file_hash, parsed.title, json.dumps(parsed.tags), doc_id), - ) - # Remove old vectors - if old_chunk_ids: - vector_store.delete(old_chunk_ids) - else: - conn.execute( - "INSERT INTO source_documents (id, file_path, file_hash, title, doc_type, tags) VALUES (?, ?, ?, ?, ?, ?)", - (doc_id, str(file_path), file_hash, parsed.title, "markdown", json.dumps(parsed.tags)), - ) + try: + with get_connection() as conn: + # Remove old data if re-ingesting + if existing: + doc_id = existing["id"] + old_chunk_ids = [ + row["id"] + for row in conn.execute( + "SELECT id FROM source_chunks WHERE document_id = ?", + (doc_id,), + ).fetchall() + ] + conn.execute( + "DELETE FROM source_chunks WHERE document_id = ?", (doc_id,) + ) + conn.execute( + "UPDATE source_documents SET file_hash = ?, title = ?, tags = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", + (file_hash, parsed.title, json.dumps(parsed.tags), doc_id), + ) + else: + conn.execute( + "INSERT INTO source_documents (id, file_path, file_hash, title, doc_type, tags) VALUES (?, ?, ?, ?, ?, ?)", + (doc_id, str(file_path), file_hash, parsed.title, "markdown", json.dumps(parsed.tags)), + ) - # Insert chunks - chunk_ids = [] - chunk_contents = [] - chunk_metadatas = [] + if not chunks: + log.warning("no_chunks_created", file_path=str(file_path)) + else: + # Insert chunks + chunk_contents = [] + chunk_metadatas = [] - for chunk in chunks: - chunk_id = str(uuid.uuid4()) - chunk_ids.append(chunk_id) - chunk_contents.append(chunk.content) - chunk_metadatas.append({ - "document_id": doc_id, - "heading_path": chunk.heading_path, - "source_file": str(file_path), - "tags": json.dumps(parsed.tags), - "title": parsed.title, - }) + for chunk in chunks: + chunk_id = str(uuid.uuid4()) + new_chunk_ids.append(chunk_id) + chunk_contents.append(chunk.content) + chunk_metadatas.append({ + "document_id": doc_id, + "heading_path": chunk.heading_path, + "source_file": str(file_path), + "tags": json.dumps(parsed.tags), + "title": parsed.title, + }) - conn.execute( - "INSERT INTO source_chunks (id, document_id, chunk_index, content, heading_path, char_count, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - chunk_id, - doc_id, - chunk.chunk_index, - chunk.content, - chunk.heading_path, - chunk.char_count, - json.dumps(chunk.metadata), - ), - ) + conn.execute( + "INSERT INTO source_chunks (id, document_id, chunk_index, content, heading_path, char_count, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + chunk_id, + doc_id, + chunk.chunk_index, + chunk.content, + chunk.heading_path, + chunk.char_count, + json.dumps(chunk.metadata), + ), + ) - # Store embeddings - vector_store.add(chunk_ids, chunk_contents, chunk_metadatas) + # Add new vectors before commit so DB can still roll back on failure. + vector_store.add(new_chunk_ids, chunk_contents, chunk_metadatas) + except Exception: + if new_chunk_ids: + vector_store.delete(new_chunk_ids) + raise + + # Delete stale vectors only after the DB transaction committed. + if old_chunk_ids: + vector_store.delete(old_chunk_ids) duration_ms = int((time.time() - start) * 1000) - log.info( - "file_ingested", - file_path=str(file_path), - chunks_created=len(chunks), - duration_ms=duration_ms, - ) + if chunks: + log.info( + "file_ingested", + file_path=str(file_path), + chunks_created=len(chunks), + duration_ms=duration_ms, + ) + else: + log.info( + "file_ingested_empty", + file_path=str(file_path), + duration_ms=duration_ms, + ) return { "file": str(file_path), - "status": "ingested", + "status": "ingested" if chunks else "empty", "chunks": len(chunks), "duration_ms": duration_ms, } @@ -152,7 +164,9 @@ def ingest_folder(folder_path: Path, purge_deleted: bool = True) -> list[dict]: raise NotADirectoryError(f"Not a directory: {folder_path}") results = [] - md_files = sorted(folder_path.rglob("*.md")) + md_files = sorted( + list(folder_path.rglob("*.md")) + list(folder_path.rglob("*.markdown")) + ) current_paths = {str(f.resolve()) for f in md_files} log.info("ingestion_started", folder=str(folder_path), file_count=len(md_files)) @@ -213,32 +227,35 @@ def _purge_deleted_files(folder_path: Path, current_paths: set[str]) -> int: folder_str = str(folder_path) deleted_count = 0 vector_store = get_vector_store() + chunk_ids_to_delete: list[str] = [] with get_connection() as conn: - # Find documents under this folder rows = conn.execute( - "SELECT id, file_path FROM source_documents WHERE file_path LIKE ?", - (f"{folder_str}%",), + "SELECT id, file_path FROM source_documents" ).fetchall() for row in rows: + doc_path = Path(row["file_path"]) + try: + doc_path.relative_to(folder_path) + except ValueError: + continue + if row["file_path"] not in current_paths: doc_id = row["id"] - # Get chunk IDs for vector deletion - chunk_ids = [ + chunk_ids_to_delete.extend( r["id"] for r in conn.execute( "SELECT id FROM source_chunks WHERE document_id = ?", (doc_id,), ).fetchall() - ] - # Delete from DB + ) conn.execute("DELETE FROM source_chunks WHERE document_id = ?", (doc_id,)) conn.execute("DELETE FROM source_documents WHERE id = ?", (doc_id,)) - # Delete from vectors - if chunk_ids: - vector_store.delete(chunk_ids) log.info("purged_deleted_file", file_path=row["file_path"]) deleted_count += 1 + if chunk_ids_to_delete: + vector_store.delete(chunk_ids_to_delete) + return deleted_count diff --git a/src/atocore/main.py b/src/atocore/main.py index 9c79b0d..7ee360c 100644 --- a/src/atocore/main.py +++ b/src/atocore/main.py @@ -3,7 +3,7 @@ from fastapi import FastAPI from atocore.api.routes import router -from atocore.config import settings +import atocore.config as _config from atocore.context.project_state import init_project_state_schema from atocore.models.database import init_db from atocore.observability.logger import setup_logging @@ -29,7 +29,7 @@ if __name__ == "__main__": uvicorn.run( "atocore.main:app", - host=settings.host, - port=settings.port, + host=_config.settings.host, + port=_config.settings.port, reload=True, ) diff --git a/src/atocore/memory/service.py b/src/atocore/memory/service.py index 7e7d6c6..5102ef0 100644 --- a/src/atocore/memory/service.py +++ b/src/atocore/memory/service.py @@ -14,7 +14,6 @@ Memories have: - optional link to source chunk: traceability """ -import json import uuid from dataclasses import dataclass from datetime import datetime, timezone @@ -57,6 +56,7 @@ def create_memory( """Create a new memory entry.""" if memory_type not in MEMORY_TYPES: raise ValueError(f"Invalid memory type '{memory_type}'. Must be one of: {MEMORY_TYPES}") + _validate_confidence(confidence) memory_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() @@ -64,8 +64,9 @@ def create_memory( # Check for duplicate content within same type+project with get_connection() as conn: existing = conn.execute( - "SELECT id FROM memories WHERE memory_type = ? AND content = ? AND status = 'active'", - (memory_type, content), + "SELECT id FROM memories " + "WHERE memory_type = ? AND content = ? AND project = ? AND status = 'active'", + (memory_type, content, project), ).fetchone() if existing: log.info("memory_duplicate_skipped", memory_type=memory_type, content_preview=content[:80]) @@ -74,9 +75,9 @@ def create_memory( ) conn.execute( - "INSERT INTO memories (id, memory_type, content, source_chunk_id, confidence, status) " - "VALUES (?, ?, ?, ?, ?, 'active')", - (memory_id, memory_type, content, source_chunk_id or None, confidence), + "INSERT INTO memories (id, memory_type, content, project, source_chunk_id, confidence, status) " + "VALUES (?, ?, ?, ?, ?, ?, 'active')", + (memory_id, memory_type, content, project, source_chunk_id or None, confidence), ) log.info("memory_created", memory_type=memory_type, content_preview=content[:80]) @@ -96,6 +97,7 @@ def create_memory( def get_memories( memory_type: str | None = None, + project: str | None = None, active_only: bool = True, min_confidence: float = 0.0, limit: int = 50, @@ -107,6 +109,9 @@ def get_memories( if memory_type: query += " AND memory_type = ?" params.append(memory_type) + if project is not None: + query += " AND project = ?" + params.append(project) if active_only: query += " AND status = 'active'" if min_confidence > 0: @@ -129,28 +134,46 @@ def update_memory( status: str | None = None, ) -> bool: """Update an existing memory.""" - updates = [] - params: list = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if confidence is not None: - updates.append("confidence = ?") - params.append(confidence) - if status is not None: - if status not in ("active", "superseded", "invalid"): - raise ValueError(f"Invalid status '{status}'") - updates.append("status = ?") - params.append(status) - - if not updates: - return False - - updates.append("updated_at = CURRENT_TIMESTAMP") - params.append(memory_id) - with get_connection() as conn: + existing = conn.execute("SELECT * FROM memories WHERE id = ?", (memory_id,)).fetchone() + if existing is None: + return False + + next_content = content if content is not None else existing["content"] + next_status = status if status is not None else existing["status"] + if confidence is not None: + _validate_confidence(confidence) + + if next_status == "active": + duplicate = conn.execute( + "SELECT id FROM memories " + "WHERE memory_type = ? AND content = ? AND project = ? AND status = 'active' AND id != ?", + (existing["memory_type"], next_content, existing["project"] or "", memory_id), + ).fetchone() + if duplicate: + raise ValueError("Update would create a duplicate active memory") + + updates = [] + params: list = [] + + if content is not None: + updates.append("content = ?") + params.append(content) + if confidence is not None: + updates.append("confidence = ?") + params.append(confidence) + if status is not None: + if status not in ("active", "superseded", "invalid"): + raise ValueError(f"Invalid status '{status}'") + updates.append("status = ?") + params.append(status) + + if not updates: + return False + + updates.append("updated_at = CURRENT_TIMESTAMP") + params.append(memory_id) + result = conn.execute( f"UPDATE memories SET {', '.join(updates)} WHERE id = ?", params, @@ -174,6 +197,7 @@ def supersede_memory(memory_id: str) -> bool: def get_memories_for_context( memory_types: list[str] | None = None, + project: str | None = None, budget: int = 500, ) -> tuple[str, int]: """Get formatted memories for context injection. @@ -186,33 +210,42 @@ def get_memories_for_context( if memory_types is None: memory_types = ["identity", "preference"] - memories = [] - for mtype in memory_types: - memories.extend(get_memories(memory_type=mtype, min_confidence=0.5, limit=10)) - - if not memories: + if budget <= 0: return "", 0 - lines = ["--- AtoCore Memory ---"] - used = len(lines[0]) + 1 - included = [] - - for mem in memories: - entry = f"[{mem.memory_type}] {mem.content}" - entry_len = len(entry) + 1 - if used + entry_len > budget: - break - lines.append(entry) - used += entry_len - included.append(mem) - - if len(included) == 0: + header = "--- AtoCore Memory ---" + footer = "--- End Memory ---" + wrapper_chars = len(header) + len(footer) + 2 + if budget <= wrapper_chars: return "", 0 - lines.append("--- End Memory ---") + available = budget - wrapper_chars + selected_entries: list[str] = [] + + for index, mtype in enumerate(memory_types): + type_budget = available if index == len(memory_types) - 1 else max(0, available // (len(memory_types) - index)) + type_used = 0 + for mem in get_memories( + memory_type=mtype, + project=project, + min_confidence=0.5, + limit=10, + ): + entry = f"[{mem.memory_type}] {mem.content}" + entry_len = len(entry) + 1 + if entry_len > type_budget - type_used: + continue + selected_entries.append(entry) + type_used += entry_len + available -= type_used + + if not selected_entries: + return "", 0 + + lines = [header, *selected_entries, footer] text = "\n".join(lines) - log.info("memories_for_context", count=len(included), chars=len(text)) + log.info("memories_for_context", count=len(selected_entries), chars=len(text)) return text, len(text) @@ -222,10 +255,15 @@ def _row_to_memory(row) -> Memory: id=row["id"], memory_type=row["memory_type"], content=row["content"], - project="", + project=row["project"] or "", source_chunk_id=row["source_chunk_id"] or "", confidence=row["confidence"], status=row["status"], created_at=row["created_at"], updated_at=row["updated_at"], ) + + +def _validate_confidence(confidence: float) -> None: + if not 0.0 <= confidence <= 1.0: + raise ValueError("Confidence must be between 0.0 and 1.0") diff --git a/src/atocore/models/database.py b/src/atocore/models/database.py index 592eea3..2a57964 100644 --- a/src/atocore/models/database.py +++ b/src/atocore/models/database.py @@ -37,6 +37,7 @@ CREATE TABLE IF NOT EXISTS memories ( id TEXT PRIMARY KEY, memory_type TEXT NOT NULL, content TEXT NOT NULL, + project TEXT DEFAULT '', source_chunk_id TEXT REFERENCES source_chunks(id), confidence REAL DEFAULT 1.0, status TEXT DEFAULT 'active', @@ -64,6 +65,7 @@ CREATE TABLE IF NOT EXISTS interactions ( CREATE INDEX IF NOT EXISTS idx_chunks_document ON source_chunks(document_id); CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(memory_type); +CREATE INDEX IF NOT EXISTS idx_memories_project ON memories(project); CREATE INDEX IF NOT EXISTS idx_memories_status ON memories(status); CREATE INDEX IF NOT EXISTS idx_interactions_project ON interactions(project_id); """ @@ -78,9 +80,22 @@ def init_db() -> None: _ensure_data_dir() with get_connection() as conn: conn.executescript(SCHEMA_SQL) + _apply_migrations(conn) log.info("database_initialized", path=str(_config.settings.db_path)) +def _apply_migrations(conn: sqlite3.Connection) -> None: + """Apply lightweight schema migrations for existing local databases.""" + if not _column_exists(conn, "memories", "project"): + conn.execute("ALTER TABLE memories ADD COLUMN project TEXT DEFAULT ''") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_project ON memories(project)") + + +def _column_exists(conn: sqlite3.Connection, table: str, column: str) -> bool: + rows = conn.execute(f"PRAGMA table_info({table})").fetchall() + return any(row["name"] == column for row in rows) + + @contextmanager def get_connection() -> Generator[sqlite3.Connection, None, None]: """Get a database connection with row factory.""" diff --git a/src/atocore/observability/logger.py b/src/atocore/observability/logger.py index 6ee409c..1f9fd4a 100644 --- a/src/atocore/observability/logger.py +++ b/src/atocore/observability/logger.py @@ -2,10 +2,9 @@ import logging +import atocore.config as _config import structlog -from atocore.config import settings - _LOG_LEVELS = { "DEBUG": logging.DEBUG, "INFO": logging.INFO, @@ -16,7 +15,7 @@ _LOG_LEVELS = { def setup_logging() -> None: """Configure structlog with JSON output.""" - log_level = "DEBUG" if settings.debug else "INFO" + log_level = "DEBUG" if _config.settings.debug else "INFO" structlog.configure( processors=[ diff --git a/src/atocore/retrieval/embeddings.py b/src/atocore/retrieval/embeddings.py index 37a781e..7352d85 100644 --- a/src/atocore/retrieval/embeddings.py +++ b/src/atocore/retrieval/embeddings.py @@ -1,8 +1,8 @@ """Embedding model management.""" +import atocore.config as _config from sentence_transformers import SentenceTransformer -from atocore.config import settings from atocore.observability.logger import get_logger log = get_logger("embeddings") @@ -14,9 +14,9 @@ def get_model() -> SentenceTransformer: """Load and cache the embedding model.""" global _model if _model is None: - log.info("loading_embedding_model", model=settings.embedding_model) - _model = SentenceTransformer(settings.embedding_model) - log.info("embedding_model_loaded", model=settings.embedding_model) + log.info("loading_embedding_model", model=_config.settings.embedding_model) + _model = SentenceTransformer(_config.settings.embedding_model) + log.info("embedding_model_loaded", model=_config.settings.embedding_model) return _model diff --git a/src/atocore/retrieval/retriever.py b/src/atocore/retrieval/retriever.py index 6920e11..fa19362 100644 --- a/src/atocore/retrieval/retriever.py +++ b/src/atocore/retrieval/retriever.py @@ -3,7 +3,8 @@ import time from dataclasses import dataclass -from atocore.config import settings +import atocore.config as _config +from atocore.models.database import get_connection from atocore.observability.logger import get_logger from atocore.retrieval.embeddings import embed_query from atocore.retrieval.vector_store import get_vector_store @@ -29,7 +30,7 @@ def retrieve( filter_tags: list[str] | None = None, ) -> list[ChunkResult]: """Retrieve the most relevant chunks for a query.""" - top_k = top_k or settings.context_top_k + top_k = top_k or _config.settings.context_top_k start = time.time() query_embedding = embed_query(query) @@ -59,7 +60,10 @@ def retrieve( chunks = [] if results and results["ids"] and results["ids"][0]: + existing_ids = _existing_chunk_ids(results["ids"][0]) for i, chunk_id in enumerate(results["ids"][0]): + if chunk_id not in existing_ids: + continue # ChromaDB returns distances (lower = more similar for cosine) # Convert to similarity score (1 - distance) distance = results["distances"][0][i] if results["distances"] else 0 @@ -90,3 +94,17 @@ def retrieve( ) return chunks + + +def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]: + """Filter out stale vector entries whose chunk rows no longer exist.""" + if not chunk_ids: + return set() + + placeholders = ", ".join("?" for _ in chunk_ids) + with get_connection() as conn: + rows = conn.execute( + f"SELECT id FROM source_chunks WHERE id IN ({placeholders})", + chunk_ids, + ).fetchall() + return {row["id"] for row in rows} diff --git a/src/atocore/retrieval/vector_store.py b/src/atocore/retrieval/vector_store.py index 4143d9c..6039b7b 100644 --- a/src/atocore/retrieval/vector_store.py +++ b/src/atocore/retrieval/vector_store.py @@ -2,7 +2,7 @@ import chromadb -from atocore.config import settings +import atocore.config as _config from atocore.observability.logger import get_logger from atocore.retrieval.embeddings import embed_texts @@ -17,13 +17,13 @@ class VectorStore: """Wrapper around ChromaDB for chunk storage and retrieval.""" def __init__(self) -> None: - settings.chroma_path.mkdir(parents=True, exist_ok=True) - self._client = chromadb.PersistentClient(path=str(settings.chroma_path)) + _config.settings.chroma_path.mkdir(parents=True, exist_ok=True) + self._client = chromadb.PersistentClient(path=str(_config.settings.chroma_path)) self._collection = self._client.get_or_create_collection( name=COLLECTION_NAME, metadata={"hnsw:space": "cosine"}, ) - log.info("vector_store_initialized", path=str(settings.chroma_path)) + log.info("vector_store_initialized", path=str(_config.settings.chroma_path)) def add( self, diff --git a/tests/conftest.py b/tests/conftest.py index b8081c1..599dc51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,14 @@ """pytest configuration and shared fixtures.""" import os +import sys import tempfile from pathlib import Path import pytest +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + # Default test data directory — overridden per-test by fixtures _default_test_dir = tempfile.mkdtemp(prefix="atocore_test_") os.environ["ATOCORE_DATA_DIR"] = _default_test_dir diff --git a/tests/test_context_builder.py b/tests/test_context_builder.py index 1939e44..7fea372 100644 --- a/tests/test_context_builder.py +++ b/tests/test_context_builder.py @@ -27,6 +27,7 @@ def test_context_respects_budget(tmp_data_dir, sample_markdown): pack = build_context("What is AtoCore?", budget=500) assert pack.total_chars <= 500 + assert len(pack.formatted_context) <= 500 def test_context_with_project_hint(tmp_data_dir, sample_markdown): @@ -82,6 +83,18 @@ def test_project_state_included_in_context(tmp_data_dir, sample_markdown): assert pack.project_state_chars > 0 +def test_trusted_state_precedence_is_restated_in_retrieved_context(tmp_data_dir, sample_markdown): + """When trusted state and retrieval coexist, the context should restate precedence explicitly.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + set_state("atocore", "status", "phase", "Phase 2") + pack = build_context("What is AtoCore?", project_hint="atocore") + + assert "If retrieved context conflicts with Trusted Project State above" in pack.formatted_context + + def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown): """Test that project state is included even with tight budget.""" init_db() @@ -95,6 +108,32 @@ def test_project_state_takes_priority_budget(tmp_data_dir, sample_markdown): assert "Phase 1 in progress" in pack.formatted_context +def test_project_state_respects_total_budget(tmp_data_dir, sample_markdown): + """Trusted state should still fit within the total context budget.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + set_state("atocore", "status", "notes", "x" * 400) + set_state("atocore", "decision", "details", "y" * 400) + + pack = build_context("status?", project_hint="atocore", budget=120) + assert pack.total_chars <= 120 + assert pack.budget_remaining >= 0 + assert len(pack.formatted_context) <= 120 + + +def test_project_hint_matches_state_case_insensitively(tmp_data_dir, sample_markdown): + """Project state lookup should not depend on exact casing.""" + init_db() + init_project_state_schema() + ingest_file(sample_markdown) + + set_state("AtoCore", "status", "phase", "Phase 2") + pack = build_context("status?", project_hint="atocore") + assert "Phase 2" in pack.formatted_context + + def test_no_project_state_without_hint(tmp_data_dir, sample_markdown): """Test that project state is not included without project hint.""" init_db() diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 7c941a5..3922d3d 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -1,10 +1,8 @@ """Tests for the ingestion pipeline.""" -from pathlib import Path - from atocore.ingestion.parser import parse_markdown from atocore.models.database import get_connection, init_db -from atocore.ingestion.pipeline import ingest_file +from atocore.ingestion.pipeline import ingest_file, ingest_folder def test_parse_markdown(sample_markdown): @@ -69,3 +67,104 @@ def test_ingest_updates_changed(tmp_data_dir, sample_markdown): ) result = ingest_file(sample_markdown) assert result["status"] == "ingested" + + +def test_parse_markdown_uses_supplied_text(sample_markdown): + """Parsing should be able to reuse pre-read content from ingestion.""" + latin_text = """---\ntags: parser\n---\n# Parser Title\n\nBody text.""" + parsed = parse_markdown(sample_markdown, text=latin_text) + assert parsed.title == "Parser Title" + assert "parser" in parsed.tags + + +def test_reingest_empty_replaces_stale_chunks(tmp_data_dir, sample_markdown, monkeypatch): + """Re-ingesting a file with no chunks should clear stale DB/vector state.""" + init_db() + + class FakeVectorStore: + def __init__(self): + self.deleted_ids = [] + + def add(self, ids, documents, metadatas): + return None + + def delete(self, ids): + self.deleted_ids.extend(ids) + + fake_store = FakeVectorStore() + monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: fake_store) + + first = ingest_file(sample_markdown) + assert first["status"] == "ingested" + + sample_markdown.write_text("# Changed\n\nThis update should now produce no chunks after monkeypatching.", encoding="utf-8") + monkeypatch.setattr("atocore.ingestion.pipeline.chunk_markdown", lambda *args, **kwargs: []) + second = ingest_file(sample_markdown) + assert second["status"] == "empty" + + with get_connection() as conn: + chunk_count = conn.execute("SELECT COUNT(*) AS c FROM source_chunks").fetchone() + assert chunk_count["c"] == 0 + + assert fake_store.deleted_ids + + +def test_ingest_folder_includes_markdown_extension(tmp_data_dir, sample_folder, monkeypatch): + """Folder ingestion should include both .md and .markdown files.""" + init_db() + markdown_file = sample_folder / "third_note.markdown" + markdown_file.write_text("# Third Note\n\nThis file should be discovered during folder ingestion.", encoding="utf-8") + + class FakeVectorStore: + def add(self, ids, documents, metadatas): + return None + + def delete(self, ids): + return None + + @property + def count(self): + return 0 + + monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: FakeVectorStore()) + results = ingest_folder(sample_folder) + files = {result["file"] for result in results if "file" in result} + assert str(markdown_file.resolve()) in files + + +def test_purge_deleted_files_does_not_match_sibling_prefix(tmp_data_dir, sample_folder, monkeypatch): + """Purging one folder should not delete entries from a sibling folder with the same prefix.""" + init_db() + + class FakeVectorStore: + def add(self, ids, documents, metadatas): + return None + + def delete(self, ids): + return None + + @property + def count(self): + return 0 + + monkeypatch.setattr("atocore.ingestion.pipeline.get_vector_store", lambda: FakeVectorStore()) + + kept_folder = tmp_data_dir / "notes" + kept_folder.mkdir() + kept_file = kept_folder / "keep.md" + kept_file.write_text("# Keep\n\nThis document should survive purge.", encoding="utf-8") + ingest_file(kept_file) + + purge_folder = tmp_data_dir / "notes-project" + purge_folder.mkdir() + purge_file = purge_folder / "gone.md" + purge_file.write_text("# Gone\n\nThis document will be purged.", encoding="utf-8") + ingest_file(purge_file) + purge_file.unlink() + + ingest_folder(purge_folder, purge_deleted=True) + + with get_connection() as conn: + rows = conn.execute("SELECT file_path FROM source_documents").fetchall() + remaining_paths = {row["file_path"] for row in rows} + assert str(kept_file.resolve()) in remaining_paths diff --git a/tests/test_memory.py b/tests/test_memory.py index 8ee0811..32085cf 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -47,6 +47,23 @@ def test_create_memory_dedup(isolated_db): assert m1.id == m2.id +def test_create_memory_dedup_is_project_scoped(isolated_db): + from atocore.memory.service import create_memory + m1 = create_memory("project", "Uses SQLite for local state", project="atocore") + m2 = create_memory("project", "Uses SQLite for local state", project="openclaw") + assert m1.id != m2.id + + +def test_project_is_persisted_and_filterable(isolated_db): + from atocore.memory.service import create_memory, get_memories + create_memory("project", "Uses SQLite for local state", project="atocore") + create_memory("project", "Uses Postgres in production", project="openclaw") + + atocore_memories = get_memories(memory_type="project", project="atocore") + assert len(atocore_memories) == 1 + assert atocore_memories[0].project == "atocore" + + def test_get_memories_all(isolated_db): from atocore.memory.service import create_memory, get_memories create_memory("identity", "User is an engineer") @@ -97,6 +114,25 @@ def test_update_memory(isolated_db): assert mems[0].confidence == 0.8 +def test_update_memory_rejects_duplicate_active_memory(isolated_db): + from atocore.memory.service import create_memory, update_memory + import pytest + + first = create_memory("knowledge", "Canonical fact", project="atocore") + second = create_memory("knowledge", "Different fact", project="atocore") + + with pytest.raises(ValueError, match="duplicate active memory"): + update_memory(second.id, content="Canonical fact") + + +def test_create_memory_validates_confidence(isolated_db): + from atocore.memory.service import create_memory + import pytest + + with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"): + create_memory("knowledge", "Out of range", confidence=1.5) + + def test_invalidate_memory(isolated_db): from atocore.memory.service import create_memory, get_memories, invalidate_memory mem = create_memory("knowledge", "Wrong fact") @@ -126,6 +162,25 @@ def test_memories_for_context(isolated_db): assert chars > 0 +def test_memories_for_context_reserves_room_for_each_type(isolated_db): + from atocore.memory.service import create_memory, get_memories_for_context + create_memory("identity", "Identity entry that is intentionally long so it could consume the whole budget on its own") + create_memory("preference", "Preference entry that should still appear") + + text, _ = get_memories_for_context(memory_types=["identity", "preference"], budget=120) + assert "[preference]" in text + + +def test_memories_for_context_respects_actual_serialized_budget(isolated_db): + from atocore.memory.service import create_memory, get_memories_for_context + create_memory("identity", "Identity text that should fit the wrapper-aware memory budget calculation") + create_memory("preference", "Preference text that should also fit") + + text, chars = get_memories_for_context(memory_types=["identity", "preference"], budget=140) + assert chars == len(text) + assert chars <= 140 + + def test_memories_for_context_empty(isolated_db): from atocore.memory.service import get_memories_for_context text, chars = get_memories_for_context() diff --git a/tests/test_project_state.py b/tests/test_project_state.py index ffa7248..dcb7a90 100644 --- a/tests/test_project_state.py +++ b/tests/test_project_state.py @@ -57,6 +57,12 @@ def test_set_state_invalid_category(): set_state("myproject", "invalid_category", "key", "value") +def test_set_state_validates_confidence(): + """Project-state confidence should stay within the documented range.""" + with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"): + set_state("myproject", "status", "phase", "Phase 1", confidence=1.2) + + def test_get_state_all(): """Test getting all state entries for a project.""" set_state("proj", "status", "phase", "Phase 1") diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index d4cbdcb..4b24c53 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -1,7 +1,7 @@ """Tests for the retrieval system.""" from atocore.ingestion.pipeline import ingest_file -from atocore.models.database import init_db +from atocore.models.database import get_connection, init_db from atocore.retrieval.retriever import retrieve from atocore.retrieval.vector_store import get_vector_store @@ -39,3 +39,31 @@ def test_vector_store_count(tmp_data_dir, sample_markdown): ingest_file(sample_markdown) store = get_vector_store() assert store.count > 0 + + +def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monkeypatch): + """Retriever should ignore vector hits whose chunk rows no longer exist.""" + init_db() + ingest_file(sample_markdown) + + with get_connection() as conn: + chunk_ids = [row["id"] for row in conn.execute("SELECT id FROM source_chunks").fetchall()] + + class FakeStore: + def query(self, query_embedding, top_k=10, where=None): + return { + "ids": [[chunk_ids[0], "missing-chunk"]], + "documents": [["valid doc", "stale doc"]], + "metadatas": [[ + {"heading_path": "Overview", "source_file": "valid.md", "tags": "[]", "title": "Valid", "document_id": "doc-1"}, + {"heading_path": "Ghost", "source_file": "ghost.md", "tags": "[]", "title": "Ghost", "document_id": "doc-2"}, + ]], + "distances": [[0.1, 0.2]], + } + + monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore()) + monkeypatch.setattr("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1]) + + results = retrieve("overview", top_k=2) + assert len(results) == 1 + assert results[0].chunk_id == chunk_ids[0]