blakeblackshear.frigate/frigate/embeddings/onnx/base_embedding.py
Martin Weinelt 4d4d54d030
Fix various typing issues (#18187)
* Fix the `Any` typing hint treewide

There has been confusion between the Any type[1] and the any function[2]
in typing hints.

[1] https://docs.python.org/3/library/typing.html#typing.Any
[2] https://docs.python.org/3/library/functions.html#any

* Fix typing for various frame_shape members

Frame shapes are most likely defined by height and width, so a single int
cannot express that.

* Wrap gpu stats functions in Optional[]

These can return `None`, so they need to be `Type | None`, which is what
`Optional` expresses very nicely.

* Fix return type in get_latest_segment_datetime

Returns a datetime object, not an integer.

* Make the return type of FrameManager.write optional

This is necessary since the SharedMemoryFrameManager.write function can
return None.

* Fix total_seconds() return type in get_tz_modifiers

The function returns a float, not an int.

https://docs.python.org/3/library/datetime.html#datetime.timedelta.total_seconds

* Account for floating point results in to_relative_box

Because the function uses division the return types may either be int or
float.

* Resolve ruff deprecation warning

The config has been split into formatter and linter, and the global
options are deprecated.
2025-05-13 08:27:20 -06:00

104 lines
3.2 KiB
Python

"""Base class for onnx embedding implementations."""
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from io import BytesIO
from typing import Any
import numpy as np
import requests
from PIL import Image
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
logger = logging.getLogger(__name__)
class EmbeddingTypeEnum(str, Enum):
    """Kinds of embedding named in this module.

    Mixing in ``str`` makes every member compare equal to (and usable as)
    its plain string value.
    """

    # Member names intentionally mirror their string values.
    thumbnail = "thumbnail"
    description = "description"
class BaseEmbedding(ABC):
    """Base embedding class.

    Subclasses implement ``_load_model_and_utils`` and ``_preprocess_inputs``.
    ``__call__`` also relies on a ``self.runner`` object exposing
    ``get_input_names()`` and ``run()``; this base class never assigns it, so
    it is expected to be provided by the subclass (presumably from
    ``_load_model_and_utils`` -- confirm against the implementations).
    """

    # Seconds to wait on the remote server when fetching an image by URL, so a
    # stalled connection cannot block the caller forever.
    _IMAGE_FETCH_TIMEOUT = 10

    def __init__(self, model_name: str, model_file: str, download_urls: dict[str, str]):
        """Store the model identity and the per-file download URLs.

        Args:
            model_name: Name used as the prefix of model state updates.
            model_file: Model file reference kept for use by subclasses.
            download_urls: Mapping of file base name to its download URL.
        """
        self.model_name = model_name
        self.model_file = model_file
        self.download_urls = download_urls
        # Not created here; ``_download_model`` requires it to be set
        # (presumably by the subclass) before it is called.
        self.downloader: ModelDownloader | None = None

    def _download_model(self, path: str) -> None:
        """Download the file for ``path`` if a URL is known and publish its state.

        A ``downloaded`` state is sent when no exception occurs (including
        when the file has no entry in ``download_urls``); an ``error`` state
        is sent, and the failure logged, when anything raises.

        Args:
            path: Destination path; its base name selects the download URL.
        """
        # Resolved before the try block so the error handler can always name
        # the file it is reporting on.
        file_name = os.path.basename(path)
        model_key = f"{self.model_name}-{file_name}"

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Best effort: record why it failed, then report the error state
            # instead of propagating the exception.
            logger.exception("Failed to download model file %s", model_key)
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model_key,
                    "state": ModelStatusTypesEnum.error,
                },
            )

    @abstractmethod
    def _load_model_and_utils(self):
        """Prepare the model for use; invoked at the start of every ``__call__``."""

    @abstractmethod
    def _preprocess_inputs(self, raw_inputs: Any) -> Any:
        """Convert raw inputs into an iterable of ``{input name: value}`` dicts.

        ``__call__`` iterates the result, calls ``.items()`` on each element
        and uses ``value[0]`` of every entry whose key is a model input name.
        """

    def _process_image(self, image, output: str = "RGB") -> Image.Image:
        """Turn a URL, raw bytes or a numpy array into a PIL image.

        Args:
            image: A ``str`` starting with ``"http"`` (fetched over HTTP),
                encoded image ``bytes``, or an ``np.ndarray``. Any other
                value, including a non-URL string, is returned unchanged.
            output: PIL mode the image is converted to for the URL and bytes
                cases. Arrays are wrapped with ``Image.fromarray`` and are
                not converted.

        Returns:
            The resulting image, or the original object when no branch applies.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                response = requests.get(image, timeout=self._IMAGE_FETCH_TIMEOUT)
                image = Image.open(BytesIO(response.content)).convert(output)
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image)).convert(output)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        return image

    def _postprocess_outputs(self, outputs: Any) -> Any:
        """Hook for subclasses to transform runner output; identity by default."""
        return outputs

    def __call__(
        self, inputs: list[str] | list[Image.Image] | list[bytes]
    ) -> list[np.ndarray]:
        """Run the model over ``inputs`` and return one embedding per output row.

        Args:
            inputs: Raw inputs handed to ``_preprocess_inputs``.

        Returns:
            The post-processed first runner output, materialised as a list.
        """
        self._load_model_and_utils()
        processed = self._preprocess_inputs(inputs)
        input_names = self.runner.get_input_names()
        onnx_inputs: dict[str, Any] = {name: [] for name in input_names}

        processed_input: dict[str, Any]
        for processed_input in processed:
            for key, value in processed_input.items():
                # Keys the model does not declare as inputs are ignored.
                if key in onnx_inputs:
                    # Only the first element of each value is kept
                    # (presumably dropping a leading batch axis -- TODO confirm).
                    onnx_inputs[key].append(value[0])

        for key in input_names:
            if onnx_inputs.get(key):
                onnx_inputs[key] = np.stack(onnx_inputs[key])
            else:
                # The empty list is left in place and still passed to the runner.
                logger.warning("Expected input '%s' not found in onnx_inputs", key)

        outputs = self.runner.run(onnx_inputs)[0]
        embeddings = self._postprocess_outputs(outputs)

        return list(embeddings)