Files
ATOCore/tests/test_capture_stop.py

250 lines
9.5 KiB
Python
Raw Normal View History

"""Tests for deploy/hooks/capture_stop.py — Claude Code Stop hook."""
from __future__ import annotations
import json
import os
import sys
import tempfile
import textwrap
from io import StringIO
from pathlib import Path
from unittest import mock
import pytest
# capture_stop.py is a standalone script under <repo>/deploy/hooks rather than
# part of an importable package, so its directory has to be placed on sys.path
# before the import below can succeed.
_HOOK_DIR = str(Path(__file__).resolve().parent.parent.joinpath("deploy", "hooks"))
if _HOOK_DIR not in sys.path:
    sys.path.insert(0, _HOOK_DIR)
import capture_stop  # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _write_transcript(tmp: Path, entries: list[dict]) -> str:
    """Serialise *entries* as JSON Lines into ``tmp/transcript.jsonl``.

    Each entry becomes one line of JSON (non-ASCII characters are written
    as-is, UTF-8 encoded).  Returns the file path as a string.
    """
    target = tmp / "transcript.jsonl"
    with open(target, "w", encoding="utf-8") as handle:
        # Generator keeps serialisation lazy, one entry per written line.
        handle.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in entries
        )
    return str(target)
def _user_entry(content: str, *, is_meta: bool = False) -> dict:
    """Build a transcript entry for a user turn.

    Args:
        content: Value stored under ``message.content``.
        is_meta: Value stored under the top-level ``isMeta`` key.
    """
    message = {"role": "user", "content": content}
    return {"type": "user", "isMeta": is_meta, "message": message}
def _assistant_entry() -> dict:
    """Build a transcript entry for an assistant turn with one canned text block."""
    reply_block = {"type": "text", "text": "Sure, here's the answer."}
    message = {"role": "assistant", "content": [reply_block]}
    return {"type": "assistant", "message": message}
def _system_entry() -> dict:
    """Build a transcript entry for a system message."""
    message = dict(role="system", content="system init")
    return dict(type="system", message=message)
# ---------------------------------------------------------------------------
# _extract_last_user_prompt
# ---------------------------------------------------------------------------
class TestExtractLastUserPrompt:
    """Checks for ``capture_stop._extract_last_user_prompt`` on JSONL transcripts."""

    def test_returns_last_real_prompt(self, tmp_path):
        """With two user turns, the later one is returned."""
        entries = [
            _user_entry("First prompt that is long enough to capture"),
            _assistant_entry(),
            _user_entry("Second prompt that should be the one we capture"),
            _assistant_entry(),
        ]
        transcript = _write_transcript(tmp_path, entries)

        prompt = capture_stop._extract_last_user_prompt(transcript)

        assert prompt == "Second prompt that should be the one we capture"

    def test_skips_meta_messages(self, tmp_path):
        """Trailing XML-wrapped and isMeta entries lose to the earlier real prompt."""
        entries = [
            _user_entry("Real prompt that is definitely long enough"),
            _user_entry("<local-command>some system stuff</local-command>"),
            _user_entry("Meta message that looks real enough", is_meta=True),
        ]
        transcript = _write_transcript(tmp_path, entries)

        prompt = capture_stop._extract_last_user_prompt(transcript)

        assert prompt == "Real prompt that is definitely long enough"

    def test_skips_xml_content(self, tmp_path):
        """A trailing entry that is only an XML command tag is not returned."""
        entries = [
            _user_entry("Actual prompt from a real human user"),
            _user_entry("<command-name>/help</command-name>"),
        ]
        transcript = _write_transcript(tmp_path, entries)

        prompt = capture_stop._extract_last_user_prompt(transcript)

        assert prompt == "Actual prompt from a real human user"

    def test_skips_short_messages(self, tmp_path):
        """A trailing "yes" is treated as too short; the earlier prompt is returned."""
        entries = [
            _user_entry("This prompt is long enough to be captured"),
            _user_entry("yes"),  # too short
        ]
        transcript = _write_transcript(tmp_path, entries)

        prompt = capture_stop._extract_last_user_prompt(transcript)

        assert prompt == "This prompt is long enough to be captured"

    def test_handles_content_blocks(self, tmp_path):
        """List-of-blocks content yields a result containing every text block."""
        blocks = [
            {"type": "text", "text": "First paragraph of the prompt."},
            {"type": "text", "text": "Second paragraph continues here."},
        ]
        entry = {
            "type": "user",
            "message": {"role": "user", "content": blocks},
        }
        transcript = _write_transcript(tmp_path, [entry])

        prompt = capture_stop._extract_last_user_prompt(transcript)

        for fragment in ("First paragraph", "Second paragraph"):
            assert fragment in prompt

    def test_empty_transcript(self, tmp_path):
        """A transcript with no entries produces an empty string."""
        transcript = _write_transcript(tmp_path, [])

        prompt = capture_stop._extract_last_user_prompt(transcript)

        assert prompt == ""

    def test_missing_file(self):
        """A path that does not exist produces an empty string."""
        prompt = capture_stop._extract_last_user_prompt("/nonexistent/path.jsonl")

        assert prompt == ""

    def test_empty_path(self):
        """An empty path string produces an empty string."""
        prompt = capture_stop._extract_last_user_prompt("")

        assert prompt == ""
# ---------------------------------------------------------------------------
# _infer_project
# ---------------------------------------------------------------------------
class TestInferProject:
    """Checks for ``capture_stop._infer_project`` on working-directory strings."""

    def test_empty_cwd(self):
        """An empty cwd yields an empty project name."""
        project = capture_stop._infer_project("")

        assert project == ""

    def test_unknown_path(self):
        """A directory expected to have no mapping yields an empty project name."""
        project = capture_stop._infer_project("C:\\Users\\antoi\\random")

        assert project == ""

    def test_mapped_path(self):
        """A subdirectory of a mapped directory yields that mapping's project."""
        extra_mapping = {"C:\\Users\\antoi\\gigabit": "p04-gigabit"}
        # patch.dict adds the entry for the duration of the block only.
        with mock.patch.dict(capture_stop._PROJECT_PATH_MAP, extra_mapping):
            project = capture_stop._infer_project("C:\\Users\\antoi\\gigabit\\src")

        assert project == "p04-gigabit"
# ---------------------------------------------------------------------------
# _capture (integration-style, mocking HTTP)
# ---------------------------------------------------------------------------
class TestCapture:
    """Integration-style tests for ``capture_stop._capture`` with HTTP mocked.

    Every test replaces ``sys.stdin`` with a Stop-hook JSON payload and patches
    ``capture_stop.urllib.request.urlopen`` so nothing goes over the network;
    assertions inspect the request object that was handed to ``urlopen``.
    """

    # Environment switch that turns capture off (see test_skips_when_disabled).
    # Tests that expect the hook to run hide any ambient value, so a developer
    # shell that exports it cannot change their outcome.
    _DISABLE_ENV = "ATOCORE_CAPTURE_DISABLED"

    def _hook_input(self, *, transcript_path: str = "", **overrides) -> str:
        """Return a JSON-encoded Stop-hook payload.

        Args:
            transcript_path: Value for the ``transcript_path`` field.
            **overrides: Extra or replacement top-level fields, applied last.
        """
        data = {
            "session_id": "test-session-123",
            "transcript_path": transcript_path,
            "cwd": "C:\\Users\\antoi\\ATOCore",
            "permission_mode": "default",
            "hook_event_name": "Stop",
            "last_assistant_message": "Here is the answer to your question about the code.",
            "turn_number": 3,
        }
        data.update(overrides)
        return json.dumps(data)

    def _capture_enabled(self):
        """Return a context manager under which the disable flag is unset.

        ``os.environ`` is restored to its previous contents on exit.
        """
        cleaned = {k: v for k, v in os.environ.items() if k != self._DISABLE_ENV}
        return mock.patch.dict(os.environ, cleaned, clear=True)

    @staticmethod
    def _json_response(payload):
        """Return a mock response whose ``read()`` gives *payload* as JSON bytes."""
        resp = mock.MagicMock()
        resp.read.return_value = json.dumps(payload).encode()
        return resp

    @staticmethod
    def _sent_request(mock_urlopen):
        """Return the single request passed to ``urlopen``.

        Asserting the call first gives a clear failure message instead of a
        ``TypeError`` from indexing ``call_args`` when it is ``None``.
        """
        mock_urlopen.assert_called_once()
        return mock_urlopen.call_args[0][0]

    @mock.patch("capture_stop.urllib.request.urlopen")
    def test_posts_to_atocore(self, mock_urlopen, tmp_path):
        """A normal turn is sent once, carrying prompt, client and session fields."""
        transcript = _write_transcript(tmp_path, [
            _user_entry("Please explain how the backup system works in detail"),
            _assistant_entry(),
        ])
        mock_urlopen.return_value = self._json_response(
            {"id": "int-001", "status": "recorded"}
        )
        stdin = StringIO(self._hook_input(transcript_path=transcript))
        with self._capture_enabled(), mock.patch("sys.stdin", stdin):
            capture_stop._capture()
        req = self._sent_request(mock_urlopen)
        body = json.loads(req.data.decode())
        assert body["prompt"] == "Please explain how the backup system works in detail"
        assert body["client"] == "claude-code"
        assert body["session_id"] == "test-session-123"
        assert body["reinforce"] is True

    @mock.patch("capture_stop.urllib.request.urlopen")
    def test_skips_when_disabled(self, mock_urlopen, tmp_path):
        """With the disable flag set to "1", no request is made."""
        transcript = _write_transcript(tmp_path, [
            _user_entry("A prompt that would normally be captured"),
        ])
        stdin = StringIO(self._hook_input(transcript_path=transcript))
        with mock.patch.dict(os.environ, {"ATOCORE_CAPTURE_DISABLED": "1"}):
            with mock.patch("sys.stdin", stdin):
                capture_stop._capture()
        mock_urlopen.assert_not_called()

    @mock.patch("capture_stop.urllib.request.urlopen")
    def test_skips_short_prompt(self, mock_urlopen, tmp_path):
        """A transcript whose only prompt is "yes" produces no request."""
        transcript = _write_transcript(tmp_path, [
            _user_entry("yes"),
        ])
        stdin = StringIO(self._hook_input(transcript_path=transcript))
        # Capture is explicitly enabled so the skip can only come from the
        # short prompt, not from an ambient disable flag.
        with self._capture_enabled(), mock.patch("sys.stdin", stdin):
            capture_stop._capture()
        mock_urlopen.assert_not_called()

    @mock.patch("capture_stop.urllib.request.urlopen")
    def test_truncates_long_response(self, mock_urlopen, tmp_path):
        """A 60,000-character assistant message is cut down and marked truncated."""
        transcript = _write_transcript(tmp_path, [
            _user_entry("Tell me everything about the entire codebase architecture"),
        ])
        long_response = "x" * 60_000
        mock_urlopen.return_value = self._json_response({"id": "int-002"})
        stdin = StringIO(
            self._hook_input(transcript_path=transcript, last_assistant_message=long_response)
        )
        with self._capture_enabled(), mock.patch("sys.stdin", stdin):
            capture_stop._capture()
        req = self._sent_request(mock_urlopen)
        body = json.loads(req.data.decode())
        # The +20 leaves room for the trailing truncation marker.
        assert len(body["response"]) <= capture_stop.MAX_RESPONSE_LENGTH + 20
        assert body["response"].endswith("[truncated]")

    def test_main_never_raises(self):
        """main() must always exit 0, even on garbage input."""
        # Capture is enabled so the garbage input is not bypassed by an
        # ambient disable flag (assuming main() honours it -- TODO confirm).
        with self._capture_enabled(), mock.patch("sys.stdin", StringIO("not json at all")):
            # Should not raise
            capture_stop.main()

    @mock.patch("capture_stop.urllib.request.urlopen")
    def test_uses_atocore_url_env(self, mock_urlopen, tmp_path):
        """The request URL is built from ``capture_stop.ATOCORE_URL``."""
        transcript = _write_transcript(tmp_path, [
            _user_entry("Please help me with this particular problem in the code"),
        ])
        mock_urlopen.return_value = self._json_response({"id": "int-003"})
        stdin = StringIO(self._hook_input(transcript_path=transcript))
        with self._capture_enabled():
            with mock.patch.dict(os.environ, {"ATOCORE_URL": "http://localhost:9999"}):
                # Setting the env var does not re-read it: the module-level
                # constant is patched directly as well (presumably it was
                # captured when capture_stop was imported -- TODO confirm).
                with mock.patch.object(capture_stop, "ATOCORE_URL", "http://localhost:9999"):
                    with mock.patch("sys.stdin", stdin):
                        capture_stop._capture()
        req = self._sent_request(mock_urlopen)
        assert req.full_url == "http://localhost:9999/interactions"