mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Improve rocm handling of different models (#14072)
* Improve rocm handling of different models * Formatting * Fix type check
This commit is contained in:
		
							parent
							
								
									c73f694c63
								
							
						
					
					
						commit
						18bf7f93fa
					
				@ -83,6 +83,7 @@ ARG AMDGPU
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
 | 
					COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/
 | 
				
			||||||
COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
 | 
					COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/
 | 
				
			||||||
 | 
					COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/
 | 
				
			||||||
COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
 | 
					COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/
 | 
				
			||||||
COPY --from=rocm /opt/rocm-dist/ /
 | 
					COPY --from=rocm /opt/rocm-dist/ /
 | 
				
			||||||
COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/
 | 
					COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,8 @@ from setproctitle import setproctitle
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import frigate.util as util
 | 
					import frigate.util as util
 | 
				
			||||||
from frigate.detectors import create_detector
 | 
					from frigate.detectors import create_detector
 | 
				
			||||||
from frigate.detectors.detector_config import InputTensorEnum
 | 
					from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
 | 
				
			||||||
 | 
					from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
 | 
				
			||||||
from frigate.util.builtin import EventsPerSecond, load_labels
 | 
					from frigate.util.builtin import EventsPerSecond, load_labels
 | 
				
			||||||
from frigate.util.image import SharedMemoryFrameManager
 | 
					from frigate.util.image import SharedMemoryFrameManager
 | 
				
			||||||
from frigate.util.services import listen
 | 
					from frigate.util.services import listen
 | 
				
			||||||
@ -22,11 +23,11 @@ logger = logging.getLogger(__name__)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class ObjectDetector(ABC):
 | 
					class ObjectDetector(ABC):
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def detect(self, tensor_input, threshold=0.4):
 | 
					    def detect(self, tensor_input, threshold: float = 0.4):
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def tensor_transform(desired_shape):
 | 
					def tensor_transform(desired_shape: InputTensorEnum):
 | 
				
			||||||
    # Currently this function only supports BHWC permutations
 | 
					    # Currently this function only supports BHWC permutations
 | 
				
			||||||
    if desired_shape == InputTensorEnum.nhwc:
 | 
					    if desired_shape == InputTensorEnum.nhwc:
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
@ -37,8 +38,8 @@ def tensor_transform(desired_shape):
 | 
				
			|||||||
class LocalObjectDetector(ObjectDetector):
 | 
					class LocalObjectDetector(ObjectDetector):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        detector_config=None,
 | 
					        detector_config: BaseDetectorConfig = None,
 | 
				
			||||||
        labels=None,
 | 
					        labels: str = None,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        self.fps = EventsPerSecond()
 | 
					        self.fps = EventsPerSecond()
 | 
				
			||||||
        if labels is None:
 | 
					        if labels is None:
 | 
				
			||||||
@ -47,7 +48,13 @@ class LocalObjectDetector(ObjectDetector):
 | 
				
			|||||||
            self.labels = load_labels(labels)
 | 
					            self.labels = load_labels(labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if detector_config:
 | 
					        if detector_config:
 | 
				
			||||||
            self.input_transform = tensor_transform(detector_config.model.input_tensor)
 | 
					            if detector_config.type == ROCM_DETECTOR_KEY:
 | 
				
			||||||
 | 
					                # ROCm requires NHWC as input
 | 
				
			||||||
 | 
					                self.input_transform = None
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.input_transform = tensor_transform(
 | 
				
			||||||
 | 
					                    detector_config.model.input_tensor
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.input_transform = None
 | 
					            self.input_transform = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user