From cf429549b9eb9415c633936f847a81dc92914c36 Mon Sep 17 00:00:00 2001 From: Tim Date: Tue, 15 Apr 2025 11:33:59 -0400 Subject: [PATCH] ruff formatting --- frigate/detectors/plugins/memryx.py | 214 ++++++++++++++++++---------- frigate/object_detection/base.py | 9 +- 2 files changed, 139 insertions(+), 84 deletions(-) diff --git a/frigate/detectors/plugins/memryx.py b/frigate/detectors/plugins/memryx.py index 460a645e1..a55a8ba32 100644 --- a/frigate/detectors/plugins/memryx.py +++ b/frigate/detectors/plugins/memryx.py @@ -11,7 +11,9 @@ try: # from memryx import AsyncAccl # Import MemryX SDK from memryx import AsyncAccl except ModuleNotFoundError: - raise ImportError("MemryX SDK is not installed. Install it and set up MIX environment.") + raise ImportError( + "MemryX SDK is not installed. Install it and set up MIX environment." + ) from pydantic import BaseModel, Field from typing_extensions import Literal @@ -23,17 +25,20 @@ logger = logging.getLogger(__name__) DETECTOR_KEY = "memryx" + # Configuration class for model settings class ModelConfig(BaseModel): path: str = Field(default=None, title="Model Path") # Path to the DFP file - labelmap_path: str = Field(default=None, title="Path to Label Map") + labelmap_path: str = Field(default=None, title="Path to Label Map") + class MemryXDetectorConfig(BaseDetectorConfig): type: Literal[DETECTOR_KEY] device: str = Field(default="PCIe", title="Device Path") + class MemryXDetector(DetectionApi): - type_key = DETECTOR_KEY # Set the type key + type_key = DETECTOR_KEY # Set the type key supported_models = [ ModelTypeEnum.ssd, ModelTypeEnum.yolonas, @@ -51,7 +56,7 @@ class MemryXDetector(DetectionApi): self.logger = logger self.memx_model_path = detector_config.model.path # Path to .dfp file - self.memx_post_model = None # Path to .post file + self.memx_post_model = None # Path to .post file self.expected_post_model = None self.memx_device_path = detector_config.device # Device path self.memx_model_height = detector_config.model.height @@ -61,38 +66,60 @@ class MemryXDetector(DetectionApi): self.cache_dir = "/memryx_models" if self.memx_model_type == ModelTypeEnum.yolov9: - self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip" + self.model_url = ( + "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip" + ) elif self.memx_model_type == ModelTypeEnum.yolov8: - self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip" + self.model_url = ( + "https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip" + ) if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]: # Shared constants for both yolov8 and yolov9 post-processing - self.const_A = np.load("/memryx_models/yolov9/_model_22_Constant_9_output_0.npy") - self.const_B = np.load("/memryx_models/yolov9/_model_22_Constant_10_output_0.npy") - self.const_C = np.load("/memryx_models/yolov9/_model_22_Constant_12_output_0.npy") + self.const_A = np.load( + "/memryx_models/yolov9/_model_22_Constant_9_output_0.npy" + ) + self.const_B = np.load( + "/memryx_models/yolov9/_model_22_Constant_10_output_0.npy" + ) + self.const_C = np.load( + "/memryx_models/yolov9/_model_22_Constant_12_output_0.npy" + ) elif self.memx_model_type == ModelTypeEnum.yolonas: - self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip" + self.model_url = ( + "https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip" + ) self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx" elif self.memx_model_type == ModelTypeEnum.yolox: - self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip" + self.model_url = ( + "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip" + ) self.set_strides_grids() elif self.memx_model_type == ModelTypeEnum.ssd: - self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip" - self.expected_post_model = "ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx" + self.model_url = ( + "https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip" + ) + self.expected_post_model = ( + "ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx" + ) self.check_and_prepare_model() - logger.info(f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}") + logger.info( + f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}" + ) try: # Load MemryX Model logger.info(f"dfp path: {self.memx_model_path}") # Initialization code - self.accl = AsyncAccl(self.memx_model_path, mxserver_addr="host.docker.internal") + self.accl = AsyncAccl( + self.memx_model_path, mxserver_addr="host.docker.internal" + ) # Models that use cropped post-processing sections (YOLO-NAS and SSD) # --> These will be moved to pure numpy in the future to improve performance on low-end CPUs @@ -102,7 +129,9 @@ class MemryXDetector(DetectionApi): self.accl.connect_input(self.process_input) self.accl.connect_output(self.process_output) - logger.info(f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}") + logger.info( + f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}" + ) except Exception as e: logger.error(f"Failed to initialize MemryX model: {e}") @@ -117,14 +146,20 @@ class MemryXDetector(DetectionApi): logger.info(f"Assigned Model Path: {self.memx_model_path}") else: - post_model_file_path = os.path.join(self.cache_dir, self.expected_post_model) + post_model_file_path = os.path.join( + self.cache_dir, self.expected_post_model + ) # Check if both post model file exist if os.path.isfile(post_model_file_path): self.memx_post_model = post_model_file_path - logger.info(f"Post-processing model found at {post_model_file_path}, skipping download.") + logger.info( + f"Post-processing model found at {post_model_file_path}, skipping download." + ) else: - logger.info(f"Model files not found. Downloading from {self.model_url}...") + logger.info( + f"Model files not found. Downloading from {self.model_url}..." + ) zip_path = os.path.join(self.cache_dir, "memryx_model.zip") # Download the ZIP file @@ -143,30 +178,36 @@ class MemryXDetector(DetectionApi): self.memx_post_model = os.path.join(self.cache_dir, file) logger.info(f"Assigned Model Path: {self.memx_model_path}") - logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}") + logger.info( + f"Assigned Post-processing Model Path: {self.memx_post_model}" + ) # Cleanup: Remove the ZIP file after extraction os.remove(zip_path) logger.info("Cleaned up ZIP file after extraction.") - def send_input(self, connection_id, tensor_input: np.ndarray): + def send_input(self, connection_id, tensor_input: np.ndarray): """Pre-process (if needed) and send frame to MemryX input queue""" if tensor_input is None: raise ValueError("[send_input] No image data provided for inference") - + if self.memx_model_type == ModelTypeEnum.yolox: tensor_input = tensor_input.squeeze(0) - padded_img = np.ones((640, 640, 3), - dtype=np.uint8) * 114 + padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114 - scale = min(640 / float(tensor_input.shape[0]), - 640 / float(tensor_input.shape[1])) - sx,sy = int(tensor_input.shape[1] * scale), int(tensor_input.shape[0] * scale) + scale = min( + 640 / float(tensor_input.shape[0]), 640 / float(tensor_input.shape[1]) + ) + sx, sy = ( + int(tensor_input.shape[1] * scale), + int(tensor_input.shape[0] * scale), + ) - resized_img = cv2.resize(tensor_input, (sx,sy), interpolation=cv2.INTER_LINEAR) + resized_img = cv2.resize( + tensor_input, (sx, sy), interpolation=cv2.INTER_LINEAR + ) padded_img[:sy, :sx] = resized_img.astype(np.uint8) - # Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels x0 = padded_img[0::2, 0::2, :] # Top-left @@ -176,7 +217,7 @@ class MemryXDetector(DetectionApi): # Step 5: Concatenate along the channel dimension (axis 2) concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2) - processed_input = concatenated_img.astype(np.float32) + processed_input = concatenated_img.astype(np.float32) else: processed_input = tensor_input.astype(np.float32) / 255.0 # Normalize # Assuming original input is always NHWC and MemryX wants HWNC: @@ -191,8 +232,10 @@ class MemryXDetector(DetectionApi): while True: try: # Wait for a frame from the queue (blocking call) - frame = self.capture_queue.get(block=True) # Blocks until data is available - + frame = self.capture_queue.get( + block=True + ) # Blocks until data is available + return frame except Exception as e: @@ -201,7 +244,9 @@ class MemryXDetector(DetectionApi): def receive_output(self): """Retrieve processed results from MemryX output queue + a copy of the original frame""" - connection_id = self.capture_id_queue.get() # Get the corresponding connection ID + connection_id = ( + self.capture_id_queue.get() + ) # Get the corresponding connection ID detections = self.output_queue.get() # Get detections from MemryX return connection_id, detections @@ -216,10 +261,10 @@ class MemryXDetector(DetectionApi): break (_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction - + if class_id < 0: break - + detections[i] = [ class_id, confidence, @@ -228,7 +273,7 @@ class MemryXDetector(DetectionApi): y_max / self.memx_model_height, x_max / self.memx_model_width, ] - + # Return the list of final detections self.output_queue.put(detections) @@ -244,9 +289,8 @@ class MemryXDetector(DetectionApi): (pos[0] - (pos[2] / 2)) / self.memx_model_width, # x_min (pos[1] + (pos[3] / 2)) / self.memx_model_height, # y_max (pos[0] + (pos[2] / 2)) / self.memx_model_width, # x_max - ] - + def set_strides_grids(self): grids = [] expanded_strides = [] @@ -266,15 +310,13 @@ class MemryXDetector(DetectionApi): self.expanded_strides = np.concatenate(expanded_strides, 1) def sigmoid(self, x: np.ndarray) -> np.ndarray: - return 1 / (1 + np.exp(-x)) def onnx_concat(self, inputs: list, axis: int) -> np.ndarray: - # Ensure all inputs are numpy arrays if not all(isinstance(x, np.ndarray) for x in inputs): raise TypeError("All inputs must be numpy arrays.") - + # Ensure shapes match on non-concat axes ref_shape = list(inputs[0].shape) for i, tensor in enumerate(inputs[1:], start=1): @@ -282,12 +324,13 @@ class MemryXDetector(DetectionApi): if ax == axis: continue if tensor.shape[ax] != ref_shape[ax]: - raise ValueError(f"Shape mismatch at axis {ax} between input[0] and input[{i}]") + raise ValueError( + f"Shape mismatch at axis {ax} between input[0] and input[{i}]" + ) return np.concatenate(inputs, axis=axis) def onnx_reshape(self, data: np.ndarray, shape: np.ndarray) -> np.ndarray: - # Ensure shape is a 1D array of integers target_shape = shape.astype(int).tolist() @@ -295,23 +338,24 @@ class MemryXDetector(DetectionApi): reshaped = np.reshape(data, target_shape) return reshaped - - def post_process_yolox(self, output): - output = [np.expand_dims(tensor, axis=0) for tensor in output] # Shape: (1, H, W, C) + def post_process_yolox(self, output): + output = [ + np.expand_dims(tensor, axis=0) for tensor in output + ] # Shape: (1, H, W, C) # Move channel axis from 3rd (last) position to 1st position → (1, C, H, W) output = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in output] - output_785 = output[0] # 785 - output_794 = output[1] # 794 - output_795 = output[2] # 795 - output_811 = output[3] # 811 - output_820 = output[4] # 820 - output_821 = output[5] # 821 - output_837 = output[6] # 837 - output_846 = output[7] # 846 - output_847 = output[8] # 847 + output_785 = output[0] # 785 + output_794 = output[1] # 794 + output_795 = output[2] # 795 + output_811 = output[3] # 811 + output_820 = output[4] # 820 + output_821 = output[5] # 821 + output_837 = output[6] # 837 + output_846 = output[7] # 846 + output_847 = output[8] # 847 output_795 = self.sigmoid(output_795) output_785 = self.sigmoid(output_785) @@ -324,7 +368,7 @@ class MemryXDetector(DetectionApi): concat_2 = self.onnx_concat([output_820, output_821, output_811], axis=1) concat_3 = self.onnx_concat([output_846, output_847, output_837], axis=1) - shape = np.array([1, 85, -1], dtype=np.int64) + shape = np.array([1, 85, -1], dtype=np.int64) reshape_1 = self.onnx_reshape(concat_1, shape) reshape_2 = self.onnx_reshape(concat_2, shape) @@ -332,7 +376,7 @@ class MemryXDetector(DetectionApi): concat_out = self.onnx_concat([reshape_1, reshape_2, reshape_3], axis=2) - output = concat_out.transpose(0,2,1) #1, 840, 85 + output = concat_out.transpose(0, 2, 1) # 1, 840, 85 self.num_classes = output.shape[2] - 5 @@ -343,8 +387,10 @@ class MemryXDetector(DetectionApi): results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides image_pred = results[0, ...] - class_conf = np.max(image_pred[:, 5:5 + self.num_classes], axis=1, keepdims=True) - class_pred = np.argmax(image_pred[:, 5:5 + self.num_classes], axis=1) + class_conf = np.max( + image_pred[:, 5 : 5 + self.num_classes], axis=1, keepdims=True + ) + class_pred = np.argmax(image_pred[:, 5 : 5 + self.num_classes], axis=1) class_pred = np.expand_dims(class_pred, axis=1) conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= 0.3).squeeze() @@ -364,7 +410,6 @@ class MemryXDetector(DetectionApi): self.output_queue.put(final_detections) - def post_process_ssdlite(self, outputs): dets = outputs[0].squeeze(0) # Shape: (1, num_dets, 5) labels = outputs[1].squeeze(0) @@ -414,12 +459,13 @@ class MemryXDetector(DetectionApi): x_max /= self.memx_model_width y_max /= self.memx_model_height - final_detections[i] = [class_id, confidence, y_min, x_min, y_max, x_max] + final_detections[i] = [class_id, confidence, y_min, x_min, y_max, x_max] self.output_queue.put(final_detections) - def onnx_reshape_with_allowzero(self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0) -> np.ndarray: - + def onnx_reshape_with_allowzero( + self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0 + ) -> np.ndarray: shape = shape.astype(int) input_shape = data.shape output_shape = [] @@ -436,9 +482,11 @@ class MemryXDetector(DetectionApi): return reshaped def process_output(self, *outputs): - """Output callback function -- receives frames from the MX3 and triggers post-processing""" + """Output callback function -- receives frames from the MX3 and triggers post-processing""" if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]: - outputs = [np.expand_dims(tensor, axis=0) for tensor in outputs] # Shape: (1, H, W, C) + outputs = [ + np.expand_dims(tensor, axis=0) for tensor in outputs + ] # Shape: (1, H, W, C) # Move channel axis from 3rd (last) position to 1st position → (1, C, H, W) outputs = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in outputs] @@ -466,15 +514,17 @@ class MemryXDetector(DetectionApi): split_sizes = [64, 80] # Calculate indices at which to split - indices = np.cumsum(split_sizes)[:-1] # [64] — split before the second chunk + indices = np.cumsum(split_sizes)[ + :-1 + ] # [64] — split before the second chunk # Perform split along axis 1 split_0, split_1 = np.split(concat_4, indices, axis=axis) - shape1 = np.array([1,4,16,8400]) + shape1 = np.array([1, 4, 16, 8400]) reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0) - transpose_1 = reshape_4.transpose(0,2,1,3) + transpose_1 = reshape_4.transpose(0, 2, 1, 3) axis = 1 # As per ONNX softmax node @@ -488,10 +538,14 @@ class MemryXDetector(DetectionApi): W = np.arange(16, dtype=np.float32).reshape(1, 16, 1, 1) # (1, 16, 1, 1) # Apply 1x1 convolution: this is a weighted sum over channels - conv_output = np.sum(softmax_output * W, axis=1, keepdims=True) # shape: (1, 1, 4, 8400) + conv_output = np.sum( + softmax_output * W, axis=1, keepdims=True + ) # shape: (1, 1, 4, 8400) - shape2 = np.array([1,4,8400]) - reshape_5 = self.onnx_reshape_with_allowzero(conv_output, shape2, allowzero=0) + shape2 = np.array([1, 4, 8400]) + reshape_5 = self.onnx_reshape_with_allowzero( + conv_output, shape2, allowzero=0 + ) # ONNX Slice — get first 2 channels: [0:2] along axis 1 slice_output1 = reshape_5[:, 0:2, :] # Result: (1, 2, 8400) @@ -511,7 +565,7 @@ class MemryXDetector(DetectionApi): div_output = add1 / 2.0 - concat_5 = self.onnx_concat([div_output, sub1], axis=1) + concat_5 = self.onnx_concat([div_output, sub1], axis=1) # Expand B to (1, 1, 8400) so it can broadcast across axis=1 (4 channels) const_C_expanded = self.const_C[:, np.newaxis, :] # Shape: (1, 1, 8400) @@ -522,23 +576,25 @@ class MemryXDetector(DetectionApi): sigmoid_output = self.sigmoid(split_1) outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1) - final_detections = post_process_yolov9(outputs, self.memx_model_width, self.memx_model_height) + final_detections = post_process_yolov9( + outputs, self.memx_model_width, self.memx_model_height + ) self.output_queue.put(final_detections) - + elif self.memx_model_type == ModelTypeEnum.yolonas: return self.post_process_yolonas(outputs) - + elif self.memx_model_type == ModelTypeEnum.yolox: return self.post_process_yolox(outputs) - + elif self.memx_model_type == ModelTypeEnum.ssd: return self.post_process_ssdlite(outputs) - + else: raise Exception( f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models." ) - + def detect_raw(self, tensor_input: np.ndarray): - """ Removed synchronous detect_raw() function so that we only use async """ + """Removed synchronous detect_raw() function so that we only use async""" return 0 diff --git a/frigate/object_detection/base.py b/frigate/object_detection/base.py index 7cb5cd6ca..b4b434a45 100644 --- a/frigate/object_detection/base.py +++ b/frigate/object_detection/base.py @@ -139,7 +139,7 @@ def run_detector( avg_speed.value = (avg_speed.value * 9 + duration) / 10 logger.info("Exited detection process...") - return self.detect_api.detect_raw(tensor_input=tensor_input) + return self.detect_api.detect_raw(tensor_input=tensor_input) def async_run_detector( @@ -194,7 +194,7 @@ def async_run_detector( logger.warning(f"Failed to get frame {connection_id} from SHM") continue - #send input to Accelator + # send input to Accelator start.value = datetime.datetime.now().timestamp() object_detector.detect_api.send_input(connection_id, input_frame) @@ -231,7 +231,6 @@ def async_run_detector( logger.info("Exited async detection process...") - class ObjectDetectProcess: def __init__( self, @@ -266,7 +265,7 @@ class ObjectDetectProcess: self.detection_start.value = 0.0 if (self.detect_process is not None) and self.detect_process.is_alive(): self.stop() - if (self.detector_config.type == 'memryx'): + if self.detector_config.type == "memryx": # MemryX requires asynchronous detection handling using async_run_detector self.detect_process = util.Process( target=async_run_detector, @@ -292,7 +291,7 @@ class ObjectDetectProcess: self.detection_start, self.detector_config, ), - ) + ) self.detect_process.daemon = True self.detect_process.start()