YOLOX support for OpenVINO Detector (#5285)

* Initial commit to enable Yolox models with OpenVINO in Frigate

* Fix ModelTypeEnum import error in openvino.py

* Initial edit of the docs to include verbiage about YOLOX

* Elaborate configuration and limitations in docs.

* Add capability to dynamically determine number of classes in yolox model

* Further refinements

* Removed unnecessary comments, improved documentation, addressed PR review items

* Fixed lint formatting issues
Anil Ozyalcin 2023-02-03 17:36:37 -08:00 committed by GitHub
parent 7083a5c9b6
commit b33094207c
4 changed files with 120 additions and 21 deletions

@@ -101,7 +101,7 @@ The OpenVINO device to be used is specified using the `"device"` attribute according
OpenVINO is supported on 6th Gen Intel platforms (Skylake) and newer. A supported Intel platform is required to use the `GPU` device with OpenVINO. The `MYRIAD` device may be run on any platform, including Arm devices. For detailed system requirements, see [OpenVINO System Requirements](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/system-requirements.html).
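For instance, a detector entry pinning a specific device might look like the following minimal sketch (the full, default-model example appears below):

```yaml
detectors:
  ov:
    type: openvino
    device: GPU # or CPU, MYRIAD, AUTO
```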
An OpenVINO model is provided in the container at `/openvino-model/ssdlite_mobilenet_v2.xml` and is used by this detector type by default. The model comes from Intel's Open Model Zoo [SSDLite MobileNet V2](https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/ssdlite_mobilenet_v2) and is converted to an FP16 precision IR model. Use the model configuration shown below when using the OpenVINO detector with the default model.
```yaml
detectors:
@@ -119,6 +119,25 @@ model:
  labelmap_path: /openvino-model/coco_91cl_bkgr.txt
```
This detector also supports YOLOX models, and has been verified to work with the [yolox_tiny](https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/yolox-tiny) model from Intel's Open Model Zoo. Frigate does not ship with a `yolox_tiny` model; follow the [OpenVINO documentation](https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/yolox-tiny) to provide your own model to Frigate (a download-and-convert sketch follows the configuration example below). There is currently no support for other types of YOLO models (YOLOv3, YOLOv4, etc.). Below is an example of how `yolox_tiny` and other YOLOX variants can be used in Frigate:
```yaml
detectors:
  ov:
    type: openvino
    device: AUTO
    model:
      path: /path/to/yolox_tiny.xml

model:
  width: 416
  height: 416
  input_tensor: nchw
  input_pixel_format: bgr
  model_type: yolox
  labelmap_path: /path/to/coco_80cl.txt
```
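If you need a starting point for obtaining the model, the sketch below assumes the `openvino-dev` Python package and its Open Model Zoo tools; it is not a step Frigate performs for you, and output paths may differ by version:

```bash
# Sketch: download yolox-tiny from the Open Model Zoo and convert it to an IR.
pip install openvino-dev
omz_downloader --name yolox-tiny
omz_converter --name yolox-tiny --precisions FP16
# The converted model is typically written under public/yolox-tiny/FP16/.
```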
### Intel NCS2 VPU and Myriad X Setup
Intel produces a neural net inference acceleration chip called Myriad X. This chip was sold in their Neural Compute Stick 2 (NCS2), which has since been discontinued. If you intend to use the MYRIAD device for acceleration, additional setup is required to pass through the USB device. The host needs a udev rule installed to handle the NCS2 device; a sketch of one is shown below.
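The following is a minimal sketch, not Frigate's official setup steps: it grants non-root access to the Movidius VPU (USB vendor ID `03e7`) and reloads udev. The rule file name and the broad vendor-only match are assumptions; verify the IDs with `lsusb` on your host.

```bash
# Sketch: allow non-root access to the Myriad X / NCS2 (vendor ID 03e7).
echo 'SUBSYSTEM=="usb", ATTRS{idVendor}=="03e7", MODE="0666"' \
  | sudo tee /etc/udev/rules.d/97-myriad.rules
sudo udevadm control --reload-rules
sudo udevadm trigger
```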

@@ -105,6 +105,9 @@ model:
  # Optional: Object detection model input tensor format
  # Valid values are nhwc or nchw (default: shown below)
  input_tensor: nhwc
  # Optional: Object detection model type, currently only used with the OpenVINO detector
  # Valid values are ssd or yolox (default: shown below)
  model_type: ssd
  # Optional: Label name modifications. These are merged into the standard labelmap.
  labelmap:
    2: vehicle

@@ -23,6 +23,11 @@ class InputTensorEnum(str, Enum):
    nhwc = "nhwc"


class ModelTypeEnum(str, Enum):
    ssd = "ssd"
    yolox = "yolox"


class ModelConfig(BaseModel):
    path: Optional[str] = Field(title="Custom Object detection model path.")
    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
@@ -37,6 +42,9 @@ class ModelConfig(BaseModel):
    input_pixel_format: PixelFormatEnum = Field(
        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
    )
    model_type: ModelTypeEnum = Field(
        default=ModelTypeEnum.ssd, title="Object Detection Model Type"
    )
    _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
    _colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
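To illustrate how the new field behaves at config-parse time, here is a standalone sketch (assuming pydantic v1; it mirrors the fields above but is not the project's code): pydantic coerces the YAML string into the enum, falls back to `ssd` when the key is absent, and rejects values outside the enum.

```python
from enum import Enum

from pydantic import BaseModel, Field, ValidationError


class ModelTypeEnum(str, Enum):
    ssd = "ssd"
    yolox = "yolox"


class ModelConfig(BaseModel):
    model_type: ModelTypeEnum = Field(
        default=ModelTypeEnum.ssd, title="Object Detection Model Type"
    )


print(ModelConfig().model_type)  # ModelTypeEnum.ssd -- the default
print(ModelConfig(model_type="yolox").model_type)  # string coerced to the enum

try:
    ModelConfig(model_type="yolov4")  # not a member of ModelTypeEnum
except ValidationError as err:
    print(err)
```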

@@ -3,7 +3,7 @@ import numpy as np
import openvino.runtime as ov
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
from typing import Literal
from pydantic import Extra, Field
@@ -24,12 +24,18 @@ class OvDetector(DetectionApi):
    def __init__(self, detector_config: OvDetectorConfig):
        self.ov_core = ov.Core()
        self.ov_model = self.ov_core.read_model(detector_config.model.path)
        self.ov_model_type = detector_config.model.model_type

        self.h = detector_config.model.height
        self.w = detector_config.model.width

        self.interpreter = self.ov_core.compile_model(
            model=self.ov_model, device_name=detector_config.device
        )

        logger.info(f"Model Input Shape: {self.interpreter.input(0).shape}")
        self.output_indexes = 0

        while True:
            try:
                tensor_shape = self.interpreter.output(self.output_indexes).shape
@@ -38,29 +44,92 @@
            except:
                logger.info(f"Model has {self.output_indexes} Output Tensors")
                break

        if self.ov_model_type == ModelTypeEnum.yolox:
            # Output rows are [x, y, w, h, box_score, class scores...], so the
            # class count is the output tensor width minus the 5 box fields.
            self.num_classes = tensor_shape[2] - 5
            logger.info(f"YOLOX model has {self.num_classes} classes")
            self.set_strides_grids()

    def set_strides_grids(self):
        grids = []
        expanded_strides = []

        strides = [8, 16, 32]

        hsizes = [self.h // stride for stride in strides]
        wsizes = [self.w // stride for stride in strides]

        for hsize, wsize, stride in zip(hsizes, wsizes, strides):
            # One (x, y) cell coordinate per grid position at this stride.
            xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
            grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            expanded_strides.append(np.full((*shape, 1), stride))

        self.grids = np.concatenate(grids, 1)
        self.expanded_strides = np.concatenate(expanded_strides, 1)
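To make the grid construction concrete, here is a standalone sketch (not part of the diff) checking the shapes for a 416x416 input: strides 8, 16, and 32 give 52x52, 26x26, and 13x13 cells, so 2704 + 676 + 169 = 3549 anchor positions, matching the 3549 rows in `yolox_tiny`'s output tensor.

```python
import numpy as np

h = w = 416
strides = [8, 16, 32]

grids, expanded_strides = [], []
for stride in strides:
    hsize, wsize = h // stride, w // stride
    xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
    grid = np.stack((xv, yv), 2).reshape(1, -1, 2)  # (1, hsize * wsize, 2)
    grids.append(grid)
    expanded_strides.append(np.full((1, grid.shape[1], 1), stride))

grids = np.concatenate(grids, 1)
expanded_strides = np.concatenate(expanded_strides, 1)
print(grids.shape, expanded_strides.shape)  # (1, 3549, 2) (1, 3549, 1)
```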
    def detect_raw(self, tensor_input):
        infer_request = self.interpreter.create_infer_request()
        infer_request.infer([tensor_input])

        if self.ov_model_type == ModelTypeEnum.ssd:
            results = infer_request.get_output_tensor()
            detections = np.zeros((20, 6), np.float32)
            i = 0
            for object_detected in results.data[0, 0, :]:
                if object_detected[0] != -1:
                    logger.debug(object_detected)
                if object_detected[2] < 0.1 or i == 20:
                    break
                detections[i] = [
                    object_detected[1],  # Label ID
                    float(object_detected[2]),  # Confidence
                    object_detected[4],  # y_min
                    object_detected[3],  # x_min
                    object_detected[6],  # y_max
                    object_detected[5],  # x_max
                ]
                i += 1
            return detections
        elif self.ov_model_type == ModelTypeEnum.yolox:
            out_tensor = infer_request.get_output_tensor()
            # Each row is [x, y, w, h, box_score, class_no_1, ..., class_no_80]
            results = out_tensor.data
            # Decode grid-relative offsets into absolute pixel coordinates.
            results[..., :2] = (results[..., :2] + self.grids) * self.expanded_strides
            results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides
            image_pred = results[0, ...]

            class_conf = np.max(
                image_pred[:, 5 : 5 + self.num_classes], axis=1, keepdims=True
            )
            class_pred = np.argmax(image_pred[:, 5 : 5 + self.num_classes], axis=1)
            class_pred = np.expand_dims(class_pred, axis=1)

            conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= 0.3).squeeze()
            # Detections ordered as (x, y, w, h, obj_conf, class_conf, class_pred)
            dets = np.concatenate((image_pred[:, :5], class_conf, class_pred), axis=1)
            dets = dets[conf_mask]

            # Keep at most the top 20 detections, sorted by class confidence.
            ordered = dets[dets[:, 5].argsort()[::-1]][:20]
            detections = np.zeros((20, 6), np.float32)

            i = 0
            for object_detected in ordered:
                if i < 20:
                    detections[i] = [
                        object_detected[6],  # Label ID
                        object_detected[5],  # Confidence
                        (object_detected[1] - (object_detected[3] / 2))
                        / self.h,  # y_min
                        (object_detected[0] - (object_detected[2] / 2))
                        / self.w,  # x_min
                        (object_detected[1] + (object_detected[3] / 2))
                        / self.h,  # y_max
                        (object_detected[0] + (object_detected[2] / 2))
                        / self.w,  # x_max
                    ]
                    i += 1
                else:
                    break
            return detections
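As a worked example of the final corner conversion (made-up numbers, not from the diff): a decoded detection centered at (208, 208) with a 104x104 box in a 416x416 model maps to the normalized `[y_min, x_min, y_max, x_max]` values Frigate expects.

```python
# Hypothetical decoded row: center (208, 208), size 104 x 104, 416 x 416 model.
w = h = 416
cx, cy, bw, bh = 208.0, 208.0, 104.0, 104.0

y_min = (cy - bh / 2) / h  # (208 - 52) / 416 = 0.375
x_min = (cx - bw / 2) / w  # 0.375
y_max = (cy + bh / 2) / h  # (208 + 52) / 416 = 0.625
x_max = (cx + bw / 2) / w  # 0.625
print([y_min, x_min, y_max, x_max])
```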