fix(context): rank trusted state by query relevance

This commit is contained in:
2026-04-24 20:54:56 -04:00
parent 0fc6705173
commit d3de9f67ea
3 changed files with 102 additions and 3 deletions

View File

@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
from pathlib import Path
import atocore.config as _config
from atocore.context.project_state import format_project_state, get_state
from atocore.context.project_state import ProjectStateEntry, format_project_state, get_state
from atocore.memory.service import get_memories_for_context
from atocore.observability.logger import get_logger
from atocore.engineering.service import get_entities, get_entity_with_context
@@ -116,6 +116,11 @@ def build_context(
if canonical_project:
state_entries = get_state(canonical_project)
if state_entries:
state_entries = _rank_project_state_entries(
state_entries,
query=user_prompt,
project=canonical_project,
)
project_state_text = format_project_state(state_entries)
project_state_text, project_state_chars = _truncate_text_block(
project_state_text,
@@ -284,6 +289,55 @@ def get_last_context_pack() -> ContextPack | None:
return _last_context_pack
def _rank_project_state_entries(
entries: list[ProjectStateEntry],
query: str,
project: str,
) -> list[ProjectStateEntry]:
"""Promote query-relevant trusted state before the state band is truncated."""
if not query or len(entries) <= 1:
return entries
from atocore.memory.reinforcement import _normalize, _tokenize
query_text = _normalize(query.replace("_", " "))
query_tokens = set(_tokenize(query_text))
query_tokens -= {
"how",
"what",
"when",
"where",
"which",
"who",
"why",
"current",
"status",
"project",
}
for part in (project or "").lower().replace("_", "-").split("-"):
query_tokens.discard(part)
if not query_tokens:
return entries
scored: list[tuple[int, float, float, int, ProjectStateEntry]] = []
for index, entry in enumerate(entries):
entry_text = " ".join(
[
entry.category,
entry.key.replace("_", " "),
entry.value,
entry.source,
]
)
entry_tokens = _tokenize(_normalize(entry_text))
overlap = len(entry_tokens & query_tokens) if entry_tokens else 0
density = overlap / len(entry_tokens) if entry_tokens else 0.0
scored.append((overlap, density, entry.confidence, -index, entry))
scored.sort(key=lambda item: (item[0], item[1], item[2], item[3]), reverse=True)
return [entry for _, _, _, _, entry in scored]
def _rank_chunks(
candidates: list[ChunkResult],
project_hint: str | None,