"""Retrieval: query to ranked chunks."""
|
|
|
|
import json
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass
|
|
|
|
import atocore.config as _config
|
|
from atocore.models.database import get_connection
|
|
from atocore.observability.logger import get_logger
|
|
from atocore.projects.registry import RegisteredProject, get_registered_project, load_project_registry
|
|
from atocore.retrieval.embeddings import embed_query
|
|
from atocore.retrieval.vector_store import get_vector_store
|
|
|
|
# Module-scoped structured logger; events below pass key=value fields.
log = get_logger("retriever")
# Low-signal query words excluded from keyword matching in _query_match_boost.
# Only words of 3+ characters appear: shorter tokens never match that
# function's token regex anyway.
_STOP_TOKENS = {
    "about",
    "and",
    "current",
    "for",
    "from",
    "into",
    "like",
    "project",
    "shared",
    "system",
    "that",
    "the",
    "this",
    "what",
    "with",
}
# Substrings that mark a chunk's path/title/headings as a current,
# authoritative document; any match triggers the high-signal boost in
# _path_signal_boost. Matching is plain substring containment (lowercased).
_HIGH_SIGNAL_HINTS = (
    "status",
    "decision",
    "requirements",
    "requirement",
    "roadmap",
    "charter",
    "system-map",
    "system_map",
    "contracts",
    "schema",
    "architecture",
    "workflow",
    "error-budget",
    "comparison-matrix",
    "selection-decision",
)
# Substrings that mark archival / historical content; any match triggers the
# low-signal penalty in _path_signal_boost. Both slash directions are listed
# because _path_signal_boost does not normalize backslashes.
# NOTE(review): "history" already subsumes "_history" under substring
# matching — the extra entry is harmless but redundant.
_LOW_SIGNAL_HINTS = (
    "/_archive/",
    "\\_archive\\",
    "/archive/",
    "\\archive\\",
    "_history",
    "history",
    "pre-cleanup",
    "pre-migration",
    "reviews/",
)
@dataclass
class ChunkResult:
    """One retrieved chunk plus its final (boosted, rounded) relevance score."""

    chunk_id: str  # id shared by the vector store and the source_chunks table
    content: str  # chunk text as returned by the vector store
    score: float  # (1 - distance) times boost multipliers, rounded to 4 places
    heading_path: str  # heading breadcrumb from chunk metadata ("" if absent)
    source_file: str  # originating file path from chunk metadata
    tags: str  # tags in string form; defaults to "[]" (JSON-encoded list)
    title: str  # document title from chunk metadata
    document_id: str  # owning document id from chunk metadata
def retrieve(
    query: str,
    top_k: int | None = None,
    filter_tags: list[str] | None = None,
    project_hint: str | None = None,
) -> list[ChunkResult]:
    """Retrieve the most relevant chunks for a query.

    Args:
        query: Free-text search query; embedded via ``embed_query``.
        top_k: Number of chunks to return; defaults to
            ``settings.context_top_k`` when None (or 0/falsy).
        filter_tags: When given, only chunks whose stored tags contain ALL of
            these tags are considered (``$and`` of ``$contains`` clauses).
        project_hint: Optional project name/alias; enables project-scope
            filtering (when configured) and a project-match score boost.

    Returns:
        Up to ``top_k`` ChunkResult objects sorted by descending score.
    """
    top_k = top_k or _config.settings.context_top_k
    start = time.time()
    scoped_project = get_registered_project(project_hint) if project_hint else None
    scope_filter_enabled = bool(scoped_project and _config.settings.rank_project_scope_filter)
    registered_projects = None
    query_top_k = top_k
    if scope_filter_enabled:
        # Over-fetch candidates so the scope filter can drop foreign-project
        # chunks and still leave enough results to fill the page.
        query_top_k = max(
            top_k,
            top_k * max(1, _config.settings.rank_project_scope_candidate_multiplier),
        )
        try:
            registered_projects = load_project_registry()
        except Exception:
            # Best-effort: without a registry, scope filtering degrades
            # gracefully (the per-chunk check falls back to allowing chunks).
            registered_projects = None

    query_embedding = embed_query(query)
    store = get_vector_store()

    # Tags are stored as a JSON-encoded string, so containment is matched
    # against the quoted form (e.g. '"tag"').
    where = None
    if filter_tags:
        if len(filter_tags) == 1:
            where = {"tags": {"$contains": f'"{filter_tags[0]}"'}}
        else:
            where = {
                "$and": [
                    {"tags": {"$contains": f'"{tag}"'}}
                    for tag in filter_tags
                ]
            }

    results = store.query(
        query_embedding=query_embedding,
        top_k=query_top_k,
        where=where,
    )

    chunks = []
    raw_result_count = len(results["ids"][0]) if results and results["ids"] and results["ids"][0] else 0
    if results and results["ids"] and results["ids"][0]:
        # Drop stale vector entries whose backing chunk rows were deleted.
        existing_ids = _existing_chunk_ids(results["ids"][0])
        for i, chunk_id in enumerate(results["ids"][0]):
            if chunk_id not in existing_ids:
                continue

            # Similarity score from distance (presumably cosine distance in
            # [0, 1] — TODO confirm against the vector store's metric).
            distance = results["distances"][0][i] if results["distances"] else 0
            score = 1.0 - distance
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            content = results["documents"][0][i] if results["documents"] else ""

            # Hard filter: drop chunks owned by a *different* registered project.
            if scope_filter_enabled and not _is_allowed_for_project_scope(
                scoped_project,
                meta,
                registered_projects,
            ):
                continue

            # Soft ranking signals: multiplicative boosts/penalties.
            score *= _query_match_boost(query, meta)
            score *= _path_signal_boost(meta)
            if project_hint:
                score *= _project_match_boost(project_hint, meta)

            chunks.append(
                ChunkResult(
                    chunk_id=chunk_id,
                    content=content,
                    score=round(score, 4),
                    heading_path=meta.get("heading_path", ""),
                    source_file=meta.get("source_file", ""),
                    tags=meta.get("tags", "[]"),
                    title=meta.get("title", ""),
                    document_id=meta.get("document_id", ""),
                )
            )

    duration_ms = int((time.time() - start) * 1000)
    # Boosts can reorder results relative to raw vector distance, so re-sort
    # before truncating to the caller's top_k.
    chunks.sort(key=lambda chunk: chunk.score, reverse=True)
    post_filter_count = len(chunks)
    chunks = chunks[:top_k]

    log.info(
        "retrieval_done",
        query=query[:100],
        top_k=top_k,
        query_top_k=query_top_k,
        raw_results_count=raw_result_count,
        post_filter_count=post_filter_count,
        results_count=len(chunks),
        post_filter_dropped=max(0, raw_result_count - post_filter_count),
        # Underfilled: the store returned a full candidate page, yet the
        # filters left fewer than top_k results.
        underfilled=bool(raw_result_count >= query_top_k and len(chunks) < top_k),
        duration_ms=duration_ms,
    )

    return chunks
def _is_allowed_for_project_scope(
    project: RegisteredProject,
    metadata: dict,
    registered_projects: list[RegisteredProject] | None = None,
) -> bool:
    """Decide whether a chunk may compete in project-scoped retrieval.

    A chunk is allowed when it belongs to the requested project, or when it
    is not owned by any registered project at all (shared/global docs stay
    eligible). It is rejected only when its metadata matches a *different*
    registered project — that is the boundary the registry provides.
    """
    if _metadata_matches_project(project, metadata):
        return True

    if registered_projects is None:
        try:
            registered_projects = load_project_registry()
        except Exception:
            # Registry unavailable: fail open rather than drop results.
            return True

    # Reject only when some other registered project claims this chunk.
    return not any(
        other.project_id != project.project_id
        and _metadata_matches_project(other, metadata)
        for other in registered_projects
    )
def _metadata_matches_project(project: RegisteredProject, metadata: dict) -> bool:
    """True when any scope term of *project* hits the chunk's path or tags."""
    source_path = _metadata_source_path(metadata)
    tag_set = _metadata_tags(metadata)
    return any(
        _path_matches_term(source_path, term) or term in tag_set
        for term in _project_scope_terms(project)
    )
def _project_scope_terms(project: RegisteredProject) -> set[str]:
|
|
terms = {project.project_id.lower()}
|
|
terms.update(alias.lower() for alias in project.aliases)
|
|
for source_ref in project.ingest_roots:
|
|
normalized = source_ref.subpath.replace("\\", "/").strip("/").lower()
|
|
if normalized:
|
|
terms.add(normalized)
|
|
terms.add(normalized.split("/")[-1])
|
|
return {term for term in terms if term}
|
|
|
|
|
|
def _metadata_searchable(metadata: dict) -> str:
|
|
return " ".join(
|
|
[
|
|
str(metadata.get("source_file", "")).replace("\\", "/").lower(),
|
|
str(metadata.get("title", "")).lower(),
|
|
str(metadata.get("heading_path", "")).lower(),
|
|
str(metadata.get("tags", "")).lower(),
|
|
]
|
|
)
|
|
|
|
|
|
def _metadata_source_path(metadata: dict) -> str:
|
|
return str(metadata.get("source_file", "")).replace("\\", "/").strip("/").lower()
|
|
|
|
|
|
def _metadata_tags(metadata: dict) -> set[str]:
|
|
raw_tags = metadata.get("tags", [])
|
|
if isinstance(raw_tags, (list, tuple, set)):
|
|
return {str(tag).strip().lower() for tag in raw_tags if str(tag).strip()}
|
|
if isinstance(raw_tags, str):
|
|
try:
|
|
parsed = json.loads(raw_tags)
|
|
except json.JSONDecodeError:
|
|
parsed = [raw_tags]
|
|
if isinstance(parsed, (list, tuple, set)):
|
|
return {str(tag).strip().lower() for tag in parsed if str(tag).strip()}
|
|
if isinstance(parsed, str) and parsed.strip():
|
|
return {parsed.strip().lower()}
|
|
return set()
|
|
|
|
|
|
def _path_matches_term(path: str, term: str) -> bool:
|
|
normalized = term.replace("\\", "/").strip("/").lower()
|
|
if not path or not normalized:
|
|
return False
|
|
if "/" in normalized:
|
|
return path == normalized or path.startswith(f"{normalized}/")
|
|
return normalized in set(path.split("/"))
|
|
|
|
|
|
def _metadata_has_term(metadata: dict, term: str) -> bool:
    """True when *term* appears in the chunk's path, tags, or searchable text.

    Tries the cheap structural checks first (path segments/prefix, tag set),
    then falls back to a word-boundary regex over the flattened metadata.
    """
    needle = term.replace("\\", "/").strip("/").lower()
    if not needle:
        return False
    if _path_matches_term(_metadata_source_path(metadata), needle):
        return True
    if needle in _metadata_tags(metadata):
        return True
    # Custom boundaries: only letters/digits break a match, so terms with
    # "-" or "_" still match whole-word occurrences.
    pattern = rf"(?<![a-z0-9]){re.escape(needle)}(?![a-z0-9])"
    return bool(re.search(pattern, _metadata_searchable(metadata)))
def _project_match_boost(project_hint: str, metadata: dict) -> float:
    """Return a project-aware relevance multiplier for raw retrieval.

    Uses the registered project's full scope-term set when the hint resolves;
    otherwise falls back to the bare (lowercased) hint string.
    """
    hint = project_hint.strip().lower()
    if not hint:
        return 1.0

    registered = get_registered_project(project_hint)
    terms = {hint} if registered is None else _project_scope_terms(registered)
    if any(_metadata_has_term(metadata, term) for term in terms):
        return _config.settings.rank_project_match_boost
    return 1.0
def _query_match_boost(query: str, metadata: dict) -> float:
    """Boost chunks whose path/title/headings echo the query's high-signal terms.

    Each distinct non-stopword query token (3+ chars) found as a substring of
    the flattened path/title/headings adds one configurable step, capped at
    ``rank_query_token_cap``.
    """
    keywords = {
        tok
        for tok in re.findall(r"[a-z0-9][a-z0-9_-]{2,}", query.lower())
        if tok not in _STOP_TOKENS
    }
    if not keywords:
        return 1.0

    haystack = " ".join(
        str(metadata.get(field, "")).lower()
        for field in ("source_file", "title", "heading_path")
    )
    hit_count = sum(1 for tok in keywords if tok in haystack)
    if not hit_count:
        return 1.0
    boost = 1.0 + hit_count * _config.settings.rank_query_token_step
    return min(boost, _config.settings.rank_query_token_cap)
def _path_signal_boost(metadata: dict) -> float:
    """Prefer current high-signal docs and gently down-rank archival noise.

    Both signals are substring checks over the flattened path/title/headings
    and can stack (a chunk may receive both the penalty and the boost).
    """
    haystack = " ".join(
        str(metadata.get(field, "")).lower()
        for field in ("source_file", "title", "heading_path")
    )

    factor = 1.0
    if any(marker in haystack for marker in _LOW_SIGNAL_HINTS):
        factor *= _config.settings.rank_path_low_signal_penalty
    if any(marker in haystack for marker in _HIGH_SIGNAL_HINTS):
        factor *= _config.settings.rank_path_high_signal_boost
    return factor
def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]:
    """Filter out stale vector entries whose chunk rows no longer exist.

    Returns the subset of *chunk_ids* that still have a row in
    ``source_chunks``, using a single parameterized IN query.
    """
    if not chunk_ids:
        return set()

    marks = ", ".join(["?"] * len(chunk_ids))
    sql = f"SELECT id FROM source_chunks WHERE id IN ({marks})"
    with get_connection() as conn:
        found = conn.execute(sql, chunk_ids).fetchall()
    return {record["id"] for record in found}