Files
ATOCore/src/atocore/retrieval/embeddings.py

33 lines
1018 B
Python

"""Embedding model management."""
import atocore.config as _config
from sentence_transformers import SentenceTransformer
from atocore.observability.logger import get_logger
log = get_logger("embeddings")

# Lazily-populated cache: the loaded model plus the configured name it was
# built from, so get_model() can tell when the configuration has changed.
_model: SentenceTransformer | None = None
_model_name: str | None = None


def get_model() -> SentenceTransformer:
    """Load and cache the embedding model.

    The model named by ``_config.settings.embedding_model`` is constructed on
    first use and reused on later calls. Settings are looked up through the
    module on every call. If the configured name differs from the one the
    cached model was built from, the cached model is replaced with the newly
    configured one rather than silently continuing to serve the stale model.

    Returns:
        The cached ``SentenceTransformer`` for the currently configured name.
    """
    global _model, _model_name
    # Read the setting once so the log lines and the constructor all see the
    # same value for this call.
    name = _config.settings.embedding_model
    if _model is None or _model_name != name:
        # NOTE(review): no locking here, so concurrent first calls may each
        # construct a model (same as the previous implementation).
        log.info("loading_embedding_model", model=name)
        _model = SentenceTransformer(name)
        # Only record the name after construction succeeds, so a failed load
        # is retried on the next call.
        _model_name = name
        log.info("embedding_model_loaded", model=name)
    return _model
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Generate embeddings for a list of texts.

    Args:
        texts: The strings to embed.

    Returns:
        One embedding per input text, converted from the encoder's array
        output to plain Python lists of floats. Vectors are requested with
        ``normalize_embeddings=True`` (unit L2 norm), so the dot product of
        two vectors equals their cosine similarity. An empty ``texts``
        returns ``[]``.
    """
    if not texts:
        # Nothing to encode: return early so an empty batch never triggers
        # the (potentially slow) model load in get_model().
        return []
    model = get_model()
    embeddings = model.encode(
        texts, show_progress_bar=False, normalize_embeddings=True
    )
    return embeddings.tolist()
def embed_query(query: str) -> list[float]:
    """Embed one query string.

    Convenience wrapper: the query is sent through ``embed_texts`` as a
    one-element batch, and the single resulting vector is returned.
    """
    single_batch = [query]
    vectors = embed_texts(single_batch)
    return vectors[0]