Add project-aware boost to raw query

This commit is contained in:
2026-04-06 13:32:33 -04:00
parent 4aa2b696a9
commit 26bfa94c65
4 changed files with 117 additions and 1 deletions

View File

@@ -393,3 +393,26 @@ def test_project_update_endpoint_rejects_collisions(tmp_data_dir, monkeypatch):
assert response.status_code == 400
assert "collisions" in response.json()["detail"]
def test_query_endpoint_accepts_project_hint(monkeypatch):
def fake_retrieve(prompt, top_k=10, filter_tags=None, project_hint=None):
assert prompt == "architecture"
assert top_k == 3
assert project_hint == "p04-gigabit"
return []
monkeypatch.setattr("atocore.api.routes.retrieve", fake_retrieve)
client = TestClient(app)
response = client.post(
"/query",
json={
"prompt": "architecture",
"top_k": 3,
"project": "p04-gigabit",
},
)
assert response.status_code == 200
assert response.json()["results"] == []

View File

@@ -67,3 +67,54 @@ def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monk
results = retrieve("overview", top_k=2)
assert len(results) == 1
assert results[0].chunk_id == chunk_ids[0]
def test_retrieve_project_hint_boosts_matching_chunks(monkeypatch):
class FakeStore:
def query(self, query_embedding, top_k=10, where=None):
return {
"ids": [["chunk-a", "chunk-b"]],
"documents": [["project doc", "other doc"]],
"metadatas": [[
{
"heading_path": "Overview",
"source_file": "p04-gigabit/pkm/_index.md",
"tags": '["p04-gigabit"]',
"title": "P04",
"document_id": "doc-a",
},
{
"heading_path": "Overview",
"source_file": "p05-interferometer/pkm/_index.md",
"tags": '["p05-interferometer"]',
"title": "P05",
"document_id": "doc-b",
},
]],
"distances": [[0.3, 0.25]],
}
monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore())
monkeypatch.setattr("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1])
monkeypatch.setattr(
"atocore.retrieval.retriever._existing_chunk_ids",
lambda chunk_ids: set(chunk_ids),
)
monkeypatch.setattr(
"atocore.retrieval.retriever.get_registered_project",
lambda project_name: type(
"Project",
(),
{
"project_id": "p04-gigabit",
"aliases": ("p04", "gigabit"),
"ingest_roots": (),
},
)(),
)
results = retrieve("mirror architecture", top_k=2, project_hint="p04")
assert len(results) == 2
assert results[0].chunk_id == "chunk-a"
assert results[0].score > results[1].score