33 lines
1018 B
Python
33 lines
1018 B
Python
"""Embedding model management."""
|
|
|
|
import atocore.config as _config
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
from atocore.observability.logger import get_logger
|
|
|
|
# Logger for this module, obtained from atocore's observability package under
# the name "embeddings". Calls below pass an event name plus keyword fields.
log = get_logger("embeddings")

# Process-wide cache for the embedding model. It stays None until get_model()
# constructs the model on its first call; later calls reuse this instance.
_model: SentenceTransformer | None = None
|
|
|
|
|
|
def get_model() -> SentenceTransformer:
    """Return the shared SentenceTransformer, constructing it on first use.

    The model named by ``_config.settings.embedding_model`` is built once and
    stored in the module-level ``_model``. Every later call returns that same
    instance. A log event is emitted before and after construction.

    Returns:
        The cached SentenceTransformer instance.
    """
    global _model

    # Fast path: the model was already loaded by an earlier call.
    if _model is not None:
        return _model

    model_name = _config.settings.embedding_model
    log.info("loading_embedding_model", model=model_name)
    _model = SentenceTransformer(model_name)
    log.info("embedding_model_loaded", model=model_name)
    # NOTE(review): the only cache check is "already loaded". A later change to
    # settings.embedding_model is therefore not picked up. No lock is taken
    # either, so concurrent first calls could presumably each build a model.
    return _model
|
|
|
|
|
|
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Generate normalized embeddings for a list of texts.

    Args:
        texts: The strings to embed.

    Returns:
        The embedding vectors produced by ``SentenceTransformer.encode``,
        converted to plain Python lists of floats. Per the sentence-transformers
        API, there is one vector per input text, in input order. An empty
        input returns ``[]`` without loading the model.
    """
    # Nothing to encode: return early instead of calling get_model(), which
    # would force a potentially expensive model load (and depend on the
    # library's empty-batch handling) only to produce an empty result.
    if not texts:
        return []

    model = get_model()
    # normalize_embeddings=True asks the library to scale each vector to unit
    # length, so a dot product between two results is their cosine similarity.
    embeddings = model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
    # encode() returns a NumPy array by default; tolist() gives plain floats.
    return embeddings.tolist()
|
|
|
|
|
|
def embed_query(query: str) -> list[float]:
    """Embed a single query string.

    Convenience wrapper around embed_texts. It wraps *query* in a one-element
    batch and returns the first vector of the result.
    """
    batch = embed_texts([query])
    vector = batch[0]
    return vector
|