From 18bf7f93fa5bbbd945eb8301e810490f4278844a Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 30 Sep 2024 15:40:46 -0600 Subject: [PATCH] Improve rocm handling of different models (#14072) * Improve rocm handling of different models * Formatting * Fix type check --- docker/rocm/Dockerfile | 1 + frigate/object_detection.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/docker/rocm/Dockerfile b/docker/rocm/Dockerfile index a1d6ce832..eebe04878 100644 --- a/docker/rocm/Dockerfile +++ b/docker/rocm/Dockerfile @@ -83,6 +83,7 @@ ARG AMDGPU COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/ COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/ +COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/ COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/ COPY --from=rocm /opt/rocm-dist/ / COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/ diff --git a/frigate/object_detection.py b/frigate/object_detection.py index eac019a7a..eaa3b4e04 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -12,7 +12,8 @@ from setproctitle import setproctitle import frigate.util as util from frigate.detectors import create_detector -from frigate.detectors.detector_config import InputTensorEnum +from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum +from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.image import SharedMemoryFrameManager from frigate.util.services import listen @@ -22,11 +23,11 @@ logger = logging.getLogger(__name__) class ObjectDetector(ABC): @abstractmethod - def detect(self, tensor_input, threshold=0.4): + def detect(self, tensor_input, threshold: float = 0.4): pass -def tensor_transform(desired_shape): +def tensor_transform(desired_shape: InputTensorEnum): # Currently this function only supports BHWC permutations if desired_shape == InputTensorEnum.nhwc: return None @@ -37,8 +38,8 @@ def tensor_transform(desired_shape): class LocalObjectDetector(ObjectDetector): def __init__( self, - detector_config=None, - labels=None, + detector_config: BaseDetectorConfig = None, + labels: str = None, ): self.fps = EventsPerSecond() if labels is None: @@ -47,7 +48,13 @@ class LocalObjectDetector(ObjectDetector): self.labels = load_labels(labels) if detector_config: - self.input_transform = tensor_transform(detector_config.model.input_tensor) + if detector_config.type == ROCM_DETECTOR_KEY: + # ROCm requires NHWC as input + self.input_transform = None + else: + self.input_transform = tensor_transform( + detector_config.model.input_tensor + ) else: self.input_transform = None