feat(retrieval): persist explicit chunk project ids

This commit is contained in:
2026-04-24 11:02:30 -04:00
parent f44a211497
commit c03022d864
12 changed files with 332 additions and 24 deletions

View File

@@ -0,0 +1,145 @@
"""Backfill explicit project_id into chunk and vector metadata.
Dry-run by default. The script derives ownership from the registered project
ingest roots and updates both SQLite source_chunks.metadata and Chroma vector
metadata only when --apply is provided.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from atocore.models.database import get_connection # noqa: E402
from atocore.projects.registry import list_registered_projects # noqa: E402
from atocore.retrieval.vector_store import get_vector_store # noqa: E402
def _load_project_roots() -> list[tuple[str, Path]]:
    """Collect (project_id, resolved ingest-root path) pairs for all projects.

    The result is ordered longest-path-first so that when roots nest, the most
    specific root is tried before its ancestors during ownership derivation.
    """
    collected: list[tuple[str, Path]] = []
    for entry in list_registered_projects():
        owner_id = entry["id"]
        for root_entry in entry.get("ingest_roots", []):
            raw_path = root_entry.get("path")
            if not raw_path:
                continue
            collected.append((owner_id, Path(raw_path).resolve(strict=False)))
    return sorted(collected, key=lambda pair: len(str(pair[1])), reverse=True)
def _derive_project_id(file_path: str, roots: list[tuple[str, Path]]) -> str:
if not file_path:
return ""
doc_path = Path(file_path).resolve(strict=False)
for project_id, root_path in roots:
try:
doc_path.relative_to(root_path)
except ValueError:
continue
return project_id
return ""
def _decode_metadata(raw: str | None) -> dict:
if not raw:
return {}
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
def _chunk_rows() -> list[dict]:
    """Fetch every chunk id + metadata joined to its document's file path.

    Ordered by (file_path, chunk_index) for a deterministic scan order.
    """
    query = """
        SELECT
            sc.id AS chunk_id,
            sc.metadata AS chunk_metadata,
            sd.file_path AS file_path
        FROM source_chunks sc
        JOIN source_documents sd ON sd.id = sc.document_id
        ORDER BY sd.file_path, sc.chunk_index
    """
    with get_connection() as conn:
        return [dict(row) for row in conn.execute(query).fetchall()]
def backfill(apply: bool = False, project_filter: str = "") -> dict:
    """Derive project_id ownership for every chunk and optionally persist it.

    Args:
        apply: when False (default), dry-run — compute and report only. When
            True, write the merged metadata to both SQLite source_chunks and
            the vector store.
        project_filter: if non-empty, only chunks whose derived project_id
            equals this value are considered.

    Returns:
        Summary dict: the apply flag, total chunk count, number of planned or
        applied updates, chunks matching no registered root, and a per-project
        update count (sorted by project id).
    """
    roots = _load_project_roots()
    rows = _chunk_rows()
    # Each pending update: (chunk_id, derived project_id, merged SQLite metadata).
    updates: list[tuple[str, str, dict]] = []
    by_project: dict[str, int] = {}
    skipped_unowned = 0
    for row in rows:
        project_id = _derive_project_id(row["file_path"], roots)
        if project_filter and project_id != project_filter:
            continue
        if not project_id:
            # Path matched no registered ingest root: counted, never written.
            skipped_unowned += 1
            continue
        metadata = _decode_metadata(row["chunk_metadata"])
        if metadata.get("project_id") == project_id:
            # Already correct — skipping keeps repeated runs idempotent.
            continue
        metadata["project_id"] = project_id
        updates.append((row["chunk_id"], project_id, metadata))
        by_project[project_id] = by_project.get(project_id, 0) + 1
    if apply and updates:
        vector_store = get_vector_store()
        chunk_ids = [chunk_id for chunk_id, _, _ in updates]
        # Read current vector metadata first so existing keys are preserved
        # when we merge in the new project_id.
        vector_payload = vector_store.get_metadatas(chunk_ids)
        # NOTE(review): strict=False tolerates ids/metadatas length mismatch —
        # presumably the store may omit unknown ids; confirm against its API.
        existing_vector_metadata = {
            chunk_id: metadata or {}
            for chunk_id, metadata in zip(
                vector_payload.get("ids", []),
                vector_payload.get("metadatas", []),
                strict=False,
            )
        }
        vector_metadatas = []
        for chunk_id, project_id, chunk_metadata in updates:
            vector_metadata = dict(existing_vector_metadata.get(chunk_id) or {})
            if not vector_metadata:
                # Vector store had nothing for this id: fall back to the
                # chunk's SQLite metadata as the base document.
                vector_metadata = dict(chunk_metadata)
            vector_metadata["project_id"] = project_id
            vector_metadatas.append(vector_metadata)
        # Write SQLite first, then the vector store; both lists are aligned
        # with chunk_ids by construction.
        with get_connection() as conn:
            conn.executemany(
                "UPDATE source_chunks SET metadata = ? WHERE id = ?",
                [
                    (json.dumps(metadata, ensure_ascii=True), chunk_id)
                    for chunk_id, _, metadata in updates
                ],
            )
        vector_store.update_metadatas(chunk_ids, vector_metadatas)
    return {
        "apply": apply,
        "total_chunks": len(rows),
        "updates": len(updates),
        "skipped_unowned": skipped_unowned,
        "by_project": dict(sorted(by_project.items())),
    }
def main() -> int:
    """CLI entry point: parse flags, run the backfill, print a JSON summary."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--apply",
        action="store_true",
        help="write SQLite and Chroma metadata updates",
    )
    parser.add_argument(
        "--project",
        default="",
        help="optional canonical project_id filter",
    )
    options = parser.parse_args()
    summary = backfill(apply=options.apply, project_filter=options.project.strip())
    print(json.dumps(summary, indent=2, ensure_ascii=True))
    return 0
if __name__ == "__main__":
    # SystemExit propagates main()'s integer status code to the shell.
    raise SystemExit(main())