mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-02-23 00:18:31 +01:00
* Move onnx runner
* Build out base embedding
* Convert text embedding to separate class
* Move image embedding to separate
* Move LPR to separate class
* Remove mono embedding
* Simplify model downloading
* Reorganize jina v1 embeddings
* Cleanup
* Cleanup for review
96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
"""Base class for onnx embedding implementations."""
|
|
|
|
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from io import BytesIO
from typing import Any

import numpy as np
import requests
from PIL import Image

from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingTypeEnum(str, Enum):
    """String-valued enumeration of embedding kinds.

    Members are also ``str`` instances, so each compares equal to its
    plain-string value.

    NOTE(review): presumably names which piece of data an embedding was
    computed from -- confirm against callers.
    """

    thumbnail = "thumbnail"
    description = "description"
|
|
|
|
|
|
class BaseEmbedding(ABC):
    """Base class for ONNX embedding implementations.

    Subclasses implement ``_load_model_and_utils`` and ``_preprocess_inputs``.
    Calling an instance loads the model, preprocesses the inputs, batches
    them per model input name and returns the first output of
    ``self.runner.run``.

    NOTE(review): ``self.runner`` is read by ``__call__`` but never assigned
    in this class; presumably subclasses create it in
    ``_load_model_and_utils`` -- confirm. ``self.downloader`` starts as
    ``None`` and must be set before ``_download_model`` is used.
    """

    # Seconds to wait on the remote host when fetching an image by URL, so a
    # stalled server cannot block the caller forever.
    _IMAGE_FETCH_TIMEOUT = 30

    def __init__(
        self, model_name: str, model_file: str, download_urls: dict[str, str]
    ):
        """Store the model identity and its download sources.

        Args:
            model_name: Name used as the prefix of reported model-state keys.
            model_file: Model file identifier; stored as-is.
            download_urls: Mapping of file name -> URL to download it from.
        """
        self.model_name = model_name
        self.model_file = model_file
        self.download_urls = download_urls
        self.downloader: ModelDownloader | None = None

    def _download_model(self, path: str):
        """Download the file for ``path`` if a URL is known and report state.

        The base name of ``path`` is looked up in ``self.download_urls``.
        The outcome is published via ``self.downloader.requestor.send_data``
        under ``UPDATE_MODEL_STATE``: ``downloaded`` on success, ``error`` if
        anything in the attempt raised. A file name with no known URL
        downloads nothing but is still reported as ``downloaded``.

        Args:
            path: Destination path of the model file.
        """
        # Resolved before the ``try`` so the error handler can always name
        # the model it is reporting on.
        file_name = os.path.basename(path)
        model_key = f"{self.model_name}-{file_name}"

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Deliberately best-effort: record why it failed, then surface
            # the failure through the model state rather than raising.
            logger.exception("Failed to download model file %s", model_key)
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.error,
                },
            )

    @abstractmethod
    def _load_model_and_utils(self):
        """Prepare whatever ``__call__`` needs; invoked on every call."""

    @abstractmethod
    def _preprocess_inputs(self, raw_inputs: Any) -> Any:
        """Convert raw inputs to an iterable of ``{input_name: value}`` dicts.

        ``__call__`` iterates the result, calls ``.items()`` on every element
        and takes ``value[0]`` of each entry whose key is a model input name.
        """

    def _process_image(self, image, output: str = "RGB") -> Image.Image:
        """Turn an image reference into a PIL image where possible.

        Args:
            image: A ``str`` starting with ``"http"`` is fetched and decoded,
                ``bytes`` are decoded directly. Anything else, including a
                ``str`` that does not start with ``"http"``, is returned
                unchanged.
            output: PIL mode the decoded image is converted to.

        Returns:
            The decoded and converted image, or ``image`` itself when no
            decoding applied.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                response = requests.get(image, timeout=self._IMAGE_FETCH_TIMEOUT)
                image = Image.open(BytesIO(response.content)).convert(output)
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image)).convert(output)

        return image

    def __call__(
        self, inputs: list[str] | list[Image.Image]
    ) -> list[np.ndarray]:
        """Embed ``inputs`` and return one array per row of the model output.

        Args:
            inputs: Raw items handed to ``_preprocess_inputs``.

        Returns:
            The rows of the first output of ``self.runner.run`` as a list;
            an empty list when preprocessing yields no items.
        """
        self._load_model_and_utils()

        # Materialised so an empty batch is detectable even if a subclass
        # returns a generator.
        processed: list[dict[str, Any]] = list(self._preprocess_inputs(inputs))

        if not processed:
            # Nothing to embed; don't hand the runner empty tensors.
            return []

        input_names = self.runner.get_input_names()
        onnx_inputs = {name: [] for name in input_names}

        for item in processed:
            for key, value in item.items():
                if key in onnx_inputs:
                    # Presumably strips a per-item leading batch axis before
                    # the values are stacked below -- confirm.
                    onnx_inputs[key].append(value[0])

        for key in input_names:
            if onnx_inputs[key]:
                onnx_inputs[key] = np.stack(onnx_inputs[key])
            else:
                logger.warning("Expected input '%s' not found in onnx_inputs", key)

        embeddings = self.runner.run(onnx_inputs)[0]
        return list(embeddings)
|