blakeblackshear.frigate/frigate/embeddings/functions/minilm_l6_v2.py
Josh Hawkins 24ac9f3e5a
Use sqlite-vec extension instead of chromadb for embeddings (#14163)
* swap sqlite_vec for chroma in requirements

* load sqlite_vec in embeddings manager (see the sketch after this commit log)

* remove chroma and revamp Embeddings class for sqlite_vec

* manual minilm onnx inference

* remove chroma in clip model

* migrate api from chroma to sqlite_vec

* migrate event cleanup from chroma to sqlite_vec

* migrate embedding maintainer from chroma to sqlite_vec

* genai description for sqlite_vec

* load sqlite_vec in main thread db

* extend the SqliteQueueDatabase class and use peewee db.execute_sql

* search with Event type for similarity

* fix similarity search

* install and add comment about transformers

* fix normalization

* add id filter

* clean up

* clean up

* fully remove chroma and add transformers env var

* re-add uvicorn for fastapi

* re-add tokenizer parallelism env var

* remove chroma from docs

* remove chroma from UI

* try removing custom pysqlite3 build

* hard code limit

* optimize queries

* revert explore query

* fix query

* keep building pysqlite3

* single pass fetch and process

* remove unnecessary re-embed

* update deps

* move SqliteVecQueueDatabase to db directory

* make search thumbnail take up full size of results box

* improve typing

* improve model downloading and add status screen

* daemon downloading thread

* catch case when semantic search is disabled

* fix typing

* build sqlite_vec from source

* resolve conflict

* file permissions

* try build deps

* remove sources

* sources

* fix thread start

* include git in build

* reorder embeddings after detectors are started

* build with sqlite amalgamation

* non-platform specific

* use wget instead of curl

* remove unzip -d

* remove sqlite_vec from requirements and load the compiled version

* fix build

* avoid race in db connection

* add scale_factor and bias to description zscore normalization
2024-10-07 14:30:45 -06:00
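
The core of the swap, as a minimal sketch: load the sqlite-vec loadable extension into a SQLite connection, create a vec0 virtual table for the embeddings, and run a KNN query ordered by distance. The table name, column name, and connection path below are illustrative rather than Frigate's actual schema, and the sqlite_vec Python package stands in for the compiled extension that the build above produces.

import sqlite3

import sqlite_vec
from sqlite_vec import serialize_float32

db = sqlite3.connect("frigate.db")
db.enable_load_extension(True)
sqlite_vec.load(db)  # registers the vec0 module and vec_* SQL functions
db.enable_load_extension(False)

# Virtual table holding 384-dim MiniLM sentence embeddings keyed by event id.
db.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions "
    "USING vec0(id TEXT PRIMARY KEY, description_embedding FLOAT[384])"
)

# KNN similarity search: nearest neighbors (smallest distance) first.
query_embedding = [0.0] * 384  # placeholder 384-dim query vector
rows = db.execute(
    "SELECT id, distance FROM vec_descriptions "
    "WHERE description_embedding MATCH ? ORDER BY distance LIMIT 10",
    [serialize_float32(query_embedding)],
).fetchall()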


import logging
import os
from typing import List

import numpy as np
import onnxruntime as ort

# Importing transformers without pytorch (or another backend) installed
# triggers an advisory warning:
# https://github.com/huggingface/transformers/issues/27214
# It is suppressed by setting the env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1.
from transformers import AutoTokenizer

from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader

logger = logging.getLogger(__name__)


class MiniLMEmbedding:
    """Embedding function for ONNX MiniLM-L6 model."""

    DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
    MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    IMAGE_MODEL_FILE = "model.onnx"
    TOKENIZER_FILE = "tokenizer"

    def __init__(self, preferred_providers=["CPUExecutionProvider"]):
        self.preferred_providers = preferred_providers
        self.tokenizer = None
        self.session = None

        # Start (or verify) the model download; files are fetched in a
        # background thread and waited on lazily at first inference.
        self.downloader = ModelDownloader(
            model_name=self.MODEL_NAME,
            download_path=self.DOWNLOAD_PATH,
            file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
            download_func=self._download_model,
        )
        self.downloader.ensure_model_files()

    def _download_model(self, path: str):
        try:
            if os.path.basename(path) == self.IMAGE_MODEL_FILE:
                # The ONNX export of the model is fetched directly from the
                # Hugging Face repository.
                model_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
                ModelDownloader.download_from_url(model_url, path)
            elif os.path.basename(path) == self.TOKENIZER_FILE:
                logger.info("Downloading MiniLM tokenizer")
                tokenizer = AutoTokenizer.from_pretrained(
                    self.MODEL_NAME, clean_up_tokenization_spaces=True
                )
                tokenizer.save_pretrained(path)

            # Report per-file download state so the UI can show progress.
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_tokenizer(self):
        # Lazy initialization: block until the download has finished, then
        # load the tokenizer and the ONNX session on first use.
        if self.tokenizer is None or self.session is None:
            self.downloader.wait_for_download()
            self.tokenizer = self._load_tokenizer()
            self.session = self._load_model(
                os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
                self.preferred_providers,
            )

    def _load_tokenizer(self):
        tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
        return AutoTokenizer.from_pretrained(
            tokenizer_path, clean_up_tokenization_spaces=True
        )

    def _load_model(self, path: str, providers: List[str]):
        if os.path.exists(path):
            return ort.InferenceSession(path, providers=providers)
        else:
            logger.warning(f"MiniLM model file {path} not found.")
            return None

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        self._load_model_and_tokenizer()

        if self.session is None or self.tokenizer is None:
            logger.error("MiniLM model or tokenizer is not loaded.")
            return []

        inputs = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="np"
        )

        # Only feed the tensors the ONNX graph actually declares as inputs
        # (e.g. input_ids, attention_mask, token_type_ids).
        input_names = [model_input.name for model_input in self.session.get_inputs()]
        onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}

        outputs = self.session.run(None, onnx_inputs)

        # Mean-pool the token embeddings (batch, tokens, dims) down to one
        # sentence embedding per input text.
        embeddings = outputs[0].mean(axis=1)

        return [embedding for embedding in embeddings]
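
A usage sketch for the class above, with hypothetical inputs: embed two event descriptions and compare them by cosine similarity. The embeddings are mean-pooled but not unit-normalized, so the norms are divided out explicitly; this also assumes the model files can be downloaded on first use.

import numpy as np

embedder = MiniLMEmbedding()
first, second = embedder(["person walking a dog", "a dog being walked"])
cosine = float(np.dot(first, second) / (np.linalg.norm(first) * np.linalg.norm(second)))
print(f"cosine similarity: {cosine:.3f}")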