fix(retrieval): enforce project-scoped context boundaries

This commit is contained in:
2026-04-24 10:46:56 -04:00
parent c53e61eb67
commit c7212900b0
11 changed files with 737 additions and 68 deletions

View File

@@ -1,5 +1,6 @@
"""Retrieval: query to ranked chunks."""
import json
import re
import time
from dataclasses import dataclass
@@ -7,7 +8,7 @@ from dataclasses import dataclass
import atocore.config as _config
from atocore.models.database import get_connection
from atocore.observability.logger import get_logger
from atocore.projects.registry import get_registered_project
from atocore.projects.registry import RegisteredProject, get_registered_project, load_project_registry
from atocore.retrieval.embeddings import embed_query
from atocore.retrieval.vector_store import get_vector_store
@@ -83,6 +84,19 @@ def retrieve(
"""Retrieve the most relevant chunks for a query."""
top_k = top_k or _config.settings.context_top_k
start = time.time()
scoped_project = get_registered_project(project_hint) if project_hint else None
scope_filter_enabled = bool(scoped_project and _config.settings.rank_project_scope_filter)
registered_projects = None
query_top_k = top_k
if scope_filter_enabled:
query_top_k = max(
top_k,
top_k * max(1, _config.settings.rank_project_scope_candidate_multiplier),
)
try:
registered_projects = load_project_registry()
except Exception:
registered_projects = None
query_embedding = embed_query(query)
store = get_vector_store()
@@ -101,11 +115,12 @@ def retrieve(
results = store.query(
query_embedding=query_embedding,
top_k=top_k,
top_k=query_top_k,
where=where,
)
chunks = []
raw_result_count = len(results["ids"][0]) if results and results["ids"] and results["ids"][0] else 0
if results and results["ids"] and results["ids"][0]:
existing_ids = _existing_chunk_ids(results["ids"][0])
for i, chunk_id in enumerate(results["ids"][0]):
@@ -117,6 +132,13 @@ def retrieve(
meta = results["metadatas"][0][i] if results["metadatas"] else {}
content = results["documents"][0][i] if results["documents"] else ""
if scope_filter_enabled and not _is_allowed_for_project_scope(
scoped_project,
meta,
registered_projects,
):
continue
score *= _query_match_boost(query, meta)
score *= _path_signal_boost(meta)
if project_hint:
@@ -137,42 +159,139 @@ def retrieve(
duration_ms = int((time.time() - start) * 1000)
chunks.sort(key=lambda chunk: chunk.score, reverse=True)
post_filter_count = len(chunks)
chunks = chunks[:top_k]
log.info(
"retrieval_done",
query=query[:100],
top_k=top_k,
query_top_k=query_top_k,
raw_results_count=raw_result_count,
post_filter_count=post_filter_count,
results_count=len(chunks),
post_filter_dropped=max(0, raw_result_count - post_filter_count),
underfilled=bool(raw_result_count >= query_top_k and len(chunks) < top_k),
duration_ms=duration_ms,
)
return chunks
def _is_allowed_for_project_scope(
    project: RegisteredProject,
    metadata: dict,
    registered_projects: list[RegisteredProject] | None = None,
) -> bool:
    """Decide whether a chunk may appear in results scoped to ``project``.

    A chunk is allowed when its metadata matches the requested project, or
    when it matches no *other* registered project, so unowned or shared
    content stays eligible. It is rejected only when a different registered
    project matches it.

    Args:
        project: The project the retrieval is scoped to.
        metadata: Chunk metadata, checked via ``_metadata_matches_project``.
        registered_projects: Registry entries to compare against. When
            ``None``, the registry is loaded via ``load_project_registry()``.

    Returns:
        ``True`` if the chunk is allowed, ``False`` if another registered
        project claims it. If the registry has to be loaded here and loading
        raises, the chunk is allowed.
    """
    if _metadata_matches_project(project, metadata):
        return True

    candidates = registered_projects
    if candidates is None:
        try:
            candidates = load_project_registry()
        except Exception:
            # Fail open: without a registry there is no boundary to enforce.
            return True

    claimed_by_other = any(
        _metadata_matches_project(other, metadata)
        for other in candidates
        if other.project_id != project.project_id
    )
    return not claimed_by_other
def _metadata_matches_project(project: RegisteredProject, metadata: dict) -> bool:
    """Return True when the chunk's source path or tags match a scope term.

    The scope terms come from ``_project_scope_terms(project)``; each term is
    tested against the normalized source path and against the tag set.
    """
    source_path = _metadata_source_path(metadata)
    tag_set = _metadata_tags(metadata)
    return any(
        _path_matches_term(source_path, scope_term) or scope_term in tag_set
        for scope_term in _project_scope_terms(project)
    )
def _project_scope_terms(project: RegisteredProject) -> set[str]:
    """Collect the lowercase terms that identify ``project``.

    The result contains the project id, every alias, and for each ingest root
    both its normalized subpath (forward slashes, no leading/trailing slash)
    and that subpath's final segment. Empty strings are never included.
    """
    scope_terms = {
        project.project_id.lower(),
        *(alias.lower() for alias in project.aliases),
    }
    for root in project.ingest_roots:
        subpath = root.subpath.replace("\\", "/").strip("/").lower()
        if not subpath:
            continue
        # Keep both the full subpath and its leaf directory name.
        scope_terms |= {subpath, subpath.rsplit("/", 1)[-1]}
    scope_terms.discard("")
    return scope_terms
def _metadata_searchable(metadata: dict) -> str:
return " ".join(
[
str(metadata.get("source_file", "")).replace("\\", "/").lower(),
str(metadata.get("title", "")).lower(),
str(metadata.get("heading_path", "")).lower(),
str(metadata.get("tags", "")).lower(),
]
)
def _metadata_source_path(metadata: dict) -> str:
return str(metadata.get("source_file", "")).replace("\\", "/").strip("/").lower()
def _metadata_tags(metadata: dict) -> set[str]:
raw_tags = metadata.get("tags", [])
if isinstance(raw_tags, (list, tuple, set)):
return {str(tag).strip().lower() for tag in raw_tags if str(tag).strip()}
if isinstance(raw_tags, str):
try:
parsed = json.loads(raw_tags)
except json.JSONDecodeError:
parsed = [raw_tags]
if isinstance(parsed, (list, tuple, set)):
return {str(tag).strip().lower() for tag in parsed if str(tag).strip()}
if isinstance(parsed, str) and parsed.strip():
return {parsed.strip().lower()}
return set()
def _path_matches_term(path: str, term: str) -> bool:
normalized = term.replace("\\", "/").strip("/").lower()
if not path or not normalized:
return False
if "/" in normalized:
return path == normalized or path.startswith(f"{normalized}/")
return normalized in set(path.split("/"))
def _metadata_has_term(metadata: dict, term: str) -> bool:
normalized = term.replace("\\", "/").strip("/").lower()
if not normalized:
return False
if _path_matches_term(_metadata_source_path(metadata), normalized):
return True
if normalized in _metadata_tags(metadata):
return True
return re.search(
rf"(?<![a-z0-9]){re.escape(normalized)}(?![a-z0-9])",
_metadata_searchable(metadata),
) is not None
def _project_match_boost(project_hint: str, metadata: dict) -> float:
"""Return a project-aware relevance multiplier for raw retrieval."""
hint_lower = project_hint.strip().lower()
if not hint_lower:
return 1.0
source_file = str(metadata.get("source_file", "")).lower()
title = str(metadata.get("title", "")).lower()
tags = str(metadata.get("tags", "")).lower()
searchable = " ".join([source_file, title, tags])
project = get_registered_project(project_hint)
candidate_names = {hint_lower}
if project is not None:
candidate_names.add(project.project_id.lower())
candidate_names.update(alias.lower() for alias in project.aliases)
candidate_names.update(
source_ref.subpath.replace("\\", "/").strip("/").split("/")[-1].lower()
for source_ref in project.ingest_roots
if source_ref.subpath.strip("/\\")
)
candidate_names = _project_scope_terms(project) if project is not None else {hint_lower}
for candidate in candidate_names:
if candidate and candidate in searchable:
if _metadata_has_term(metadata, candidate):
return _config.settings.rank_project_match_boost
return 1.0