import logging
import os
from typing import List, Optional, Union

import numpy as np
import onnxruntime as ort
from onnx_clip import OnnxClip, Preprocessor, Tokenizer
from PIL import Image

from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader

logger = logging.getLogger(__name__)


class Clip(OnnxClip):
    """Overrides model loading to use pre-downloaded models from the cache directory."""

    def __init__(
        self,
        model: str = "ViT-B/32",
        batch_size: Optional[int] = None,
        providers: List[str] = ["CPUExecutionProvider"],
    ):
        """
        Instantiates the model and the required encoding classes.

        Args:
            model: The model to utilize. Currently ViT-B/32 and RN50 are
                allowed.
            batch_size: If set, splits the lists in `get_image_embeddings`
                and `get_text_embeddings` into batches of this size before
                passing them to the model. The embeddings are then concatenated
                back together before being returned. This is necessary when
                passing large amounts of data (perhaps ~100 or more items).
        """
        allowed_models = ["ViT-B/32", "RN50"]
        if model not in allowed_models:
            raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")

        # Embedding dimensionality depends on the backbone.
        if model == "ViT-B/32":
            self.embedding_size = 512
        elif model == "RN50":
            self.embedding_size = 1024

        self.image_model, self.text_model = self._load_models(model, providers)
        self._tokenizer = Tokenizer()
        self._preprocessor = Preprocessor()
        self._batch_size = batch_size
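
    # For example (hypothetical numbers): with batch_size=16, embedding 600
    # images runs ceil(600 / 16) = 38 ONNX batches internally and still
    # returns a single (600, embedding_size) array to the caller.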

    @staticmethod
    def _load_models(
        model: str,
        providers: List[str],
    ) -> tuple[ort.InferenceSession, ort.InferenceSession]:
        """
        Load the image and text ONNX models from the cache directory.
        """
        if model == "ViT-B/32":
            IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
            TEXT_MODEL_FILE = "clip_text_model_vitb32.onnx"
        elif model == "RN50":
            IMAGE_MODEL_FILE = "clip_image_model_rn50.onnx"
            TEXT_MODEL_FILE = "clip_text_model_rn50.onnx"
        else:
            raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")

        models = []
        for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
            path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
            models.append(Clip._load_model(path, providers))

        return models[0], models[1]

    @staticmethod
    def _load_model(path: str, providers: List[str]):
        if os.path.exists(path):
            return ort.InferenceSession(path, providers=providers)
        else:
            # A missing file is not fatal: callers check the session for None
            # before running inference.
            logger.warning(f"CLIP model file {path} not found.")
            return None
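
# A minimal sketch of using Clip directly (hypothetical file name; assumes the
# ONNX files are already present under MODEL_CACHE_DIR/clip):
#
#   clip = Clip("ViT-B/32")
#   if clip.image_model is not None:
#       image_vector = clip.get_image_embeddings([Image.open("frame.jpg")])[0]

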
class ClipEmbedding:
    """Embedding function for the CLIP model."""

    def __init__(
        self,
        model: str = "ViT-B/32",
        silent: bool = False,
        preferred_providers: List[str] = ["CPUExecutionProvider"],
    ):
        self.model_name = model
        self.silent = silent
        self.preferred_providers = preferred_providers
        self.model_files = self._get_model_files()
        self.model = None

        self.downloader = ModelDownloader(
            model_name="clip",
            download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
            file_names=self.model_files,
            download_func=self._download_model,
            silent=self.silent,
        )
        # Kick off the download of any model files missing from the cache.
        self.downloader.ensure_model_files()

    def _get_model_files(self):
        if self.model_name == "ViT-B/32":
            return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
        elif self.model_name == "RN50":
            return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
        else:
            raise ValueError(
                f"Unexpected model {self.model_name}. No `.onnx` file found."
            )

    def _download_model(self, path: str):
        s3_url = (
            f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
        )
        try:
            ModelDownloader.download_from_url(s3_url, path, self.silent)
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Report the failure so the model state is not left as pending.
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{os.path.basename(path)}",
                    "state": ModelStatusTypesEnum.error,
                },
            )
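
    # Note: _download_model reports per-file success/failure over the
    # downloader's requestor; consumers of UPDATE_MODEL_STATE (presumably
    # Frigate's model-status handling) key off the "model"/"state" fields.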

    def _load_model(self):
        # Lazily construct the ONNX sessions on first use, once the
        # downloader has finished fetching the model files.
        if self.model is None:
            self.downloader.wait_for_download()
            self.model = Clip(self.model_name, providers=self.preferred_providers)

    def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
        self._load_model()
        if (
            self.model is None
            or self.model.image_model is None
            or self.model.text_model is None
        ):
            logger.info(
                "CLIP model is not fully loaded. Please wait for the download to complete."
            )
            return []

        embeddings = []
        for item in input:
            if isinstance(item, Image.Image):
                result = self.model.get_image_embeddings([item])
                embeddings.append(result[0])
            elif isinstance(item, str):
                result = self.model.get_text_embeddings([item])
                embeddings.append(result[0])
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return embeddings
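

if __name__ == "__main__":
    # A minimal usage sketch, not part of Frigate's runtime: assumes a Frigate
    # environment where ModelDownloader can fetch (or already finds) the ONNX
    # files; "frame.jpg" is a hypothetical image path.
    embedder = ClipEmbedding(model="ViT-B/32")

    # Text and images are embedded item by item into the same vector space.
    text_vectors = embedder(["a person walking a dog"])
    if text_vectors:
        print(f"text embedding shape: {text_vectors[0].shape}")

    if os.path.exists("frame.jpg"):
        image_vectors = embedder([Image.open("frame.jpg")])
        if image_vectors:
            print(f"image embedding shape: {image_vectors[0].shape}")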