mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Improve rocm handling of different models (#14072)
* Improve rocm handling of different models * Formatting * Fix type check
This commit is contained in:
parent
c73f694c63
commit
18bf7f93fa
@ -83,6 +83,7 @@ ARG AMDGPU
|
||||
|
||||
COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
|
||||
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
|
||||
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/
|
||||
COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
|
||||
COPY --from=rocm /opt/rocm-dist/ /
|
||||
COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/
|
||||
|
@ -12,7 +12,8 @@ from setproctitle import setproctitle
|
||||
|
||||
import frigate.util as util
|
||||
from frigate.detectors import create_detector
|
||||
from frigate.detectors.detector_config import InputTensorEnum
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
||||
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||
from frigate.util.image import SharedMemoryFrameManager
|
||||
from frigate.util.services import listen
|
||||
@ -22,11 +23,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectDetector(ABC):
|
||||
@abstractmethod
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
def detect(self, tensor_input, threshold: float = 0.4):
|
||||
pass
|
||||
|
||||
|
||||
def tensor_transform(desired_shape):
|
||||
def tensor_transform(desired_shape: InputTensorEnum):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
@ -37,8 +38,8 @@ def tensor_transform(desired_shape):
|
||||
class LocalObjectDetector(ObjectDetector):
|
||||
def __init__(
|
||||
self,
|
||||
detector_config=None,
|
||||
labels=None,
|
||||
detector_config: BaseDetectorConfig = None,
|
||||
labels: str = None,
|
||||
):
|
||||
self.fps = EventsPerSecond()
|
||||
if labels is None:
|
||||
@ -47,7 +48,13 @@ class LocalObjectDetector(ObjectDetector):
|
||||
self.labels = load_labels(labels)
|
||||
|
||||
if detector_config:
|
||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
||||
if detector_config.type == ROCM_DETECTOR_KEY:
|
||||
# ROCm requires NHWC as input
|
||||
self.input_transform = None
|
||||
else:
|
||||
self.input_transform = tensor_transform(
|
||||
detector_config.model.input_tensor
|
||||
)
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user