84 lines
2.4 KiB
Python
84 lines
2.4 KiB
Python
|
|
"""Retrieval: query → ranked chunks."""
|
||
|
|
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
from atocore.config import settings
|
||
|
|
from atocore.observability.logger import get_logger
|
||
|
|
from atocore.retrieval.embeddings import embed_query
|
||
|
|
from atocore.retrieval.vector_store import get_vector_store
|
||
|
|
|
||
|
|
log = get_logger("retriever")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ChunkResult:
    """A single retrieved chunk with its relevance score and source metadata."""

    # Identifier of the chunk as reported by the vector store.
    chunk_id: str
    # Text body of the chunk.
    content: str
    # Similarity score; higher means more relevant.
    score: float
    # Heading trail of the chunk inside its document ("" when unknown).
    heading_path: str
    # File the chunk originated from ("" when unknown).
    source_file: str
    # Tags kept as a JSON-encoded string (e.g. "[]"), not a Python list.
    tags: str
    # Title of the originating document ("" when unknown).
    title: str
    # Identifier of the originating document ("" when unknown).
    document_id: str
|
||
|
|
|
||
|
|
|
||
|
|
def retrieve(
    query: str,
    top_k: int | None = None,
    filter_tags: list[str] | None = None,
) -> list[ChunkResult]:
    """Retrieve the most relevant chunks for a query.

    The query is embedded with ``embed_query`` and looked up in the vector
    store returned by ``get_vector_store``. Each hit is converted into a
    ``ChunkResult`` whose ``score`` is ``1 - distance`` rounded to four
    decimals.

    Args:
        query: Free-text query to embed and search for.
        top_k: Maximum number of chunks to request. ``None`` (or any other
            falsy value such as ``0``) falls back to
            ``settings.context_top_k``.
        filter_tags: Optional tag filter. Only the FIRST tag is applied;
            any further tags are currently ignored.

    Returns:
        The chunks in the order the store returned them; an empty list when
        the store reports no hits.
    """
    top_k = top_k or settings.context_top_k
    # Monotonic clock: wall-clock adjustments must not skew the duration.
    started = time.perf_counter()

    query_embedding = embed_query(query)
    store = get_vector_store()

    # Build filter
    where = None
    if filter_tags:
        # Tags are stored as a JSON string, so a substring-style check is
        # used; this only supports single-tag filtering.
        # NOTE(review): ChromaDB documents `$contains` for `where_document`
        # filters — confirm the store accepts it in a metadata `where`.
        where = {"tags": {"$contains": filter_tags[0]}}

    results = store.query(
        query_embedding=query_embedding,
        top_k=top_k,
        where=where,
    )

    chunks: list[ChunkResult] = []
    if results and results["ids"] and results["ids"][0]:
        # Results are nested one level per query; we issued exactly one
        # query, so take row 0 of every column once, outside the loop.
        distances = results["distances"][0] if results["distances"] else None
        metadatas = results["metadatas"][0] if results["metadatas"] else None
        documents = results["documents"][0] if results["documents"] else None

        for i, chunk_id in enumerate(results["ids"][0]):
            # ChromaDB returns distances (lower = more similar for cosine);
            # convert to a similarity score (1 - distance). A missing
            # distance column is treated as distance 0.
            distance = distances[i] if distances is not None else 0

            # Individual entries may be None when a record was stored
            # without metadata / document text; normalise so the lookups
            # below cannot fail and `content` is always a string.
            meta = (metadatas[i] if metadatas is not None else None) or {}
            content = (documents[i] if documents is not None else None) or ""

            chunks.append(
                ChunkResult(
                    chunk_id=chunk_id,
                    content=content,
                    score=round(1.0 - distance, 4),
                    heading_path=meta.get("heading_path", ""),
                    source_file=meta.get("source_file", ""),
                    tags=meta.get("tags", "[]"),
                    title=meta.get("title", ""),
                    document_id=meta.get("document_id", ""),
                )
            )

    duration_ms = int((time.perf_counter() - started) * 1000)
    log.info(
        "retrieval_done",
        query=query[:100],  # truncate so long queries do not bloat the log
        top_k=top_k,
        results_count=len(chunks),
        duration_ms=duration_ms,
    )

    return chunks
|