# blakeblackshear.frigate/frigate/embeddings/onnx/runner.py
"""Convenience runner for onnx models."""
import logging
from typing import Any
import onnxruntime as ort
from frigate.util.model import get_ort_providers
try:
import openvino as ov
except ImportError:
# openvino is not included
pass
logger = logging.getLogger(__name__)


class ONNXModelRunner:
    """Run onnx models optimally based on available hardware."""

    def __init__(self, model_path: str, device: str, requires_fp16: bool = False):
        self.model_path = model_path
        self.ort: ort.InferenceSession = None
        self.ov: ov.Core = None
        providers, options = get_ort_providers(device == "CPU", device, requires_fp16)
        self.interpreter = None

        if "OpenVINOExecutionProvider" in providers:
            try:
                # use OpenVINO directly
                self.type = "ov"
                self.ov = ov.Core()
                # cache compiled models so subsequent loads are faster
                self.ov.set_property(
                    {ov.properties.cache_dir: "/config/model_cache/openvino"}
                )
                self.interpreter = self.ov.compile_model(
                    model=model_path, device_name=device
                )
            except Exception as e:
                logger.warning(
                    f"OpenVINO failed to build model, using CPU instead: {e}"
                )
                self.interpreter = None

        # fall back to ONNXRuntime when OpenVINO was not selected or failed to build
        if self.interpreter is None:
            self.type = "ort"
            self.ort = ort.InferenceSession(
                model_path,
                providers=providers,
                provider_options=options,
            )

    def get_input_names(self) -> list[str]:
        if self.type == "ov":
            input_names = []

            for input in self.interpreter.inputs:
                input_names.extend(input.names)

            return input_names
        elif self.type == "ort":
            return [input.name for input in self.ort.get_inputs()]

    def run(self, input: dict[str, Any]) -> Any:
        if self.type == "ov":
            infer_request = self.interpreter.create_infer_request()
            outputs = infer_request.infer(input)
            return outputs
        elif self.type == "ort":
            return self.ort.run(None, input)
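

# Usage sketch (not part of the original file): a minimal example of driving
# the runner end to end. The model path, device string, and input shape below
# are illustrative assumptions, not values taken from Frigate; the input is
# assumed to already be preprocessed to whatever shape the model expects.
if __name__ == "__main__":
    import numpy as np

    # "CPU" is passed through get_ort_providers for provider selection; a
    # device such as "GPU" may route through OpenVINO instead when available.
    runner = ONNXModelRunner("/config/model_cache/example.onnx", "CPU")
    input_name = runner.get_input_names()[0]
    dummy_input = {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)}
    print(runner.run(dummy_input))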