"""Embedding model management."""
|
||
|
|
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
|
||
|
|
from atocore.config import settings
|
||
|
|
from atocore.observability.logger import get_logger
|
||
|
|
|
||
|
|
# Structured logger for this module; events are emitted with snake_case names.
log = get_logger("embeddings")

# Process-wide cache for the loaded SentenceTransformer; populated lazily by
# get_model() on first use and reused for all subsequent calls.
_model: SentenceTransformer | None = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_model() -> SentenceTransformer:
    """Return the process-wide embedding model, loading it on first use.

    The loaded ``SentenceTransformer`` is cached in the module-level
    ``_model`` variable, so every call after the first is cheap.
    """
    global _model
    if _model is not None:
        return _model
    # First call: load the model named in the configuration, then cache it.
    log.info("loading_embedding_model", model=settings.embedding_model)
    _model = SentenceTransformer(settings.embedding_model)
    log.info("embedding_model_loaded", model=settings.embedding_model)
    return _model
|
||
|
|
|
||
|
|
|
||
|
|
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Return one embedding vector per input text.

    Vectors are L2-normalized by the encoder (``normalize_embeddings=True``),
    so dot products between them are cosine similarities.
    """
    vectors = get_model().encode(
        texts,
        show_progress_bar=False,
        normalize_embeddings=True,
    )
    return vectors.tolist()
|
||
|
|
|
||
|
|
|
||
|
|
def embed_query(query: str) -> list[float]:
    """Embed a single query string and return its vector."""
    # Delegate to the batch path with a one-element batch and unwrap it.
    (vector,) = embed_texts([query])
    return vector
|