fix(retrieval): preserve project ids across unscoped ingest
@@ -9,123 +9,145 @@ from __future__ import annotations

 import argparse
 import json
+import sqlite3
 import sys
 from pathlib import Path

 sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

 from atocore.models.database import get_connection # noqa: E402
-from atocore.projects.registry import list_registered_projects # noqa: E402
+from atocore.projects.registry import derive_project_id_for_path # noqa: E402
 from atocore.retrieval.vector_store import get_vector_store # noqa: E402


-def _load_project_roots() -> list[tuple[str, Path]]:
-    roots: list[tuple[str, Path]] = []
-    for project in list_registered_projects():
-        project_id = project["id"]
-        for root in project.get("ingest_roots", []):
-            root_path = root.get("path")
-            if root_path:
-                roots.append((project_id, Path(root_path).resolve(strict=False)))
-    roots.sort(key=lambda item: len(str(item[1])), reverse=True)
-    return roots
+DEFAULT_BATCH_SIZE = 500


-def _derive_project_id(file_path: str, roots: list[tuple[str, Path]]) -> str:
-    if not file_path:
-        return ""
-    doc_path = Path(file_path).resolve(strict=False)
-    for project_id, root_path in roots:
-        try:
-            doc_path.relative_to(root_path)
-        except ValueError:
-            continue
-        return project_id
-    return ""
-
-
-def _decode_metadata(raw: str | None) -> dict:
+def _decode_metadata(raw: str | None) -> dict | None:
     if not raw:
         return {}
     try:
         parsed = json.loads(raw)
     except json.JSONDecodeError:
-        return {}
-    return parsed if isinstance(parsed, dict) else {}
+        return None
+    return parsed if isinstance(parsed, dict) else None


-def _chunk_rows() -> list[dict]:
-    with get_connection() as conn:
-        rows = conn.execute(
-            """
-            SELECT
-                sc.id AS chunk_id,
-                sc.metadata AS chunk_metadata,
-                sd.file_path AS file_path
-            FROM source_chunks sc
-            JOIN source_documents sd ON sd.id = sc.document_id
-            ORDER BY sd.file_path, sc.chunk_index
-            """
-        ).fetchall()
-    return [dict(row) for row in rows]
+def _chunk_rows() -> tuple[list[dict], str]:
+    try:
+        with get_connection() as conn:
+            rows = conn.execute(
+                """
+                SELECT
+                    sc.id AS chunk_id,
+                    sc.metadata AS chunk_metadata,
+                    sd.file_path AS file_path
+                FROM source_chunks sc
+                JOIN source_documents sd ON sd.id = sc.document_id
+                ORDER BY sd.file_path, sc.chunk_index
+                """
+            ).fetchall()
+    except sqlite3.OperationalError as exc:
+        if "source_chunks" in str(exc) or "source_documents" in str(exc):
+            return [], f"missing ingestion tables: {exc}"
+        raise
+    return [dict(row) for row in rows], ""


-def backfill(apply: bool = False, project_filter: str = "") -> dict:
-    roots = _load_project_roots()
-    rows = _chunk_rows()
+def _batches(items: list, batch_size: int) -> list[list]:
+    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
+
+
+def backfill(
+    apply: bool = False,
+    project_filter: str = "",
+    batch_size: int = DEFAULT_BATCH_SIZE,
+    require_chroma_snapshot: bool = False,
+) -> dict:
+    rows, db_warning = _chunk_rows()
     updates: list[tuple[str, str, dict]] = []
     by_project: dict[str, int] = {}
     skipped_unowned = 0
     already_tagged = 0
+    malformed_metadata = 0

     for row in rows:
-        project_id = _derive_project_id(row["file_path"], roots)
+        project_id = derive_project_id_for_path(row["file_path"])
         if project_filter and project_id != project_filter:
            continue
         if not project_id:
             skipped_unowned += 1
             continue
         metadata = _decode_metadata(row["chunk_metadata"])
+        if metadata is None:
+            malformed_metadata += 1
+            continue
         if metadata.get("project_id") == project_id:
             already_tagged += 1
             continue
         metadata["project_id"] = project_id
         updates.append((row["chunk_id"], project_id, metadata))
         by_project[project_id] = by_project.get(project_id, 0) + 1

+    missing_vectors: list[str] = []
+    applied_updates = 0
     if apply and updates:
+        if not require_chroma_snapshot:
+            raise ValueError(
+                "--apply requires --chroma-snapshot-confirmed after taking a Chroma backup"
+            )
         vector_store = get_vector_store()
-        chunk_ids = [chunk_id for chunk_id, _, _ in updates]
-        vector_payload = vector_store.get_metadatas(chunk_ids)
-        existing_vector_metadata = {
-            chunk_id: metadata or {}
-            for chunk_id, metadata in zip(
-                vector_payload.get("ids", []),
-                vector_payload.get("metadatas", []),
-                strict=False,
-            )
-        }
-        vector_metadatas = []
-        for chunk_id, project_id, chunk_metadata in updates:
-            vector_metadata = dict(existing_vector_metadata.get(chunk_id) or {})
-            if not vector_metadata:
-                vector_metadata = dict(chunk_metadata)
-            vector_metadata["project_id"] = project_id
-            vector_metadatas.append(vector_metadata)
+        for batch in _batches(updates, max(1, batch_size)):
+            chunk_ids = [chunk_id for chunk_id, _, _ in batch]
+            vector_payload = vector_store.get_metadatas(chunk_ids)
+            existing_vector_metadata = {
+                chunk_id: metadata
+                for chunk_id, metadata in zip(
+                    vector_payload.get("ids", []),
+                    vector_payload.get("metadatas", []),
+                    strict=False,
+                )
+                if isinstance(metadata, dict)
+            }

-        with get_connection() as conn:
-            conn.executemany(
-                "UPDATE source_chunks SET metadata = ? WHERE id = ?",
-                [
-                    (json.dumps(metadata, ensure_ascii=True), chunk_id)
-                    for chunk_id, _, metadata in updates
-                ],
-            )
-        vector_store.update_metadatas(chunk_ids, vector_metadatas)
+            vector_ids = []
+            vector_metadatas = []
+            sql_updates = []
+            for chunk_id, project_id, chunk_metadata in batch:
+                vector_metadata = existing_vector_metadata.get(chunk_id)
+                if vector_metadata is None:
+                    missing_vectors.append(chunk_id)
+                    continue
+                vector_metadata = dict(vector_metadata)
+                vector_metadata["project_id"] = project_id
+                vector_ids.append(chunk_id)
+                vector_metadatas.append(vector_metadata)
+                sql_updates.append((json.dumps(chunk_metadata, ensure_ascii=True), chunk_id))
+
+            if not vector_ids:
+                continue
+
+            vector_store.update_metadatas(vector_ids, vector_metadatas)
+            with get_connection() as conn:
+                cursor = conn.executemany(
+                    "UPDATE source_chunks SET metadata = ? WHERE id = ?",
+                    sql_updates,
+                )
+                if cursor.rowcount != len(sql_updates):
+                    raise RuntimeError(
+                        f"SQLite rowcount mismatch: {cursor.rowcount} != {len(sql_updates)}"
+                    )
+            applied_updates += len(sql_updates)

     return {
         "apply": apply,
         "total_chunks": len(rows),
         "updates": len(updates),
+        "applied_updates": applied_updates,
         "already_tagged": already_tagged,
         "skipped_unowned": skipped_unowned,
+        "malformed_metadata": malformed_metadata,
+        "missing_vectors": len(missing_vectors),
+        "db_warning": db_warning,
         "by_project": dict(sorted(by_project.items())),
     }
@@ -134,9 +156,20 @@ def main() -> int:
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--apply", action="store_true", help="write SQLite and Chroma metadata updates")
     parser.add_argument("--project", default="", help="optional canonical project_id filter")
+    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
+    parser.add_argument(
+        "--chroma-snapshot-confirmed",
+        action="store_true",
+        help="required with --apply; confirms a Chroma snapshot exists",
+    )
     args = parser.parse_args()

-    payload = backfill(apply=args.apply, project_filter=args.project.strip())
+    payload = backfill(
+        apply=args.apply,
+        project_filter=args.project.strip(),
+        batch_size=args.batch_size,
+        require_chroma_snapshot=args.chroma_snapshot_confirmed,
+    )
     print(json.dumps(payload, indent=2, ensure_ascii=True))
     return 0
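
Usage note for reviewers: a minimal sketch of calling the new backfill() entry point directly, assuming the script can be imported as backfill_project_ids (the module path is not shown in this diff); the keyword arguments and report keys below are taken from the signature and return payload added above.

# Hypothetical usage sketch; only the backfill() signature and report keys come from this diff.
from backfill_project_ids import backfill  # assumed module name

# Dry run: derive project ids and report what would change, without writing anything.
report = backfill(apply=False)
print(report["updates"], report["skipped_unowned"], report["db_warning"])

# Apply in batches once a Chroma snapshot has been taken;
# require_chroma_snapshot=True mirrors the --chroma-snapshot-confirmed CLI flag.
report = backfill(
    apply=True,
    project_filter="example-project",  # hypothetical canonical project_id
    batch_size=500,
    require_chroma_snapshot=True,
)
print(report["applied_updates"], report["missing_vectors"])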