mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Add openvino support for the DFINE model (#17238)
* add openvino support for the dfine model * update docs to show DFINE support for openvino * remove warning about OpenVINO for DFINE
This commit is contained in:
parent
125c266585
commit
e340c9aaba
@ -129,8 +129,8 @@ detectors:
|
|||||||
type: edgetpu
|
type: edgetpu
|
||||||
device: pci
|
device: pci
|
||||||
```
|
```
|
||||||
---
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Hailo-8
|
## Hailo-8
|
||||||
|
|
||||||
@ -146,6 +146,7 @@ If both are provided, the detector will first check for the model at the given l
|
|||||||
#### YOLO
|
#### YOLO
|
||||||
|
|
||||||
Use this configuration for YOLO-based models. When no custom model path or URL is provided, the detector automatically downloads the default model based on the detected hardware:
|
Use this configuration for YOLO-based models. When no custom model path or URL is provided, the detector automatically downloads the default model based on the detected hardware:
|
||||||
|
|
||||||
- **Hailo-8 hardware:** Uses **YOLOv6n** (default: `yolov6n.hef`)
|
- **Hailo-8 hardware:** Uses **YOLOv6n** (default: `yolov6n.hef`)
|
||||||
- **Hailo-8L hardware:** Uses **YOLOv6n** (default: `yolov6n.hef`)
|
- **Hailo-8L hardware:** Uses **YOLOv6n** (default: `yolov6n.hef`)
|
||||||
|
|
||||||
@ -224,6 +225,7 @@ model:
|
|||||||
# Alternatively, or as a fallback, provide a custom URL:
|
# Alternatively, or as a fallback, provide a custom URL:
|
||||||
# path: https://custom-model-url.com/path/to/model.hef
|
# path: https://custom-model-url.com/path/to/model.hef
|
||||||
```
|
```
|
||||||
|
|
||||||
For additional ready-to-use models, please visit: https://github.com/hailo-ai/hailo_model_zoo
|
For additional ready-to-use models, please visit: https://github.com/hailo-ai/hailo_model_zoo
|
||||||
|
|
||||||
Hailo8 supports all models in the Hailo Model Zoo that include HailoRT post-processing. You're welcome to choose any of these pre-configured models for your implementation.
|
Hailo8 supports all models in the Hailo Model Zoo that include HailoRT post-processing. You're welcome to choose any of these pre-configured models for your implementation.
|
||||||
@ -233,8 +235,6 @@ Hailo8 supports all models in the Hailo Model Zoo that include HailoRT post-proc
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## OpenVINO Detector
|
## OpenVINO Detector
|
||||||
|
|
||||||
The OpenVINO detector type runs an OpenVINO IR model on AMD and Intel CPUs, Intel GPUs and Intel VPU hardware. To configure an OpenVINO detector, set the `"type"` attribute to `"openvino"`.
|
The OpenVINO detector type runs an OpenVINO IR model on AMD and Intel CPUs, Intel GPUs and Intel VPU hardware. To configure an OpenVINO detector, set the `"type"` attribute to `"openvino"`.
|
||||||
@ -340,6 +340,30 @@ model:
|
|||||||
|
|
||||||
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
|
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
|
||||||
|
|
||||||
|
#### D-FINE
|
||||||
|
|
||||||
|
[D-FINE](https://github.com/Peterande/D-FINE) is the [current state of the art](https://paperswithcode.com/sota/real-time-object-detection-on-coco?p=d-fine-redefine-regression-task-in-detrs-as) at the time of writing. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
|
||||||
|
|
||||||
|
After placing the downloaded onnx model in your config/model_cache folder, you can use the following configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
detectors:
|
||||||
|
ov:
|
||||||
|
type: openvino
|
||||||
|
device: GPU
|
||||||
|
|
||||||
|
model:
|
||||||
|
model_type: dfine
|
||||||
|
width: 640
|
||||||
|
height: 640
|
||||||
|
input_tensor: nchw
|
||||||
|
input_dtype: float
|
||||||
|
path: /config/model_cache/dfine_s_obj2coco.onnx
|
||||||
|
labelmap_path: /labelmap/coco-80.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
|
||||||
|
|
||||||
## NVidia TensorRT Detector
|
## NVidia TensorRT Detector
|
||||||
|
|
||||||
Nvidia GPUs may be used for object detection using the TensorRT libraries. Due to the size of the additional libraries, this detector is only provided in images with the `-tensorrt` tag suffix, e.g. `ghcr.io/blakeblackshear/frigate:stable-tensorrt`. This detector is designed to work with Yolo models for object detection.
|
Nvidia GPUs may be used for object detection using the TensorRT libraries. Due to the size of the additional libraries, this detector is only provided in images with the `-tensorrt` tag suffix, e.g. `ghcr.io/blakeblackshear/frigate:stable-tensorrt`. This detector is designed to work with Yolo models for object detection.
|
||||||
@ -529,6 +553,7 @@ $ docker exec -it frigate /bin/bash -c '(unset HSA_OVERRIDE_GFX_VERSION && /opt/
|
|||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
See [ONNX supported models](#supported-models) for supported models, there are some caveats:
|
See [ONNX supported models](#supported-models) for supported models, there are some caveats:
|
||||||
|
|
||||||
- D-FINE models are not supported
|
- D-FINE models are not supported
|
||||||
- YOLO-NAS models are known to not run well on integrated GPUs
|
- YOLO-NAS models are known to not run well on integrated GPUs
|
||||||
|
|
||||||
@ -626,12 +651,6 @@ Note that the labelmap uses a subset of the complete COCO label set that has onl
|
|||||||
|
|
||||||
[D-FINE](https://github.com/Peterande/D-FINE) is the [current state of the art](https://paperswithcode.com/sota/real-time-object-detection-on-coco?p=d-fine-redefine-regression-task-in-detrs-as) at the time of writing. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
|
[D-FINE](https://github.com/Peterande/D-FINE) is the [current state of the art](https://paperswithcode.com/sota/real-time-object-detection-on-coco?p=d-fine-redefine-regression-task-in-detrs-as) at the time of writing. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
|
||||||
|
|
||||||
:::warning
|
|
||||||
|
|
||||||
D-FINE is currently not supported on OpenVINO
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
After placing the downloaded onnx model in your config/model_cache folder, you can use the following configuration:
|
After placing the downloaded onnx model in your config/model_cache folder, you can use the following configuration:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -10,7 +10,7 @@ from typing_extensions import Literal
|
|||||||
from frigate.const import MODEL_CACHE_DIR
|
from frigate.const import MODEL_CACHE_DIR
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
||||||
from frigate.util.model import post_process_yolov9
|
from frigate.util.model import post_process_dfine, post_process_yolov9
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -29,6 +29,7 @@ class OvDetector(DetectionApi):
|
|||||||
ModelTypeEnum.yolonas,
|
ModelTypeEnum.yolonas,
|
||||||
ModelTypeEnum.yolov9,
|
ModelTypeEnum.yolov9,
|
||||||
ModelTypeEnum.yolox,
|
ModelTypeEnum.yolox,
|
||||||
|
ModelTypeEnum.dfine,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, detector_config: OvDetectorConfig):
|
def __init__(self, detector_config: OvDetectorConfig):
|
||||||
@ -163,6 +164,21 @@ class OvDetector(DetectionApi):
|
|||||||
infer_request = self.interpreter.create_infer_request()
|
infer_request = self.interpreter.create_infer_request()
|
||||||
# TODO: see if we can use shared_memory=True
|
# TODO: see if we can use shared_memory=True
|
||||||
input_tensor = ov.Tensor(array=tensor_input)
|
input_tensor = ov.Tensor(array=tensor_input)
|
||||||
|
|
||||||
|
if self.ov_model_type == ModelTypeEnum.dfine:
|
||||||
|
infer_request.set_tensor("images", input_tensor)
|
||||||
|
target_sizes_tensor = ov.Tensor(
|
||||||
|
np.array([[self.h, self.w]], dtype=np.int64)
|
||||||
|
)
|
||||||
|
infer_request.set_tensor("orig_target_sizes", target_sizes_tensor)
|
||||||
|
infer_request.infer()
|
||||||
|
tensor_output = (
|
||||||
|
infer_request.get_output_tensor(0).data,
|
||||||
|
infer_request.get_output_tensor(1).data,
|
||||||
|
infer_request.get_output_tensor(2).data,
|
||||||
|
)
|
||||||
|
return post_process_dfine(tensor_output, self.w, self.h)
|
||||||
|
|
||||||
infer_request.infer(input_tensor)
|
infer_request.infer(input_tensor)
|
||||||
|
|
||||||
detections = np.zeros((20, 6), np.float32)
|
detections = np.zeros((20, 6), np.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user