feat: Phase 7A — semantic memory dedup ("sleep cycle" V1)
New table memory_merge_candidates + service functions to cluster near-duplicate active memories within (project, memory_type) buckets, draft a unified content via LLM, and merge on human approval. Source memories become superseded (never deleted); merged memory carries union of tags, max of confidence, sum of reference_count. - schema migration for memory_merge_candidates - atocore.memory.similarity: cosine + transitive clustering - atocore.memory._dedup_prompt: stdlib-only LLM prompt preserving every specific - service: merge_memories / create_merge_candidate / get_merge_candidates / reject_merge_candidate - scripts/memory_dedup.py: host-side detector (HTTP-only, idempotent) - 5 API endpoints under /admin/memory/merge-candidates* + /admin/memory/dedup-scan - triage UI: purple "🔗 Merge Candidates" section + "🔗 Scan for duplicates" bar - batch-extract.sh Step B3 (0.90 daily, 0.85 Sundays) - deploy/dalidou/dedup-watcher.sh for UI-triggered scans - 21 new tests (374 → 395) - docs/PHASE-7-MEMORY-CONSOLIDATION.md covering 7A-7H roadmap Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
156
src/atocore/memory/_dedup_prompt.py
Normal file
156
src/atocore/memory/_dedup_prompt.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Shared LLM prompt + parser for memory dedup (Phase 7A).
|
||||
|
||||
Stdlib-only — must be importable from both the in-container service
|
||||
layer (when a user clicks "scan for duplicates" in the UI) and the
|
||||
host-side batch script (``scripts/memory_dedup.py``), which runs on
|
||||
Dalidou where the container's Python deps are not available.
|
||||
|
||||
The prompt instructs the model to draft a UNIFIED memory that
|
||||
preserves every specific detail from the sources. We never want a
|
||||
merge to lose information — if two memories disagree on a number, the
|
||||
merged content should surface both with context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
DEDUP_PROMPT_VERSION = "dedup-0.1.0"
|
||||
MAX_CONTENT_CHARS = 1000
|
||||
MAX_SOURCES = 8 # cluster size cap — bigger clusters are suspicious
|
||||
|
||||
SYSTEM_PROMPT = """You consolidate near-duplicate memories for AtoCore, a personal context engine.
|
||||
|
||||
Given 2-8 memories that a semantic-similarity scan flagged as likely duplicates, draft a UNIFIED replacement that preserves every specific detail from every source.
|
||||
|
||||
CORE PRINCIPLE: information never gets lost. If the sources disagree on a number, date, vendor, or spec, surface BOTH with attribution (e.g., "quoted at $3.2k on 2026-03-01, revised to $3.8k on 2026-04-10"). If one source is more specific than another, keep the specificity. If they say the same thing differently, pick the clearer wording.
|
||||
|
||||
YOU MUST:
|
||||
- Produce content under 500 characters that reads as a single coherent statement
|
||||
- Keep all project/vendor/person/part names that appear in any source
|
||||
- Keep all numbers, dates, and identifiers
|
||||
- Keep the strongest claim wording ("ratified", "decided", "committed") if any source has it
|
||||
- Propose domain_tags as a UNION of the sources' tags (lowercase, deduped, cap 6)
|
||||
- Return valid_until = latest non-null valid_until across sources, or null if any source has null (permanent beats transient)
|
||||
|
||||
REFUSE TO MERGE (return action="reject") if:
|
||||
- The memories are actually about DIFFERENT subjects that just share vocabulary (e.g., "p04 mirror" and "p05 mirror" — same project bucket means same project, but different components)
|
||||
- One memory CONTRADICTS another and you cannot reconcile them — flag for contradiction review instead
|
||||
- The sources span different time snapshots of a changing state that should stay as a timeline, not be collapsed
|
||||
|
||||
OUTPUT — raw JSON, no prose, no markdown fences:
|
||||
{
|
||||
"action": "merge" | "reject",
|
||||
"content": "the unified memory content",
|
||||
"memory_type": "knowledge|project|preference|adaptation|episodic|identity",
|
||||
"project": "project-slug or empty",
|
||||
"domain_tags": ["tag1", "tag2"],
|
||||
"confidence": 0.5,
|
||||
"reason": "one sentence explaining the merge (or the rejection)"
|
||||
}
|
||||
|
||||
On action=reject, still fill content with a short explanation and set confidence=0."""
|
||||
|
||||
|
||||
def build_user_message(sources: list[dict[str, Any]]) -> str:
    """Format N source memories for the model to consolidate.

    Each source dict should carry id, content, project, memory_type,
    domain_tags, confidence, valid_until, reference_count. At most
    ``MAX_SOURCES`` entries are rendered; any extras are dropped.
    """
    shown = sources[:MAX_SOURCES]
    # Announce the number of memories actually rendered — quoting
    # len(sources) would mislead the model whenever the cluster is
    # larger than MAX_SOURCES.
    lines = [f"You have {len(shown)} source memories in the same (project, memory_type) bucket:\n"]
    for i, src in enumerate(shown, start=1):
        tags = src.get("domain_tags") or []
        if isinstance(tags, str):
            # DB rows may carry tags as a JSON string; best-effort decode.
            try:
                tags = json.loads(tags)
            except Exception:
                tags = []
        # Coerce id to str before slicing — detector rows may carry
        # non-string ids and slicing would raise TypeError otherwise.
        lines.append(
            f"--- Source {i} (id={str(src.get('id', '?'))[:8]}, "
            f"refs={src.get('reference_count', 0)}, "
            f"conf={src.get('confidence', 0):.2f}, "
            f"valid_until={src.get('valid_until') or 'permanent'}) ---"
        )
        lines.append(f"project: {src.get('project', '')}")
        lines.append(f"type: {src.get('memory_type', '')}")
        lines.append(f"tags: {tags}")
        # Truncate long contents so the message stays within budget.
        lines.append(f"content: {(src.get('content') or '')[:MAX_CONTENT_CHARS]}")
        lines.append("")
    lines.append("Return the JSON object now.")
    return "\n".join(lines)
|
||||
|
||||
|
||||
def parse_merge_verdict(raw_output: str) -> dict[str, Any] | None:
    """Best-effort extraction of the model's JSON verdict.

    Tolerates markdown code fences and surrounding prose; returns the
    parsed dict, or None when no JSON object can be recovered.
    """
    candidate = (raw_output or "").strip()

    # Unwrap a ```...``` fence: drop the backticks, then drop the
    # language tag sitting on the first line (e.g. "json").
    if candidate.startswith("```"):
        candidate = candidate.strip("`")
        newline_at = candidate.find("\n")
        if newline_at >= 0:
            candidate = candidate[newline_at + 1:]
        candidate = candidate.strip()

    # If prose surrounds the object, slice from first "{" to last "}".
    if not candidate.lstrip().startswith("{"):
        lbrace = candidate.find("{")
        rbrace = candidate.rfind("}")
        if lbrace >= 0 and rbrace > lbrace:
            candidate = candidate[lbrace:rbrace + 1]

    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def normalize_merge_verdict(verdict: dict[str, Any]) -> dict[str, Any] | None:
    """Validate + normalize a raw merge verdict.

    Returns a dict with a fixed key set (action, content, memory_type,
    project, domain_tags, confidence, reason), or None if the verdict
    is unusable (empty content, unknown action).
    """
    action = str(verdict.get("action") or "").strip().lower()
    if action not in ("merge", "reject"):
        return None

    content = str(verdict.get("content") or "").strip()
    if not content:
        return None

    memory_type = str(verdict.get("memory_type") or "knowledge").strip().lower()
    project = str(verdict.get("project") or "").strip()

    # Tags: accept either a list or a comma-separated string.
    raw_tags = verdict.get("domain_tags") or []
    if isinstance(raw_tags, str):
        raw_tags = [t.strip() for t in raw_tags.split(",") if t.strip()]
    if not isinstance(raw_tags, list):
        raw_tags = []
    # Dedup/filter FIRST, then cap at 6 — capping the raw list up front
    # (raw_tags[:6]) let duplicates and non-strings at the head of the
    # list silently eat the tag budget.
    tags: list[str] = []
    for t in raw_tags:
        if len(tags) >= 6:
            break
        if not isinstance(t, str):
            continue
        tt = t.strip().lower()
        if tt and tt not in tags:
            tags.append(tt)

    # Confidence: coerce to float, fall back to 0.5, clamp to [0, 1].
    try:
        confidence = float(verdict.get("confidence", 0.5))
    except (TypeError, ValueError):
        confidence = 0.5
    confidence = max(0.0, min(1.0, confidence))

    reason = str(verdict.get("reason") or "").strip()[:500]

    return {
        "action": action,
        "content": content[:1000],
        "memory_type": memory_type,
        "project": project,
        "domain_tags": tags,
        "confidence": confidence,
        "reason": reason,
    }
|
||||
@@ -925,3 +925,327 @@ def _row_to_memory(row) -> Memory:
|
||||
def _validate_confidence(confidence: float) -> None:
|
||||
if not 0.0 <= confidence <= 1.0:
|
||||
raise ValueError("Confidence must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
# Phase 7A — Memory Consolidation: merge-candidate lifecycle
# ---------------------------------------------------------------------
#
# The detector (scripts/memory_dedup.py) writes proposals into
# memory_merge_candidates. The triage UI lists pending rows, a human
# reviews, and on approve we execute the merge here — never at detect
# time. This keeps the audit trail clean: every mutation is a human
# decision.
|
||||
|
||||
|
||||
def create_merge_candidate(
    memory_ids: list[str],
    similarity: float,
    proposed_content: str,
    proposed_memory_type: str,
    proposed_project: str,
    proposed_tags: list[str] | None = None,
    proposed_confidence: float = 0.6,
    reason: str = "",
) -> str | None:
    """Insert a merge-candidate row.

    Returns the new row id, or None if a pending candidate already
    covers this exact set of memory ids (idempotent scan — re-running
    the detector doesn't double-create).

    Raises ValueError when fewer than 2 distinct memory ids are given.
    """
    import json as _json

    # Dedup BEFORE the size check: ["a", "a"] must not slip past a
    # pre-dedup length test and create a single-source "merge".
    memory_ids_sorted = sorted(set(memory_ids or []))
    if len(memory_ids_sorted) < 2:
        raise ValueError("merge candidate requires at least 2 memory_ids")

    # Normalize once up front; a None similarity (detector edge case)
    # becomes 0.0 so both the INSERT and the log line stay crash-free
    # (round(None, 4) would raise TypeError).
    sim_value = float(similarity or 0.0)
    confidence_value = float(proposed_confidence if proposed_confidence is not None else 0.6)
    confidence_value = max(0.0, min(1.0, confidence_value))

    memory_ids_json = _json.dumps(memory_ids_sorted)
    tags_json = _json.dumps(_normalize_tags(proposed_tags))
    candidate_id = str(uuid.uuid4())

    with get_connection() as conn:
        # Idempotency: same sorted-id set already pending? skip.
        existing = conn.execute(
            "SELECT id FROM memory_merge_candidates "
            "WHERE status = 'pending' AND memory_ids = ?",
            (memory_ids_json,),
        ).fetchone()
        if existing:
            return None

        conn.execute(
            "INSERT INTO memory_merge_candidates "
            "(id, status, memory_ids, similarity, proposed_content, "
            "proposed_memory_type, proposed_project, proposed_tags, "
            "proposed_confidence, reason) "
            "VALUES (?, 'pending', ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                candidate_id, memory_ids_json, sim_value,
                (proposed_content or "")[:2000],
                (proposed_memory_type or "knowledge")[:50],
                (proposed_project or "")[:100],
                tags_json,
                confidence_value,
                (reason or "")[:500],
            ),
        )
    log.info(
        "merge_candidate_created",
        candidate_id=candidate_id,
        memory_count=len(memory_ids_sorted),
        similarity=round(sim_value, 4),
    )
    return candidate_id
|
||||
|
||||
|
||||
def get_merge_candidates(status: str = "pending", limit: int = 100) -> list[dict]:
    """List merge candidates with their source memories inlined."""
    import json as _json

    def _json_list(raw) -> list:
        # Columns are stored as JSON text; tolerate NULL / malformed rows.
        try:
            return _json.loads(raw or "[]")
        except Exception:
            return []

    candidates: list[dict] = []
    with get_connection() as conn:
        rows = conn.execute(
            "SELECT * FROM memory_merge_candidates "
            "WHERE status = ? ORDER BY created_at DESC LIMIT ?",
            (status, limit),
        ).fetchall()

        for row in rows:
            mem_ids = _json_list(row["memory_ids"])

            # Inline each surviving source memory, in candidate order.
            sources = []
            for mid in mem_ids:
                src = conn.execute(
                    "SELECT id, memory_type, content, project, confidence, "
                    "status, reference_count, domain_tags, valid_until "
                    "FROM memories WHERE id = ?",
                    (mid,),
                ).fetchone()
                if src is None:
                    continue
                sources.append({
                    "id": src["id"],
                    "memory_type": src["memory_type"],
                    "content": src["content"],
                    "project": src["project"] or "",
                    "confidence": src["confidence"],
                    "status": src["status"],
                    "reference_count": int(src["reference_count"] or 0),
                    "domain_tags": _json_list(src["domain_tags"]),
                    "valid_until": src["valid_until"] or "",
                })

            candidates.append({
                "id": row["id"],
                "status": row["status"],
                "memory_ids": mem_ids,
                "similarity": row["similarity"],
                "proposed_content": row["proposed_content"] or "",
                "proposed_memory_type": row["proposed_memory_type"] or "knowledge",
                "proposed_project": row["proposed_project"] or "",
                "proposed_tags": _json_list(row["proposed_tags"]),
                "proposed_confidence": row["proposed_confidence"],
                "reason": row["reason"] or "",
                "created_at": row["created_at"],
                "resolved_at": row["resolved_at"],
                "resolved_by": row["resolved_by"],
                "result_memory_id": row["result_memory_id"],
                "sources": sources,
            })
    return candidates
|
||||
|
||||
|
||||
def reject_merge_candidate(candidate_id: str, actor: str = "human-triage", note: str = "") -> bool:
    """Mark a merge candidate as rejected. Source memories stay untouched."""
    resolved_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    with get_connection() as conn:
        # The status guard makes the update a no-op on already-resolved
        # rows, so a double click can't flip an approved row to rejected.
        updated = conn.execute(
            "UPDATE memory_merge_candidates "
            "SET status = 'rejected', resolved_at = ?, resolved_by = ? "
            "WHERE id = ? AND status = 'pending'",
            (resolved_at, actor, candidate_id),
        )
        if updated.rowcount == 0:
            return False
        log.info("merge_candidate_rejected", candidate_id=candidate_id, actor=actor, note=note[:100])
        return True
|
||||
|
||||
|
||||
def merge_memories(
    candidate_id: str,
    actor: str = "human-triage",
    override_content: str | None = None,
    override_tags: list[str] | None = None,
) -> str | None:
    """Execute an approved merge candidate.

    1. Validate all source memories still status=active
    2. Create the new merged memory (status=active)
    3. Mark each source status=superseded with an audit row pointing at
       the new merged id
    4. Mark the candidate status=approved, record result_memory_id
    5. Write a consolidated audit row on the new memory

    Returns the new merged memory's id, or None if the candidate cannot
    be executed (already resolved, source tampered, etc.).

    ``override_content`` and ``override_tags`` let the UI pass the human's
    edits before clicking approve.
    """
    import json as _json

    now_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    with get_connection() as conn:
        # Re-read under the write connection: a second approve click or
        # a concurrent reject must not execute the merge twice.
        row = conn.execute(
            "SELECT * FROM memory_merge_candidates WHERE id = ?",
            (candidate_id,),
        ).fetchone()
        if row is None or row["status"] != "pending":
            log.warning("merge_candidate_not_pending", candidate_id=candidate_id)
            return None

        try:
            mem_ids = _json.loads(row["memory_ids"] or "[]")
        except Exception:
            mem_ids = []
        if not mem_ids or len(mem_ids) < 2:
            log.warning("merge_candidate_invalid_memory_ids", candidate_id=candidate_id)
            return None

        # Snapshot sources + validate all active. Any source that was
        # deleted or superseded since detection aborts the whole merge —
        # the candidate is stale and a fresh scan should re-propose.
        source_rows = []
        for mid in mem_ids:
            srow = conn.execute(
                "SELECT * FROM memories WHERE id = ?", (mid,)
            ).fetchone()
            if srow is None or srow["status"] != "active":
                log.warning(
                    "merge_source_not_active",
                    candidate_id=candidate_id,
                    memory_id=mid,
                    actual_status=(srow["status"] if srow else "missing"),
                )
                return None
            source_rows.append(srow)

        # Build merged memory fields — prefer human overrides, then proposed
        content = (override_content or row["proposed_content"] or "").strip()
        if not content:
            log.warning("merge_candidate_empty_content", candidate_id=candidate_id)
            return None

        # The proposed type came from an LLM and is untrusted — fall back
        # to the first source's type when it isn't a known MEMORY_TYPES value.
        merged_type = (row["proposed_memory_type"] or source_rows[0]["memory_type"]).lower()
        if merged_type not in MEMORY_TYPES:
            merged_type = source_rows[0]["memory_type"]

        merged_project = row["proposed_project"] or source_rows[0]["project"] or ""
        merged_project = resolve_project_name(merged_project)

        # Tags: override wins, else proposed, else union of sources
        if override_tags is not None:
            merged_tags = _normalize_tags(override_tags)
        else:
            try:
                proposed_tags = _json.loads(row["proposed_tags"] or "[]")
            except Exception:
                proposed_tags = []
            if proposed_tags:
                merged_tags = _normalize_tags(proposed_tags)
            else:
                # Order-preserving union across sources (first occurrence wins).
                union: list[str] = []
                for srow in source_rows:
                    try:
                        stags = _json.loads(srow["domain_tags"] or "[]")
                    except Exception:
                        stags = []
                    for t in stags:
                        if isinstance(t, str) and t and t not in union:
                            union.append(t)
                merged_tags = union

        # confidence = max; reference_count = sum
        merged_confidence = max(float(s["confidence"]) for s in source_rows)
        total_refs = sum(int(s["reference_count"] or 0) for s in source_rows)

        # valid_until: if any source is permanent (None/empty), merged is permanent.
        # Otherwise take the latest (lexical compare on ISO dates works).
        merged_vu: str | None = ""  # overwritten by exactly one branch below
        has_permanent = any(not (s["valid_until"] or "").strip() for s in source_rows)
        if has_permanent:
            merged_vu = None
        else:
            merged_vu = max((s["valid_until"] or "").strip() for s in source_rows) or None

        new_id = str(uuid.uuid4())
        tags_json = _json.dumps(merged_tags)

        # source_chunk_id is NULL: the merged memory derives from other
        # memories, not from an ingested chunk.
        conn.execute(
            "INSERT INTO memories (id, memory_type, content, project, "
            "source_chunk_id, confidence, status, domain_tags, valid_until, "
            "reference_count, last_referenced_at) "
            "VALUES (?, ?, ?, ?, NULL, ?, 'active', ?, ?, ?, ?)",
            (
                new_id, merged_type, content[:2000], merged_project,
                merged_confidence, tags_json, merged_vu, total_refs, now_str,
            ),
        )

        # Mark sources superseded (never deleted — history stays queryable).
        for srow in source_rows:
            conn.execute(
                "UPDATE memories SET status = 'superseded', updated_at = ? "
                "WHERE id = ?",
                (now_str, srow["id"]),
            )

        # Mark candidate approved
        conn.execute(
            "UPDATE memory_merge_candidates SET status = 'approved', "
            "resolved_at = ?, resolved_by = ?, result_memory_id = ? WHERE id = ?",
            (now_str, actor, new_id, candidate_id),
        )

    # Audit rows (out of the transaction; fail-open via _audit_memory).
    # NOTE(review): assumes get_connection commits on clean context exit,
    # so a failed audit write cannot roll the merge back — confirm.
    _audit_memory(
        memory_id=new_id,
        action="created_via_merge",
        actor=actor,
        after={
            "memory_type": merged_type,
            "content": content,
            "project": merged_project,
            "confidence": merged_confidence,
            "domain_tags": merged_tags,
            "reference_count": total_refs,
            "merged_from": list(mem_ids),
            "merge_candidate_id": candidate_id,
        },
        note=f"merged {len(mem_ids)} sources via candidate {candidate_id[:8]}",
    )
    # One "superseded" audit row per source, each pointing at the new id.
    for srow in source_rows:
        _audit_memory(
            memory_id=srow["id"],
            action="superseded",
            actor=actor,
            before={"status": "active", "content": srow["content"]},
            after={"status": "superseded", "superseded_by": new_id},
            note=f"merged into {new_id}",
        )

    log.info(
        "merge_executed",
        candidate_id=candidate_id,
        result_memory_id=new_id,
        source_count=len(source_rows),
        actor=actor,
    )
    return new_id
|
||||
|
||||
88
src/atocore/memory/similarity.py
Normal file
88
src/atocore/memory/similarity.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Phase 7A (Memory Consolidation): semantic similarity helpers.
|
||||
|
||||
Thin wrapper over ``atocore.retrieval.embeddings`` that exposes
|
||||
pairwise + batch cosine similarity on normalized embeddings. Used by
|
||||
the dedup detector to cluster near-duplicate active memories.
|
||||
|
||||
Embeddings from ``embed_texts()`` are already L2-normalized, so cosine
|
||||
similarity reduces to a dot product — no extra normalization needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from atocore.retrieval.embeddings import embed_texts
|
||||
|
||||
|
||||
def _dot(a: list[float], b: list[float]) -> float:
|
||||
return sum(x * y for x, y in zip(a, b))
|
||||
|
||||
|
||||
def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity on already-normalized vectors, clamped to [0, 1].

    Inputs are assumed unit-norm (embed_texts output via
    paraphrase-multilingual-MiniLM), so the similarity reduces to a dot
    product; the clamp keeps negative values from leaking into
    threshold comparisons.
    """
    raw = _dot(a, b)
    return max(0.0, min(1.0, raw))
|
||||
|
||||
|
||||
def compute_memory_similarity(text_a: str, text_b: str) -> float:
    """Cosine similarity of two memory contents, in [0, 1].

    Convenience wrapper for one-off checks and tests; embeds both texts
    in a single batch. For bulk work (the dedup detector), call
    ``embed_texts()`` once and build the similarity matrix yourself so
    shared texts are not re-embedded.
    """
    if not (text_a and text_b):
        return 0.0
    vec_a, vec_b = embed_texts([text_a, text_b])
    return cosine(vec_a, vec_b)
|
||||
|
||||
|
||||
def similarity_matrix(texts: list[str]) -> list[list[float]]:
    """N×N cosine-similarity matrix: symmetric, with a unit diagonal."""
    if not texts:
        return []
    vectors = embed_texts(texts)
    size = len(vectors)
    result = [[0.0 for _ in range(size)] for _ in range(size)]
    for row in range(size):
        result[row][row] = 1.0
        # Only the upper triangle is computed; mirror into the lower half.
        for col in range(row + 1, size):
            sim = cosine(vectors[row], vectors[col])
            result[row][col] = sim
            result[col][row] = sim
    return result
|
||||
|
||||
|
||||
def cluster_by_threshold(texts: list[str], threshold: float) -> list[list[int]]:
    """Greedy transitive clustering over ``texts``.

    Any pair with cosine similarity >= ``threshold`` lands in the same
    cluster (union-find, so A~B and B~C collapse into one group).
    Returns clusters as lists of indices into ``texts``; singletons are
    included. The dedup detector uses this to emit one merge proposal
    per connected group instead of one proposal per similar pair.
    """
    if not texts:
        return []

    sims = similarity_matrix(texts)
    count = len(texts)
    parent = list(range(count))

    def root(idx: int) -> int:
        # Path-halving keeps the trees shallow without recursion.
        while parent[idx] != idx:
            parent[idx] = parent[parent[idx]]
            idx = parent[idx]
        return idx

    for i in range(count):
        for j in range(i + 1, count):
            if sims[i][j] >= threshold:
                root_i, root_j = root(i), root(j)
                if root_i != root_j:
                    parent[root_i] = root_j

    # Group indices by final root; insertion order keeps clusters
    # ordered by their first member.
    clusters: dict[int, list[int]] = {}
    for idx in range(count):
        clusters.setdefault(root(idx), []).append(idx)
    return list(clusters.values())
|
||||
Reference in New Issue
Block a user