From 18f1cc16388b2aa296e28533eb9b1e56363044ee Mon Sep 17 00:00:00 2001
From: Abinila Siva
Date: Mon, 5 May 2025 16:18:20 -0400
Subject: [PATCH] Add AsyncLocalObjectDetector class

---
 frigate/detectors/plugins/memryx.py |  18 ++---
 frigate/object_detection/base.py    | 105 ++++++++++++++++++----------
 2 files changed, 78 insertions(+), 45 deletions(-)

diff --git a/frigate/detectors/plugins/memryx.py b/frigate/detectors/plugins/memryx.py
index 280d94817..3c710ec86 100644
--- a/frigate/detectors/plugins/memryx.py
+++ b/frigate/detectors/plugins/memryx.py
@@ -19,7 +19,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 from frigate.detectors.detection_api import DetectionApi
 from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
-from frigate.util.model import __post_process_multipart_yolo
+from frigate.util.model import post_process_yolo
 
 logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ class MemryXDetector(DetectionApi):
 
         if self.memx_model_type == ModelTypeEnum.yologeneric:
             self.model_url = (
-                "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
+                "https://developer.memryx.com/example_files/1p2_frigate/yolo-generic.zip"
             )
             self.expected_dfp_model = (
                 "YOLO_v9_small_640_640_3_onnx.dfp"
@@ -214,6 +214,7 @@ class MemryXDetector(DetectionApi):
 
         if self.memx_model_type == ModelTypeEnum.yolox:
             tensor_input = tensor_input.squeeze(0)
+            tensor_input = tensor_input * 255.0
 
             padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
             scale = min(
@@ -238,10 +239,11 @@ class MemryXDetector(DetectionApi):
             # Step 5: Concatenate along the channel dimension (axis 2)
             concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
             processed_input = concatenated_img.astype(np.float32)
-        else:
-            processed_input = tensor_input.astype(np.float32) / 255.0  # Normalize
-            # Assuming original input is always NHWC and MemryX wants HWNC:
-            processed_input = processed_input.transpose(1, 2, 0, 3)  # NHWC -> HWNC
+
+        else:
+            tensor_input = tensor_input.squeeze(0)  # (H, W, C)
+            # Add axis=2 to create Z=1: (H, W, Z=1, C)
+            processed_input = np.expand_dims(tensor_input, axis=2)  # Now (H, W, 1, 3)
 
         # Send frame to MemryX for processing
         self.capture_queue.put(processed_input)
@@ -596,7 +598,7 @@ class MemryXDetector(DetectionApi):
             sigmoid_output = self.sigmoid(split_1)
             outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
 
-            final_detections = __post_process_multipart_yolo(
+            final_detections = post_process_yolo(
                 outputs, self.memx_model_width, self.memx_model_height
             )
             self.output_queue.put(final_detections)
@@ -617,4 +619,4 @@ class MemryXDetector(DetectionApi):
 
     def detect_raw(self, tensor_input: np.ndarray):
         """Removed synchronous detect_raw() function so that we only use async"""
-        return 0
+        return 0
\ No newline at end of file
diff --git a/frigate/object_detection/base.py b/frigate/object_detection/base.py
index 8ad8fa322..46f28cf69 100644
--- a/frigate/object_detection/base.py
+++ b/frigate/object_detection/base.py
@@ -7,6 +7,8 @@ import queue
 import signal
 import threading
 from abc import ABC, abstractmethod
+from multiprocessing import Queue, Value
+from multiprocessing.synchronize import Event as MpEvent
 
 import numpy as np
 from setproctitle import setproctitle
@@ -16,6 +18,7 @@ from frigate.detectors import create_detector
 from frigate.detectors.detector_config import (
     BaseDetectorConfig,
     InputDTypeEnum,
+    ModelConfig,
 )
 from frigate.util.builtin import EventsPerSecond, load_labels
 from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
@@ -32,7 +35,7 @@ class ObjectDetector(ABC):
         pass
 
 
-class LocalObjectDetector(ObjectDetector):
+class BaseLocalDetector(ObjectDetector):
     def __init__(
         self,
         detector_config: BaseDetectorConfig = None,
@@ -54,6 +57,18 @@ class LocalObjectDetector(ObjectDetector):
 
         self.detect_api = create_detector(detector_config)
 
+    def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
+        if self.input_transform:
+            tensor_input = np.transpose(tensor_input, self.input_transform)
+
+        if self.dtype == InputDTypeEnum.float:
+            tensor_input = tensor_input.astype(np.float32)
+            tensor_input /= 255
+        elif self.dtype == InputDTypeEnum.float_denorm:
+            tensor_input = tensor_input.astype(np.float32)
+
+        return tensor_input
+
     def detect(self, tensor_input: np.ndarray, threshold=0.4):
         detections = []
 
@@ -71,25 +86,30 @@ class LocalObjectDetector(ObjectDetector):
         self.fps.update()
         return detections
 
+
+class LocalObjectDetector(BaseLocalDetector):
     def detect_raw(self, tensor_input: np.ndarray):
-        if self.input_transform:
-            tensor_input = np.transpose(tensor_input, self.input_transform)
-
-        if self.dtype == InputDTypeEnum.float:
-            tensor_input = tensor_input.astype(np.float32)
-            tensor_input /= 255
-
+        tensor_input = self._transform_input(tensor_input)
         return self.detect_api.detect_raw(tensor_input=tensor_input)
 
 
-def prepare_detector(name, detector_config, out_events):
+class AsyncLocalObjectDetector(BaseLocalDetector):
+    def async_send_input(self, tensor_input: np.ndarray, connection_id):
+        tensor_input = self._transform_input(tensor_input)
+        return self.detect_api.send_input(connection_id, tensor_input)
+
+    def async_receive_output(self):
+        return self.detect_api.receive_output()
+
+
+def prepare_detector(name, out_events):
     threading.current_thread().name = f"detector:{name}"
     logger = logging.getLogger(f"detector.{name}")
     logger.info(f"Starting detection process: {os.getpid()}")
     setproctitle(f"frigate.detector.{name}")
     listen()
 
-    stop_event = mp.Event()
+    stop_event: MpEvent = mp.Event()
 
     def receiveSignal(signalNumber, frame):
         stop_event.set()
@@ -98,7 +118,6 @@ def prepare_detector(name, detector_config, out_events):
     signal.signal(signal.SIGINT, receiveSignal)
 
     frame_manager = SharedMemoryFrameManager()
-    object_detector = LocalObjectDetector(detector_config=detector_config)
 
     outputs = {}
     for name in out_events.keys():
@@ -106,22 +125,24 @@ def prepare_detector(name, detector_config, out_events):
         out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
         outputs[name] = {"shm": out_shm, "np": out_np}
 
-    return stop_event, frame_manager, object_detector, outputs, logger
+    return stop_event, frame_manager, outputs, logger
 
 
 def run_detector(
     name: str,
-    detection_queue: mp.Queue,
-    out_events: dict[str, mp.Event],
-    avg_speed,
-    start,
-    detector_config,
+    detection_queue: Queue,
+    out_events: dict[str, MpEvent],
+    avg_speed: Value,
+    start: Value,
+    detector_config: BaseDetectorConfig,
 ):
-    stop_event, frame_manager, object_detector, outputs, logger = prepare_detector(
-        name, detector_config, out_events
+    stop_event, frame_manager, outputs, logger = prepare_detector(
+        name, out_events
     )
 
+    object_detector = LocalObjectDetector(detector_config=detector_config)
+
     while not stop_event.is_set():
         try:
             connection_id = detection_queue.get(timeout=1)
@@ -152,17 +173,19 @@ def run_detector(
 
 def async_run_detector(
     name: str,
-    detection_queue: mp.Queue,
-    out_events: dict[str, mp.Event],
-    avg_speed,
-    start,
-    detector_config,
+    detection_queue: Queue,
+    out_events: dict[str, MpEvent],
+    avg_speed: Value,
+    start: Value,
+    detector_config: BaseDetectorConfig,
 ):
-    stop_event, frame_manager, object_detector, outputs, logger = prepare_detector(
-        name, detector_config, out_events
+    stop_event, frame_manager, outputs, logger = prepare_detector(
+        name, out_events
    )
 
+    object_detector = AsyncLocalObjectDetector(detector_config=detector_config)
+
     def detect_worker():
         # Continuously fetch frames and send them to the async detector
         logger.info("Starting Detect Worker Thread")
@@ -184,13 +207,13 @@ def async_run_detector(
 
             # send input to Accelator
             start.value = datetime.datetime.now().timestamp()
-            object_detector.detect_api.send_input(connection_id, input_frame)
+            object_detector.async_send_input(input_frame, connection_id)
 
     def result_worker():
         # Continuously receive detection results from the async detector
         logger.info("Starting Result Worker Thread")
         while not stop_event.is_set():
-            connection_id, detections = object_detector.detect_api.receive_output()
+            connection_id, detections = object_detector.async_receive_output()
             duration = datetime.datetime.now().timestamp() - start.value
 
             frame_manager.close(connection_id)
@@ -222,17 +245,17 @@ def async_run_detector(
 class ObjectDetectProcess:
     def __init__(
         self,
-        name,
-        detection_queue,
-        out_events,
-        detector_config,
+        name: str,
+        detection_queue: Queue,
+        out_events: dict[str, MpEvent],
+        detector_config: BaseDetectorConfig,
     ):
         self.name = name
         self.out_events = out_events
         self.detection_queue = detection_queue
-        self.avg_inference_speed = mp.Value("d", 0.01)
-        self.detection_start = mp.Value("d", 0.0)
-        self.detect_process = None
+        self.avg_inference_speed = Value("d", 0.01)
+        self.detection_start = Value("d", 0.0)
+        self.detect_process: util.Process | None = None
         self.detector_config = detector_config
         self.start_or_restart()
 
@@ -285,7 +308,15 @@ class ObjectDetectProcess:
 
 
 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_config, stop_event):
+    def __init__(
+        self,
+        name: str,
+        labels: dict[int, str],
+        detection_queue: Queue,
+        event: MpEvent,
+        model_config: ModelConfig,
+        stop_event: MpEvent,
+    ):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
@@ -328,4 +359,4 @@ class RemoteObjectDetector:
 
     def cleanup(self):
         self.shm.unlink()
-        self.out_shm.unlink()
+        self.out_shm.unlink()
\ No newline at end of file
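
A minimal, self-contained sketch of the send/receive worker-thread pattern that async_run_detector introduces, assuming a toy FakeAsyncDetectApi in place of a real detector plugin; the class name, queue sizes, and camera names below are illustrative only and are not part of this patch or of Frigate's API:

import queue
import threading

import numpy as np


class FakeAsyncDetectApi:
    """Toy stand-in for an async detector plugin's send_input/receive_output pair."""

    def __init__(self):
        self._results = queue.Queue()

    def send_input(self, connection_id, tensor_input):
        # Pretend the accelerator finished immediately and produced an empty
        # (20, 6) detection matrix, the same shape the detector process writes
        # into shared memory.
        self._results.put((connection_id, np.zeros((20, 6), dtype=np.float32)))

    def receive_output(self):
        return self._results.get()


def main():
    api = FakeAsyncDetectApi()
    frames = queue.Queue()
    stop_event = threading.Event()

    def detect_worker():
        # Mirrors detect_worker in async_run_detector: pull a connection id,
        # build the input tensor, and hand it to the accelerator.
        while not stop_event.is_set():
            try:
                connection_id = frames.get(timeout=0.1)
            except queue.Empty:
                continue
            api.send_input(connection_id, np.zeros((320, 320, 3), dtype=np.uint8))

    def result_worker(expected):
        # Mirrors result_worker: block on receive_output and publish results.
        for _ in range(expected):
            connection_id, detections = api.receive_output()
            print(f"{connection_id}: {detections.shape[0]} detection slots")

    sender = threading.Thread(target=detect_worker)
    receiver = threading.Thread(target=result_worker, args=(3,))
    sender.start()
    receiver.start()

    for name in ("cam_front", "cam_back", "cam_side"):
        frames.put(name)

    receiver.join()
    stop_event.set()
    sender.join()


if __name__ == "__main__":
    main()

In the patch itself, AsyncLocalObjectDetector fills the role of the fake wrapper above: async_send_input applies the shared _transform_input step before calling the plugin's send_input, and async_receive_output simply forwards receive_output.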