"""Extractor eval runner — scores the rule-based extractor against a labeled interaction corpus. Pulls full interaction content from a frozen snapshot, runs each through ``extract_candidates_from_interaction``, and compares the output to the expected counts from a labels file. Produces a per-label scorecard plus aggregate precision / recall / yield numbers. This harness deliberately stays file-based: snapshot + labels + this runner. No Dalidou HTTP dependency once the snapshot is frozen, so the eval is reproducible run-to-run even as live captures drift. Usage: python scripts/extractor_eval.py # human report python scripts/extractor_eval.py --json # machine-readable python scripts/extractor_eval.py \\ --snapshot scripts/eval_data/interactions_snapshot_2026-04-11.json \\ --labels scripts/eval_data/extractor_labels_2026-04-11.json """ from __future__ import annotations import argparse import json import sys from dataclasses import dataclass, field from pathlib import Path # Make src/ importable without requiring an install. 
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT / "src"))

from atocore.interactions.service import Interaction  # noqa: E402
from atocore.memory.extractor import extract_candidates_from_interaction  # noqa: E402
from atocore.memory.extractor_llm import extract_candidates_llm  # noqa: E402

DEFAULT_SNAPSHOT = _REPO_ROOT / "scripts" / "eval_data" / "interactions_snapshot_2026-04-11.json"
DEFAULT_LABELS = _REPO_ROOT / "scripts" / "eval_data" / "extractor_labels_2026-04-11.json"


@dataclass
class LabelResult:
    """Score for one labeled interaction.

    ``actual_count`` is ``-1`` (a sentinel, not a real count) when the labeled
    interaction is missing from the snapshot; ``miss_class`` is then
    ``"not_in_snapshot"`` and ``ok`` is ``False``.
    """

    id: str
    expected_count: int
    actual_count: int
    ok: bool
    miss_class: str
    notes: str
    actual_candidates: list[dict] = field(default_factory=list)


def load_snapshot(path: Path) -> dict[str, dict]:
    """Load the frozen snapshot file and index its interactions by id."""
    data = json.loads(path.read_text(encoding="utf-8"))
    return {item["id"]: item for item in data.get("interactions", [])}


def load_labels(path: Path) -> dict:
    """Load the labels document (expected to contain a top-level ``"labels"`` list)."""
    return json.loads(path.read_text(encoding="utf-8"))


def interaction_from_snapshot(snap: dict) -> Interaction:
    """Rehydrate an ``Interaction`` from one snapshot record.

    Missing or null fields are normalized to empty strings so the extractor
    sees clean string inputs regardless of snapshot gaps.
    """
    return Interaction(
        id=snap["id"],
        prompt=snap.get("prompt", "") or "",
        response=snap.get("response", "") or "",
        response_summary="",
        project=snap.get("project", "") or "",
        client=snap.get("client", "") or "",
        session_id=snap.get("session_id", "") or "",
        created_at=snap.get("created_at", "") or "",
    )


def score(snapshot: dict[str, dict], labels_doc: dict, mode: str = "rule") -> list[LabelResult]:
    """Run the selected extractor over every labeled interaction.

    Args:
        snapshot: interactions keyed by id (see :func:`load_snapshot`).
        labels_doc: parsed labels file; must contain a ``"labels"`` list.
        mode: ``"llm"`` to score the LLM extractor; anything else scores the
            rule-based extractor (the default).

    Returns:
        One :class:`LabelResult` per label, in label-file order.
    """
    results: list[LabelResult] = []
    for label in labels_doc["labels"]:
        iid = label["id"]
        snap = snapshot.get(iid)
        if snap is None:
            # Labeled but absent from the frozen snapshot: scored as a miss
            # with the -1 sentinel so aggregate() can exclude it from totals.
            results.append(
                LabelResult(
                    id=iid,
                    expected_count=int(label.get("expected_count", 0)),
                    actual_count=-1,
                    ok=False,
                    miss_class="not_in_snapshot",
                    notes=label.get("notes", ""),
                )
            )
            continue

        interaction = interaction_from_snapshot(snap)
        if mode == "llm":
            candidates = extract_candidates_llm(interaction)
        else:
            candidates = extract_candidates_from_interaction(interaction)

        expected_count = int(label.get("expected_count", 0))
        actual_count = len(candidates)
        results.append(
            LabelResult(
                id=iid,
                expected_count=expected_count,
                actual_count=actual_count,
                ok=(actual_count == expected_count),
                miss_class=label.get("miss_class", "n/a"),
                notes=label.get("notes", ""),
                actual_candidates=[
                    {
                        "memory_type": c.memory_type,
                        "content": c.content,
                        "project": c.project,
                        "rule": c.rule,
                    }
                    for c in candidates
                ],
            )
        )
    return results


def aggregate(results: list[LabelResult]) -> dict:
    """Compute interaction-level precision / recall / yield over *results*.

    TP / FP / FN are counted per *interaction* (did it produce any candidate
    at all), not per candidate, because labels only record an expected count.
    The ``-1`` "not in snapshot" sentinel is excluded from candidate totals
    and from the miss-class breakdown (which covers true extraction FNs only).
    """
    from collections import Counter  # local: only needed for the FN breakdown

    total = len(results)
    exact_match = sum(1 for r in results if r.ok)
    true_positive = sum(1 for r in results if r.expected_count > 0 and r.actual_count > 0)
    false_positive_interactions = sum(
        1 for r in results if r.expected_count == 0 and r.actual_count > 0
    )
    false_negative_interactions = sum(
        1 for r in results if r.expected_count > 0 and r.actual_count == 0
    )
    positive_expected = sum(1 for r in results if r.expected_count > 0)
    total_expected_candidates = sum(r.expected_count for r in results)
    # max(..., 0) keeps the -1 sentinel from subtracting from the total.
    total_actual_candidates = sum(max(r.actual_count, 0) for r in results)
    yield_rate = total_actual_candidates / total if total else 0.0

    # Recall over interactions that had at least one expected candidate.
    recall = true_positive / positive_expected if positive_expected else 0.0
    # Precision over interactions that produced any candidate.
    precision_denom = true_positive + false_positive_interactions
    precision = true_positive / precision_denom if precision_denom else 0.0

    # False-negative breakdown by annotated miss class. Note the strict
    # ``== 0``: not-in-snapshot entries (-1) are deliberately excluded.
    miss_classes = Counter(
        r.miss_class or "unlabeled"
        for r in results
        if r.expected_count > 0 and r.actual_count == 0
    )

    return {
        "total": total,
        "exact_match": exact_match,
        "positive_expected": positive_expected,
        "total_expected_candidates": total_expected_candidates,
        "total_actual_candidates": total_actual_candidates,
        "yield_rate": round(yield_rate, 3),
        "recall": round(recall, 3),
        "precision": round(precision, 3),
        "false_positive_interactions": false_positive_interactions,
        "false_negative_interactions": false_negative_interactions,
        "miss_classes": dict(miss_classes),
    }


def print_human(results: list[LabelResult], summary: dict) -> None:
    """Print the human-readable report: summary, FN breakdown, per-interaction lines."""
    print("=== Extractor eval ===")
    print(
        f"labeled={summary['total']} "
        f"exact_match={summary['exact_match']} "
        f"positive_expected={summary['positive_expected']}"
    )
    print(
        f"yield={summary['yield_rate']} "
        f"recall={summary['recall']} "
        f"precision={summary['precision']}"
    )
    print(
        f"false_positives={summary['false_positive_interactions']} "
        f"false_negatives={summary['false_negative_interactions']}"
    )
    print()
    print("miss class breakdown (FN):")
    if summary["miss_classes"]:
        for k, v in sorted(summary["miss_classes"].items(), key=lambda kv: -kv[1]):
            print(f"  {v:3d}  {k}")
    else:
        print("  (none)")
    print()
    print("per-interaction:")
    for r in results:
        marker = "OK  " if r.ok else "MISS"
        iid_short = r.id[:8]
        print(f"  {marker} {iid_short} expected={r.expected_count} actual={r.actual_count} class={r.miss_class}")
        if r.actual_candidates:
            for c in r.actual_candidates:
                preview = (c["content"] or "")[:80]
                print(f"         [{c['memory_type']}] {preview}")


def print_json(results: list[LabelResult], summary: dict) -> None:
    """Emit the full report as JSON on stdout (for CI / tooling consumption)."""
    payload = {
        "summary": summary,
        "results": [
            {
                "id": r.id,
                "expected_count": r.expected_count,
                "actual_count": r.actual_count,
                "ok": r.ok,
                "miss_class": r.miss_class,
                "notes": r.notes,
                "actual_candidates": r.actual_candidates,
            }
            for r in results
        ],
    }
    json.dump(payload, sys.stdout, indent=2)
    sys.stdout.write("\n")


def main() -> int:
    """CLI entry point.

    Returns 0 only when there are no false-positive and no false-negative
    interactions, so the script doubles as a CI gate.
    """
    parser = argparse.ArgumentParser(description="AtoCore extractor eval")
    parser.add_argument("--snapshot", type=Path, default=DEFAULT_SNAPSHOT)
    parser.add_argument("--labels", type=Path, default=DEFAULT_LABELS)
    parser.add_argument("--json", action="store_true", help="emit machine-readable JSON")
    parser.add_argument(
        "--mode",
        choices=["rule", "llm"],
        default="rule",
        help="which extractor to score (default: rule)",
    )
    args = parser.parse_args()

    snapshot = load_snapshot(args.snapshot)
    labels = load_labels(args.labels)
    results = score(snapshot, labels, mode=args.mode)
    summary = aggregate(results)
    summary["mode"] = args.mode

    if args.json:
        print_json(results, summary)
    else:
        print_human(results, summary)

    ok = (
        summary["false_negative_interactions"] == 0
        and summary["false_positive_interactions"] == 0
    )
    return 0 if ok else 1


if __name__ == "__main__":
    raise SystemExit(main())