blakeblackshear.frigate/frigate/embeddings/onnx/base_embedding.py
Nicolas Mowen · c736b1dae5
Refactor ONNX embedding class to use a base class and type-specific classes (#16703)

* Move onnx runner
* Build out base embedding
* Convert text embedding to separate class
* Move image embedding to a separate class
* Move LPR to a separate class
* Remove mono embedding
* Simplify model downloading
* Reorganize jina v1 embeddings
* Cleanup
* Cleanup for review

2025-02-20 10:17:07 -06:00


"""Base class for onnx embedding implementations."""
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from io import BytesIO
import numpy as np
import requests
from PIL import Image
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
logger = logging.getLogger(__name__)
class EmbeddingTypeEnum(str, Enum):
thumbnail = "thumbnail"
description = "description"
class BaseEmbedding(ABC):
"""Base embedding class."""
def __init__(self, model_name: str, model_file: str, download_urls: dict[str, str]):
self.model_name = model_name
self.model_file = model_file
self.download_urls = download_urls
self.downloader: ModelDownloader = None
def _download_model(self, path: str):
try:
file_name = os.path.basename(path)
if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path)
self.downloader.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{self.model_name}-{file_name}",
"state": ModelStatusTypesEnum.downloaded,
},
)
except Exception:
self.downloader.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{self.model_name}-{file_name}",
"state": ModelStatusTypesEnum.error,
},
)
@abstractmethod
def _load_model_and_utils(self):
pass
@abstractmethod
def _preprocess_inputs(self, raw_inputs: any) -> any:
pass
def _process_image(self, image, output: str = "RGB") -> Image.Image:
if isinstance(image, str):
if image.startswith("http"):
response = requests.get(image)
image = Image.open(BytesIO(response.content)).convert(output)
elif isinstance(image, bytes):
image = Image.open(BytesIO(image)).convert(output)
return image
def __call__(
self, inputs: list[str] | list[Image.Image] | list[str]
) -> list[np.ndarray]:
self._load_model_and_utils()
processed = self._preprocess_inputs(inputs)
input_names = self.runner.get_input_names()
onnx_inputs = {name: [] for name in input_names}
input: dict[str, any]
for input in processed:
for key, value in input.items():
if key in input_names:
onnx_inputs[key].append(value[0])
for key in input_names:
if onnx_inputs.get(key):
onnx_inputs[key] = np.stack(onnx_inputs[key])
else:
logger.warning(f"Expected input '{key}' not found in onnx_inputs")
embeddings = self.runner.run(onnx_inputs)[0]
return [embedding for embedding in embeddings]
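As a rough illustration of the pattern the commit describes, a type-specific embedding subclasses BaseEmbedding and fills in the two abstract hooks: _load_model_and_utils() must leave self.runner pointing at something that exposes get_input_names() and run(), and _preprocess_inputs() must return one dict per input keyed by the model's ONNX input names. The sketch below is illustrative only; DummyRunner and MiniImageEmbedding are made-up names, not the runner or subclasses actually shipped in the Frigate repository.

# Illustrative sketch only: DummyRunner and MiniImageEmbedding are hypothetical
# stand-ins for Frigate's real ONNX runner and type-specific embedding classes.
import numpy as np
from PIL import Image


class DummyRunner:
    """Fake runner exposing the interface BaseEmbedding.__call__ relies on."""

    def get_input_names(self) -> list[str]:
        return ["pixel_values"]

    def run(self, onnx_inputs: dict) -> list[np.ndarray]:
        # Pretend inference: one 4-dim embedding per item in the batch.
        batch = onnx_inputs["pixel_values"]
        return [np.zeros((batch.shape[0], 4), dtype=np.float32)]


class MiniImageEmbedding(BaseEmbedding):
    """Hypothetical type-specific embedding built on the base class."""

    def __init__(self):
        super().__init__(
            model_name="mini-image",
            model_file="model.onnx",
            download_urls={},  # nothing to download for this sketch
        )
        self.runner = None

    def _load_model_and_utils(self):
        # A real subclass would download the model via _download_model()
        # and wrap it in Frigate's ONNX runner; here we stub it out.
        if self.runner is None:
            self.runner = DummyRunner()

    def _preprocess_inputs(self, raw_inputs):
        # Produce one dict per input, keyed by the model's input names,
        # with a leading batch dimension (value[0] is taken in __call__).
        outputs = []
        for img in raw_inputs:
            img = self._process_image(img).resize((224, 224))
            arr = np.asarray(img, dtype=np.float32)[None, ...] / 255.0
            outputs.append({"pixel_values": arr})
        return outputs


embedder = MiniImageEmbedding()
embeddings = embedder([Image.new("RGB", (640, 480))])
print(len(embeddings), embeddings[0].shape)  # -> 1 (4,)

The batch dimension matters here: __call__ takes value[0] from each preprocessed dict and re-stacks the samples with np.stack before handing them to the runner, so each preprocessed array must carry a leading batch axis of size one.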