# ATOCore/src/atocore/retrieval/retriever.py
"""Retrieval: query to ranked chunks."""
import json
import re
import time
from dataclasses import dataclass
import atocore.config as _config
from atocore.models.database import get_connection
from atocore.observability.logger import get_logger
from atocore.projects.registry import RegisteredProject, get_registered_project, load_project_registry
from atocore.retrieval.embeddings import embed_query
from atocore.retrieval.vector_store import get_vector_store
log = get_logger("retriever")
# Low-information query tokens ignored by _query_match_boost when counting
# query/metadata term overlap.
_STOP_TOKENS = {
    "about",
    "and",
    "current",
    "for",
    "from",
    "into",
    "like",
    "project",
    "shared",
    "system",
    "that",
    "the",
    "this",
    "what",
    "with",
}
# Substrings whose presence in a chunk's path/title/headings marks it as a
# current, high-value document; _path_signal_boost multiplies the score by
# settings.rank_path_high_signal_boost on a match.
_HIGH_SIGNAL_HINTS = (
    "status",
    "decision",
    "requirements",
    "requirement",
    "roadmap",
    "charter",
    "system-map",
    "system_map",
    "contracts",
    "schema",
    "architecture",
    "workflow",
    "error-budget",
    "comparison-matrix",
    "selection-decision",
)
# Substrings that mark archival/historical content (both slash directions for
# Windows paths); _path_signal_boost multiplies the score by
# settings.rank_path_low_signal_penalty on a match.
_LOW_SIGNAL_HINTS = (
    "/_archive/",
    "\\_archive\\",
    "/archive/",
    "\\archive\\",
    "_history",
    "history",
    "pre-cleanup",
    "pre-migration",
    "reviews/",
)
@dataclass
class ChunkResult:
    """A single retrieved chunk with its boosted relevance score and provenance."""

    chunk_id: str      # Vector-store / database id of the chunk.
    content: str       # Raw chunk text returned to the caller.
    score: float       # Similarity score after boosts, rounded to 4 decimals.
    heading_path: str  # Heading breadcrumb inside the source document (may be "").
    source_file: str   # Path of the document the chunk came from.
    tags: str          # Tags as stored in metadata — presumably a JSON-encoded list (default "[]").
    title: str         # Title of the source document.
    document_id: str   # Id of the parent document.
def retrieve(
    query: str,
    top_k: int | None = None,
    filter_tags: list[str] | None = None,
    project_hint: str | None = None,
) -> list[ChunkResult]:
    """Retrieve the most relevant chunks for a query.

    Args:
        query: Free-text query; embedded via embed_query for vector search.
        top_k: Maximum results to return; defaults to settings.context_top_k.
        filter_tags: If set, restrict to chunks whose tags field contains every
            listed tag (matched as a quoted JSON token).
        project_hint: Optional project id/alias; enables project scope
            filtering and a project-match score boost.

    Returns:
        Up to ``top_k`` ChunkResult items sorted by descending score.
    """
    top_k = top_k or _config.settings.context_top_k
    start = time.time()
    # Resolve the hint to a registered project (None when absent/unknown).
    scoped_project = get_registered_project(project_hint) if project_hint else None
    scope_filter_enabled = bool(scoped_project and _config.settings.rank_project_scope_filter)
    registered_projects = None
    query_top_k = top_k
    if scope_filter_enabled:
        # Over-fetch candidates so post-filter trimming can still fill top_k.
        query_top_k = max(
            top_k,
            top_k * max(1, _config.settings.rank_project_scope_candidate_multiplier),
        )
        try:
            registered_projects = load_project_registry()
        except Exception:
            # Best-effort: scope checks fall back to reloading or failing open.
            registered_projects = None
    query_embedding = embed_query(query)
    store = get_vector_store()
    where = None
    if filter_tags:
        # Tags metadata is a JSON-encoded string, so match the quoted token.
        if len(filter_tags) == 1:
            where = {"tags": {"$contains": f'"{filter_tags[0]}"'}}
        else:
            where = {
                "$and": [
                    {"tags": {"$contains": f'"{tag}"'}}
                    for tag in filter_tags
                ]
            }
    results = store.query(
        query_embedding=query_embedding,
        top_k=query_top_k,
        where=where,
    )
    chunks = []
    raw_result_count = len(results["ids"][0]) if results and results["ids"] and results["ids"][0] else 0
    if results and results["ids"] and results["ids"][0]:
        # Drop vector hits whose backing chunk rows no longer exist.
        existing_ids = _existing_chunk_ids(results["ids"][0])
        for i, chunk_id in enumerate(results["ids"][0]):
            if chunk_id not in existing_ids:
                continue
            distance = results["distances"][0][i] if results["distances"] else 0
            # Convert distance to a similarity-style score (higher is better).
            score = 1.0 - distance
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            content = results["documents"][0][i] if results["documents"] else ""
            # Exclude chunks owned by a different registered project.
            if scope_filter_enabled and not _is_allowed_for_project_scope(
                scoped_project,
                meta,
                registered_projects,
            ):
                continue
            # Heuristic multipliers: query-term echo, path signal, project match.
            score *= _query_match_boost(query, meta)
            score *= _path_signal_boost(meta)
            if project_hint:
                score *= _project_match_boost(project_hint, meta)
            chunks.append(
                ChunkResult(
                    chunk_id=chunk_id,
                    content=content,
                    score=round(score, 4),
                    heading_path=meta.get("heading_path", ""),
                    source_file=meta.get("source_file", ""),
                    tags=meta.get("tags", "[]"),
                    title=meta.get("title", ""),
                    document_id=meta.get("document_id", ""),
                )
            )
    duration_ms = int((time.time() - start) * 1000)
    # Re-rank by boosted score, then trim the over-fetched candidate list.
    chunks.sort(key=lambda chunk: chunk.score, reverse=True)
    post_filter_count = len(chunks)
    chunks = chunks[:top_k]
    log.info(
        "retrieval_done",
        query=query[:100],
        top_k=top_k,
        query_top_k=query_top_k,
        raw_results_count=raw_result_count,
        post_filter_count=post_filter_count,
        results_count=len(chunks),
        post_filter_dropped=max(0, raw_result_count - post_filter_count),
        # underfilled: store returned a full candidate page but filtering left
        # fewer than top_k results.
        underfilled=bool(raw_result_count >= query_top_k and len(chunks) < top_k),
        duration_ms=duration_ms,
    )
    return chunks
def _is_allowed_for_project_scope(
    project: RegisteredProject,
    metadata: dict,
    registered_projects: list[RegisteredProject] | None = None,
) -> bool:
    """Decide whether a chunk survives project-scoped retrieval.

    A chunk is kept when it belongs to the requested project, or when no
    registered project claims it at all (shared/global docs stay eligible).
    It is dropped only when some *other* registered project owns it — the
    registry defines that boundary.
    """
    # Chunks that match the requested project are always eligible.
    if _metadata_matches_project(project, metadata):
        return True
    # Lazily load the registry when the caller did not supply one.
    if registered_projects is None:
        try:
            registered_projects = load_project_registry()
        except Exception:
            # Registry unavailable: fail open rather than drop results.
            return True
    # Reject only when a different registered project claims this chunk.
    owned_by_other = any(
        other.project_id != project.project_id
        and _metadata_matches_project(other, metadata)
        for other in registered_projects
    )
    return not owned_by_other
def _metadata_matches_project(project: RegisteredProject, metadata: dict) -> bool:
    """Return True when the chunk's path or tags hit any scope term of *project*."""
    source_path = _metadata_source_path(metadata)
    tag_set = _metadata_tags(metadata)
    return any(
        _path_matches_term(source_path, term) or term in tag_set
        for term in _project_scope_terms(project)
    )
def _project_scope_terms(project: RegisteredProject) -> set[str]:
    """Collect lowercase identifiers that mark content as belonging to *project*.

    Includes the project id, every alias, and for each ingest root both the
    normalized subpath and its final path segment.
    """
    collected: set[str] = {project.project_id.lower()}
    collected |= {alias.lower() for alias in project.aliases}
    for ref in project.ingest_roots:
        subpath = ref.subpath.replace("\\", "/").strip("/").lower()
        if subpath:
            collected.add(subpath)
            # Also index by the leaf directory name for loose matching.
            collected.add(subpath.rsplit("/", 1)[-1])
    # Guard against empty strings slipping in from blank ids/aliases.
    return {term for term in collected if term}
def _metadata_searchable(metadata: dict) -> str:
return " ".join(
[
str(metadata.get("source_file", "")).replace("\\", "/").lower(),
str(metadata.get("title", "")).lower(),
str(metadata.get("heading_path", "")).lower(),
str(metadata.get("tags", "")).lower(),
]
)
def _metadata_source_path(metadata: dict) -> str:
return str(metadata.get("source_file", "")).replace("\\", "/").strip("/").lower()
def _metadata_tags(metadata: dict) -> set[str]:
raw_tags = metadata.get("tags", [])
if isinstance(raw_tags, (list, tuple, set)):
return {str(tag).strip().lower() for tag in raw_tags if str(tag).strip()}
if isinstance(raw_tags, str):
try:
parsed = json.loads(raw_tags)
except json.JSONDecodeError:
parsed = [raw_tags]
if isinstance(parsed, (list, tuple, set)):
return {str(tag).strip().lower() for tag in parsed if str(tag).strip()}
if isinstance(parsed, str) and parsed.strip():
return {parsed.strip().lower()}
return set()
def _path_matches_term(path: str, term: str) -> bool:
normalized = term.replace("\\", "/").strip("/").lower()
if not path or not normalized:
return False
if "/" in normalized:
return path == normalized or path.startswith(f"{normalized}/")
return normalized in set(path.split("/"))
def _metadata_has_term(metadata: dict, term: str) -> bool:
    """True when *term* appears in the chunk's path, tags, or searchable text."""
    needle = term.replace("\\", "/").strip("/").lower()
    if not needle:
        return False
    # Cheap structural checks first: path segments/prefix, then exact tag hit.
    if _path_matches_term(_metadata_source_path(metadata), needle):
        return True
    if needle in _metadata_tags(metadata):
        return True
    # Fall back to a word-boundary search over all searchable metadata text.
    pattern = rf"(?<![a-z0-9]){re.escape(needle)}(?![a-z0-9])"
    return bool(re.search(pattern, _metadata_searchable(metadata)))
def _project_match_boost(project_hint: str, metadata: dict) -> float:
    """Return a project-aware relevance multiplier for raw retrieval."""
    hint = project_hint.strip().lower()
    if not hint:
        return 1.0
    registered = get_registered_project(project_hint)
    # Registered projects contribute their whole scope vocabulary; an
    # unregistered hint falls back to the raw hint string itself.
    terms = {hint} if registered is None else _project_scope_terms(registered)
    if any(_metadata_has_term(metadata, term) for term in terms):
        return _config.settings.rank_project_match_boost
    return 1.0
def _query_match_boost(query: str, metadata: dict) -> float:
    """Boost chunks whose path/title/headings echo the query's high-signal terms."""
    # Unique query terms of 3+ chars, minus stop words.
    terms = {
        token
        for token in re.findall(r"[a-z0-9][a-z0-9_-]{2,}", query.lower())
        if token not in _STOP_TOKENS
    }
    if not terms:
        return 1.0
    haystack = " ".join(
        str(metadata.get(field, "")).lower()
        for field in ("source_file", "title", "heading_path")
    )
    hits = sum(1 for term in terms if term in haystack)
    if hits == 0:
        return 1.0
    # Linear step per matched term, capped so one chunk can't run away.
    boosted = 1.0 + hits * _config.settings.rank_query_token_step
    return min(boosted, _config.settings.rank_query_token_cap)
def _path_signal_boost(metadata: dict) -> float:
    """Prefer current high-signal docs and gently down-rank archival noise."""
    haystack = " ".join(
        str(metadata.get(field, "")).lower()
        for field in ("source_file", "title", "heading_path")
    )
    boost = 1.0
    # Penalty and boost are independent; a chunk can receive both.
    if any(marker in haystack for marker in _LOW_SIGNAL_HINTS):
        boost *= _config.settings.rank_path_low_signal_penalty
    if any(marker in haystack for marker in _HIGH_SIGNAL_HINTS):
        boost *= _config.settings.rank_path_high_signal_boost
    return boost
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
    """Filter out stale vector entries whose chunk rows no longer exist."""
    if not chunk_ids:
        return set()
    # Parameterized IN (...) clause sized to the candidate list — ids are
    # bound as parameters, never interpolated into the SQL text.
    marks = ", ".join(["?"] * len(chunk_ids))
    sql = f"SELECT id FROM source_chunks WHERE id IN ({marks})"
    with get_connection() as conn:
        found = conn.execute(sql, chunk_ids).fetchall()
    return {row["id"] for row in found}