213 lines
5.9 KiB
Python
213 lines
5.9 KiB
Python
|
|
"""Context pack assembly: retrieve, rank, budget, format."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from atocore.config import settings
|
||
|
|
from atocore.observability.logger import get_logger
|
||
|
|
from atocore.retrieval.retriever import ChunkResult, retrieve
|
||
|
|
|
||
|
|
# Logger for this module, obtained from the project's get_logger helper;
# build_context() emits its "context_built" (info) and "context_pack_detail"
# (debug) events through it.
log = get_logger("context_builder")
|
||
|
|
|
||
|
|
# Preamble placed in front of the formatted context block in every full prompt
# built by build_context(). The three sentences are newline-separated with no
# trailing newline; build_context adds the blank-line separators itself.
SYSTEM_PREFIX = "".join(
    (
        "You have access to the following personal context from the user's knowledge base.\n",
        "Use it to inform your answer. If the context is not relevant, ignore it.\n",
        "Do not mention the context system unless asked.",
    )
)
|
||
|
|
|
||
|
|
# Most recently built context pack, kept only for debug inspection through
# get_last_context_pack(). It is None until build_context() has run once;
# each later build overwrites it.
_last_context_pack: "ContextPack | None"
_last_context_pack = None
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class ContextChunk:
    """A single retrieved chunk chosen for inclusion in a context pack.

    Within this module these are created by ``_select_within_budget``, which
    stores the shortened display path and the final ranking score.
    """

    # Full text of the chunk as it will appear in the prompt.
    content: str
    # Display path of the file the chunk came from.
    source_file: str
    # Heading trail locating the chunk inside its source document.
    heading_path: str
    # Ranking score used to order/select the chunk.
    score: float
    # Length of ``content`` in characters (counted against the budget).
    char_count: int
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class ContextPack:
    """Outcome of a build_context() call, plus metadata for debugging.

    Every field has a default, so ``ContextPack()`` yields an empty pack.
    """

    # Chunks placed in the prompt (a fresh list per instance).
    chunks_used: list[ContextChunk] = field(default_factory=list)
    # Total characters of chunk content included.
    total_chars: int = 0
    # Character budget the selection was made against.
    budget: int = 0
    # Unused part of the budget (budget - total_chars as set by build_context).
    budget_remaining: int = 0
    # Rendered context block, delimiters included.
    formatted_context: str = ""
    # Preamble + context block + user prompt, ready to send.
    full_prompt: str = ""
    # The user prompt the pack was built for.
    query: str = ""
    # Project hint used for ranking ("" when none was given).
    project_hint: str = ""
    # Time spent building the pack, in milliseconds.
    duration_ms: int = 0
|
||
|
|
|
||
|
|
|
||
|
|
def build_context(
    user_prompt: str,
    project_hint: str | None = None,
    budget: int | None = None,
) -> ContextPack:
    """Build a context pack for a user prompt.

    Pipeline: retrieve candidate chunks, rank them (passing ``project_hint``
    to the ranking step), greedily keep the best ones that fit the character
    budget, render them as a context block, and prepend ``SYSTEM_PREFIX``.

    Args:
        user_prompt: The user's question. It is used both as the retrieval
            query and as the final section of the assembled prompt.
        project_hint: Optional project name forwarded to the ranking step.
        budget: Maximum number of chunk-content characters to include.
            ``None`` means ``settings.context_budget``. An explicit ``0`` is
            honoured rather than being replaced by the default.

    Returns:
        The assembled ContextPack. It is also stored as the module's "last
        pack" so that ``get_last_context_pack()`` can return it.
    """
    global _last_context_pack
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted (unlike time.time()).
    start = time.perf_counter()
    # Only fall back to the configured default when no budget was supplied;
    # `budget or default` would also throw away an explicit 0.
    if budget is None:
        budget = settings.context_budget

    # 1. Retrieve candidates
    candidates = retrieve(user_prompt, top_k=settings.context_top_k)

    # 2. Score and rank
    scored = _rank_chunks(candidates, project_hint)

    # 3. Select within budget
    selected = _select_within_budget(scored, budget)

    # 4. Format
    formatted = _format_context_block(selected)

    # 5. Build full prompt
    full_prompt = f"{SYSTEM_PREFIX}\n\n{formatted}\n\n{user_prompt}"

    total_chars = sum(c.char_count for c in selected)
    budget_remaining = budget - total_chars
    duration_ms = int((time.perf_counter() - start) * 1000)

    pack = ContextPack(
        chunks_used=selected,
        total_chars=total_chars,
        budget=budget,
        budget_remaining=budget_remaining,
        formatted_context=formatted,
        full_prompt=full_prompt,
        query=user_prompt,
        project_hint=project_hint or "",
        duration_ms=duration_ms,
    )

    # Keep the latest pack around for get_last_context_pack() debugging.
    _last_context_pack = pack

    log.info(
        "context_built",
        chunks_used=len(selected),
        total_chars=total_chars,
        budget_remaining=budget_remaining,
        duration_ms=duration_ms,
    )
    log.debug("context_pack_detail", pack=_pack_to_dict(pack))

    return pack
|
||
|
|
|
||
|
|
|
||
|
|
def get_last_context_pack() -> ContextPack | None:
    """Expose the pack produced by the most recent build_context() call.

    Meant for debugging and inspection. Returns None when no pack has been
    built yet in this process.
    """
    latest = _last_context_pack
    return latest
|
||
|
|
|
||
|
|
|
||
|
|
def _rank_chunks(
    candidates: list[ChunkResult],
    project_hint: str | None,
    *,
    project_boost: float = 0.3,
) -> list[tuple[float, ChunkResult]]:
    """Rank candidates with boosting for project match.

    Deduplication: a candidate whose first 200 characters equal an earlier
    candidate's is dropped, so the first occurrence in input order wins.
    (Presumably the retriever yields candidates best-first, which makes the
    survivor the best-scoring copy; TODO confirm against ``retrieve``.)

    Scoring: each kept chunk starts from its similarity ``score``. When
    ``project_hint`` is non-blank, chunks whose tags, source path, or title
    contain the hint (case-insensitive substring match, hint stripped of
    surrounding whitespace) get ``project_boost`` added.

    Args:
        candidates: Retrieved chunks to rank.
        project_hint: Optional project name to favour; ``None``, ``""`` or
            whitespace-only disables the boost.
        project_boost: Amount added to a matching chunk's score (default 0.3).

    Returns:
        ``(final_score, chunk)`` pairs, highest score first. Ties keep input
        order because ``list.sort`` is stable.
    """
    # Normalise the hint once instead of per candidate. Stripping it stops a
    # padded or whitespace-only hint from "matching" any field that merely
    # contains a space.
    hint = project_hint.strip().lower() if project_hint else ""

    scored: list[tuple[float, ChunkResult]] = []
    seen_content: set[str] = set()

    for chunk in candidates:
        # Deduplicate by content prefix (first 200 chars)
        content_key = chunk.content[:200]
        if content_key in seen_content:
            continue
        seen_content.add(content_key)

        # Base score from similarity
        final_score = chunk.score

        # Project boost
        if hint:
            haystacks = (
                chunk.tags.lower() if chunk.tags else "",
                chunk.source_file.lower(),
                chunk.title.lower() if chunk.title else "",
            )
            if any(hint in text for text in haystacks):
                final_score += project_boost

        scored.append((final_score, chunk))

    # Sort by score descending
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored
|
||
|
|
|
||
|
|
|
||
|
|
def _select_within_budget(
    scored: list[tuple[float, ChunkResult]],
    budget: int,
) -> list[ContextChunk]:
    """Greedily take ranked chunks while their combined length fits *budget*.

    Chunks are visited in the given order. A chunk that would overflow the
    remaining budget is skipped, not treated as a stop signal, so a later and
    shorter chunk may still be taken. Each accepted chunk is converted to a
    ContextChunk with a shortened display path and its ranking score.
    """
    picked: list[ContextChunk] = []
    remaining = budget

    for rank_score, result in scored:
        size = len(result.content)
        # Equivalent to "running total + size <= budget".
        if size <= remaining:
            picked.append(
                ContextChunk(
                    content=result.content,
                    source_file=_shorten_path(result.source_file),
                    heading_path=result.heading_path,
                    score=rank_score,
                    char_count=size,
                )
            )
            remaining -= size

    return picked
|
||
|
|
|
||
|
|
|
||
|
|
def _format_context_block(chunks: list[ContextChunk]) -> str:
    """Render *chunks* as the delimited text block embedded in the prompt.

    Layout: an opening delimiter line, then for each chunk a
    ``[Source | Section | Score]`` tag line, the chunk content, and a blank
    separator line, then a closing delimiter line. An empty list renders a
    placeholder block instead.
    """
    if chunks:
        parts = ["--- AtoCore Context ---"]
        for chunk in chunks:
            tag = f"[Source: {chunk.source_file} | Section: {chunk.heading_path} | Score: {chunk.score:.2f}]"
            parts += [tag, chunk.content, ""]
        parts.append("--- End Context ---")
        return "\n".join(parts)

    return "--- AtoCore Context ---\nNo relevant context found.\n--- End Context ---"
|
||
|
|
|
||
|
|
|
||
|
|
def _shorten_path(path: str) -> str:
    """Trim *path* to at most its last three components for display.

    Components are counted as in ``PurePath.parts``, so an anchor such as
    ``/`` counts as one. Paths with three or fewer components come back
    unchanged apart from pathlib's normalisation.
    """
    as_path = Path(path)
    components = as_path.parts
    return str(Path(*components[-3:])) if len(components) > 3 else str(as_path)
|
||
|
|
|
||
|
|
|
||
|
|
def _pack_to_dict(pack: ContextPack) -> dict:
    """Summarise *pack* as a plain dict of JSON-friendly values.

    Scalar metadata is copied directly and ``chunks_used`` becomes a count.
    Each chunk is reduced to its location, score, size, and the first 100
    characters of its content. The formatted context and full prompt are
    omitted.
    """
    previews = [
        {
            "source_file": chunk.source_file,
            "heading_path": chunk.heading_path,
            "score": chunk.score,
            "char_count": chunk.char_count,
            "content_preview": chunk.content[:100],
        }
        for chunk in pack.chunks_used
    ]

    return {
        "query": pack.query,
        "project_hint": pack.project_hint,
        "chunks_used": len(previews),
        "total_chars": pack.total_chars,
        "budget": pack.budget,
        "budget_remaining": pack.budget_remaining,
        "duration_ms": pack.duration_ms,
        "chunks": previews,
    }
|