diff --git a/src/atocore/api/routes.py b/src/atocore/api/routes.py
index d723eaf..6d268d1 100644
--- a/src/atocore/api/routes.py
+++ b/src/atocore/api/routes.py
@@ -88,6 +88,7 @@ class QueryRequest(BaseModel):
     prompt: str
     top_k: int = 10
     filter_tags: list[str] | None = None
+    project: str | None = None
 
 
 class QueryResponse(BaseModel):
@@ -258,7 +259,12 @@ def api_refresh_project(project_name: str, purge_deleted: bool = False) -> Proje
 def api_query(req: QueryRequest) -> QueryResponse:
     """Retrieve relevant chunks for a prompt."""
     try:
-        chunks = retrieve(req.prompt, top_k=req.top_k, filter_tags=req.filter_tags)
+        chunks = retrieve(
+            req.prompt,
+            top_k=req.top_k,
+            filter_tags=req.filter_tags,
+            project_hint=req.project,
+        )
     except Exception as e:
         log.error("query_failed", prompt=req.prompt[:100], error=str(e))
         raise HTTPException(status_code=500, detail=f"Query failed: {e}")
diff --git a/src/atocore/retrieval/retriever.py b/src/atocore/retrieval/retriever.py
index fa19362..524a523 100644
--- a/src/atocore/retrieval/retriever.py
+++ b/src/atocore/retrieval/retriever.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
 import atocore.config as _config
 from atocore.models.database import get_connection
 from atocore.observability.logger import get_logger
+from atocore.projects.registry import get_registered_project
 from atocore.retrieval.embeddings import embed_query
 from atocore.retrieval.vector_store import get_vector_store
@@ -28,6 +29,7 @@ def retrieve(
     query: str,
     top_k: int | None = None,
     filter_tags: list[str] | None = None,
+    project_hint: str | None = None,
 ) -> list[ChunkResult]:
     """Retrieve the most relevant chunks for a query."""
     top_k = top_k or _config.settings.context_top_k
@@ -71,6 +73,9 @@
         meta = results["metadatas"][0][i] if results["metadatas"] else {}
         content = results["documents"][0][i] if results["documents"] else ""
 
+        if project_hint:
+            score *= _project_match_boost(project_hint, meta)
+
         chunks.append(
             ChunkResult(
                 chunk_id=chunk_id,
@@ -85,6 +90,8 @@
         )
 
     duration_ms = int((time.time() - start) * 1000)
+    chunks.sort(key=lambda chunk: chunk.score, reverse=True)
+
     log.info(
         "retrieval_done",
         query=query[:100],
@@ -96,6 +103,35 @@
     return chunks
 
 
+def _project_match_boost(project_hint: str, metadata: dict) -> float:
+    """Return a project-aware relevance multiplier for raw retrieval."""
+    hint_lower = project_hint.strip().lower()
+    if not hint_lower:
+        return 1.0
+
+    source_file = str(metadata.get("source_file", "")).lower()
+    title = str(metadata.get("title", "")).lower()
+    tags = str(metadata.get("tags", "")).lower()
+    searchable = " ".join([source_file, title, tags])
+
+    project = get_registered_project(project_hint)
+    candidate_names = {hint_lower}
+    if project is not None:
+        candidate_names.add(project.project_id.lower())
+        candidate_names.update(alias.lower() for alias in project.aliases)
+        candidate_names.update(
+            source_ref.subpath.replace("\\", "/").strip("/").split("/")[-1].lower()
+            for source_ref in project.ingest_roots
+            if source_ref.subpath.strip("/\\")
+        )
+
+    for candidate in candidate_names:
+        if candidate and candidate in searchable:
+            return 2.0
+
+    return 1.0
+
+
 def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
     """Filter out stale vector entries whose chunk rows no longer exist."""
     if not chunk_ids:
diff --git a/tests/test_api_storage.py b/tests/test_api_storage.py
index 7492bf3..25fee45 100644
--- a/tests/test_api_storage.py
+++ b/tests/test_api_storage.py
@@ -393,3 +393,26 @@ def test_project_update_endpoint_rejects_collisions(tmp_data_dir, monkeypatch):
 
     assert response.status_code == 400
     assert "collisions" in response.json()["detail"]
+
+
+def test_query_endpoint_accepts_project_hint(monkeypatch):
+    def fake_retrieve(prompt, top_k=10, filter_tags=None, project_hint=None):
+        assert prompt == "architecture"
+        assert top_k == 3
+        assert project_hint == "p04-gigabit"
+        return []
+
+    monkeypatch.setattr("atocore.api.routes.retrieve", fake_retrieve)
+
+    client = TestClient(app)
+    response = client.post(
+        "/query",
+        json={
+            "prompt": "architecture",
+            "top_k": 3,
+            "project": "p04-gigabit",
+        },
+    )
+
+    assert response.status_code == 200
+    assert response.json()["results"] == []
diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py
index 4b24c53..46ff746 100644
--- a/tests/test_retrieval.py
+++ b/tests/test_retrieval.py
@@ -67,3 +67,54 @@ def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monk
     results = retrieve("overview", top_k=2)
     assert len(results) == 1
     assert results[0].chunk_id == chunk_ids[0]
+
+
+def test_retrieve_project_hint_boosts_matching_chunks(monkeypatch):
+    class FakeStore:
+        def query(self, query_embedding, top_k=10, where=None):
+            return {
+                "ids": [["chunk-a", "chunk-b"]],
+                "documents": [["project doc", "other doc"]],
+                "metadatas": [[
+                    {
+                        "heading_path": "Overview",
+                        "source_file": "p04-gigabit/pkm/_index.md",
+                        "tags": '["p04-gigabit"]',
+                        "title": "P04",
+                        "document_id": "doc-a",
+                    },
+                    {
+                        "heading_path": "Overview",
+                        "source_file": "p05-interferometer/pkm/_index.md",
+                        "tags": '["p05-interferometer"]',
+                        "title": "P05",
+                        "document_id": "doc-b",
+                    },
+                ]],
+                "distances": [[0.3, 0.25]],
+            }
+
+    monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore())
+    monkeypatch.setattr("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1])
+    monkeypatch.setattr(
+        "atocore.retrieval.retriever._existing_chunk_ids",
+        lambda chunk_ids: set(chunk_ids),
+    )
+    monkeypatch.setattr(
+        "atocore.retrieval.retriever.get_registered_project",
+        lambda project_name: type(
+            "Project",
+            (),
+            {
+                "project_id": "p04-gigabit",
+                "aliases": ("p04", "gigabit"),
+                "ingest_roots": (),
+            },
+        )(),
+    )
+
+    results = retrieve("mirror architecture", top_k=2, project_hint="p04")
+
+    assert len(results) == 2
+    assert results[0].chunk_id == "chunk-a"
+    assert results[0].score > results[1].score