mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Improve rocm handling of different models (#14072)
* Improve rocm handling of different models * Formatting * Fix type check
This commit is contained in:
parent
c73f694c63
commit
18bf7f93fa
@ -83,6 +83,7 @@ ARG AMDGPU
|
|||||||
|
|
||||||
COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
|
COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
|
||||||
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
|
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
|
||||||
|
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/
|
||||||
COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
|
COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
|
||||||
COPY --from=rocm /opt/rocm-dist/ /
|
COPY --from=rocm /opt/rocm-dist/ /
|
||||||
COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/
|
COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/
|
||||||
|
@ -12,7 +12,8 @@ from setproctitle import setproctitle
|
|||||||
|
|
||||||
import frigate.util as util
|
import frigate.util as util
|
||||||
from frigate.detectors import create_detector
|
from frigate.detectors import create_detector
|
||||||
from frigate.detectors.detector_config import InputTensorEnum
|
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
||||||
|
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||||
from frigate.util.image import SharedMemoryFrameManager
|
from frigate.util.image import SharedMemoryFrameManager
|
||||||
from frigate.util.services import listen
|
from frigate.util.services import listen
|
||||||
@ -22,11 +23,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ObjectDetector(ABC):
|
class ObjectDetector(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def detect(self, tensor_input, threshold=0.4):
|
def detect(self, tensor_input, threshold: float = 0.4):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def tensor_transform(desired_shape):
|
def tensor_transform(desired_shape: InputTensorEnum):
|
||||||
# Currently this function only supports BHWC permutations
|
# Currently this function only supports BHWC permutations
|
||||||
if desired_shape == InputTensorEnum.nhwc:
|
if desired_shape == InputTensorEnum.nhwc:
|
||||||
return None
|
return None
|
||||||
@ -37,8 +38,8 @@ def tensor_transform(desired_shape):
|
|||||||
class LocalObjectDetector(ObjectDetector):
|
class LocalObjectDetector(ObjectDetector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector_config=None,
|
detector_config: BaseDetectorConfig = None,
|
||||||
labels=None,
|
labels: str = None,
|
||||||
):
|
):
|
||||||
self.fps = EventsPerSecond()
|
self.fps = EventsPerSecond()
|
||||||
if labels is None:
|
if labels is None:
|
||||||
@ -47,7 +48,13 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.labels = load_labels(labels)
|
self.labels = load_labels(labels)
|
||||||
|
|
||||||
if detector_config:
|
if detector_config:
|
||||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
if detector_config.type == ROCM_DETECTOR_KEY:
|
||||||
|
# ROCm requires NHWC as input
|
||||||
|
self.input_transform = None
|
||||||
|
else:
|
||||||
|
self.input_transform = tensor_transform(
|
||||||
|
detector_config.model.input_tensor
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.input_transform = None
|
self.input_transform = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user