"""Retrieval: query to ranked chunks.""" import json import re import time from dataclasses import dataclass import atocore.config as _config from atocore.models.database import get_connection from atocore.observability.logger import get_logger from atocore.projects.registry import RegisteredProject, get_registered_project, load_project_registry from atocore.retrieval.embeddings import embed_query from atocore.retrieval.vector_store import get_vector_store log = get_logger("retriever") _STOP_TOKENS = { "about", "and", "current", "for", "from", "into", "like", "project", "shared", "system", "that", "the", "this", "what", "with", } _HIGH_SIGNAL_HINTS = ( "status", "decision", "requirements", "requirement", "roadmap", "charter", "system-map", "system_map", "contracts", "schema", "architecture", "workflow", "error-budget", "comparison-matrix", "selection-decision", ) _LOW_SIGNAL_HINTS = ( "/_archive/", "\\_archive\\", "/archive/", "\\archive\\", "_history", "history", "pre-cleanup", "pre-migration", "reviews/", ) @dataclass class ChunkResult: chunk_id: str content: str score: float heading_path: str source_file: str tags: str title: str document_id: str def retrieve( query: str, top_k: int | None = None, filter_tags: list[str] | None = None, project_hint: str | None = None, ) -> list[ChunkResult]: """Retrieve the most relevant chunks for a query.""" top_k = top_k or _config.settings.context_top_k start = time.time() scoped_project = get_registered_project(project_hint) if project_hint else None scope_filter_enabled = bool(scoped_project and _config.settings.rank_project_scope_filter) registered_projects = None query_top_k = top_k if scope_filter_enabled: query_top_k = max( top_k, top_k * max(1, _config.settings.rank_project_scope_candidate_multiplier), ) try: registered_projects = load_project_registry() except Exception: registered_projects = None query_embedding = embed_query(query) store = get_vector_store() where = None if filter_tags: if len(filter_tags) == 1: where = {"tags": {"$contains": f'"{filter_tags[0]}"'}} else: where = { "$and": [ {"tags": {"$contains": f'"{tag}"'}} for tag in filter_tags ] } results = store.query( query_embedding=query_embedding, top_k=query_top_k, where=where, ) chunks = [] raw_result_count = len(results["ids"][0]) if results and results["ids"] and results["ids"][0] else 0 if results and results["ids"] and results["ids"][0]: existing_ids = _existing_chunk_ids(results["ids"][0]) for i, chunk_id in enumerate(results["ids"][0]): if chunk_id not in existing_ids: continue distance = results["distances"][0][i] if results["distances"] else 0 score = 1.0 - distance meta = results["metadatas"][0][i] if results["metadatas"] else {} content = results["documents"][0][i] if results["documents"] else "" if scope_filter_enabled and not _is_allowed_for_project_scope( scoped_project, meta, registered_projects, ): continue score *= _query_match_boost(query, meta) score *= _path_signal_boost(meta) if project_hint: score *= _project_match_boost(project_hint, meta) chunks.append( ChunkResult( chunk_id=chunk_id, content=content, score=round(score, 4), heading_path=meta.get("heading_path", ""), source_file=meta.get("source_file", ""), tags=meta.get("tags", "[]"), title=meta.get("title", ""), document_id=meta.get("document_id", ""), ) ) duration_ms = int((time.time() - start) * 1000) chunks.sort(key=lambda chunk: chunk.score, reverse=True) post_filter_count = len(chunks) chunks = chunks[:top_k] log.info( "retrieval_done", query=query[:100], top_k=top_k, query_top_k=query_top_k, raw_results_count=raw_result_count, post_filter_count=post_filter_count, results_count=len(chunks), post_filter_dropped=max(0, raw_result_count - post_filter_count), underfilled=bool(raw_result_count >= query_top_k and len(chunks) < top_k), duration_ms=duration_ms, ) return chunks def _is_allowed_for_project_scope( project: RegisteredProject, metadata: dict, registered_projects: list[RegisteredProject] | None = None, ) -> bool: """Return True when a chunk is target-project or not project-owned. Project-hinted retrieval should not let one registered project's corpus compete with another's. At the same time, unowned/global sources should remain eligible because shared docs and cross-project references can be genuinely useful. The registry gives us the boundary: if metadata matches a registered project and it is not the requested project, filter it out. """ if _metadata_matches_project(project, metadata): return True if registered_projects is None: try: registered_projects = load_project_registry() except Exception: return True for other in registered_projects: if other.project_id == project.project_id: continue if _metadata_matches_project(other, metadata): return False return True def _metadata_matches_project(project: RegisteredProject, metadata: dict) -> bool: path = _metadata_source_path(metadata) tags = _metadata_tags(metadata) for term in _project_scope_terms(project): if _path_matches_term(path, term) or term in tags: return True return False def _project_scope_terms(project: RegisteredProject) -> set[str]: terms = {project.project_id.lower()} terms.update(alias.lower() for alias in project.aliases) for source_ref in project.ingest_roots: normalized = source_ref.subpath.replace("\\", "/").strip("/").lower() if normalized: terms.add(normalized) terms.add(normalized.split("/")[-1]) return {term for term in terms if term} def _metadata_searchable(metadata: dict) -> str: return " ".join( [ str(metadata.get("source_file", "")).replace("\\", "/").lower(), str(metadata.get("title", "")).lower(), str(metadata.get("heading_path", "")).lower(), str(metadata.get("tags", "")).lower(), ] ) def _metadata_source_path(metadata: dict) -> str: return str(metadata.get("source_file", "")).replace("\\", "/").strip("/").lower() def _metadata_tags(metadata: dict) -> set[str]: raw_tags = metadata.get("tags", []) if isinstance(raw_tags, (list, tuple, set)): return {str(tag).strip().lower() for tag in raw_tags if str(tag).strip()} if isinstance(raw_tags, str): try: parsed = json.loads(raw_tags) except json.JSONDecodeError: parsed = [raw_tags] if isinstance(parsed, (list, tuple, set)): return {str(tag).strip().lower() for tag in parsed if str(tag).strip()} if isinstance(parsed, str) and parsed.strip(): return {parsed.strip().lower()} return set() def _path_matches_term(path: str, term: str) -> bool: normalized = term.replace("\\", "/").strip("/").lower() if not path or not normalized: return False if "/" in normalized: return path == normalized or path.startswith(f"{normalized}/") return normalized in set(path.split("/")) def _metadata_has_term(metadata: dict, term: str) -> bool: normalized = term.replace("\\", "/").strip("/").lower() if not normalized: return False if _path_matches_term(_metadata_source_path(metadata), normalized): return True if normalized in _metadata_tags(metadata): return True return re.search( rf"(? float: """Return a project-aware relevance multiplier for raw retrieval.""" hint_lower = project_hint.strip().lower() if not hint_lower: return 1.0 project = get_registered_project(project_hint) candidate_names = _project_scope_terms(project) if project is not None else {hint_lower} for candidate in candidate_names: if _metadata_has_term(metadata, candidate): return _config.settings.rank_project_match_boost return 1.0 def _query_match_boost(query: str, metadata: dict) -> float: """Boost chunks whose path/title/headings echo the query's high-signal terms.""" tokens = [ token for token in re.findall(r"[a-z0-9][a-z0-9_-]{2,}", query.lower()) if token not in _STOP_TOKENS ] if not tokens: return 1.0 searchable = " ".join( [ str(metadata.get("source_file", "")).lower(), str(metadata.get("title", "")).lower(), str(metadata.get("heading_path", "")).lower(), ] ) matches = sum(1 for token in set(tokens) if token in searchable) if matches <= 0: return 1.0 return min( 1.0 + matches * _config.settings.rank_query_token_step, _config.settings.rank_query_token_cap, ) def _path_signal_boost(metadata: dict) -> float: """Prefer current high-signal docs and gently down-rank archival noise.""" searchable = " ".join( [ str(metadata.get("source_file", "")).lower(), str(metadata.get("title", "")).lower(), str(metadata.get("heading_path", "")).lower(), ] ) multiplier = 1.0 if any(hint in searchable for hint in _LOW_SIGNAL_HINTS): multiplier *= _config.settings.rank_path_low_signal_penalty if any(hint in searchable for hint in _HIGH_SIGNAL_HINTS): multiplier *= _config.settings.rank_path_high_signal_boost return multiplier def _existing_chunk_ids(chunk_ids: list[str]) -> set[str]: """Filter out stale vector entries whose chunk rows no longer exist.""" if not chunk_ids: return set() placeholders = ", ".join("?" for _ in chunk_ids) with get_connection() as conn: rows = conn.execute( f"SELECT id FROM source_chunks WHERE id IN ({placeholders})", chunk_ids, ).fetchall() return {row["id"] for row in rows}