Add project-aware boost to raw query
This commit is contained in:
@@ -88,6 +88,7 @@ class QueryRequest(BaseModel):
|
|||||||
prompt: str
|
prompt: str
|
||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
filter_tags: list[str] | None = None
|
filter_tags: list[str] | None = None
|
||||||
|
project: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class QueryResponse(BaseModel):
|
class QueryResponse(BaseModel):
|
||||||
@@ -258,7 +259,12 @@ def api_refresh_project(project_name: str, purge_deleted: bool = False) -> Proje
|
|||||||
def api_query(req: QueryRequest) -> QueryResponse:
|
def api_query(req: QueryRequest) -> QueryResponse:
|
||||||
"""Retrieve relevant chunks for a prompt."""
|
"""Retrieve relevant chunks for a prompt."""
|
||||||
try:
|
try:
|
||||||
chunks = retrieve(req.prompt, top_k=req.top_k, filter_tags=req.filter_tags)
|
chunks = retrieve(
|
||||||
|
req.prompt,
|
||||||
|
top_k=req.top_k,
|
||||||
|
filter_tags=req.filter_tags,
|
||||||
|
project_hint=req.project,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("query_failed", prompt=req.prompt[:100], error=str(e))
|
log.error("query_failed", prompt=req.prompt[:100], error=str(e))
|
||||||
raise HTTPException(status_code=500, detail=f"Query failed: {e}")
|
raise HTTPException(status_code=500, detail=f"Query failed: {e}")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||||||
import atocore.config as _config
|
import atocore.config as _config
|
||||||
from atocore.models.database import get_connection
|
from atocore.models.database import get_connection
|
||||||
from atocore.observability.logger import get_logger
|
from atocore.observability.logger import get_logger
|
||||||
|
from atocore.projects.registry import get_registered_project
|
||||||
from atocore.retrieval.embeddings import embed_query
|
from atocore.retrieval.embeddings import embed_query
|
||||||
from atocore.retrieval.vector_store import get_vector_store
|
from atocore.retrieval.vector_store import get_vector_store
|
||||||
|
|
||||||
@@ -28,6 +29,7 @@ def retrieve(
|
|||||||
query: str,
|
query: str,
|
||||||
top_k: int | None = None,
|
top_k: int | None = None,
|
||||||
filter_tags: list[str] | None = None,
|
filter_tags: list[str] | None = None,
|
||||||
|
project_hint: str | None = None,
|
||||||
) -> list[ChunkResult]:
|
) -> list[ChunkResult]:
|
||||||
"""Retrieve the most relevant chunks for a query."""
|
"""Retrieve the most relevant chunks for a query."""
|
||||||
top_k = top_k or _config.settings.context_top_k
|
top_k = top_k or _config.settings.context_top_k
|
||||||
@@ -71,6 +73,9 @@ def retrieve(
|
|||||||
meta = results["metadatas"][0][i] if results["metadatas"] else {}
|
meta = results["metadatas"][0][i] if results["metadatas"] else {}
|
||||||
content = results["documents"][0][i] if results["documents"] else ""
|
content = results["documents"][0][i] if results["documents"] else ""
|
||||||
|
|
||||||
|
if project_hint:
|
||||||
|
score *= _project_match_boost(project_hint, meta)
|
||||||
|
|
||||||
chunks.append(
|
chunks.append(
|
||||||
ChunkResult(
|
ChunkResult(
|
||||||
chunk_id=chunk_id,
|
chunk_id=chunk_id,
|
||||||
@@ -85,6 +90,8 @@ def retrieve(
|
|||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = int((time.time() - start) * 1000)
|
duration_ms = int((time.time() - start) * 1000)
|
||||||
|
chunks.sort(key=lambda chunk: chunk.score, reverse=True)
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"retrieval_done",
|
"retrieval_done",
|
||||||
query=query[:100],
|
query=query[:100],
|
||||||
@@ -96,6 +103,35 @@ def retrieve(
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _project_match_boost(project_hint: str, metadata: dict) -> float:
|
||||||
|
"""Return a project-aware relevance multiplier for raw retrieval."""
|
||||||
|
hint_lower = project_hint.strip().lower()
|
||||||
|
if not hint_lower:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
source_file = str(metadata.get("source_file", "")).lower()
|
||||||
|
title = str(metadata.get("title", "")).lower()
|
||||||
|
tags = str(metadata.get("tags", "")).lower()
|
||||||
|
searchable = " ".join([source_file, title, tags])
|
||||||
|
|
||||||
|
project = get_registered_project(project_hint)
|
||||||
|
candidate_names = {hint_lower}
|
||||||
|
if project is not None:
|
||||||
|
candidate_names.add(project.project_id.lower())
|
||||||
|
candidate_names.update(alias.lower() for alias in project.aliases)
|
||||||
|
candidate_names.update(
|
||||||
|
source_ref.subpath.replace("\\", "/").strip("/").split("/")[-1].lower()
|
||||||
|
for source_ref in project.ingest_roots
|
||||||
|
if source_ref.subpath.strip("/\\")
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidate_names:
|
||||||
|
if candidate and candidate in searchable:
|
||||||
|
return 2.0
|
||||||
|
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
|
||||||
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
|
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
|
||||||
"""Filter out stale vector entries whose chunk rows no longer exist."""
|
"""Filter out stale vector entries whose chunk rows no longer exist."""
|
||||||
if not chunk_ids:
|
if not chunk_ids:
|
||||||
|
|||||||
@@ -393,3 +393,26 @@ def test_project_update_endpoint_rejects_collisions(tmp_data_dir, monkeypatch):
|
|||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "collisions" in response.json()["detail"]
|
assert "collisions" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_endpoint_accepts_project_hint(monkeypatch):
    """POST /query hands the request's ``project`` field to retrieve as ``project_hint``."""

    def fake_retrieve(prompt, top_k=10, filter_tags=None, project_hint=None):
        # Stand-in for the retriever: verifies what the route passed along.
        assert prompt == "architecture"
        assert top_k == 3
        assert project_hint == "p04-gigabit"
        return []

    monkeypatch.setattr("atocore.api.routes.retrieve", fake_retrieve)

    payload = {
        "prompt": "architecture",
        "top_k": 3,
        "project": "p04-gigabit",
    }
    reply = TestClient(app).post("/query", json=payload)

    # The stub returned no chunks, so the endpoint should report an empty result list.
    assert reply.status_code == 200
    assert reply.json()["results"] == []
|
||||||
|
|||||||
@@ -67,3 +67,54 @@ def test_retrieve_skips_stale_vector_entries(tmp_data_dir, sample_markdown, monk
|
|||||||
results = retrieve("overview", top_k=2)
|
results = retrieve("overview", top_k=2)
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].chunk_id == chunk_ids[0]
|
assert results[0].chunk_id == chunk_ids[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_project_hint_boosts_matching_chunks(monkeypatch):
    """A chunk whose metadata matches the hinted project is ranked first."""

    class FakeStore:
        """Vector store stub returning two chunks from different projects."""

        def query(self, query_embedding, top_k=10, where=None):
            matching_meta = {
                "heading_path": "Overview",
                "source_file": "p04-gigabit/pkm/_index.md",
                "tags": '["p04-gigabit"]',
                "title": "P04",
                "document_id": "doc-a",
            }
            other_meta = {
                "heading_path": "Overview",
                "source_file": "p05-interferometer/pkm/_index.md",
                "tags": '["p05-interferometer"]',
                "title": "P05",
                "document_id": "doc-b",
            }
            # chunk-b has the smaller distance; presumably it would outrank
            # chunk-a if no project boost were applied.
            return {
                "ids": [["chunk-a", "chunk-b"]],
                "documents": [["project doc", "other doc"]],
                "metadatas": [[matching_meta, other_meta]],
                "distances": [[0.3, 0.25]],
            }

    class Project:
        """Registry entry stub for the hinted project."""

        project_id = "p04-gigabit"
        aliases = ("p04", "gigabit")
        ingest_roots = ()

    patches = [
        ("atocore.retrieval.retriever.get_vector_store", lambda: FakeStore()),
        ("atocore.retrieval.retriever.embed_query", lambda query: [0.0, 0.1]),
        ("atocore.retrieval.retriever._existing_chunk_ids", lambda chunk_ids: set(chunk_ids)),
        ("atocore.retrieval.retriever.get_registered_project", lambda project_name: Project()),
    ]
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    ranked = retrieve("mirror architecture", top_k=2, project_hint="p04")

    assert len(ranked) == 2
    best, runner_up = ranked[0], ranked[1]
    assert best.chunk_id == "chunk-a"
    assert best.score > runner_up.score
|
||||||
|
|||||||
Reference in New Issue
Block a user