"""Generic ONNX embedding wrapper for text and vision models."""

import logging
import os
import warnings
from enum import Enum
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import numpy as np
import requests
from PIL import Image

# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar

from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from frigate.util.model import ONNXModelRunner

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="The class CLIPFeatureExtractor is deprecated",
)

# disables the progress bar for downloading tokenizers and feature extractors
disable_progress_bar()
logger = logging.getLogger(__name__)

# Seconds to wait on the remote server when fetching an image by URL; without
# a timeout ``requests.get`` may block indefinitely.
_IMAGE_FETCH_TIMEOUT = 10


class ModelTypeEnum(str, Enum):
    """Kinds of model understood by :class:`GenericONNXEmbedding`.

    Members are also ``str`` instances, so a plain string such as ``"text"``
    compares equal to the matching member.
    """

    face = "face"
    vision = "vision"
    text = "text"


class GenericONNXEmbedding:
    """Generic embedding function for ONNX models (text and vision)."""

    def __init__(
        self,
        model_name: str,
        model_file: str,
        download_urls: Dict[str, str],
        model_size: str,
        model_type: str,
        requestor: InterProcessRequestor,
        tokenizer_file: Optional[str] = None,
        device: str = "AUTO",
    ):
        """Set up the embedding and make sure its files are available.

        If any required file is missing from the model's cache directory a
        ``ModelDownloader`` is started; otherwise the files are marked as
        downloaded and the model is loaded immediately.

        Args:
            model_name: Name of the model; used as the sub-directory of
                ``MODEL_CACHE_DIR`` and passed to the tokenizer loader.
            model_file: File name of the ONNX model inside the download path.
            download_urls: Mapping of file name to the URL it is fetched from.
            model_size: Forwarded unchanged to ``ONNXModelRunner``.
            model_type: Compared against ``ModelTypeEnum`` members to select
                tokenizer vs. feature-extractor handling.
            requestor: Used to publish the state of already-present files.
            tokenizer_file: Optional extra file name that must exist in the
                download path in addition to the keys of ``download_urls``.
            device: Forwarded unchanged to ``ONNXModelRunner``.
        """
        self.model_name = model_name
        self.model_file = model_file
        self.tokenizer_file = tokenizer_file
        self.requestor = requestor
        self.download_urls = download_urls
        self.model_type = model_type  # compared against ModelTypeEnum members
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.feature_extractor = None
        self.runner = None

        files_names = list(self.download_urls.keys()) + (
            [self.tokenizer_file] if self.tokenizer_file else []
        )

        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
        else:
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_tokenizer()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _download_model(self, path: str):
        """Download the single file at ``path`` and publish its resulting state.

        Used as the ``download_func`` callback of ``ModelDownloader``. Files
        listed in ``download_urls`` are fetched from their URL; the tokenizer
        file of a text model is fetched via ``AutoTokenizer`` and saved to
        ``path``. Failures are logged and reported as an ``error`` state
        rather than raised.
        """
        # Resolved outside the ``try`` so it is always bound in the handler.
        file_name = os.path.basename(path)

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif (
                file_name == self.tokenizer_file
                and self.model_type == ModelTypeEnum.text
            ):
                # NOTE(review): tokenizer fetch is assumed to sit inside this
                # existence guard together with the log line - confirm.
                if not os.path.exists(path + "/" + self.model_name):
                    logger.info(f"Downloading {self.model_name} tokenizer")
                    tokenizer = AutoTokenizer.from_pretrained(
                        self.model_name,
                        trust_remote_code=True,
                        cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
                        clean_up_tokenization_spaces=True,
                    )
                    tokenizer.save_pretrained(path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Best effort: keep reporting the error state instead of raising,
            # but record the traceback so the failure is diagnosable.
            logger.exception(
                "Failed to download %s for %s", file_name, self.model_name
            )
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_tokenizer(self):
        """Create the ONNX runner and its pre-processor if not yet loaded.

        Waits for a pending download first. Text models get a tokenizer;
        every other model type gets a feature extractor. Does nothing once
        ``self.runner`` is set.
        """
        if self.runner is None:
            if self.downloader:
                self.downloader.wait_for_download()

            if self.model_type == ModelTypeEnum.text:
                self.tokenizer = self._load_tokenizer()
            else:
                self.feature_extractor = self._load_feature_extractor()

            self.runner = ONNXModelRunner(
                os.path.join(self.download_path, self.model_file),
                self.device,
                self.model_size,
            )

    def _load_tokenizer(self):
        """Return the tokenizer for ``model_name`` using the model's cache dir."""
        tokenizer_path = f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
        return AutoTokenizer.from_pretrained(
            self.model_name,
            cache_dir=tokenizer_path,
            trust_remote_code=True,
            clean_up_tokenization_spaces=True,
        )

    def _load_feature_extractor(self):
        """Return the feature extractor loaded from the model's cache dir."""
        return AutoFeatureExtractor.from_pretrained(
            f"{MODEL_CACHE_DIR}/{self.model_name}",
        )

    def _preprocess_inputs(self, raw_inputs: Any) -> Any:
        """Convert raw inputs into one pre-processor output per input.

        Text inputs are tokenized and padded to the longest encoded length in
        the batch; vision inputs are normalised through ``_process_image``
        and then run through the feature extractor.

        Raises:
            ValueError: If ``model_type`` is neither text nor vision.
        """
        if self.model_type == ModelTypeEnum.text:
            # Pad every text to the longest tokenized length in this batch.
            max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
            return [
                self.tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="np",
                )
                for text in raw_inputs
            ]
        elif self.model_type == ModelTypeEnum.vision:
            processed_images = [self._process_image(img) for img in raw_inputs]
            return [
                self.feature_extractor(images=image, return_tensors="np")
                for image in processed_images
            ]
        else:
            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")

    def _process_image(self, image):
        """Turn a URL or raw bytes into an RGB ``PIL`` image.

        Strings starting with ``"http"`` are fetched over the network, and
        ``bytes`` are decoded directly. Any other value, including strings
        that are not URLs, is returned unchanged.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                # A timeout stops a stalled server from blocking the caller.
                response = requests.get(image, timeout=_IMAGE_FETCH_TIMEOUT)
                image = Image.open(BytesIO(response.content)).convert("RGB")
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image)).convert("RGB")

        return image

    def __call__(
        self, inputs: Union[List[str], List[Image.Image], List[bytes]]
    ) -> List[np.ndarray]:
        """Run the model on ``inputs`` and return one result row per input.

        Args:
            inputs: Texts for a text model; images, image bytes or image URLs
                for a vision model.

        Returns:
            The rows of the first output of the ONNX runner as a list. An
            empty list is returned when the model or its pre-processor is not
            loaded, or when ``inputs`` is empty.
        """
        self._load_model_and_tokenizer()
        if self.runner is None or (
            self.tokenizer is None and self.feature_extractor is None
        ):
            logger.error(
                f"{self.model_name} model or tokenizer/feature extractor is not loaded."
            )
            return []

        # Nothing to embed; also avoids ``max()`` on an empty sequence.
        if len(inputs) == 0:
            return []

        processed_inputs = self._preprocess_inputs(inputs)
        input_names = self.runner.get_input_names()
        onnx_inputs: Dict[str, Any] = {name: [] for name in input_names}

        # Collect, per model input name, the first element of each processed
        # value (the pre-processors were asked for "np" tensors).
        for processed in processed_inputs:
            for key, value in processed.items():
                if key in input_names:
                    onnx_inputs[key].append(value[0])

        for key in input_names:
            if onnx_inputs.get(key):
                onnx_inputs[key] = np.stack(onnx_inputs[key])
            else:
                logger.warning(f"Expected input '{key}' not found in onnx_inputs")

        embeddings = self.runner.run(onnx_inputs)[0]
        return list(embeddings)