mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-31 13:48:19 +02:00
Refactor ONNX embedding class to use a base class and type-specific classes (#16703)
* Move onnx runner
* Build out base embedding
* Convert text embedding to separate class
* Move image embedding to separate class
* Move LPR to separate class
* Remove mono embedding
* Simplify model downloading
* Reorganize jina v1 embeddings
* Cleanup
* Cleanup for review
This commit is contained in:
parent
649e5cfda5
commit
c736b1dae5
@ -16,7 +16,12 @@ from shapely.geometry import Polygon
|
|||||||
from frigate.comms.inter_process import InterProcessRequestor
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
from frigate.config import FrigateConfig
|
from frigate.config import FrigateConfig
|
||||||
from frigate.const import FRIGATE_LOCALHOST
|
from frigate.const import FRIGATE_LOCALHOST
|
||||||
from frigate.embeddings.functions.onnx import GenericONNXEmbedding, ModelTypeEnum
|
from frigate.embeddings.onnx.lpr_embedding import (
|
||||||
|
LicensePlateDetector,
|
||||||
|
PaddleOCRClassification,
|
||||||
|
PaddleOCRDetection,
|
||||||
|
PaddleOCRRecognition,
|
||||||
|
)
|
||||||
from frigate.util.image import area
|
from frigate.util.image import area
|
||||||
|
|
||||||
from ..types import DataProcessorMetrics
|
from ..types import DataProcessorMetrics
|
||||||
@ -52,49 +57,26 @@ class LicensePlateProcessor(RealTimeProcessorApi):
|
|||||||
self.lpr_recognition_model = None
|
self.lpr_recognition_model = None
|
||||||
|
|
||||||
if self.config.lpr.enabled:
|
if self.config.lpr.enabled:
|
||||||
self.detection_model = GenericONNXEmbedding(
|
self.detection_model = PaddleOCRDetection(
|
||||||
model_name="paddleocr-onnx",
|
|
||||||
model_file="detection.onnx",
|
|
||||||
download_urls={
|
|
||||||
"detection.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/detection.onnx"
|
|
||||||
},
|
|
||||||
model_size="large",
|
model_size="large",
|
||||||
model_type=ModelTypeEnum.lpr_detect,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.classification_model = GenericONNXEmbedding(
|
self.classification_model = PaddleOCRClassification(
|
||||||
model_name="paddleocr-onnx",
|
|
||||||
model_file="classification.onnx",
|
|
||||||
download_urls={
|
|
||||||
"classification.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/classification.onnx"
|
|
||||||
},
|
|
||||||
model_size="large",
|
model_size="large",
|
||||||
model_type=ModelTypeEnum.lpr_classify,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognition_model = GenericONNXEmbedding(
|
self.recognition_model = PaddleOCRRecognition(
|
||||||
model_name="paddleocr-onnx",
|
|
||||||
model_file="recognition.onnx",
|
|
||||||
download_urls={
|
|
||||||
"recognition.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/recognition.onnx"
|
|
||||||
},
|
|
||||||
model_size="large",
|
model_size="large",
|
||||||
model_type=ModelTypeEnum.lpr_recognize,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
self.yolov9_detection_model = GenericONNXEmbedding(
|
|
||||||
model_name="yolov9_license_plate",
|
self.yolov9_detection_model = LicensePlateDetector(
|
||||||
model_file="yolov9-256-license-plates.onnx",
|
|
||||||
download_urls={
|
|
||||||
"yolov9-256-license-plates.onnx": "https://github.com/hawkeye217/yolov9-license-plates/raw/refs/heads/master/models/yolov9-256-license-plates.onnx"
|
|
||||||
},
|
|
||||||
model_size="large",
|
model_size="large",
|
||||||
model_type=ModelTypeEnum.yolov9_lpr_detect,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,7 @@ from frigate.types import ModelStatusTypesEnum
|
|||||||
from frigate.util.builtin import serialize
|
from frigate.util.builtin import serialize
|
||||||
from frigate.util.path import get_event_thumbnail_bytes
|
from frigate.util.path import get_event_thumbnail_bytes
|
||||||
|
|
||||||
from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum
|
from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -97,36 +97,14 @@ class Embeddings:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.text_embedding = GenericONNXEmbedding(
|
self.text_embedding = JinaV1TextEmbedding(
|
||||||
model_name="jinaai/jina-clip-v1",
|
|
||||||
model_file="text_model_fp16.onnx",
|
|
||||||
tokenizer_file="tokenizer",
|
|
||||||
download_urls={
|
|
||||||
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
|
|
||||||
},
|
|
||||||
model_size=config.semantic_search.model_size,
|
model_size=config.semantic_search.model_size,
|
||||||
model_type=ModelTypeEnum.text,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_file = (
|
self.vision_embedding = JinaV1ImageEmbedding(
|
||||||
"vision_model_fp16.onnx"
|
|
||||||
if self.config.semantic_search.model_size == "large"
|
|
||||||
else "vision_model_quantized.onnx"
|
|
||||||
)
|
|
||||||
|
|
||||||
download_urls = {
|
|
||||||
model_file: f"https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/{model_file}",
|
|
||||||
"preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
self.vision_embedding = GenericONNXEmbedding(
|
|
||||||
model_name="jinaai/jina-clip-v1",
|
|
||||||
model_file=model_file,
|
|
||||||
download_urls=download_urls,
|
|
||||||
model_size=config.semantic_search.model_size,
|
model_size=config.semantic_search.model_size,
|
||||||
model_type=ModelTypeEnum.vision,
|
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="GPU" if config.semantic_search.model_size == "large" else "CPU",
|
device="GPU" if config.semantic_search.model_size == "large" else "CPU",
|
||||||
)
|
)
|
||||||
|
@ -1,325 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from enum import Enum
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# importing this without pytorch or others causes a warning
|
|
||||||
# https://github.com/huggingface/transformers/issues/27214
|
|
||||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
|
||||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
|
||||||
from transformers.utils.logging import disable_progress_bar
|
|
||||||
|
|
||||||
from frigate.comms.inter_process import InterProcessRequestor
|
|
||||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
|
||||||
from frigate.types import ModelStatusTypesEnum
|
|
||||||
from frigate.util.downloader import ModelDownloader
|
|
||||||
from frigate.util.model import ONNXModelRunner
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
category=FutureWarning,
|
|
||||||
message="The class CLIPFeatureExtractor is deprecated",
|
|
||||||
)
|
|
||||||
|
|
||||||
# disables the progress bar for downloading tokenizers and feature extractors
|
|
||||||
disable_progress_bar()
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
FACE_EMBEDDING_SIZE = 160
|
|
||||||
LPR_EMBEDDING_SIZE = 256
|
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeEnum(str, Enum):
    """Model kinds that GenericONNXEmbedding dispatches on.

    Members are ``str`` subclasses, so each compares equal to its plain
    string value. GenericONNXEmbedding uses the member to choose how the
    model's helpers are loaded and how raw inputs are preprocessed.
    """

    face = "face"
    vision = "vision"
    text = "text"
    lpr_detect = "lpr_detect"
    lpr_classify = "lpr_classify"
    lpr_recognize = "lpr_recognize"
    yolov9_lpr_detect = "yolov9_lpr_detect"
|
|
||||||
|
|
||||||
|
|
||||||
class GenericONNXEmbedding:
    """Generic embedding function for ONNX models (text and vision).

    Loading of helper objects and preprocessing of inputs are selected by
    ``model_type``; see ``_load_model_and_utils`` and ``_preprocess_inputs``.
    """

    # Seconds to wait when fetching an image over HTTP, so that a stalled
    # remote server cannot block the caller indefinitely.
    IMAGE_FETCH_TIMEOUT = 10

    def __init__(
        self,
        model_name: str,
        model_file: str,
        download_urls: Dict[str, str],
        model_size: str,
        model_type: ModelTypeEnum,
        requestor: InterProcessRequestor,
        tokenizer_file: Optional[str] = None,
        device: str = "AUTO",
    ):
        """Record the model settings and start a download or load the model.

        Args:
            model_name: Model name; also the sub-directory of MODEL_CACHE_DIR
                the files are stored in.
            model_file: File name of the ONNX model inside that directory.
            download_urls: Mapping of file name to the URL it is fetched from.
            model_size: Passed through to ONNXModelRunner.
            model_type: Selects the loading and preprocessing behavior.
            requestor: Used to publish model file states.
            tokenizer_file: Optional extra file name required for the model.
            device: Passed through to ONNXModelRunner.
        """
        self.model_name = model_name
        self.model_file = model_file
        self.tokenizer_file = tokenizer_file
        self.requestor = requestor
        self.download_urls = download_urls
        self.model_type = model_type
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.feature_extractor = None
        self.runner = None
        files_names = list(self.download_urls.keys()) + (
            [self.tokenizer_file] if self.tokenizer_file else []
        )

        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            # at least one required file is missing: hand off to the downloader
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
        else:
            # everything is on disk: report it and load immediately
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _download_model(self, path: str):
        """Download one required file and publish its resulting state.

        Args:
            path: Destination path; its base name identifies the file.
        """
        # resolved outside the try so the error report below can always use it
        file_name = os.path.basename(path)

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif (
                file_name == self.tokenizer_file
                and self.model_type == ModelTypeEnum.text
            ):
                if not os.path.exists(path + "/" + self.model_name):
                    logger.info(f"Downloading {self.model_name} tokenizer")

                # NOTE(review): trust_remote_code=True lets the model repository
                # supply code that is executed locally.
                tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
                    clean_up_tokenization_spaces=True,
                )
                tokenizer.save_pretrained(path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Stay best-effort (publish the error state rather than raise), but
            # record the cause so a failed download can be diagnosed.
            logger.exception(f"Failed to download {file_name} for {self.model_name}")
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_utils(self):
        """Create the runner and the per-type helper on first use.

        Does nothing once ``self.runner`` is set. Waits for a pending download
        before loading.
        """
        if self.runner is None:
            if self.downloader:
                self.downloader.wait_for_download()
            if self.model_type == ModelTypeEnum.text:
                self.tokenizer = self._load_tokenizer()
            elif self.model_type == ModelTypeEnum.vision:
                self.feature_extractor = self._load_feature_extractor()
            # The remaining types need no helper object; an empty list is
            # stored so the "not loaded" check in __call__ (which tests for
            # None) passes.
            elif self.model_type == ModelTypeEnum.face:
                self.feature_extractor = []
            elif self.model_type == ModelTypeEnum.lpr_detect:
                self.feature_extractor = []
            elif self.model_type == ModelTypeEnum.lpr_classify:
                self.feature_extractor = []
            elif self.model_type == ModelTypeEnum.lpr_recognize:
                self.feature_extractor = []
            elif self.model_type == ModelTypeEnum.yolov9_lpr_detect:
                self.feature_extractor = []

            self.runner = ONNXModelRunner(
                os.path.join(self.download_path, self.model_file),
                self.device,
                self.model_size,
            )

    def _load_tokenizer(self):
        """Load the tokenizer for ``model_name`` using the model's cache dir."""
        tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
        return AutoTokenizer.from_pretrained(
            self.model_name,
            cache_dir=tokenizer_path,
            trust_remote_code=True,
            clean_up_tokenization_spaces=True,
        )

    def _load_feature_extractor(self):
        """Load the feature extractor from the model's cache directory."""
        return AutoFeatureExtractor.from_pretrained(
            f"{MODEL_CACHE_DIR}/{self.model_name}",
        )

    def _preprocess_inputs(self, raw_inputs: any) -> any:
        """Convert raw inputs into a list of ``{input_name: value}`` dicts.

        Args:
            raw_inputs: A list for the text, vision and lpr_* types; a single
                (non-list) image for the face and yolov9_lpr_detect types.

        Returns:
            A list of dicts keyed by model input name.

        Raises:
            ValueError: If a list is passed for a type that does not support
                batches, or the model type is not handled.
        """
        if self.model_type == ModelTypeEnum.text:
            # pad every text to the longest tokenized text in this batch
            max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
            return [
                self.tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="np",
                )
                for text in raw_inputs
            ]
        elif self.model_type == ModelTypeEnum.vision:
            processed_images = [self._process_image(img) for img in raw_inputs]
            return [
                self.feature_extractor(images=image, return_tensors="np")
                for image in processed_images
            ]
        elif self.model_type == ModelTypeEnum.face:
            if isinstance(raw_inputs, list):
                raise ValueError("Face embedding does not support batch inputs.")

            pil = self._process_image(raw_inputs)

            # handle images larger than input size
            width, height = pil.size
            if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
                # scale the longer side to the target size; the shorter side is
                # rounded down to a multiple of 4
                if width > height:
                    new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
                    pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
                else:
                    new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
                    pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))

            og = np.array(pil).astype(np.float32)

            # Image must be FACE_EMBEDDING_SIZExFACE_EMBEDDING_SIZE
            og_h, og_w, channels = og.shape
            frame = np.full(
                (FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels),
                (0, 0, 0),
                dtype=np.float32,
            )

            # compute center offset
            x_center = (FACE_EMBEDDING_SIZE - og_w) // 2
            y_center = (FACE_EMBEDDING_SIZE - og_h) // 2

            # copy img image into center of result image
            frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
            frame = np.expand_dims(frame, axis=0)
            return [{"input_2": frame}]
        elif self.model_type == ModelTypeEnum.lpr_detect:
            preprocessed = []
            for x in raw_inputs:
                preprocessed.append(x)
            # only the first input is forwarded to the detection model
            return [{"x": preprocessed[0]}]
        elif self.model_type == ModelTypeEnum.lpr_classify:
            processed = []
            for img in raw_inputs:
                processed.append({"x": img})
            return processed
        elif self.model_type == ModelTypeEnum.lpr_recognize:
            processed = []
            for img in raw_inputs:
                processed.append({"x": img})
            return processed
        elif self.model_type == ModelTypeEnum.yolov9_lpr_detect:
            if isinstance(raw_inputs, list):
                raise ValueError(
                    "License plate embedding does not support batch inputs."
                )
            # Get image as numpy array
            # (presumably the caller already passes an array, which
            # _process_image returns unchanged -- TODO confirm)
            img = self._process_image(raw_inputs)
            height, width, channels = img.shape

            # Resize maintaining aspect ratio
            if width > height:
                new_height = int(((height / width) * LPR_EMBEDDING_SIZE) // 4 * 4)
                img = cv2.resize(img, (LPR_EMBEDDING_SIZE, new_height))
            else:
                new_width = int(((width / height) * LPR_EMBEDDING_SIZE) // 4 * 4)
                img = cv2.resize(img, (new_width, LPR_EMBEDDING_SIZE))

            # Get new dimensions after resize
            og_h, og_w, channels = img.shape

            # Create black square frame
            frame = np.full(
                (LPR_EMBEDDING_SIZE, LPR_EMBEDDING_SIZE, channels),
                (0, 0, 0),
                dtype=np.float32,
            )

            # Center the resized image in the square frame
            x_center = (LPR_EMBEDDING_SIZE - og_w) // 2
            y_center = (LPR_EMBEDDING_SIZE - og_h) // 2
            frame[y_center : y_center + og_h, x_center : x_center + og_w] = img

            # Normalize to 0-1
            frame = frame / 255.0

            # Convert from HWC to CHW format and add batch dimension
            frame = np.transpose(frame, (2, 0, 1))
            frame = np.expand_dims(frame, axis=0)
            return [{"images": frame}]
        else:
            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")

    def _process_image(self, image, output: str = "RGB") -> Image.Image:
        """Decode an image given as an "http" URL string or as raw bytes.

        Args:
            image: URL string starting with "http", encoded image bytes, or
                any other object.
            output: PIL mode the decoded image is converted to.

        Returns:
            The decoded, converted image; ``image`` itself is returned
            unchanged when it is neither an "http" string nor bytes.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                # bounded wait: without a timeout requests may block forever
                response = requests.get(image, timeout=self.IMAGE_FETCH_TIMEOUT)
                image = Image.open(BytesIO(response.content)).convert(output)
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image)).convert(output)

        return image

    def __call__(
        self, inputs: Union[List[str], List[Image.Image], List[str]]
    ) -> List[np.ndarray]:
        """Preprocess ``inputs``, run the model and return its first output.

        Returns:
            The entries of the runner's first output as a list, or an empty
            list (after logging an error) when the runner, or both the
            tokenizer and the feature extractor, are not loaded.
        """
        self._load_model_and_utils()
        if self.runner is None or (
            self.tokenizer is None and self.feature_extractor is None
        ):
            logger.error(
                f"{self.model_name} model or tokenizer/feature extractor is not loaded."
            )
            return []

        processed_inputs = self._preprocess_inputs(inputs)
        input_names = self.runner.get_input_names()
        onnx_inputs = {name: [] for name in input_names}

        for model_input in processed_inputs:
            for key, value in model_input.items():
                if key in input_names:
                    # value[0] presumably drops a per-item leading batch axis so
                    # np.stack below can rebuild the batch -- TODO confirm
                    onnx_inputs[key].append(value[0])

        for key in input_names:
            if onnx_inputs.get(key):
                onnx_inputs[key] = np.stack(onnx_inputs[key])
            else:
                logger.warning(f"Expected input '{key}' not found in onnx_inputs")

        embeddings = self.runner.run(onnx_inputs)[0]
        return [embedding for embedding in embeddings]
|
|
95
frigate/embeddings/onnx/base_embedding.py
Normal file
95
frigate/embeddings/onnx/base_embedding.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""Base class for onnx embedding implementations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from frigate.const import UPDATE_MODEL_STATE
|
||||||
|
from frigate.types import ModelStatusTypesEnum
|
||||||
|
from frigate.util.downloader import ModelDownloader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingTypeEnum(str, Enum):
    """Names of embedding kinds.

    Members are ``str`` subclasses, so each compares equal to its plain
    string value.
    """

    thumbnail = "thumbnail"
    description = "description"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbedding(ABC):
    """Base embedding class.

    Subclasses implement ``_load_model_and_utils`` and ``_preprocess_inputs``
    and must have assigned ``self.runner`` by the time ``__call__`` runs
    inference. This class supplies the shared file download, image decoding
    and inference steps.
    """

    # Seconds to wait when fetching an image over HTTP, so that a stalled
    # remote server cannot block the caller indefinitely.
    IMAGE_FETCH_TIMEOUT = 10

    def __init__(self, model_name: str, model_file: str, download_urls: dict[str, str]):
        """Store the model identity and where its files are downloaded from.

        Args:
            model_name: Name used when publishing model file states.
            model_file: File name of the ONNX model.
            download_urls: Mapping of file name to the URL it is fetched from.
        """
        self.model_name = model_name
        self.model_file = model_file
        self.download_urls = download_urls
        # stays None unless a subclass creates a downloader for missing files
        self.downloader: ModelDownloader | None = None

    def _download_model(self, path: str):
        """Download one model file and publish its resulting state.

        Args:
            path: Destination path; its base name is looked up in
                ``download_urls``.
        """
        # resolved outside the try so the error report below can always use it
        file_name = os.path.basename(path)

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Stay best-effort (publish the error state rather than raise), but
            # record the cause so a failed download can be diagnosed.
            logger.exception(f"Failed to download {file_name} for {self.model_name}")
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    @abstractmethod
    def _load_model_and_utils(self):
        """Load the model runner and any helper objects the subclass needs."""
        pass

    @abstractmethod
    def _preprocess_inputs(self, raw_inputs: any) -> any:
        """Convert raw inputs into an iterable of ``{input_name: value}`` dicts."""
        pass

    def _process_image(self, image, output: str = "RGB") -> Image.Image:
        """Decode an image given as an "http" URL string or as raw bytes.

        Args:
            image: URL string starting with "http", encoded image bytes, or
                any other object.
            output: PIL mode the decoded image is converted to.

        Returns:
            The decoded, converted image; ``image`` itself is returned
            unchanged when it is neither an "http" string nor bytes.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                # bounded wait: without a timeout requests may block forever
                response = requests.get(image, timeout=self.IMAGE_FETCH_TIMEOUT)
                image = Image.open(BytesIO(response.content)).convert(output)
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image)).convert(output)

        return image

    def __call__(
        self, inputs: list[str] | list[Image.Image]
    ) -> list[np.ndarray]:
        """Preprocess ``inputs``, run the model and return its first output.

        Returns:
            The entries of the runner's first output, as a list.
        """
        self._load_model_and_utils()
        processed = self._preprocess_inputs(inputs)
        input_names = self.runner.get_input_names()
        onnx_inputs = {name: [] for name in input_names}

        # "model_input" rather than "input" to avoid shadowing the builtin
        for model_input in processed:
            for key, value in model_input.items():
                if key in input_names:
                    # value[0] presumably drops a per-item leading batch axis so
                    # np.stack below can rebuild the batch -- TODO confirm
                    onnx_inputs[key].append(value[0])

        for key in input_names:
            if onnx_inputs.get(key):
                onnx_inputs[key] = np.stack(onnx_inputs[key])
            else:
                logger.warning(f"Expected input '{key}' not found in onnx_inputs")

        embeddings = self.runner.run(onnx_inputs)[0]
        return list(embeddings)
|
216
frigate/embeddings/onnx/jina_v1_embedding.py
Normal file
216
frigate/embeddings/onnx/jina_v1_embedding.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
"""JinaV1 Embeddings."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
# importing this without pytorch or others causes a warning
|
||||||
|
# https://github.com/huggingface/transformers/issues/27214
|
||||||
|
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||||
|
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||||
|
from transformers.utils.logging import disable_progress_bar
|
||||||
|
|
||||||
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
|
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||||
|
from frigate.types import ModelStatusTypesEnum
|
||||||
|
from frigate.util.downloader import ModelDownloader
|
||||||
|
|
||||||
|
from .base_embedding import BaseEmbedding
|
||||||
|
from .runner import ONNXModelRunner
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
category=FutureWarning,
|
||||||
|
message="The class CLIPFeatureExtractor is deprecated",
|
||||||
|
)
|
||||||
|
|
||||||
|
# disables the progress bar for downloading tokenizers and feature extractors
|
||||||
|
disable_progress_bar()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JinaV1TextEmbedding(BaseEmbedding):
    """Text embedding backed by the jinaai/jina-clip-v1 ONNX text model."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Start a download of missing files, or load the model if present.

        Args:
            model_size: Passed through to ONNXModelRunner.
            requestor: Used to publish model file states.
            device: Passed through to ONNXModelRunner.
        """
        super().__init__(
            model_name="jinaai/jina-clip-v1",
            model_file="text_model_fp16.onnx",
            download_urls={
                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
            },
        )
        self.tokenizer_file = "tokenizer"
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.feature_extractor = None
        self.runner = None
        # the tokenizer has no download URL; it is fetched in _download_model
        files_names = list(self.download_urls.keys()) + [self.tokenizer_file]

        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
        else:
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _download_model(self, path: str):
        """Download the model file or the tokenizer and publish the result.

        Args:
            path: Destination path; its base name selects what is fetched.
        """
        # resolved outside the try so the error report below can always use it
        file_name = os.path.basename(path)

        try:
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif file_name == self.tokenizer_file:
                if not os.path.exists(path + "/" + self.model_name):
                    logger.info(f"Downloading {self.model_name} tokenizer")

                # NOTE(review): trust_remote_code=True lets the model repository
                # supply code that is executed locally.
                tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
                    clean_up_tokenization_spaces=True,
                )
                tokenizer.save_pretrained(path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # Stay best-effort (publish the error state rather than raise), but
            # record the cause so a failed download can be diagnosed.
            logger.exception(f"Failed to download {file_name} for {self.model_name}")
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_utils(self):
        """Load the tokenizer and create the runner on first use.

        Does nothing once ``self.runner`` is set. Waits for a pending download
        before loading.
        """
        if self.runner is None:
            if self.downloader:
                self.downloader.wait_for_download()

            tokenizer_path = f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=tokenizer_path,
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )

            self.runner = ONNXModelRunner(
                os.path.join(self.download_path, self.model_file),
                self.device,
                self.model_size,
            )

    def _preprocess_inputs(self, raw_inputs):
        """Tokenize each text, padded to the longest text in the batch.

        Args:
            raw_inputs: Non-empty iterable of texts (``max`` raises ValueError
                on an empty one).

        Returns:
            One tokenizer output per text, requested as NumPy tensors.
        """
        max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
        return [
            self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="np",
            )
            for text in raw_inputs
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class JinaV1ImageEmbedding(BaseEmbedding):
    """Image embedding backed by the jinaai/jina-clip-v1 ONNX vision model."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Pick the model variant, then download missing files or load it.

        Args:
            model_size: "large" selects the fp16 model, anything else the
                quantized one; also passed through to ONNXModelRunner.
            requestor: Used to publish model file states.
            device: Passed through to ONNXModelRunner.
        """
        if model_size == "large":
            model_file = "vision_model_fp16.onnx"
        else:
            model_file = "vision_model_quantized.onnx"

        urls = {
            model_file: f"https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/{model_file}",
            "preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
        }
        super().__init__(
            model_name="jinaai/jina-clip-v1",
            model_file=model_file,
            download_urls=urls,
        )
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.feature_extractor = None
        self.runner: ONNXModelRunner | None = None

        required_files = list(self.download_urls.keys())
        anything_missing = any(
            not os.path.exists(os.path.join(self.download_path, name))
            for name in required_files
        )

        if not anything_missing:
            # all files are on disk: report it and load right away
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                required_files,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")
            return

        logger.debug(f"starting model download for {self.model_name}")
        self.downloader = ModelDownloader(
            model_name=self.model_name,
            download_path=self.download_path,
            file_names=required_files,
            download_func=self._download_model,
        )
        self.downloader.ensure_model_files()

    def _load_model_and_utils(self):
        """Load the feature extractor and create the runner on first use."""
        if self.runner is not None:
            return

        if self.downloader:
            self.downloader.wait_for_download()

        extractor_dir = f"{MODEL_CACHE_DIR}/{self.model_name}"
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            extractor_dir,
        )

        model_path = os.path.join(self.download_path, self.model_file)
        self.runner = ONNXModelRunner(
            model_path,
            self.device,
            self.model_size,
        )

    def _preprocess_inputs(self, raw_inputs):
        """Decode every input image, then run each through the feature extractor."""
        # decode everything first so a bad image fails before any extraction
        decoded = []
        for raw in raw_inputs:
            decoded.append(self._process_image(raw))

        features = []
        for picture in decoded:
            features.append(self.feature_extractor(images=picture, return_tensors="np"))
        return features
|
297
frigate/embeddings/onnx/lpr_embedding.py
Normal file
297
frigate/embeddings/onnx/lpr_embedding.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
|
from frigate.const import MODEL_CACHE_DIR
|
||||||
|
from frigate.types import ModelStatusTypesEnum
|
||||||
|
from frigate.util.downloader import ModelDownloader
|
||||||
|
|
||||||
|
from .base_embedding import BaseEmbedding
|
||||||
|
from .runner import ONNXModelRunner
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
category=FutureWarning,
|
||||||
|
message="The class CLIPFeatureExtractor is deprecated",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LPR_EMBEDDING_SIZE = 256
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleOCRDetection(BaseEmbedding):
    """Runner wrapper for the ``detection.onnx`` model of the paddleocr-onnx set."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Register the model files and either load them or start a download.

        Args:
            model_size: Stored on the instance and forwarded to the runner.
            requestor: Used to publish the "downloaded" state of the files.
            device: Device string handed to the ONNX runner.
        """
        super().__init__(
            model_name="paddleocr-onnx",
            model_file="detection.onnx",
            download_urls={
                "detection.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/detection.onnx"
            },
        )
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.runner: ONNXModelRunner | None = None

        required_files = list(self.download_urls.keys())
        files_present = all(
            os.path.exists(os.path.join(self.download_path, file_name))
            for file_name in required_files
        )

        if files_present:
            # Everything is on disk: report it and load the model right away.
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                required_files,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")
        else:
            # Loading is deferred until _load_model_and_utils() is called later.
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=required_files,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()

    def _load_model_and_utils(self):
        """Create the ONNX runner once, waiting for a pending download first."""
        if self.runner is not None:
            return

        if self.downloader:
            self.downloader.wait_for_download()

        # NOTE(review): model_size lands in ONNXModelRunner's third positional
        # parameter (`requires_fp16`) -- confirm this is intended.
        onnx_path = os.path.join(self.download_path, self.model_file)
        self.runner = ONNXModelRunner(onnx_path, self.device, self.model_size)

    def _preprocess_inputs(self, raw_inputs):
        """Wrap only the first raw input as the model's ``x`` feed.

        All inputs are consumed, but any beyond the first are ignored; an
        empty input raises ``IndexError``.
        """
        collected = list(raw_inputs)
        return [{"x": collected[0]}]
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleOCRClassification(BaseEmbedding):
    """Runner wrapper for the ``classification.onnx`` model of the paddleocr-onnx set."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Register the model files and either load them or start a download.

        Args:
            model_size: Stored on the instance and forwarded to the runner.
            requestor: Used to publish the "downloaded" state of the files.
            device: Device string handed to the ONNX runner.
        """
        super().__init__(
            model_name="paddleocr-onnx",
            model_file="classification.onnx",
            download_urls={
                "classification.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/classification.onnx"
            },
        )
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.runner: ONNXModelRunner | None = None

        required_files = list(self.download_urls.keys())
        files_present = all(
            os.path.exists(os.path.join(self.download_path, file_name))
            for file_name in required_files
        )

        if files_present:
            # Everything is on disk: report it and load the model right away.
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                required_files,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")
        else:
            # Loading is deferred until _load_model_and_utils() is called later.
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=required_files,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()

    def _load_model_and_utils(self):
        """Create the ONNX runner once, waiting for a pending download first."""
        if self.runner is not None:
            return

        if self.downloader:
            self.downloader.wait_for_download()

        # NOTE(review): model_size lands in ONNXModelRunner's third positional
        # parameter (`requires_fp16`) -- confirm this is intended.
        onnx_path = os.path.join(self.download_path, self.model_file)
        self.runner = ONNXModelRunner(onnx_path, self.device, self.model_size)

    def _preprocess_inputs(self, raw_inputs):
        """Return one ``{"x": image}`` feed per raw input, in input order."""
        return [{"x": img} for img in raw_inputs]
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleOCRRecognition(BaseEmbedding):
    """Runner wrapper for the ``recognition.onnx`` model of the paddleocr-onnx set."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Register the model files and either load them or start a download.

        Args:
            model_size: Stored on the instance and forwarded to the runner.
            requestor: Used to publish the "downloaded" state of the files.
            device: Device string handed to the ONNX runner.
        """
        super().__init__(
            model_name="paddleocr-onnx",
            model_file="recognition.onnx",
            download_urls={
                "recognition.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/recognition.onnx"
            },
        )
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.runner: ONNXModelRunner | None = None

        required_files = list(self.download_urls.keys())
        files_present = all(
            os.path.exists(os.path.join(self.download_path, file_name))
            for file_name in required_files
        )

        if files_present:
            # Everything is on disk: report it and load the model right away.
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                required_files,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")
        else:
            # Loading is deferred until _load_model_and_utils() is called later.
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=required_files,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()

    def _load_model_and_utils(self):
        """Create the ONNX runner once, waiting for a pending download first."""
        if self.runner is not None:
            return

        if self.downloader:
            self.downloader.wait_for_download()

        # NOTE(review): model_size lands in ONNXModelRunner's third positional
        # parameter (`requires_fp16`) -- confirm this is intended.
        onnx_path = os.path.join(self.download_path, self.model_file)
        self.runner = ONNXModelRunner(onnx_path, self.device, self.model_size)

    def _preprocess_inputs(self, raw_inputs):
        """Return one ``{"x": image}`` feed per raw input, in input order."""
        return [{"x": img} for img in raw_inputs]
|
||||||
|
|
||||||
|
|
||||||
|
class LicensePlateDetector(BaseEmbedding):
    """Runner wrapper for the ``yolov9-256-license-plates.onnx`` model."""

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
    ):
        """Register the model file and either load it or start a download.

        Args:
            model_size: Stored on the instance and forwarded to the runner.
            requestor: Used to publish the "downloaded" state of the files.
            device: Device string handed to the ONNX runner.
        """
        super().__init__(
            model_name="yolov9_license_plate",
            model_file="yolov9-256-license-plates.onnx",
            download_urls={
                "yolov9-256-license-plates.onnx": "https://github.com/hawkeye217/yolov9-license-plates/raw/refs/heads/master/models/yolov9-256-license-plates.onnx"
            },
        )

        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.runner: ONNXModelRunner | None = None
        files_names = list(self.download_urls.keys())
        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            # Loading is deferred until _load_model_and_utils() is called later.
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
        else:
            # Everything is on disk: report it and load the model right away.
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _load_model_and_utils(self):
        """Create the ONNX runner once, waiting for a pending download first."""
        if self.runner is None:
            if self.downloader:
                self.downloader.wait_for_download()

            # NOTE(review): model_size lands in ONNXModelRunner's third
            # positional parameter (`requires_fp16`) -- confirm this is intended.
            self.runner = ONNXModelRunner(
                os.path.join(self.download_path, self.model_file),
                self.device,
                self.model_size,
            )

    @staticmethod
    def _scaled_short_side(short_side: int, long_side: int) -> int:
        """Scale the short side to the model size, keeping the aspect ratio.

        The result is rounded down to a multiple of 4 and clamped to at least
        4 pixels: without the clamp, aspect ratios above 64:1 floor to 0 and
        cv2.resize rejects an empty destination size.
        """
        scaled = int(((short_side / long_side) * LPR_EMBEDDING_SIZE) // 4 * 4)
        return max(4, scaled)

    def _preprocess_inputs(self, raw_inputs):
        """Letterbox a single image into a square model input.

        Args:
            raw_inputs: One image accepted by ``self._process_image``.

        Returns:
            A single-element list holding ``{"images": array}`` where the array
            is float32, scaled to 0-1, laid out as (1, channels, size, size).

        Raises:
            ValueError: If ``raw_inputs`` is a list (batches are unsupported).
        """
        if isinstance(raw_inputs, list):
            raise ValueError("License plate embedding does not support batch inputs.")

        # Get image as numpy array
        img = self._process_image(raw_inputs)
        height, width, channels = img.shape

        # Resize maintaining aspect ratio; the long side becomes the model size
        if width > height:
            new_height = self._scaled_short_side(height, width)
            img = cv2.resize(img, (LPR_EMBEDDING_SIZE, new_height))
        else:
            new_width = self._scaled_short_side(width, height)
            img = cv2.resize(img, (new_width, LPR_EMBEDDING_SIZE))

        # Get new dimensions after resize
        og_h, og_w, channels = img.shape

        # Create black square frame
        frame = np.zeros(
            (LPR_EMBEDDING_SIZE, LPR_EMBEDDING_SIZE, channels),
            dtype=np.float32,
        )

        # Center the resized image in the square frame
        x_center = (LPR_EMBEDDING_SIZE - og_w) // 2
        y_center = (LPR_EMBEDDING_SIZE - og_h) // 2
        frame[y_center : y_center + og_h, x_center : x_center + og_w] = img

        # Normalize to 0-1
        frame = frame / 255.0

        # Convert from HWC to CHW format and add batch dimension
        frame = np.transpose(frame, (2, 0, 1))
        frame = np.expand_dims(frame, axis=0)
        return [{"images": frame}]
|
79
frigate/embeddings/onnx/runner.py
Normal file
79
frigate/embeddings/onnx/runner.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Convenience runner for onnx models."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
from frigate.util.model import get_ort_providers
|
||||||
|
|
||||||
|
try:
|
||||||
|
import openvino as ov
|
||||||
|
except ImportError:
|
||||||
|
# openvino is not included
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXModelRunner:
    """Run an ONNX model through OpenVINO when offered, else ONNX Runtime."""

    def __init__(self, model_path: str, device: str, requires_fp16: bool = False):
        """Pick a backend for ``model_path`` and build its session.

        Args:
            model_path: Path of the ONNX file to load.
            device: Device string; the literal "CPU" is passed to
                ``get_ort_providers`` as its force-CPU flag.
            requires_fp16: Forwarded unchanged to ``get_ort_providers``.
        """
        self.model_path = model_path
        self.ort: ort.InferenceSession = None
        self.ov: ov.Core = None
        providers, options = get_ort_providers(device == "CPU", device, requires_fp16)
        self.interpreter = None

        if "OpenVINOExecutionProvider" in providers:
            try:
                # Drive OpenVINO directly rather than through ONNX Runtime.
                self.type = "ov"
                core = ov.Core()
                self.ov = core
                core.set_property(
                    {ov.properties.cache_dir: "/config/model_cache/openvino"}
                )
                self.interpreter = core.compile_model(
                    model=model_path, device_name=device
                )
            except Exception as e:
                # Any failure (including openvino not being importable) falls
                # through to the ONNX Runtime path below.
                logger.warning(
                    f"OpenVINO failed to build model, using CPU instead: {e}"
                )
                self.interpreter = None

        if self.interpreter is not None:
            return

        # No OpenVINO model was compiled: use ONNX Runtime.
        self.type = "ort"
        self.ort = ort.InferenceSession(
            model_path,
            providers=providers,
            provider_options=options,
        )

    def get_input_names(self) -> list[str]:
        """Return the model's input names for the active backend."""
        if self.type == "ort":
            return [node.name for node in self.ort.get_inputs()]

        if self.type == "ov":
            # Each OpenVINO input may carry several names; flatten them in order.
            return [name for port in self.interpreter.inputs for name in port.names]

    def run(self, input: dict[str, Any]) -> Any:
        """Run inference on a name-to-array feed and return the outputs.

        The ONNX Runtime path returns whatever ``InferenceSession.run`` yields;
        the OpenVINO path returns a one-element list with the output data.
        """
        if self.type == "ort":
            return self.ort.run(None, input)

        if self.type == "ov":
            request = self.interpreter.create_infer_request()
            values = list(input.values())
            # A single input is unwrapped; several are passed as a list.
            payload = values[0] if len(values) == 1 else values
            request.infer(ov.Tensor(array=payload))
            return [request.get_output_tensor().data]
|
@ -2,18 +2,11 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
try:
|
|
||||||
import openvino as ov
|
|
||||||
except ImportError:
|
|
||||||
# openvino is not included
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
### Post Processing
|
### Post Processing
|
||||||
@ -124,66 +117,3 @@ def get_ort_providers(
|
|||||||
options.append({})
|
options.append({})
|
||||||
|
|
||||||
return (providers, options)
|
return (providers, options)
|
||||||
|
|
||||||
|
|
||||||
class ONNXModelRunner:
|
|
||||||
"""Run onnx models optimally based on available hardware."""
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, device: str, requires_fp16: bool = False):
|
|
||||||
self.model_path = model_path
|
|
||||||
self.ort: ort.InferenceSession = None
|
|
||||||
self.ov: ov.Core = None
|
|
||||||
providers, options = get_ort_providers(device == "CPU", device, requires_fp16)
|
|
||||||
self.interpreter = None
|
|
||||||
|
|
||||||
if "OpenVINOExecutionProvider" in providers:
|
|
||||||
try:
|
|
||||||
# use OpenVINO directly
|
|
||||||
self.type = "ov"
|
|
||||||
self.ov = ov.Core()
|
|
||||||
self.ov.set_property(
|
|
||||||
{ov.properties.cache_dir: "/config/model_cache/openvino"}
|
|
||||||
)
|
|
||||||
self.interpreter = self.ov.compile_model(
|
|
||||||
model=model_path, device_name=device
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"OpenVINO failed to build model, using CPU instead: {e}"
|
|
||||||
)
|
|
||||||
self.interpreter = None
|
|
||||||
|
|
||||||
# Use ONNXRuntime
|
|
||||||
if self.interpreter is None:
|
|
||||||
self.type = "ort"
|
|
||||||
self.ort = ort.InferenceSession(
|
|
||||||
model_path,
|
|
||||||
providers=providers,
|
|
||||||
provider_options=options,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_input_names(self) -> list[str]:
|
|
||||||
if self.type == "ov":
|
|
||||||
input_names = []
|
|
||||||
|
|
||||||
for input in self.interpreter.inputs:
|
|
||||||
input_names.extend(input.names)
|
|
||||||
|
|
||||||
return input_names
|
|
||||||
elif self.type == "ort":
|
|
||||||
return [input.name for input in self.ort.get_inputs()]
|
|
||||||
|
|
||||||
def run(self, input: dict[str, Any]) -> Any:
|
|
||||||
if self.type == "ov":
|
|
||||||
infer_request = self.interpreter.create_infer_request()
|
|
||||||
input_tensor = list(input.values())
|
|
||||||
|
|
||||||
if len(input_tensor) == 1:
|
|
||||||
input_tensor = ov.Tensor(array=input_tensor[0])
|
|
||||||
else:
|
|
||||||
input_tensor = ov.Tensor(array=input_tensor)
|
|
||||||
|
|
||||||
infer_request.infer(input_tensor)
|
|
||||||
return [infer_request.get_output_tensor().data]
|
|
||||||
elif self.type == "ort":
|
|
||||||
return self.ort.run(None, input)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user