Improve ROCm handling of different models (#14072)

* Improve ROCm handling of different models

* Formatting

* Fix type check
This commit is contained in:
Nicolas Mowen 2024-09-30 15:40:46 -06:00 committed by GitHub
parent c73f694c63
commit 18bf7f93fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 6 deletions

View File

@ -83,6 +83,7 @@ ARG AMDGPU
COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/
COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
COPY --from=rocm /opt/rocm-dist/ /
COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/

View File

@ -12,7 +12,8 @@ from setproctitle import setproctitle
import frigate.util as util
from frigate.detectors import create_detector
from frigate.detectors.detector_config import InputTensorEnum
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
from frigate.util.builtin import EventsPerSecond, load_labels
from frigate.util.image import SharedMemoryFrameManager
from frigate.util.services import listen
@ -22,11 +23,11 @@ logger = logging.getLogger(__name__)
class ObjectDetector(ABC):
@abstractmethod
def detect(self, tensor_input, threshold=0.4):
def detect(self, tensor_input, threshold: float = 0.4):
pass
def tensor_transform(desired_shape):
def tensor_transform(desired_shape: InputTensorEnum):
# Currently this function only supports BHWC permutations
if desired_shape == InputTensorEnum.nhwc:
return None
@ -37,8 +38,8 @@ def tensor_transform(desired_shape):
class LocalObjectDetector(ObjectDetector):
def __init__(
self,
detector_config=None,
labels=None,
detector_config: BaseDetectorConfig = None,
labels: str = None,
):
self.fps = EventsPerSecond()
if labels is None:
@ -47,7 +48,13 @@ class LocalObjectDetector(ObjectDetector):
self.labels = load_labels(labels)
if detector_config:
self.input_transform = tensor_transform(detector_config.model.input_tensor)
if detector_config.type == ROCM_DETECTOR_KEY:
# ROCm requires NHWC as input
self.input_transform = None
else:
self.input_transform = tensor_transform(
detector_config.model.input_tensor
)
else:
self.input_transform = None