Add support for RF-DETR models (#17298)

* Add support for rf-detr models

* Add docs for rf-detr model

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-03-21 18:55:46 -06:00 committed by GitHub
parent 4e83237d47
commit 48e4c44b32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 97 additions and 5 deletions

View File

@ -342,7 +342,7 @@ Note that the labelmap uses a subset of the complete COCO label set that has onl
#### D-FINE
[D-FINE](https://github.com/Peterande/D-FINE) is the [current state of the art](https://paperswithcode.com/sota/real-time-object-detection-on-coco?p=d-fine-redefine-regression-task-in-detrs-as) at the time of writing. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
[D-FINE](https://github.com/Peterande/D-FINE) is a DETR based model. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
After placing the downloaded onnx model in your config/model_cache folder, you can use the following configuration:
@ -647,9 +647,29 @@ model:
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
#### RF-DETR
[RF-DETR](https://github.com/roboflow/rf-detr) is a DETR-based model. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-rf-detr-model) for more information on downloading the RF-DETR model for use in Frigate.
After placing the downloaded onnx model in your `config/model_cache` folder, you can use the following configuration:
```yaml
detectors:
  onnx:
    type: onnx

model:
  model_type: rfdetr
  width: 560
  height: 560
  input_tensor: nchw
  input_dtype: float
  path: /config/model_cache/rfdetr.onnx
```
#### D-FINE
[D-FINE](https://github.com/Peterande/D-FINE) is the [current state of the art](https://paperswithcode.com/sota/real-time-object-detection-on-coco?p=d-fine-redefine-regression-task-in-detrs-as) at the time of writing. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
[D-FINE](https://github.com/Peterande/D-FINE) is a DETR based model. The ONNX exported models are supported, but not included by default. See [the models section](#downloading-d-fine-model) for more information on downloading the D-FINE model for use in Frigate.
After placing the downloaded onnx model in your config/model_cache folder, you can use the following configuration:
@ -873,6 +893,16 @@ Make sure you change the batch size to 1 before exporting.
:::
### Downloading RF-DETR Model
To export as ONNX:
1. `pip3 install rfdetr`
2. `python`
3. `from rfdetr import RFDETRBase`
4. `x = RFDETRBase()`
5. `x.export()`
### Downloading YOLO-NAS Model
You can build and download a compatible model with pre-trained weights using [this notebook](https://github.com/blakeblackshear/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/blakeblackshear/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb).

View File

@ -33,11 +33,12 @@ class InputDTypeEnum(str, Enum):
class ModelTypeEnum(str, Enum):
    """Identifiers for the detection model architectures that can be selected.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    # Each member name may be assigned only once: Enum raises TypeError at
    # class-creation time if a name is reused.
    dfine = "dfine"
    rfdetr = "rfdetr"
    ssd = "ssd"
    yolox = "yolox"
    yolov9 = "yolov9"
    yolonas = "yolonas"
    yologeneric = "yolo-generic"

View File

@ -12,6 +12,7 @@ from frigate.detectors.detector_config import (
from frigate.util.model import (
get_ort_providers,
post_process_dfine,
post_process_rfdetr,
post_process_yolov9,
)
@ -73,7 +74,9 @@ class ONNXDetector(DetectionApi):
model_input_name = self.model.get_inputs()[0].name
tensor_output = self.model.run(None, {model_input_name: tensor_input})
if self.onnx_model_type == ModelTypeEnum.yolonas:
if self.onnx_model_type == ModelTypeEnum.rfdetr:
return post_process_rfdetr(tensor_output)
elif self.onnx_model_type == ModelTypeEnum.yolonas:
predictions = tensor_output[0]
detections = np.zeros((20, 6), np.float32)

View File

@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
### Post Processing
def post_process_dfine(tensor_output: np.ndarray, width, height) -> np.ndarray:
def post_process_dfine(
tensor_output: np.ndarray, width: int, height: int
) -> np.ndarray:
class_ids = tensor_output[0][tensor_output[2] > 0.4]
boxes = tensor_output[1][tensor_output[2] > 0.4]
scores = tensor_output[2][tensor_output[2] > 0.4]
@ -41,6 +45,60 @@ def post_process_dfine(tensor_output: np.ndarray, width, height) -> np.ndarray:
return detections
def post_process_rfdetr(tensor_output: list[np.ndarray]) -> np.ndarray:
    """Convert raw RF-DETR outputs into a fixed-size detections array.

    Args:
        tensor_output: Model outputs. Index 0 holds the boxes and index 1 the
            raw class scores. The indexing below reads boxes as
            ``[batch, query, 4]`` in ``[x_center, y_center, width, height]``
            order and scores as ``[batch, query, class]``; only batch entry 0
            is used.

    Returns:
        A ``(20, 6)`` float32 array. Each populated row is
        ``[class_id, confidence, y_min, x_min, y_max, x_max]``; rows that are
        not filled stay all-zero.
    """
    boxes = tensor_output[0]
    raw_scores = tensor_output[1]
    detections = np.zeros((20, 6), np.float32)

    # softmax over the class axis; subtracting the per-row max first keeps
    # np.exp from overflowing without changing the result
    exp = np.exp(raw_scores - np.max(raw_scores, axis=-1, keepdims=True))
    all_scores = exp / np.sum(exp, axis=-1, keepdims=True)

    # best class per detection; score column 0 is excluded, so label 0
    # corresponds to score column 1
    class_scores = all_scores[0, :, 1:]
    scores = np.max(class_scores, axis=-1)
    labels = np.argmax(class_scores, axis=-1)

    keep = scores > 0.4

    if not np.any(keep):
        # nothing cleared the threshold, so there is nothing to run NMS on
        return detections

    filtered_boxes = boxes[0, keep]
    filtered_scores = scores[keep]
    filtered_labels = labels[keep]

    # convert boxes from [x_center, y_center, width, height] to corners
    x_center = filtered_boxes[:, 0]
    y_center = filtered_boxes[:, 1]
    w = filtered_boxes[:, 2]
    h = filtered_boxes[:, 3]
    x_min = x_center - w / 2
    y_min = y_center - h / 2
    x_max = x_center + w / 2
    y_max = y_center + h / 2

    # NMSBoxes takes rectangles as [x, y, width, height]; handing it corner
    # coordinates would make it compute overlap on the wrong rectangles
    nms_boxes = np.stack([x_min, y_min, w, h], axis=-1)
    indices = cv2.dnn.NMSBoxes(
        nms_boxes, filtered_scores, score_threshold=0.4, nms_threshold=0.4
    )

    # NOTE(review): the index container returned by NMSBoxes is believed to
    # differ between OpenCV versions (empty tuple, 1-D array, or (N, 1)
    # array) -- flatten so every form indexes the same way. Confirm.
    kept = np.asarray(indices, dtype=np.intp).reshape(-1)[:20]
    count = len(kept)

    # the output rows use [y_min, x_min, y_max, x_max] coordinate order
    detections[:count, 0] = filtered_labels[kept]
    detections[:count, 1] = filtered_scores[kept]
    detections[:count, 2] = y_min[kept]
    detections[:count, 3] = x_min[kept]
    detections[:count, 4] = y_max[kept]
    detections[:count, 5] = x_max[kept]

    return detections
def post_process_yolov9(predictions: np.ndarray, width, height) -> np.ndarray:
predictions = np.squeeze(predictions).T
scores = np.max(predictions[:, 4:], axis=1)