Files
ATOCore/scripts/memory_dedup.py

279 lines
10 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Phase 7A — semantic memory dedup detector.
Finds clusters of near-duplicate active memories and writes merge-
candidate proposals for human review in the triage UI.
Algorithm:
1. Fetch active memories via HTTP
2. Group by (project, memory_type) — cross-bucket merges are deferred
to the Phase 7B contradiction flow
3. Within each group, embed contents via atocore.retrieval.embeddings
4. Greedy transitive cluster at similarity >= threshold
5. For each cluster of size >= 2, ask claude-p to draft unified content
6. POST the proposal to /admin/memory/merge-candidates/create (server-
side dedupes by the sorted memory-id set, so re-runs don't double-
create)
Host-side because claude CLI lives on Dalidou, not the container. Reuses
the same PYTHONPATH=src pattern as scripts/graduate_memories.py for
atocore imports (embeddings, similarity, prompt module).
Usage:
python3 scripts/memory_dedup.py --base-url http://127.0.0.1:8100 \\
--similarity-threshold 0.88 --max-batch 50
Threshold conventions (see Phase 7 doc):
0.88 — interactive / default: balanced precision/recall
0.90 — nightly cron: tight, only near-duplicates
0.85 — weekly cron: deeper cleanup
"""
from __future__ import annotations
import argparse
import atexit
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request
from collections import defaultdict
from typing import Any
# Put the sibling ``src`` directory at the front of sys.path so the atocore
# package can be imported when this script is run directly (the original note
# says graduate_memories.py uses the same pattern).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_SRC_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, "src"))
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)
from atocore.memory._dedup_prompt import ( # noqa: E402
DEDUP_PROMPT_VERSION,
SYSTEM_PROMPT,
build_user_message,
normalize_merge_verdict,
parse_merge_verdict,
)
from atocore.memory.similarity import cluster_by_threshold # noqa: E402
# Defaults for the CLI flags below; each can be overridden through the
# environment variable named in the lookup.
DEFAULT_BASE_URL = os.getenv("ATOCORE_BASE_URL", "http://127.0.0.1:8100")
DEFAULT_MODEL = os.getenv("ATOCORE_DEDUP_MODEL", "sonnet")
# Parsed eagerly: a non-numeric value raises ValueError at import time.
DEFAULT_TIMEOUT_S = float(os.getenv("ATOCORE_DEDUP_TIMEOUT_S", "60"))
# Scratch directory used as the working directory for claude subprocesses.
# Created lazily by get_sandbox_cwd() and shared for the life of the process.
_sandbox_cwd = None


def get_sandbox_cwd() -> str:
    """Return the per-process scratch directory, creating it on first use.

    The directory is created with ``tempfile.mkdtemp(prefix="ato-dedup-")``
    and cached in the module-level ``_sandbox_cwd``, so every call in the
    same process returns the same path. Removal is scheduled for interpreter
    exit so that repeated runs do not accumulate stale directories.
    """
    global _sandbox_cwd
    if _sandbox_cwd is None:
        _sandbox_cwd = tempfile.mkdtemp(prefix="ato-dedup-")
        # mkdtemp never cleans up after itself; without this hook every run
        # would leave one more directory behind in the temp location.
        atexit.register(shutil.rmtree, _sandbox_cwd, ignore_errors=True)
    return _sandbox_cwd
def api_get(base_url: str, path: str) -> dict:
    """GET ``base_url`` + ``path`` and return the JSON-decoded response body.

    The request uses a 30 second timeout and the body is decoded as UTF-8.
    Nothing is caught here: errors from ``urlopen``, decoding, or JSON
    parsing propagate to the caller.
    """
    url = f"{base_url}{path}"
    request = urllib.request.Request(url)
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))
def api_post(base_url: str, path: str, body: dict | None = None) -> dict:
    """POST ``body`` as JSON to ``base_url`` + ``path`` and return the decoded reply.

    A missing or empty ``body`` is sent as ``{}``. The request uses a
    30 second timeout; errors from ``urlopen``, decoding, or JSON parsing
    propagate to the caller.
    """
    payload = json.dumps(body if body else {}).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}{path}",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))
def call_claude(system_prompt: str, user_message: str, model: str, timeout_s: float) -> tuple[str | None, str | None]:
    """Run ``claude -p`` once per attempt, up to three attempts.

    Returns ``(stdout, None)`` on the first attempt that exits with status 0
    (stdout stripped), otherwise ``(None, error)`` where ``error`` describes
    the last failure. If the ``claude`` executable is not on PATH the
    function returns immediately without trying. The second and third
    attempts are preceded by sleeps of 2 and 4 seconds. The original note
    says this mirrors the caller used by auto_triage.
    """
    if shutil.which("claude") is None:
        return None, "claude CLI not available"

    command = [
        "claude", "-p",
        "--model", model,
        "--append-system-prompt", system_prompt,
        "--disable-slash-commands",
        user_message,
    ]

    failure = ""
    for attempt in range(3):
        if attempt:
            # Exponential back-off before each retry.
            time.sleep(2 ** attempt)
        try:
            proc = subprocess.run(
                command,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=timeout_s,
                cwd=get_sandbox_cwd(),
            )
        except subprocess.TimeoutExpired:
            failure = f"{model} timed out"
        except Exception as exc:
            failure = f"subprocess error: {exc}"
        else:
            if proc.returncode == 0:
                return (proc.stdout or "").strip(), None
            # Keep only the head of stderr so the message stays short.
            err_text = (proc.stderr or "").strip()[:200]
            failure = f"{model} exit {proc.returncode}"
            if err_text:
                failure += f": {err_text}"
    return None, failure
def fetch_active_memories(base_url: str, project: str | None) -> list[dict]:
    """Fetch active memories from ``/memory``, optionally limited to one project.

    Requests ``active_only=true&limit=2000`` and, when ``project`` is
    non-empty, adds it as a percent-encoded ``project`` query value.

    Returns the entries of the response's ``memories`` list whose ``status``
    is ``"active"`` (a missing or empty status counts as active). Per the
    original note, this client-side filter is what keeps graduated memories
    out of dedup. Any failure to fetch, or a response that is not a JSON
    object, is reported on stderr and yields an empty list.
    """
    query = "active_only=true&limit=2000"
    if project:
        # safe="" so that "/" in a project name is escaped as well.
        encoded_project = urllib.parse.quote(project, safe="")
        query += f"&project={encoded_project}"
    try:
        result = api_get(base_url, f"/memory?{query}")
    except Exception as e:
        # Best-effort boundary: the caller treats an empty list as "nothing to do".
        print(f"ERROR: could not fetch memories: {e}", file=sys.stderr)
        return []
    if not isinstance(result, dict):
        print(
            f"ERROR: could not fetch memories: unexpected response type {type(result).__name__}",
            file=sys.stderr,
        )
        return []
    # "memories" may be absent or JSON null; treat both as no memories.
    mems = result.get("memories") or []
    return [
        m for m in mems
        if isinstance(m, dict) and (m.get("status") or "active") == "active"
    ]
def group_memories(mems: list[dict]) -> dict[tuple[str, str], list[dict]]:
    """Group memories under a normalized ``(project, memory_type)`` key.

    Both key parts are stripped and lower-cased; a missing or empty value
    becomes ``""``, so memories without a project share one bucket. Input
    order is preserved inside each bucket. The result is a ``defaultdict``.
    """
    def _normalize(value) -> str:
        return (value or "").strip().lower()

    grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for memory in mems:
        bucket_key = (
            _normalize(memory.get("project")),
            _normalize(memory.get("memory_type")),
        )
        grouped[bucket_key].append(memory)
    return grouped
def draft_merge(sources: list[dict], model: str, timeout_s: float) -> dict[str, Any] | None:
    """Ask claude for a merge verdict covering ``sources``.

    Builds the user message from ``sources``, calls claude with the shared
    system prompt, then parses and normalizes the reply. Returns the
    normalized verdict, or ``None`` (after a warning on stderr) when the
    call fails or the reply cannot be parsed.
    """
    raw, err = call_claude(SYSTEM_PROMPT, build_user_message(sources), model, timeout_s)
    if err:
        print(f" WARN: claude call failed: {err}", file=sys.stderr)
        return None

    reply = raw or ""
    parsed = parse_merge_verdict(reply)
    if parsed is not None:
        return normalize_merge_verdict(parsed)

    # Show only the head of the reply so the warning stays readable.
    print(f" WARN: could not parse verdict: {reply[:200]}", file=sys.stderr)
    return None
def submit_candidate(
    base_url: str,
    memory_ids: list[str],
    similarity: float,
    verdict: dict[str, Any],
    dry_run: bool,
) -> str | None:
    """POST one merge-candidate proposal, or just print it in dry-run mode.

    Args:
        base_url: Server root, without a trailing slash.
        memory_ids: Ids of the memories proposed for merging.
        similarity: Similarity score to record on the proposal.
        verdict: Merge verdict; must contain ``content``, ``memory_type``,
            ``project``, ``domain_tags``, ``confidence`` and ``reason``.
        dry_run: When true, print a preview and make no request.

    Returns:
        ``"dry-run"`` in dry-run mode; otherwise the ``candidate_id`` from
        the server's reply, which is ``None`` when the reply has none.
        ``None`` is also returned, after an error on stderr, when the
        request fails.

    Raises:
        KeyError: If ``verdict`` lacks one of the required keys.
    """
    body = {
        "memory_ids": memory_ids,
        "similarity": similarity,
        "proposed_content": verdict["content"],
        "proposed_memory_type": verdict["memory_type"],
        "proposed_project": verdict["project"],
        "proposed_tags": verdict["domain_tags"],
        "proposed_confidence": verdict["confidence"],
        "reason": verdict["reason"],
    }
    if dry_run:
        print(f" [dry-run] would POST: {json.dumps(body)[:200]}...")
        return "dry-run"
    try:
        result = api_post(base_url, "/admin/memory/merge-candidates/create", body)
        return result.get("candidate_id")
    except urllib.error.HTTPError as e:
        # Reading or decoding the error body must not raise from inside this
        # handler, or one bad response would abort the whole run.
        try:
            detail = e.read().decode("utf-8", errors="replace")[:200]
        except Exception as read_err:
            detail = f"<unreadable error body: {read_err}>"
        print(f" ERROR: submit failed: {e.code} {detail}", file=sys.stderr)
        return None
    except Exception as e:
        print(f" ERROR: submit failed: {e}", file=sys.stderr)
        return None
def main() -> None:
    """Command-line entry point: find duplicate clusters and propose merges.

    Fetches active memories, buckets them by (project, memory_type),
    clusters each bucket's contents at the chosen similarity threshold, asks
    claude for a merge verdict on every cluster of two or more, and submits
    each non-rejected verdict as a merge candidate. Stops once
    ``--max-batch`` candidates have been created and prints a summary line.
    """
    parser = argparse.ArgumentParser(description="Phase 7A semantic dedup detector")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--project", default="", help="Only scan this project (empty = all)")
    parser.add_argument("--similarity-threshold", type=float, default=0.88)
    parser.add_argument("--max-batch", type=int, default=50,
                        help="Max clusters to propose per run")
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--timeout-s", type=float, default=DEFAULT_TIMEOUT_S)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    base = args.base_url.rstrip("/")
    threshold = args.similarity_threshold
    print(f"memory_dedup {DEDUP_PROMPT_VERSION} | threshold={threshold} | model={args.model}")

    mems = fetch_active_memories(base, args.project or None)
    print(f"fetched {len(mems)} active memories")
    if not mems:
        return

    buckets = group_memories(mems)
    print(f"grouped into {len(buckets)} (project, memory_type) buckets")

    clusters_found = 0
    candidates_created = 0
    skipped_existing = 0
    llm_rejections = 0

    # Bucket keys are unique tuples, so sorting the keys gives the same
    # order as sorting the (key, group) items.
    for bucket_key in sorted(buckets):
        proj, mtype = bucket_key
        group = buckets[bucket_key]
        if len(group) < 2:
            continue
        if candidates_created >= args.max_batch:
            print(f"reached max-batch={args.max_batch}, stopping")
            break

        texts = [(m.get("content") or "") for m in group]
        # Singleton clusters have nothing to merge.
        clusters = [c for c in cluster_by_threshold(texts, threshold) if len(c) >= 2]
        if not clusters:
            continue
        print(f"\n[{proj or '(global)'}/{mtype}] {len(group)} mems → {len(clusters)} cluster(s)")

        for cluster in clusters:
            if candidates_created >= args.max_batch:
                break
            clusters_found += 1
            sources = [group[i] for i in cluster]
            ids = [s["id"] for s in sources]
            print(f" cluster of {len(cluster)}: {[s['id'][:8] for s in sources]}")

            verdict = draft_merge(sources, args.model, args.timeout_s)
            if verdict is None:
                continue
            if verdict["action"] == "reject":
                llm_rejections += 1
                print(f" LLM rejected: {verdict['reason'][:100]}")
                continue

            # The threshold itself is reported as the cluster similarity;
            # the true minimum pairwise similarity inside a transitive
            # cluster may be lower.
            cid = submit_candidate(base, ids, threshold, verdict, args.dry_run)
            if not cid:
                # No candidate id came back (submit failed or none returned).
                skipped_existing += 1
            else:
                candidates_created += 1
                if cid != "dry-run":
                    print(f" → candidate {cid[:8]}")
            time.sleep(0.3)  # be kind to claude CLI

    print(
        f"\nsummary: clusters_found={clusters_found} "
        f"candidates_created={candidates_created} "
        f"llm_rejections={llm_rejections} "
        f"skipped_existing={skipped_existing}"
    )


if __name__ == "__main__":
    main()