import logging
import os
from typing import List

import numpy as np
import onnxruntime as ort

# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from transformers import AutoTokenizer

from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader

logger = logging.getLogger(__name__)


class MiniLMEmbedding:
    """Embedding function for the ONNX all-MiniLM-L6-v2 model."""

    DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
    MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    IMAGE_MODEL_FILE = "model.onnx"
    TOKENIZER_FILE = "tokenizer"

    def __init__(self, preferred_providers=["CPUExecutionProvider"]):
        self.preferred_providers = preferred_providers
        self.tokenizer = None
        self.session = None

        # Ensure the model and tokenizer files are present in the local
        # cache, downloading them if necessary; actual loading is deferred
        # until the first call.
        self.downloader = ModelDownloader(
            model_name=self.MODEL_NAME,
            download_path=self.DOWNLOAD_PATH,
            file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
            download_func=self._download_model,
        )
        self.downloader.ensure_model_files()

    def _download_model(self, path: str):
        try:
            if os.path.basename(path) == self.IMAGE_MODEL_FILE:
                # The ONNX export is published alongside the model on Hugging Face.
                model_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
                ModelDownloader.download_from_url(model_url, path)
            elif os.path.basename(path) == self.TOKENIZER_FILE:
                logger.info("Downloading MiniLM tokenizer")
                tokenizer = AutoTokenizer.from_pretrained(
                    self.MODEL_NAME, clean_up_tokenization_spaces=True
                )
                tokenizer.save_pretrained(path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception as e:
            # Log the failure instead of swallowing it silently.
            logger.warning(f"Failed to download {os.path.basename(path)}: {e}")
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_tokenizer(self):
        if self.tokenizer is None or self.session is None:
            # Block until any pending download has completed before loading.
            self.downloader.wait_for_download()
            self.tokenizer = self._load_tokenizer()
            self.session = self._load_model(
                os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
                self.preferred_providers,
            )

    def _load_tokenizer(self):
        tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
        return AutoTokenizer.from_pretrained(
            tokenizer_path, clean_up_tokenization_spaces=True
        )

    def _load_model(self, path: str, providers: List[str]):
        if os.path.exists(path):
            return ort.InferenceSession(path, providers=providers)
        else:
            logger.warning(f"MiniLM model file {path} not found.")
            return None

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        self._load_model_and_tokenizer()

        if self.session is None or self.tokenizer is None:
            logger.error("MiniLM model or tokenizer is not loaded.")
            return []

        inputs = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="np"
        )
        # Only feed the tensors the ONNX graph actually declares as inputs.
        input_names = [model_input.name for model_input in self.session.get_inputs()]
        onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}

        outputs = self.session.run(None, onnx_inputs)
        # Simple mean pooling over the token axis of the last hidden state.
        embeddings = outputs[0].mean(axis=1)
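        # Note: this plain mean includes padding positions. A mask-aware
        # variant (the pooling the sentence-transformers model card applies
        # for this model) would weight by the attention mask instead, e.g.:
        #   mask = inputs["attention_mask"][..., None]
        #   embeddings = (outputs[0] * mask).sum(axis=1) / mask.sum(axis=1)
        # Plain mean pooling is kept here to preserve the existing behavior.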

        return [embedding for embedding in embeddings]
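

if __name__ == "__main__":
    # Minimal usage sketch, not part of the Frigate runtime path. It assumes
    # a working Frigate environment: MODEL_CACHE_DIR must be writable and
    # ModelDownloader's requestor must be able to publish model-state updates.
    # Running this directly downloads the model files (roughly 90 MB) on
    # first use.
    embedder = MiniLMEmbedding(preferred_providers=["CPUExecutionProvider"])
    vectors = embedder(["a person walking a dog", "a red car in the driveway"])
    for vector in vectors:
        print(vector.shape)  # all-MiniLM-L6-v2 embeddings are 384-dimensional

    # Illustrative cosine similarity between the two embeddings.
    a, b = vectors
    similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    print(f"cosine similarity: {similarity:.3f}")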