# Files
# ATOCore/scripts/backfill_chunk_project_ids.py
#
# 179 lines
# 6.1 KiB
# Python
#
"""Backfill explicit project_id into chunk and vector metadata.
Dry-run by default. The script derives ownership from the registered project
ingest roots and updates both SQLite source_chunks.metadata and Chroma vector
metadata only when --apply is provided.
"""
from __future__ import annotations
import argparse
import json
import sqlite3
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from atocore.models.database import get_connection # noqa: E402
from atocore.projects.registry import derive_project_id_for_path # noqa: E402
from atocore.retrieval.vector_store import get_vector_store # noqa: E402
DEFAULT_BATCH_SIZE = 500
def _decode_metadata(raw: str | None) -> dict | None:
if not raw:
return {}
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _chunk_rows() -> tuple[list[dict], str]:
    """Fetch every source chunk joined to its document's file path.

    Returns ``(rows, warning)``: *rows* is a list of plain dicts with
    ``chunk_id``, ``chunk_metadata`` and ``file_path`` keys, and *warning*
    is a non-empty message when the ingestion tables do not exist yet
    (any other SQLite error is re-raised).
    """
    query = """
            SELECT
                sc.id AS chunk_id,
                sc.metadata AS chunk_metadata,
                sd.file_path AS file_path
            FROM source_chunks sc
            JOIN source_documents sd ON sd.id = sc.document_id
            ORDER BY sd.file_path, sc.chunk_index
            """
    try:
        with get_connection() as conn:
            fetched = conn.execute(query).fetchall()
    except sqlite3.OperationalError as exc:
        message = str(exc)
        # A missing-table error just means ingestion has never run;
        # surface it as a warning instead of crashing the report.
        if "source_chunks" in message or "source_documents" in message:
            return [], f"missing ingestion tables: {exc}"
        raise
    return [dict(row) for row in fetched], ""
def _batches(items: list, batch_size: int) -> list[list]:
return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
def _plan_updates(
    rows: list[dict],
    project_filter: str,
) -> tuple[list[tuple[str, str, dict]], dict[str, int], int, int, int]:
    """Decide which chunks need a ``project_id`` written into their metadata.

    Returns ``(updates, by_project, skipped_unowned, already_tagged,
    malformed_metadata)`` where *updates* holds
    ``(chunk_id, project_id, new_metadata)`` triples.
    """
    updates: list[tuple[str, str, dict]] = []
    by_project: dict[str, int] = {}
    skipped_unowned = 0
    already_tagged = 0
    malformed_metadata = 0
    for row in rows:
        project_id = derive_project_id_for_path(row["file_path"])
        # Bug fix: count unowned chunks BEFORE applying --project, so
        # skipped_unowned is consistent whether or not a filter is set.
        # (Previously the filter dropped unowned rows silently.)
        if not project_id:
            skipped_unowned += 1
            continue
        if project_filter and project_id != project_filter:
            continue
        metadata = _decode_metadata(row["chunk_metadata"])
        if metadata is None:
            malformed_metadata += 1
            continue
        if metadata.get("project_id") == project_id:
            already_tagged += 1
            continue
        metadata["project_id"] = project_id
        updates.append((row["chunk_id"], project_id, metadata))
        by_project[project_id] = by_project.get(project_id, 0) + 1
    return updates, by_project, skipped_unowned, already_tagged, malformed_metadata


def _apply_batch(
    vector_store,
    batch: list[tuple[str, str, dict]],
    missing_vectors: list[str],
) -> int:
    """Apply one batch of planned updates to Chroma and then SQLite.

    Chunks whose vectors are absent from Chroma are appended to
    *missing_vectors* and skipped in SQLite too, so the two stores never
    diverge on project_id.  Returns the number of chunks updated.
    """
    chunk_ids = [chunk_id for chunk_id, _, _ in batch]
    vector_payload = vector_store.get_metadatas(chunk_ids)
    existing_vector_metadata = {
        chunk_id: metadata
        for chunk_id, metadata in zip(
            vector_payload.get("ids", []),
            vector_payload.get("metadatas", []),
            strict=False,
        )
        if isinstance(metadata, dict)
    }
    vector_ids: list[str] = []
    vector_metadatas: list[dict] = []
    sql_updates: list[tuple[str, str]] = []
    for chunk_id, project_id, chunk_metadata in batch:
        vector_metadata = existing_vector_metadata.get(chunk_id)
        if vector_metadata is None:
            missing_vectors.append(chunk_id)
            continue
        vector_metadata = dict(vector_metadata)
        vector_metadata["project_id"] = project_id
        vector_ids.append(chunk_id)
        vector_metadatas.append(vector_metadata)
        sql_updates.append((json.dumps(chunk_metadata, ensure_ascii=True), chunk_id))
    if not vector_ids:
        return 0
    # NOTE(review): Chroma is written before SQLite, so a SQLite failure
    # here leaves the stores divergent — this is why --apply requires a
    # confirmed Chroma snapshot before running.
    vector_store.update_metadatas(vector_ids, vector_metadatas)
    with get_connection() as conn:
        cursor = conn.executemany(
            "UPDATE source_chunks SET metadata = ? WHERE id = ?",
            sql_updates,
        )
        # executemany's rowcount sums affected rows; a mismatch means a
        # chunk id vanished between planning and apply.
        if cursor.rowcount != len(sql_updates):
            raise RuntimeError(
                f"SQLite rowcount mismatch: {cursor.rowcount} != {len(sql_updates)}"
            )
    return len(sql_updates)


def backfill(
    apply: bool = False,
    project_filter: str = "",
    batch_size: int = DEFAULT_BATCH_SIZE,
    require_chroma_snapshot: bool = False,
) -> dict:
    """Backfill explicit project_id into chunk and vector metadata.

    Args:
        apply: when False (default) only report what would change; when
            True, write both SQLite and Chroma metadata.
        project_filter: optional canonical project_id; restrict updates
            to chunks owned by that project.
        batch_size: number of chunks written per Chroma/SQLite round trip
            (clamped to at least 1).
        require_chroma_snapshot: must be True for an apply run with
            pending updates; confirms a Chroma backup exists.

    Returns:
        A JSON-serializable summary dict of counts and warnings.

    Raises:
        ValueError: apply run with pending updates but no snapshot
            confirmation.
        RuntimeError: SQLite updated fewer rows than planned.
    """
    rows, db_warning = _chunk_rows()
    (
        updates,
        by_project,
        skipped_unowned,
        already_tagged,
        malformed_metadata,
    ) = _plan_updates(rows, project_filter)
    missing_vectors: list[str] = []
    applied_updates = 0
    if apply and updates:
        if not require_chroma_snapshot:
            raise ValueError(
                "--apply requires --chroma-snapshot-confirmed after taking a Chroma backup"
            )
        vector_store = get_vector_store()
        for batch in _batches(updates, max(1, batch_size)):
            applied_updates += _apply_batch(vector_store, batch, missing_vectors)
    return {
        "apply": apply,
        "total_chunks": len(rows),
        "updates": len(updates),
        "applied_updates": applied_updates,
        "already_tagged": already_tagged,
        "skipped_unowned": skipped_unowned,
        "malformed_metadata": malformed_metadata,
        "missing_vectors": len(missing_vectors),
        "db_warning": db_warning,
        "by_project": dict(sorted(by_project.items())),
    }
def main() -> int:
    """CLI entry point: parse arguments, run the backfill, print a JSON summary."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--apply",
        action="store_true",
        help="write SQLite and Chroma metadata updates",
    )
    parser.add_argument(
        "--project",
        default="",
        help="optional canonical project_id filter",
    )
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    parser.add_argument(
        "--chroma-snapshot-confirmed",
        action="store_true",
        help="required with --apply; confirms a Chroma snapshot exists",
    )
    options = parser.parse_args()
    summary = backfill(
        apply=options.apply,
        project_filter=options.project.strip(),
        batch_size=options.batch_size,
        require_chroma_snapshot=options.chroma_snapshot_confirmed,
    )
    print(json.dumps(summary, indent=2, ensure_ascii=True))
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())