Add project-aware boost to raw query
This commit is contained in:
@@ -393,3 +393,26 @@ def test_project_update_endpoint_rejects_collisions(tmp_data_dir, monkeypatch):
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "collisions" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_query_endpoint_accepts_project_hint(monkeypatch):
|
||||
def fake_retrieve(prompt, top_k=10, filter_tags=None, project_hint=None):
|
||||
assert prompt == "architecture"
|
||||
assert top_k == 3
|
||||
assert project_hint == "p04-gigabit"
|
||||
return []
|
||||
|
||||
monkeypatch.setattr("atocore.api.routes.retrieve", fake_retrieve)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/query",
|
||||
json={
|
||||
"prompt": "architecture",
|
||||
"top_k": 3,
|
||||
"project": "p04-gigabit",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["results"] == []
|
||||
|
||||
@@ -67,3 +67,54 @@ def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monk
|
||||
results = retrieve("overview", top_k=2)
|
||||
assert len(results) == 1
|
||||
assert results[0].chunk_id == chunk_ids[0]
|
||||
|
||||
|
||||
def test_retrieve_project_hint_boosts_matching_chunks(monkeypatch):
|
||||
class FakeStore:
|
||||
def query(self, query_embedding, top_k=10, where=None):
|
||||
return {
|
||||
"ids": [["chunk-a", "chunk-b"]],
|
||||
"documents": [["project doc", "other doc"]],
|
||||
"metadatas": [[
|
||||
{
|
||||
"heading_path": "Overview",
|
||||
"source_file": "p04-gigabit/pkm/_index.md",
|
||||
"tags": '["p04-gigabit"]',
|
||||
"title": "P04",
|
||||
"document_id": "doc-a",
|
||||
},
|
||||
{
|
||||
"heading_path": "Overview",
|
||||
"source_file": "p05-interferometer/pkm/_index.md",
|
||||
"tags": '["p05-interferometer"]',
|
||||
"title": "P05",
|
||||
"document_id": "doc-b",
|
||||
},
|
||||
]],
|
||||
"distances": [[0.3, 0.25]],
|
||||
}
|
||||
|
||||
monkeypatch.setattr("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore())
|
||||
monkeypatch.setattr("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1])
|
||||
monkeypatch.setattr(
|
||||
"atocore.retrieval.retriever._existing_chunk_ids",
|
||||
lambda chunk_ids: set(chunk_ids),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"atocore.retrieval.retriever.get_registered_project",
|
||||
lambda project_name: type(
|
||||
"Project",
|
||||
(),
|
||||
{
|
||||
"project_id": "p04-gigabit",
|
||||
"aliases": ("p04", "gigabit"),
|
||||
"ingest_roots": (),
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
results = retrieve("mirror architecture", top_k=2, project_hint="p04")
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0].chunk_id == "chunk-a"
|
||||
assert results[0].score > results[1].score
|
||||
|
||||
Reference in New Issue
Block a user