fix(retrieval): enforce project-scoped context boundaries
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Retrieval: query to ranked chunks."""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -7,7 +8,7 @@ from dataclasses import dataclass
|
||||
import atocore.config as _config
|
||||
from atocore.models.database import get_connection
|
||||
from atocore.observability.logger import get_logger
|
||||
from atocore.projects.registry import get_registered_project
|
||||
from atocore.projects.registry import RegisteredProject, get_registered_project, load_project_registry
|
||||
from atocore.retrieval.embeddings import embed_query
|
||||
from atocore.retrieval.vector_store import get_vector_store
|
||||
|
||||
@@ -83,6 +84,19 @@ def retrieve(
|
||||
"""Retrieve the most relevant chunks for a query."""
|
||||
top_k = top_k or _config.settings.context_top_k
|
||||
start = time.time()
|
||||
scoped_project = get_registered_project(project_hint) if project_hint else None
|
||||
scope_filter_enabled = bool(scoped_project and _config.settings.rank_project_scope_filter)
|
||||
registered_projects = None
|
||||
query_top_k = top_k
|
||||
if scope_filter_enabled:
|
||||
query_top_k = max(
|
||||
top_k,
|
||||
top_k * max(1, _config.settings.rank_project_scope_candidate_multiplier),
|
||||
)
|
||||
try:
|
||||
registered_projects = load_project_registry()
|
||||
except Exception:
|
||||
registered_projects = None
|
||||
|
||||
query_embedding = embed_query(query)
|
||||
store = get_vector_store()
|
||||
@@ -101,11 +115,12 @@ def retrieve(
|
||||
|
||||
results = store.query(
|
||||
query_embedding=query_embedding,
|
||||
top_k=top_k,
|
||||
top_k=query_top_k,
|
||||
where=where,
|
||||
)
|
||||
|
||||
chunks = []
|
||||
raw_result_count = len(results["ids"][0]) if results and results["ids"] and results["ids"][0] else 0
|
||||
if results and results["ids"] and results["ids"][0]:
|
||||
existing_ids = _existing_chunk_ids(results["ids"][0])
|
||||
for i, chunk_id in enumerate(results["ids"][0]):
|
||||
@@ -117,6 +132,13 @@ def retrieve(
|
||||
meta = results["metadatas"][0][i] if results["metadatas"] else {}
|
||||
content = results["documents"][0][i] if results["documents"] else ""
|
||||
|
||||
if scope_filter_enabled and not _is_allowed_for_project_scope(
|
||||
scoped_project,
|
||||
meta,
|
||||
registered_projects,
|
||||
):
|
||||
continue
|
||||
|
||||
score *= _query_match_boost(query, meta)
|
||||
score *= _path_signal_boost(meta)
|
||||
if project_hint:
|
||||
@@ -137,42 +159,139 @@ def retrieve(
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
chunks.sort(key=lambda chunk: chunk.score, reverse=True)
|
||||
post_filter_count = len(chunks)
|
||||
chunks = chunks[:top_k]
|
||||
|
||||
log.info(
|
||||
"retrieval_done",
|
||||
query=query[:100],
|
||||
top_k=top_k,
|
||||
query_top_k=query_top_k,
|
||||
raw_results_count=raw_result_count,
|
||||
post_filter_count=post_filter_count,
|
||||
results_count=len(chunks),
|
||||
post_filter_dropped=max(0, raw_result_count - post_filter_count),
|
||||
underfilled=bool(raw_result_count >= query_top_k and len(chunks) < top_k),
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _is_allowed_for_project_scope(
|
||||
project: RegisteredProject,
|
||||
metadata: dict,
|
||||
registered_projects: list[RegisteredProject] | None = None,
|
||||
) -> bool:
|
||||
"""Return True when a chunk is target-project or not project-owned.
|
||||
|
||||
Project-hinted retrieval should not let one registered project's corpus
|
||||
compete with another's. At the same time, unowned/global sources should
|
||||
remain eligible because shared docs and cross-project references can be
|
||||
genuinely useful. The registry gives us the boundary: if metadata matches
|
||||
a registered project and it is not the requested project, filter it out.
|
||||
"""
|
||||
if _metadata_matches_project(project, metadata):
|
||||
return True
|
||||
|
||||
if registered_projects is None:
|
||||
try:
|
||||
registered_projects = load_project_registry()
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
for other in registered_projects:
|
||||
if other.project_id == project.project_id:
|
||||
continue
|
||||
if _metadata_matches_project(other, metadata):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _metadata_matches_project(project: RegisteredProject, metadata: dict) -> bool:
|
||||
path = _metadata_source_path(metadata)
|
||||
tags = _metadata_tags(metadata)
|
||||
for term in _project_scope_terms(project):
|
||||
if _path_matches_term(path, term) or term in tags:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _project_scope_terms(project: RegisteredProject) -> set[str]:
|
||||
terms = {project.project_id.lower()}
|
||||
terms.update(alias.lower() for alias in project.aliases)
|
||||
for source_ref in project.ingest_roots:
|
||||
normalized = source_ref.subpath.replace("\\", "/").strip("/").lower()
|
||||
if normalized:
|
||||
terms.add(normalized)
|
||||
terms.add(normalized.split("/")[-1])
|
||||
return {term for term in terms if term}
|
||||
|
||||
|
||||
def _metadata_searchable(metadata: dict) -> str:
|
||||
return " ".join(
|
||||
[
|
||||
str(metadata.get("source_file", "")).replace("\\", "/").lower(),
|
||||
str(metadata.get("title", "")).lower(),
|
||||
str(metadata.get("heading_path", "")).lower(),
|
||||
str(metadata.get("tags", "")).lower(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _metadata_source_path(metadata: dict) -> str:
|
||||
return str(metadata.get("source_file", "")).replace("\\", "/").strip("/").lower()
|
||||
|
||||
|
||||
def _metadata_tags(metadata: dict) -> set[str]:
|
||||
raw_tags = metadata.get("tags", [])
|
||||
if isinstance(raw_tags, (list, tuple, set)):
|
||||
return {str(tag).strip().lower() for tag in raw_tags if str(tag).strip()}
|
||||
if isinstance(raw_tags, str):
|
||||
try:
|
||||
parsed = json.loads(raw_tags)
|
||||
except json.JSONDecodeError:
|
||||
parsed = [raw_tags]
|
||||
if isinstance(parsed, (list, tuple, set)):
|
||||
return {str(tag).strip().lower() for tag in parsed if str(tag).strip()}
|
||||
if isinstance(parsed, str) and parsed.strip():
|
||||
return {parsed.strip().lower()}
|
||||
return set()
|
||||
|
||||
|
||||
def _path_matches_term(path: str, term: str) -> bool:
|
||||
normalized = term.replace("\\", "/").strip("/").lower()
|
||||
if not path or not normalized:
|
||||
return False
|
||||
if "/" in normalized:
|
||||
return path == normalized or path.startswith(f"{normalized}/")
|
||||
return normalized in set(path.split("/"))
|
||||
|
||||
|
||||
def _metadata_has_term(metadata: dict, term: str) -> bool:
|
||||
normalized = term.replace("\\", "/").strip("/").lower()
|
||||
if not normalized:
|
||||
return False
|
||||
if _path_matches_term(_metadata_source_path(metadata), normalized):
|
||||
return True
|
||||
if normalized in _metadata_tags(metadata):
|
||||
return True
|
||||
return re.search(
|
||||
rf"(?<![a-z0-9]){re.escape(normalized)}(?![a-z0-9])",
|
||||
_metadata_searchable(metadata),
|
||||
) is not None
|
||||
|
||||
|
||||
def _project_match_boost(project_hint: str, metadata: dict) -> float:
|
||||
"""Return a project-aware relevance multiplier for raw retrieval."""
|
||||
hint_lower = project_hint.strip().lower()
|
||||
if not hint_lower:
|
||||
return 1.0
|
||||
|
||||
source_file = str(metadata.get("source_file", "")).lower()
|
||||
title = str(metadata.get("title", "")).lower()
|
||||
tags = str(metadata.get("tags", "")).lower()
|
||||
searchable = " ".join([source_file, title, tags])
|
||||
|
||||
project = get_registered_project(project_hint)
|
||||
candidate_names = {hint_lower}
|
||||
if project is not None:
|
||||
candidate_names.add(project.project_id.lower())
|
||||
candidate_names.update(alias.lower() for alias in project.aliases)
|
||||
candidate_names.update(
|
||||
source_ref.subpath.replace("\\", "/").strip("/").split("/")[-1].lower()
|
||||
for source_ref in project.ingest_roots
|
||||
if source_ref.subpath.strip("/\\")
|
||||
)
|
||||
|
||||
candidate_names = _project_scope_terms(project) if project is not None else {hint_lower}
|
||||
for candidate in candidate_names:
|
||||
if candidate and candidate in searchable:
|
||||
if _metadata_has_term(metadata, candidate):
|
||||
return _config.settings.rank_project_match_boost
|
||||
|
||||
return 1.0
|
||||
|
||||
Reference in New Issue
Block a user