From 99b2b372389978f1ce8cf4aab663f6c1e06327d2 Mon Sep 17 00:00:00 2001 From: Abinila Siva <163017635+abinila4@users.noreply.github.com> Date: Wed, 21 May 2025 12:13:41 -0400 Subject: [PATCH] Added variable model size support --- frigate/detectors/plugins/memryx.py | 37 +++++++++++++++++++---------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/frigate/detectors/plugins/memryx.py b/frigate/detectors/plugins/memryx.py index 89ca9caf3..77fdd829e 100644 --- a/frigate/detectors/plugins/memryx.py +++ b/frigate/detectors/plugins/memryx.py @@ -73,16 +73,26 @@ class MemryXDetector(DetectionApi): self.cache_dir = "/memryx_models" if self.memx_model_type == ModelTypeEnum.yologeneric: - self.model_url = ( - "https://developer.memryx.com/example_files/1p2_frigate/yolo-generic.zip" + model_mapping = { + (640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_640.zip", "yolov9_640"), + (320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320") + } + self.model_url, self.model_folder = model_mapping.get( + (self.memx_model_height, self.memx_model_width), + ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320") ) self.expected_dfp_model = ( - "YOLO_v9_small_640_640_3_onnx.dfp" + "YOLO_v9_small_onnx.dfp" ) elif self.memx_model_type == ModelTypeEnum.yolonas: - self.model_url = ( - "https://developer.memryx.com/example_files/1p2_frigate/yolonas.zip" + model_mapping = { + (640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_640.zip", "yolonas_640"), + (320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320") + } + self.model_url, self.model_folder = model_mapping.get( + (self.memx_model_height, self.memx_model_width), + ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320") ) self.expected_dfp_model = ( "yolo_nas_s.dfp" @@ -92,6 +102,7 @@ class MemryXDetector(DetectionApi): ) elif self.memx_model_type == ModelTypeEnum.yolox: + self.model_folder = "yolox" self.model_url = ( "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip" ) @@ -101,6 +112,7 @@ class MemryXDetector(DetectionApi): self.set_strides_grids() elif self.memx_model_type == ModelTypeEnum.ssd: + self.model_folder = "ssd" self.model_url = ( "https://developer.memryx.com/example_files/1p2_frigate/ssd.zip" ) @@ -124,7 +136,7 @@ class MemryXDetector(DetectionApi): # Load MemryX Model with a unique device target self.accl = AsyncAccl( self.memx_model_path, - mxserver_addr=mxserver_addr, + mxserver_addr = mxserver_addr, group_id=device_id, # AsyncAccl device id ) @@ -145,7 +157,7 @@ class MemryXDetector(DetectionApi): raise def load_yolo_constants(self): - base = f"{self.cache_dir}/{self.memx_model_type.value}" + base = f"{self.cache_dir}/{self.model_folder}" # constants for yolov9 post-processing self.const_A = np.load( f"{base}/_model_22_Constant_9_output_0.npy" @@ -170,7 +182,7 @@ class MemryXDetector(DetectionApi): return logger.info(f"Model files not found. Downloading from {self.model_url}...") - zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip") + zip_path = os.path.join(self.cache_dir, f"{self.model_folder}.zip") try: # Before downloading, check if already downloaded @@ -180,7 +192,7 @@ class MemryXDetector(DetectionApi): logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...") # Before extracting, check if model folder exists already - model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value) + model_subdir = os.path.join(self.cache_dir, self.model_folder) if not os.path.exists(model_subdir): with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(self.cache_dir) @@ -541,7 +553,8 @@ class MemryXDetector(DetectionApi): # Perform split along axis 1 split_0, split_1 = np.split(concat_4, indices, axis=axis) - shape1 = np.array([1, 4, 16, 8400]) + num_boxes = 2100 if self.memx_model_height == 320 else 8400 + shape1 = np.array([1, 4, 16, num_boxes]) reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0) transpose_1 = reshape_4.transpose(0, 2, 1, 3) @@ -562,7 +575,7 @@ class MemryXDetector(DetectionApi): softmax_output * W, axis=1, keepdims=True ) # shape: (1, 1, 4, 8400) - shape2 = np.array([1, 4, 8400]) + shape2 = np.array([1, 4, num_boxes]) reshape_5 = self.onnx_reshape_with_allowzero( conv_output, shape2, allowzero=0 ) @@ -617,4 +630,4 @@ class MemryXDetector(DetectionApi): def detect_raw(self, tensor_input: np.ndarray): """Removed synchronous detect_raw() function so that we only use async""" - return 0 \ No newline at end of file + return 0