"""Tests for the retrieval system."""

from atocore.ingestion.pipeline import ingest_file
from atocore.models.database import get_connection, init_db
from atocore.retrieval.retriever import retrieve
from atocore.retrieval.vector_store import get_vector_store


def test_retrieve_returns_results(tmp_data_dir, sample_markdown):
    """A query against an ingested document yields at least one usable hit.

    Every returned hit must have a strictly positive score and truthy content.
    """
    init_db()
    ingest_file(sample_markdown)

    hits = retrieve("What are the memory types?", top_k=5)

    assert len(hits) > 0
    for hit in hits:
        assert hit.score > 0
    for hit in hits:
        assert hit.content


def test_retrieve_scores_ranked(tmp_data_dir, sample_markdown):
    """Hits come back ordered from highest score to lowest.

    The ordering is only checked when two or more hits are returned.
    """
    init_db()
    ingest_file(sample_markdown)

    ranked = retrieve("architecture layers", top_k=5)
    if len(ranked) < 2:
        # Nothing to compare with fewer than two hits.
        return

    observed = [hit.score for hit in ranked]
    expected = sorted(observed, reverse=True)
    assert observed == expected


def test_vector_store_count(tmp_data_dir, sample_markdown):
    """The vector store reports a positive chunk count after ingestion."""
    init_db()

    # Drop the cached module-level store before ingesting, for a clean test.
    import atocore.retrieval.vector_store as vector_store_module

    vector_store_module._store = None

    ingest_file(sample_markdown)

    assert get_vector_store().count > 0


def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monkeypatch):
    """Vector hits whose chunk rows no longer exist are dropped by the retriever.

    The vector store is replaced by a stub that returns one id present in
    ``source_chunks`` and one id that is not; only the former may be returned.
    """
    init_db()
    ingest_file(sample_markdown)

    with get_connection() as conn:
        cursor = conn.execute("SELECT id FROM source_chunks")
        chunk_ids = [record["id"] for record in cursor.fetchall()]

    class StubStore:
        """Stand-in vector store returning a fixed two-hit payload."""

        def query(self, query_embedding, top_k=10, where=None):
            live_meta = {
                "heading_path": "Overview",
                "source_file": "valid.md",
                "tags": "[]",
                "title": "Valid",
                "document_id": "doc-1",
            }
            ghost_meta = {
                "heading_path": "Ghost",
                "source_file": "ghost.md",
                "tags": "[]",
                "title": "Ghost",
                "document_id": "doc-2",
            }
            return {
                "ids": [[chunk_ids[0], "missing-chunk"]],
                "documents": [["valid doc", "stale doc"]],
                "metadatas": [[live_meta, ghost_meta]],
                "distances": [[0.1, 0.2]],
            }

    def make_store():
        # A fresh stub per call, as the retriever may look the store up repeatedly.
        return StubStore()

    def fake_embed(query):
        # Fixed embedding; the stub store ignores it anyway.
        return [0.0, 0.1]

    monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", make_store)
    monkeypatch.setattr("atocore.retrieval.retriever.embed_query", fake_embed)

    survivors = retrieve("overview", top_k=2)

    assert len(survivors) == 1
    assert survivors[0].chunk_id == chunk_ids[0]