From bd0ee86db91b8a488736b273e2f9aacb4ab94d99 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 14 Apr 2025 15:05:41 -0600 Subject: [PATCH] Refactor yolov9 detector to support v3, v4, v7 as well (#17697) * Implement blobbed yolov7 post processing and consolidate yolo implementation * Update documentation * Add repo * fix name --- docs/docs/configuration/object_detectors.md | 54 ++++++++---- frigate/detectors/detector_config.py | 1 - frigate/detectors/plugins/onnx.py | 10 +-- frigate/detectors/plugins/openvino.py | 16 ++-- frigate/util/model.py | 96 ++++++++++++++++++++- 5 files changed, 145 insertions(+), 32 deletions(-) diff --git a/docs/docs/configuration/object_detectors.md b/docs/docs/configuration/object_detectors.md index 31f0df1da..2906a7829 100644 --- a/docs/docs/configuration/object_detectors.md +++ b/docs/docs/configuration/object_detectors.md @@ -312,13 +312,13 @@ model: Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects. -#### YOLOv9 +#### YOLO (v3, v4, v7, v9) -[YOLOv9](https://github.com/WongKinYiu/yolov9) models are supported, but not included by default. +YOLOv3, YOLOv4, YOLOv7, and [YOLOv9](https://github.com/WongKinYiu/yolov9) models are supported, but not included by default. :::tip -The YOLOv9 detector has been designed to support YOLOv9 models, but may support other YOLO model architectures as well. +The YOLO detector has been designed to support YOLOv3, YOLOv4, YOLOv7, and YOLOv9 models, but may support other YOLO model architectures as well. ::: @@ -331,12 +331,12 @@ detectors: device: GPU model: - model_type: yolov9 - width: 640 # <--- should match the imgsize set during model export - height: 640 # <--- should match the imgsize set during model export + model_type: yolo-generic + width: 320 # <--- should match the imgsize set during model export + height: 320 # <--- should match the imgsize set during model export input_tensor: nchw input_dtype: float - path: /config/model_cache/yolov9-t.onnx + path: /config/model_cache/yolo.onnx labelmap_path: /labelmap/coco-80.txt ``` @@ -653,13 +653,13 @@ model: labelmap_path: /labelmap/coco-80.txt ``` -#### YOLOv9 +#### YOLO (v3, v4, v7, v9) -[YOLOv9](https://github.com/WongKinYiu/yolov9) models are supported, but not included by default. +YOLOv3, YOLOv4, YOLOv7, and [YOLOv9](https://github.com/WongKinYiu/yolov9) models are supported, but not included by default. :::tip -The YOLOv9 detector has been designed to support YOLOv9 models, but may support other YOLO model architectures as well. +The YOLO detector has been designed to support YOLOv3, YOLOv4, YOLOv7, and YOLOv9 models, but may support other YOLO model architectures as well. ::: @@ -671,12 +671,12 @@ detectors: type: onnx model: - model_type: yolov9 - width: 640 # <--- should match the imgsize set during model export - height: 640 # <--- should match the imgsize set during model export + model_type: yolo-generic + width: 320 # <--- should match the imgsize set during model export + height: 320 # <--- should match the imgsize set during model export input_tensor: nchw input_dtype: float - path: /config/model_cache/yolov9-t.onnx + path: /config/model_cache/yolo.onnx labelmap_path: /labelmap/coco-80.txt ``` @@ -684,7 +684,7 @@ Note that the labelmap uses a subset of the complete COCO label set that has onl #### RF-DETR -[RF-DETR](https://github.com/roboflow/rf-detr) is a DETR based model. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-rf-detr-model) for more informatoin on downloading the RF-DETR model for use in Frigate. +[RF-DETR](https://github.com/roboflow/rf-detr) is a DETR based model. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-rf-detr-model) for more information on downloading the RF-DETR model for use in Frigate. After placing the downloaded onnx model in your `config/model_cache` folder, you can use the following configuration: @@ -959,3 +959,27 @@ The pre-trained YOLO-NAS weights from DeciAI are subject to their license and ca ::: The input image size in this notebook is set to 320x320. This results in lower CPU usage and faster inference times without impacting performance in most cases due to the way Frigate crops video frames to areas of interest before running detection. The notebook and config can be updated to 640x640 if desired. + +### Downloading YOLO Models + +#### YOLOv3, YOLOv4, and YOLOv7 + +To export as ONNX: + +```sh +git clone https://github.com/NateMeyer/tensorrt_demos +cd tensorrt_demos/yolo +./download_yolo.sh +python3 yolo_to_onnx.py -m yolov7-320 +``` + +#### YOLOv9 + +YOLOv9 models can be exported using the below code or they [can be downloaded from hugging face](https://huggingface.co/Xenova/yolov9-onnx/tree/main) + +```sh +git clone https://github.com/WongKinYiu/yolov9 +cd yolov9 +wget -O yolov9-t.pt "https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-t-converted.pt" +python3 export.py --weights ./yolov9-t-converted.pt --imgsz 320 --simplify +``` diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py index ce7738493..2c54d11e5 100644 --- a/frigate/detectors/detector_config.py +++ b/frigate/detectors/detector_config.py @@ -37,7 +37,6 @@ class ModelTypeEnum(str, Enum): rfdetr = "rfdetr" ssd = "ssd" yolox = "yolox" - yolov9 = "yolov9" yolonas = "yolonas" yologeneric = "yolo-generic" diff --git a/frigate/detectors/plugins/onnx.py b/frigate/detectors/plugins/onnx.py index a10447b48..aef6e909b 100644 --- a/frigate/detectors/plugins/onnx.py +++ b/frigate/detectors/plugins/onnx.py @@ -13,7 +13,7 @@ from frigate.util.model import ( get_ort_providers, post_process_dfine, post_process_rfdetr, - post_process_yolov9, + post_process_yolo, ) logger = logging.getLogger(__name__) @@ -97,12 +97,8 @@ class ONNXDetector(DetectionApi): x_max / self.w, ] return detections - elif ( - self.onnx_model_type == ModelTypeEnum.yolov9 - or self.onnx_model_type == ModelTypeEnum.yologeneric - ): - predictions: np.ndarray = tensor_output[0] - return post_process_yolov9(predictions, self.w, self.h) + elif self.onnx_model_type == ModelTypeEnum.yologeneric: + return post_process_yolo(tensor_output, self.w, self.h) else: raise Exception( f"{self.onnx_model_type} is currently not supported for onnx. See the docs for more info on supported models." diff --git a/frigate/detectors/plugins/openvino.py b/frigate/detectors/plugins/openvino.py index d90352772..9c7ed5248 100644 --- a/frigate/detectors/plugins/openvino.py +++ b/frigate/detectors/plugins/openvino.py @@ -13,7 +13,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum from frigate.util.model import ( post_process_dfine, post_process_rfdetr, - post_process_yolov9, + post_process_yolo, ) logger = logging.getLogger(__name__) @@ -33,7 +33,6 @@ class OvDetector(DetectionApi): ModelTypeEnum.rfdetr, ModelTypeEnum.ssd, ModelTypeEnum.yolonas, - ModelTypeEnum.yolov9, ModelTypeEnum.yologeneric, ModelTypeEnum.yolox, ] @@ -232,12 +231,13 @@ class OvDetector(DetectionApi): x_max / self.w, ] return detections - elif ( - self.ov_model_type == ModelTypeEnum.yolov9 - or self.ov_model_type == ModelTypeEnum.yologeneric - ): - out_tensor = infer_request.get_output_tensor(0).data - return post_process_yolov9(out_tensor, self.w, self.h) + elif self.ov_model_type == ModelTypeEnum.yologeneric: + out_tensor = [] + + for item in infer_request.output_tensors: + out_tensor.append(item.data) + + return post_process_yolo(out_tensor, self.w, self.h) elif self.ov_model_type == ModelTypeEnum.yolox: out_tensor = infer_request.get_output_tensor() # [x, y, h, w, box_score, class_no_1, ..., class_no_80], diff --git a/frigate/util/model.py b/frigate/util/model.py index 19b3b1bf5..a4ff9bd75 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -99,7 +99,94 @@ def post_process_rfdetr(tensor_output: list[np.ndarray, np.ndarray]) -> np.ndarr return detections -def post_process_yolov9(predictions: np.ndarray, width, height) -> np.ndarray: +def __post_process_multipart_yolo( + output_list, + width, + height, +): + anchors = [ + [(12, 16), (19, 36), (40, 28)], + [(36, 75), (76, 55), (72, 146)], + [(142, 110), (192, 243), (459, 401)], + ] + + stride_map = {0: 8, 1: 16, 2: 32} + + all_boxes = [] + all_scores = [] + all_class_ids = [] + + for i, output in enumerate(output_list): + bs, _, ny, nx = output.shape + stride = stride_map[i] + anchor_set = anchors[i] + + num_anchors = len(anchor_set) + output = output.reshape(bs, num_anchors, 85, ny, nx) + output = output.transpose(0, 1, 3, 4, 2) + output = output[0] + + for a_idx, (anchor_w, anchor_h) in enumerate(anchor_set): + for y in range(ny): + for x in range(nx): + pred = output[a_idx, y, x] + class_probs = pred[5:] + class_id = np.argmax(class_probs) + class_conf = class_probs[class_id] + conf = class_conf * pred[4] + + if conf < 0.4: + continue + + dx = pred[0] + dy = pred[1] + dw = pred[2] + dh = pred[3] + + bx = ((dx * 2.0 - 0.5) + x) * stride + by = ((dy * 2.0 - 0.5) + y) * stride + bw = ((dw * 2.0) ** 2) * anchor_w + bh = ((dh * 2.0) ** 2) * anchor_h + + x1 = max(0, bx - bw / 2) / width + y1 = max(0, by - bh / 2) / height + x2 = min(width, bx + bw / 2) / width + y2 = min(height, by + bh / 2) / height + + all_boxes.append([x1, y1, x2, y2]) + all_scores.append(conf) + all_class_ids.append(class_id) + + formatted_boxes = [ + [ + int(x1 * width), + int(y1 * height), + int((x2 - x1) * width), + int((y2 - y1) * height), + ] + for x1, y1, x2, y2 in all_boxes + ] + + indices = cv2.dnn.NMSBoxes( + bboxes=formatted_boxes, + scores=all_scores, + score_threshold=0.4, + nms_threshold=0.4, + ) + + results = np.zeros((20, 6), np.float32) + + if len(indices) > 0: + for i, idx in enumerate(indices.flatten()[:20]): + class_id = all_class_ids[idx] + conf = all_scores[idx] + x1, y1, x2, y2 = all_boxes[idx] + results[i] = [class_id, conf, y1, x1, y2, x2] + + return np.array(results, dtype=np.float32) + + +def __post_process_nms_yolo(predictions: np.ndarray, width, height) -> np.ndarray: predictions = np.squeeze(predictions).T scores = np.max(predictions[:, 4:], axis=1) predictions = predictions[scores > 0.4, :] @@ -131,6 +218,13 @@ def post_process_yolov9(predictions: np.ndarray, width, height) -> np.ndarray: return detections +def post_process_yolo(output: list[np.ndarray], width: int, height: int) -> np.ndarray: + if len(output) > 1: + return __post_process_multipart_yolo(output, width, height) + else: + return __post_process_nms_yolo(output[0], width, height) + + ### ONNX Utilities