# ATOCore/scripts/retrieval_eval.py
"""Retrieval quality eval harness.
Runs a fixed set of project-hinted questions against
``POST /context/build`` on a live AtoCore instance and scores the
resulting ``formatted_context`` against per-question expectations.
The goal is a diffable scorecard that tells you, run-to-run,
whether a retrieval / builder / ingestion change moved the needle.
Design notes
------------
- Fixtures live in ``scripts/retrieval_eval_fixtures.json`` so new
questions can be added without touching Python. Each fixture
names the project, the prompt, and a checklist of substrings that
MUST appear in ``formatted_context`` (``expect_present``) and
substrings that MUST NOT appear (``expect_absent``). The absent
list catches cross-project bleed and stale content.
- The checklist is deliberately substring-based (not regex, not
embedding-similarity) so a failure is always a trivially
reproducible "this string is not in that string". Richer scoring
can come later once we know the harness is useful.
- The harness is external to the app runtime and talks to AtoCore
over HTTP, so it works against dev, staging, or prod. It follows
the same environment-variable contract as ``atocore_client.py``
(``ATOCORE_BASE_URL``, ``ATOCORE_TIMEOUT_SECONDS``).
- Exit code 0 on all-pass, 1 on any fixture failure. Intended for
manual runs today; a future cron / CI hook can consume the
JSON output via ``--json``.
Usage
-----
python scripts/retrieval_eval.py # human-readable report
python scripts/retrieval_eval.py --json # machine-readable
python scripts/retrieval_eval.py --fixtures path/to/custom.json
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
# Connection defaults follow the same environment-variable contract as
# atocore_client.py (see module docstring): ATOCORE_BASE_URL / ATOCORE_TIMEOUT_SECONDS.
DEFAULT_BASE_URL = os.environ.get("ATOCORE_BASE_URL", "http://dalidou:8100")
DEFAULT_TIMEOUT = int(os.environ.get("ATOCORE_TIMEOUT_SECONDS", "30"))
# Default per-fixture context budget sent to /context/build; the unit
# (chars vs tokens) is defined by the server — confirm against the API.
DEFAULT_BUDGET = 3000
# Fixtures live next to this script so they version together with the code.
DEFAULT_FIXTURES = Path(__file__).parent / "retrieval_eval_fixtures.json"
def request_json(base_url: str, path: str, timeout: int) -> dict:
    """GET ``base_url + path`` and return the decoded JSON body.

    An empty (or whitespace-only) response body is treated as ``{}``.
    Transport errors and malformed JSON propagate to the caller.
    """
    request = urllib.request.Request(f"{base_url}{path}", method="GET")
    with urllib.request.urlopen(request, timeout=timeout) as response:
        raw = response.read().decode("utf-8")
    if not raw.strip():
        return {}
    return json.loads(raw)
@dataclass
class Fixture:
    """One eval case: a project-hinted prompt plus substring expectations."""

    name: str  # unique identifier shown in reports
    project: str  # project hint for /context/build; "" means no hint
    prompt: str  # the question sent to the context builder
    budget: int = DEFAULT_BUDGET  # context budget forwarded in the request payload
    expect_present: list[str] = field(default_factory=list)  # substrings that MUST appear in formatted_context
    expect_absent: list[str] = field(default_factory=list)  # substrings that MUST NOT appear (catches cross-project bleed)
    known_issue: bool = False  # failure is expected/tracked; does not affect the exit code
    notes: str = ""  # free-form context printed when the fixture fails
@dataclass
class FixtureResult:
    """Outcome of running a single :class:`Fixture` against the service."""

    fixture: Fixture
    ok: bool  # True iff every expectation held and no transport error occurred
    missing_present: list[str]  # expect_present entries NOT found in formatted_context
    unexpected_absent: list[str]  # expect_absent entries that DID appear
    total_chars: int  # len(formatted_context); 0 on transport error
    known_issue: bool = False  # copied from the fixture for reporting
    error: str = ""  # transport-level error description, if any

    @property
    def blocking_failure(self) -> bool:
        # A failure only blocks (process exit code 1) when it is not a
        # tracked known issue.
        return not self.ok and not self.known_issue
def load_fixtures(path: Path) -> list[Fixture]:
    """Parse the fixtures JSON file at *path* into :class:`Fixture` objects.

    Raises ``ValueError`` for any structural problem — non-array file,
    non-object entry, or a missing required key — always naming the
    offending fixture index. Previously a fixture without ``"name"`` or
    ``"prompt"`` surfaced as a bare ``KeyError`` with no indication of
    which entry was broken, inconsistent with the indexed ``ValueError``s
    raised for the other structural checks.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError(f"{path} must contain a JSON array of fixtures")
    fixtures: list[Fixture] = []
    for i, raw in enumerate(data):
        if not isinstance(raw, dict):
            raise ValueError(f"fixture {i} is not an object")
        # "name" and "prompt" have no sensible default; fail with the index.
        for key in ("name", "prompt"):
            if key not in raw:
                raise ValueError(f"fixture {i} is missing required key {key!r}")
        fixtures.append(
            Fixture(
                name=raw["name"],
                project=raw.get("project", ""),
                prompt=raw["prompt"],
                budget=int(raw.get("budget", DEFAULT_BUDGET)),
                expect_present=list(raw.get("expect_present", [])),
                expect_absent=list(raw.get("expect_absent", [])),
                known_issue=bool(raw.get("known_issue", False)),
                notes=raw.get("notes", ""),
            )
        )
    return fixtures
def run_fixture(fixture: Fixture, base_url: str, timeout: int) -> FixtureResult:
    """POST the fixture's prompt to ``/context/build`` and score the response.

    Any transport or decode failure is folded into a failing
    :class:`FixtureResult` (with every ``expect_present`` entry counted as
    missing) instead of raising, so one unreachable fixture does not abort
    the whole run. The original only caught ``urllib.error.URLError``; a
    socket read timeout (``TimeoutError``/``OSError`` without a URLError
    wrapper) or a non-JSON response body (``json.JSONDecodeError``) would
    crash the harness. The exception set now mirrors the one ``main()``
    already uses for the ``/health`` probe.
    """
    payload = {
        "prompt": fixture.prompt,
        # Normalize "" to null so the server does not treat an empty
        # string as a real project hint.
        "project": fixture.project or None,
        "budget": fixture.budget,
    }
    req = urllib.request.Request(
        url=f"{base_url}/context/build",
        method="POST",
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload).encode("utf-8"),
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = json.loads(resp.read().decode("utf-8"))
    except (urllib.error.URLError, TimeoutError, OSError, json.JSONDecodeError) as exc:
        return FixtureResult(
            fixture=fixture,
            ok=False,
            missing_present=list(fixture.expect_present),
            unexpected_absent=[],
            total_chars=0,
            known_issue=fixture.known_issue,
            error=f"http_error: {exc}",
        )
    formatted = body.get("formatted_context") or ""
    missing = [s for s in fixture.expect_present if s not in formatted]
    unexpected = [s for s in fixture.expect_absent if s in formatted]
    return FixtureResult(
        fixture=fixture,
        ok=not missing and not unexpected,
        missing_present=missing,
        unexpected_absent=unexpected,
        total_chars=len(formatted),
        known_issue=fixture.known_issue,
    )
def print_human_report(results: list[FixtureResult], metadata: dict) -> None:
    """Render the human-readable scorecard to stdout."""
    total = len(results)
    passed = len([r for r in results if r.ok])
    known = len([r for r in results if not r.ok and r.known_issue])
    blocking = len([r for r in results if r.blocking_failure])
    health = metadata.get("health", {})
    print(f"Retrieval eval: {passed}/{total} fixtures passed")
    print(
        "Target: "
        f"{metadata.get('base_url', 'unknown')} "
        f"build={health.get('build_sha', 'unknown')}"
    )
    # Only surface the breakdown line when there is something to break down.
    if known or blocking:
        print(f"Blocking failures: {blocking} Known issues: {known}")
    print()
    for result in results:
        if result.ok:
            status = "PASS"
        elif result.known_issue:
            status = "KNOWN"
        else:
            status = "FAIL"
        print(
            f"[{status}] {result.fixture.name} "
            f"project={result.fixture.project} chars={result.total_chars}"
        )
        if result.error:
            print(f" error: {result.error}")
        for needle in result.missing_present:
            print(f" missing expected: {needle!r}")
        for needle in result.unexpected_absent:
            print(f" unexpected present: {needle!r}")
        # Notes only add signal on failure; suppress them on PASS.
        if result.fixture.notes and not result.ok:
            print(f" notes: {result.fixture.notes}")
def print_json_report(results: list[FixtureResult], metadata: dict) -> None:
    """Emit the scorecard as indented JSON on stdout (for ``--json``)."""
    # Key order deliberately matches the human report's emphasis:
    # run metadata first, aggregate counts, then per-fixture detail.
    fixture_rows = [
        {
            "name": result.fixture.name,
            "project": result.fixture.project,
            "ok": result.ok,
            "known_issue": result.known_issue,
            "total_chars": result.total_chars,
            "missing_present": result.missing_present,
            "unexpected_absent": result.unexpected_absent,
            "error": result.error,
        }
        for result in results
    ]
    report = {
        "generated_at": metadata.get("generated_at"),
        "base_url": metadata.get("base_url"),
        "health": metadata.get("health", {}),
        "total": len(results),
        "passed": len([r for r in results if r.ok]),
        "known_issues": len([r for r in results if not r.ok and r.known_issue]),
        "blocking_failures": len([r for r in results if r.blocking_failure]),
        "fixtures": fixture_rows,
    }
    sys.stdout.write(json.dumps(report, indent=2))
    sys.stdout.write("\n")
def main() -> int:
    """CLI entry point. Returns 0 on all-pass, 1 on any blocking failure."""
    parser = argparse.ArgumentParser(description="AtoCore retrieval quality eval harness")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT)
    parser.add_argument("--fixtures", type=Path, default=DEFAULT_FIXTURES)
    parser.add_argument("--json", action="store_true", help="emit machine-readable JSON")
    args = parser.parse_args()

    base_url = args.base_url.rstrip("/")
    # Probe /health up front so the report records which build was tested;
    # an unreachable endpoint becomes metadata rather than a crash.
    try:
        health = request_json(base_url, "/health", args.timeout)
    except (urllib.error.URLError, TimeoutError, OSError, json.JSONDecodeError) as exc:
        health = {"error": str(exc)}
    metadata = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "base_url": base_url,
        "health": health,
    }

    fixtures = load_fixtures(args.fixtures)
    results = [run_fixture(fixture, base_url, args.timeout) for fixture in fixtures]
    reporter = print_json_report if args.json else print_human_report
    reporter(results, metadata)
    return 1 if any(r.blocking_failure for r in results) else 0


if __name__ == "__main__":
    raise SystemExit(main())