Files
ATOCore/src/atocore/retrieval/retriever.py

111 lines
3.3 KiB
Python
Raw Normal View History

"""Retrieval: query → ranked chunks."""
import time
from dataclasses import dataclass
import atocore.config as _config
from atocore.models.database import get_connection
from atocore.observability.logger import get_logger
from atocore.retrieval.embeddings import embed_query
from atocore.retrieval.vector_store import get_vector_store
# Module-level logger shared by the functions in this file, obtained from the
# project's get_logger under the name "retriever".
log = get_logger("retriever")
@dataclass
class ChunkResult:
    """One retrieved chunk together with its score and source metadata.

    Instances are built by ``retrieve`` from a vector-store hit. Field order
    is part of the positional constructor and must not change.
    """

    # Identifier of the chunk as returned by the vector store.
    chunk_id: str
    # Text of the chunk ("" when the store returned no documents).
    content: str
    # Similarity score: ``1 - distance``, rounded to 4 decimals by ``retrieve``.
    score: float
    # The remaining fields are copied from the hit's metadata; ``retrieve``
    # substitutes "" (or "[]" for ``tags``) when a key is missing.
    heading_path: str
    source_file: str
    # Tags are kept as a JSON-encoded string such as '["tag1", "tag2"]'.
    tags: str
    title: str
    document_id: str
def _build_tag_filter(filter_tags: list[str] | None) -> dict | None:
    """Build the vector-store ``where`` clause that requires every given tag.

    Tags are stored as JSON strings like '["tag1", "tag2"]', so each tag is
    matched together with its surrounding double quotes: searching "prod"
    looks for '"prod"' and therefore does not match "production".

    Returns:
        ``None`` when no tags are given, a single ``$contains`` clause for
        one tag, or an ``$and`` of one clause per tag otherwise.
    """
    if not filter_tags:
        return None
    clauses = [{"tags": {"$contains": f'"{tag}"'}} for tag in filter_tags]
    # A lone clause is passed bare instead of being wrapped in "$and"
    # (presumably the store expects "$and" to carry several operands —
    # confirm against the vector store before changing this).
    return clauses[0] if len(clauses) == 1 else {"$and": clauses}


def retrieve(
    query: str,
    top_k: int | None = None,
    filter_tags: list[str] | None = None,
) -> list[ChunkResult]:
    """Retrieve the most relevant chunks for a query.

    Args:
        query: Free-text query; it is embedded with ``embed_query``.
        top_k: Maximum number of hits requested from the vector store.
            ``None`` (or any other falsy value, e.g. 0) falls back to
            ``_config.settings.context_top_k``.
        filter_tags: When given, only chunks whose ``tags`` metadata
            contains every listed tag are requested.

    Returns:
        One ``ChunkResult`` per hit, in the order the vector store returned
        them. Hits whose id no longer has a row in ``source_chunks`` are
        dropped, so fewer than ``top_k`` results may come back.
    """
    top_k = top_k or _config.settings.context_top_k
    # perf_counter is meant for measuring short durations; unlike the wall
    # clock read by time.time(), it is not skewed by clock adjustments that
    # happen while the query is running.
    start = time.perf_counter()

    query_embedding = embed_query(query)
    store = get_vector_store()
    results = store.query(
        query_embedding=query_embedding,
        top_k=top_k,
        where=_build_tag_filter(filter_tags),
    )

    chunks: list[ChunkResult] = []
    # Results are nested per query; only one query is sent, hence index 0.
    if results and results["ids"] and results["ids"][0]:
        hit_ids = results["ids"][0]
        existing_ids = _existing_chunk_ids(hit_ids)
        for i, chunk_id in enumerate(hit_ids):
            # Skip stale vector entries whose chunk row has been deleted.
            if chunk_id not in existing_ids:
                continue

            # ChromaDB returns distances (lower = more similar for cosine).
            # Convert to a similarity score (1 - distance).
            distance = results["distances"][0][i] if results["distances"] else 0
            score = 1.0 - distance

            # An individual hit may carry no metadata or document text and
            # come back as None (NOTE(review): ChromaDB does this for records
            # stored without those fields — confirm for this store wrapper).
            # Normalise so the lookups below cannot fail on None.
            meta = (
                results["metadatas"][0][i] if results["metadatas"] else None
            ) or {}
            content = (
                results["documents"][0][i] if results["documents"] else None
            ) or ""

            chunks.append(
                ChunkResult(
                    chunk_id=chunk_id,
                    content=content,
                    score=round(score, 4),
                    heading_path=meta.get("heading_path", ""),
                    source_file=meta.get("source_file", ""),
                    tags=meta.get("tags", "[]"),
                    title=meta.get("title", ""),
                    document_id=meta.get("document_id", ""),
                )
            )

    duration_ms = int((time.perf_counter() - start) * 1000)
    log.info(
        "retrieval_done",
        query=query[:100],  # truncate so long queries do not bloat the log
        top_k=top_k,
        results_count=len(chunks),
        duration_ms=duration_ms,
    )
    return chunks
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
    """Return the ids from *chunk_ids* that still have a ``source_chunks`` row.

    Used to drop stale vector entries whose chunk rows no longer exist.
    An empty input returns an empty set without opening a connection.
    """
    if not chunk_ids:
        return set()

    # Only the number of "?" markers is interpolated into the SQL text; the
    # ids themselves are always passed as bound parameters.
    markers = ", ".join(["?"] * len(chunk_ids))
    sql = f"SELECT id FROM source_chunks WHERE id IN ({markers})"

    with get_connection() as conn:
        cursor = conn.execute(sql, chunk_ids)
        found = cursor.fetchall()

    return {record["id"] for record in found}