blakeblackshear.frigate/frigate/embeddings/functions/clip.py
Josh Hawkins 24ac9f3e5a
Use sqlite-vec extension instead of chromadb for embeddings (#14163)
* swap sqlite_vec for chroma in requirements

* load sqlite_vec in embeddings manager

* remove chroma and revamp Embeddings class for sqlite_vec

* manual minilm onnx inference

* remove chroma in clip model

* migrate api from chroma to sqlite_vec

* migrate event cleanup from chroma to sqlite_vec

* migrate embedding maintainer from chroma to sqlite_vec

* genai description for sqlite_vec

* load sqlite_vec in main thread db

* extend the SqliteQueueDatabase class and use peewee db.execute_sql

* search with Event type for similarity

* fix similarity search

* install and add comment about transformers

* fix normalization

* add id filter

* clean up

* clean up

* fully remove chroma and add transformers env var

* re-add uvicorn for fastapi

* re-add tokenizer parallelism env var

* remove chroma from docs

* remove chroma from UI

* try removing custom pysqlite3 build

* hard code limit

* optimize queries

* revert explore query

* fix query

* keep building pysqlite3

* single pass fetch and process

* remove unnecessary re-embed

* update deps

* move SqliteVecQueueDatabase to db directory

* make search thumbnail take up full size of results box

* improve typing

* improve model downloading and add status screen

* daemon downloading thread

* catch case when semantic search is disabled

* fix typing

* build sqlite_vec from source

* resolve conflict

* file permissions

* try build deps

* remove sources

* sources

* fix thread start

* include git in build

* reorder embeddings after detectors are started

* build with sqlite amalgamation

* non-platform specific

* use wget instead of curl

* remove unzip -d

* remove sqlite_vec from requirements and load the compiled version

* fix build

* avoid race in db connection

* add scale_factor and bias to description zscore normalization
2024-10-07 14:30:45 -06:00

167 lines
5.8 KiB
Python

import logging
import os
from typing import List, Optional, Union
import numpy as np
import onnxruntime as ort
from onnx_clip import OnnxClip, Preprocessor, Tokenizer
from PIL import Image
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
logger = logging.getLogger(__name__)
class Clip(OnnxClip):
    """Override load models to use pre-downloaded models from cache directory.

    The ONNX sessions are created from files under ``MODEL_CACHE_DIR/clip``.
    A file that is missing is logged as a warning and its session is left as
    ``None``, so callers must check ``image_model`` / ``text_model`` before
    running inference.
    """

    def __init__(
        self,
        model: str = "ViT-B/32",
        batch_size: Optional[int] = None,
        providers: Optional[List[str]] = None,
    ):
        """
        Instantiates the model and required encoding classes.

        Args:
            model: The model to utilize. Currently ViT-B/32 and RN50 are
                allowed.
            batch_size: If set, splits the lists in `get_image_embeddings`
                and `get_text_embeddings` into batches of this size before
                passing them to the model. The embeddings are then concatenated
                back together before being returned. This is necessary when
                passing large amounts of data (perhaps ~100 or more).
            providers: Execution providers handed to
                `onnxruntime.InferenceSession`. When omitted (``None``),
                ``["CPUExecutionProvider"]`` is used.

        Raises:
            ValueError: If `model` is not one of the allowed model names.
        """
        # A list literal used as a default value is created once and shared
        # by every call; build the default per call instead.
        if providers is None:
            providers = ["CPUExecutionProvider"]

        allowed_models = ["ViT-B/32", "RN50"]
        if model not in allowed_models:
            raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")

        if model == "ViT-B/32":
            self.embedding_size = 512
        elif model == "RN50":
            self.embedding_size = 1024

        # NOTE(review): OnnxClip.__init__ is not called; the attributes are
        # assigned directly here - presumably so that the parent does not
        # load the models on its own. Confirm against onnx_clip.
        self.image_model, self.text_model = self._load_models(model, providers)
        self._tokenizer = Tokenizer()
        self._preprocessor = Preprocessor()
        self._batch_size = batch_size

    @staticmethod
    def _load_models(
        model: str,
        providers: List[str],
    ) -> tuple[Optional[ort.InferenceSession], Optional[ort.InferenceSession]]:
        """
        Load models from cache directory.

        Args:
            model: Name of the CLIP variant ("ViT-B/32" or "RN50").
            providers: Execution providers for the inference sessions.

        Returns:
            ``(image_session, text_session)``. Either entry is ``None`` when
            the corresponding file does not exist in the cache directory.

        Raises:
            ValueError: If `model` has no known `.onnx` files.
        """
        if model == "ViT-B/32":
            image_model_file = "clip_image_model_vitb32.onnx"
            text_model_file = "clip_text_model_vitb32.onnx"
        elif model == "RN50":
            image_model_file = "clip_image_model_rn50.onnx"
            text_model_file = "clip_text_model_rn50.onnx"
        else:
            raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")

        cache_dir = os.path.join(MODEL_CACHE_DIR, "clip")
        # Image model first, then text model, matching the returned order.
        image_session = Clip._load_model(
            os.path.join(cache_dir, image_model_file), providers
        )
        text_session = Clip._load_model(
            os.path.join(cache_dir, text_model_file), providers
        )
        return image_session, text_session

    @staticmethod
    def _load_model(
        path: str, providers: List[str]
    ) -> Optional[ort.InferenceSession]:
        """Create an inference session for `path`, or return None if it is absent.

        A missing file is reported with a warning rather than an exception so
        that construction succeeds while the model files are not yet on disk.
        """
        if os.path.exists(path):
            return ort.InferenceSession(path, providers=providers)

        logger.warning(f"CLIP model file {path} not found.")
        return None
class ClipEmbedding:
    """Embedding function for CLIP model.

    On construction a `ModelDownloader` is created for the model's `.onnx`
    files and asked to ensure they are present. The `Clip` model itself is
    only built on the first call, after the downloader has been waited on.
    """

    def __init__(
        self,
        model: str = "ViT-B/32",
        silent: bool = False,
        preferred_providers: Optional[List[str]] = None,
    ):
        """
        Args:
            model: CLIP variant name ("ViT-B/32" or "RN50").
            silent: Forwarded to the `ModelDownloader` and to
                `ModelDownloader.download_from_url`.
            preferred_providers: Execution providers passed to `Clip`. When
                omitted (``None``), ``["CPUExecutionProvider"]`` is used.

        Raises:
            ValueError: If `model` has no known `.onnx` files.
        """
        self.model_name = model
        self.silent = silent
        # A list literal used as a default value would be shared by every
        # instance; build the default per instance instead.
        self.preferred_providers = (
            ["CPUExecutionProvider"]
            if preferred_providers is None
            else preferred_providers
        )
        self.model_files = self._get_model_files()
        # Built lazily by _load_model() on the first __call__.
        self.model = None

        self.downloader = ModelDownloader(
            model_name="clip",
            download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
            file_names=self.model_files,
            download_func=self._download_model,
            silent=self.silent,
        )
        self.downloader.ensure_model_files()

    def _get_model_files(self) -> List[str]:
        """Return the image and text `.onnx` file names for `self.model_name`.

        Raises:
            ValueError: If the model name is not recognised.
        """
        if self.model_name == "ViT-B/32":
            return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
        elif self.model_name == "RN50":
            return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
        else:
            raise ValueError(
                f"Unexpected model {self.model_name}. No `.onnx` file found."
            )

    def _download_model(self, path: str) -> None:
        """Download one model file to `path` and publish the resulting state.

        Used as the downloader's `download_func`. Failures are not re-raised:
        they are logged and reported as an error state through the
        downloader's requestor.
        """
        file_name = os.path.basename(path)
        s3_url = f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{file_name}"
        model_key = f"{self.model_name}-{file_name}"

        try:
            ModelDownloader.download_from_url(s3_url, path, self.silent)
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Keep the best-effort behaviour, but record why the download
            # failed instead of discarding the exception silently.
            logger.exception(
                "Failed to download CLIP model file %s from %s", file_name, s3_url
            )
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model(self) -> None:
        """Build the `Clip` model once, after waiting on the downloader."""
        if self.model is None:
            self.downloader.wait_for_download()
            self.model = Clip(self.model_name, providers=self.preferred_providers)

    def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
        """Embed each item of `input` with the CLIP text or image model.

        Args:
            input: Items to embed; each must be a `str` or a `PIL.Image.Image`.

        Returns:
            One embedding per item, in input order. An empty list is returned
            when the image or text session is not available.

        Raises:
            ValueError: If an item is neither a string nor an image.
        """
        self._load_model()

        model = self.model
        if model is None or model.image_model is None or model.text_model is None:
            logger.info(
                "CLIP model is not fully loaded. Please wait for the download to complete."
            )
            return []

        embeddings = []
        for item in input:
            # Items are embedded one at a time; only the first row of each
            # result is kept.
            if isinstance(item, Image.Image):
                result = model.get_image_embeddings([item])
            elif isinstance(item, str):
                result = model.get_text_embeddings([item])
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
            embeddings.append(result[0])

        return embeddings