199 lines
6.5 KiB
Python
199 lines
6.5 KiB
Python
|
|
"""Tests for deploy/hooks/inject_context.py — Claude Code UserPromptSubmit hook.
|
||
|
|
|
||
|
|
These are process-level tests: we run the actual script with subprocess,
|
||
|
|
feed it stdin, and check the exit code + stdout shape. The hook must:
|
||
|
|
- always exit 0 (never block a user prompt)
|
||
|
|
- emit valid hookSpecificOutput JSON on success
|
||
|
|
- fail open (empty output) on network errors, bad stdin, kill-switch
|
||
|
|
- respect the short-prompt filter
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import subprocess
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
# Absolute path of the hook under test: deploy/hooks/inject_context.py, resolved
# relative to the parent of the directory that contains this test file.
HOOK = Path(__file__).resolve().parents[1].joinpath("deploy", "hooks", "inject_context.py")
|
||
|
|
|
||
|
|
|
||
|
|
def _run_hook(
    stdin_json: dict | str,
    env_overrides: dict[str, str] | None = None,
    timeout: float = 10,
) -> tuple[int, str, str]:
    """Run the hook script in a child interpreter and capture its result.

    Args:
        stdin_json: Payload for the hook's stdin. A dict is JSON-encoded; a
            str is sent verbatim so tests can feed malformed or empty input.
        env_overrides: Environment variables applied last, on top of the
            sanitised environment built below.
        timeout: Seconds to wait for the hook; ``subprocess.run`` raises
            ``subprocess.TimeoutExpired`` if it is exceeded.

    Returns:
        ``(returncode, stdout, stderr)`` with both streams decoded as text.
    """
    # Start from the current environment but drop every ATOCORE_* setting so
    # the developer's shell cannot leak in: a leftover kill switch would make
    # tests pass vacuously, and a real ATOCORE_URL could let a test reach a
    # live server. Tests that need specific values pass them via env_overrides.
    env = {k: v for k, v in os.environ.items() if not k.startswith("ATOCORE_")}
    # Default to an unreachable local port so an unexpected request fails
    # locally instead of going to whatever URL the hook uses when none is set.
    env["ATOCORE_URL"] = "http://127.0.0.1:1"
    if env_overrides:
        env.update(env_overrides)

    stdin = stdin_json if isinstance(stdin_json, str) else json.dumps(stdin_json)
    proc = subprocess.run(
        [sys.executable, str(HOOK)],
        input=stdin,
        text=True,
        capture_output=True,
        timeout=timeout,
        env=env,
    )
    return proc.returncode, proc.stdout, proc.stderr
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_exit_0_on_success_or_failure():
    """The hook must never block a prompt: pointed at a dead endpoint it
    still exits 0, writes nothing to stdout (fail-open) and reports the
    failure on stderr."""
    payload = {
        "prompt": "What's the p04-gigabit current status?",
        "cwd": "/tmp",
        "session_id": "t",
        "hook_event_name": "UserPromptSubmit",
    }
    dead_endpoint = {
        "ATOCORE_URL": "http://127.0.0.1:1",  # port 1: expected to be unreachable
        "ATOCORE_CONTEXT_TIMEOUT": "1",
    }

    exit_code, out, err = _run_hook(payload, env_overrides=dead_endpoint)

    assert exit_code == 0
    # Fail-open: no hookSpecificOutput JSON at all.
    assert out.strip() == ""
    assert any(msg in err for msg in ("atocore unreachable", "request failed"))
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_kill_switch():
    """ATOCORE_CONTEXT_DISABLED=1 must short-circuit the hook: exit 0, no
    stdout, and no request to AtoCore.

    The URL points at an unreachable port. Because the hook fails open, one
    that ignored the kill switch would *also* print nothing on stdout, so
    stdout alone cannot tell the two cases apart. The absence of the
    network-failure messages on stderr (the same messages
    test_hook_exit_0_on_success_or_failure expects for an unreachable URL)
    is what shows no request was attempted.
    """
    code, stdout, stderr = _run_hook(
        {"prompt": "hello world is this a thing", "cwd": "", "session_id": "t"},
        env_overrides={
            "ATOCORE_CONTEXT_DISABLED": "1",
            "ATOCORE_URL": "http://127.0.0.1:1",  # unreachable
            "ATOCORE_CONTEXT_TIMEOUT": "1",  # bound the runtime if the switch regresses
        },
    )
    assert code == 0
    assert stdout.strip() == ""
    assert "atocore unreachable" not in stderr, stderr
    assert "request failed" not in stderr, stderr
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_ignores_short_prompt():
    """A trivially short prompt ("ok") is skipped without contacting AtoCore.

    The URL is unreachable and the hook fails open, so a hook that did try
    the request would also leave stdout empty. The stderr checks (the
    network-failure messages test_hook_exit_0_on_success_or_failure expects)
    are what prove the short-prompt filter ran before any network call.
    """
    code, stdout, stderr = _run_hook(
        {"prompt": "ok", "cwd": "", "session_id": "t"},
        env_overrides={
            "ATOCORE_URL": "http://127.0.0.1:1",
            "ATOCORE_CONTEXT_TIMEOUT": "1",  # bound the runtime if the filter regresses
        },
    )
    assert code == 0
    # No network call attempted; empty output
    assert stdout.strip() == ""
    assert "atocore unreachable" not in stderr, stderr
    assert "request failed" not in stderr, stderr
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_ignores_xml_prompt():
    """System/meta prompts starting with '<' should be skipped.

    As with the short-prompt test, empty stdout alone would also be produced
    by a failed (fail-open) request to the unreachable URL. The stderr checks
    (the network-failure messages test_hook_exit_0_on_success_or_failure
    expects) confirm no request was attempted.
    """
    code, stdout, stderr = _run_hook(
        {"prompt": "<system>do something</system>", "cwd": "", "session_id": "t"},
        env_overrides={
            "ATOCORE_URL": "http://127.0.0.1:1",
            "ATOCORE_CONTEXT_TIMEOUT": "1",  # bound the runtime if the filter regresses
        },
    )
    assert code == 0
    assert stdout.strip() == ""
    assert "atocore unreachable" not in stderr, stderr
    assert "request failed" not in stderr, stderr
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_handles_bad_stdin():
    """Malformed JSON on stdin is reported on stderr and never blocks the prompt."""
    rc, out, err = _run_hook("not-json-at-all")

    assert rc == 0
    assert not out.strip()
    assert "bad stdin" in err
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_handles_empty_stdin():
    """An empty stdin stream must neither fail the hook nor produce output."""
    rc, out, _err = _run_hook("")

    assert rc == 0
    assert not out.strip()
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_success_shape_with_mock_server():
    """When the API returns a pack, the hook emits valid
    hookSpecificOutput JSON wrapping it."""
    # Start a tiny HTTP server on localhost that returns a fake pack.
    import http.server
    import threading

    pack = "Trusted State: foo=bar"

    class Handler(http.server.BaseHTTPRequestHandler):
        def do_POST(self):  # noqa: N802
            # Consume the request body before replying.
            self.rfile.read(int(self.headers.get("Content-Length", 0)))
            body = json.dumps({"formatted_context": pack}).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, *a, **kw):
            pass  # silence per-request logging in test output

    # Port 0 lets the OS pick a free port. The context manager closes the
    # listening socket on exit; shutdown() alone only stops serve_forever().
    with http.server.HTTPServer(("127.0.0.1", 0), Handler) as server:
        port = server.server_address[1]
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
        try:
            code, stdout, stderr = _run_hook(
                {
                    "prompt": "What do we know about p04?",
                    "cwd": "",
                    "session_id": "t",
                    "hook_event_name": "UserPromptSubmit",
                },
                env_overrides={
                    "ATOCORE_URL": f"http://127.0.0.1:{port}",
                    "ATOCORE_CONTEXT_TIMEOUT": "5",
                },
                timeout=15,
            )
        finally:
            server.shutdown()
            thread.join(timeout=5)

    assert code == 0, stderr
    assert stdout.strip(), "expected JSON output with context"
    out = json.loads(stdout)
    hso = out.get("hookSpecificOutput", {})
    assert hso.get("hookEventName") == "UserPromptSubmit"
    context = hso.get("additionalContext", "")
    assert pack in context
    assert "AtoCore-injected context" in context
|
||
|
|
|
||
|
|
|
||
|
|
def test_hook_project_inference_from_cwd():
    """The hook should map a known cwd to a project slug and send it in
    the /context/build payload."""
    import http.server
    import threading

    # Filled in by the server thread with the decoded JSON body of the hook's POST.
    captured_body: dict = {}

    class Handler(http.server.BaseHTTPRequestHandler):
        def do_POST(self):  # noqa: N802
            n = int(self.headers.get("Content-Length", 0))
            captured_body.update(json.loads(self.rfile.read(n).decode()))
            out = json.dumps({"formatted_context": "ok"}).encode()
            self.send_response(200)
            self.send_header("Content-Length", str(len(out)))
            self.end_headers()
            self.wfile.write(out)

        def log_message(self, *a, **kw):
            pass  # silence per-request logging in test output

    # Port 0 lets the OS pick a free port. The context manager closes the
    # listening socket on exit; shutdown() alone only stops serve_forever().
    with http.server.HTTPServer(("127.0.0.1", 0), Handler) as server:
        port = server.server_address[1]
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
        try:
            code, _stdout, stderr = _run_hook(
                {
                    "prompt": "Is this being tested properly",
                    "cwd": "C:\\Users\\antoi\\ATOCore",
                    "session_id": "t",
                },
                env_overrides={
                    "ATOCORE_URL": f"http://127.0.0.1:{port}",
                    "ATOCORE_CONTEXT_TIMEOUT": "5",
                },
            )
        finally:
            server.shutdown()
            thread.join(timeout=5)

    assert code == 0, stderr
    # An empty capture means the hook never reached the mock server; show why.
    assert captured_body, f"hook never POSTed to the mock server; stderr:\n{stderr}"
    # Hook should have inferred project="atocore" from the ATOCore cwd
    assert captured_body.get("project") == "atocore", captured_body
    assert captured_body.get("prompt", "").startswith("Is this being tested")
|