From 32bb3cd25d702a5334ad698a340bdd4a1ebc310d Mon Sep 17 00:00:00 2001
From: abinila siva <163017635+abinila4@users.noreply.github.com>
Date: Fri, 25 Apr 2025 16:06:54 -0400
Subject: [PATCH] Added multi-device MemryX support

---
 frigate/detectors/plugins/memryx.py | 56 +++++++++++++++--------------
 1 file changed, 29 insertions(+), 27 deletions(-)

diff --git a/frigate/detectors/plugins/memryx.py b/frigate/detectors/plugins/memryx.py
index f3fac1e89..a326513b1 100644
--- a/frigate/detectors/plugins/memryx.py
+++ b/frigate/detectors/plugins/memryx.py
@@ -56,7 +56,12 @@ class MemryXDetector(DetectionApi):
         self.memx_model_path = detector_config.model.path  # Path to .dfp file
         self.memx_post_model = None  # Path to .post file
         self.expected_post_model = None
+        self.memx_device_path = detector_config.device  # Device path
+
+        # Parse the device string to split PCIe:
+        device_str = self.memx_device_path
+        device_id = int(device_str.split(":")[1])
 
         self.memx_model_height = detector_config.model.height
         self.memx_model_width = detector_config.model.width
         self.memx_model_type = detector_config.model.model_type
@@ -116,10 +121,13 @@ class MemryXDetector(DetectionApi):
         logger.info(f"dfp path: {self.memx_model_path}")
 
         # Initialization code
+        # Load MemryX Model with a unique device target
         self.accl = AsyncAccl(
-            self.memx_model_path, mxserver_addr="host.docker.internal"
+            self.memx_model_path,
+            mxserver_addr="host.docker.internal",
+            group_id=device_id,  # AsyncAccl device id
         )
-
+
         # Models that use cropped post-processing sections (YOLO-NAS and SSD)
         # --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
         if self.memx_post_model:
@@ -158,34 +166,27 @@ class MemryXDetector(DetectionApi):
         post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True  # ok if no post model
 
         if dfp_exists and post_exists:
-            logger.info(
-                f"Using cached models."
-            )
+            logger.info("Using cached models.")
             return
-
-        logger.info(
-            f"Model files not found. Downloading from {self.model_url}..."
-        )
+
+        logger.info(f"Model files not found. Downloading from {self.model_url}...")
         zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
 
         try:
-            # Download the ZIP file
-            urllib.request.urlretrieve(self.model_url, zip_path)
-            logger.info(
-                f"Model ZIP downloaded to {zip_path}. Extracting..."
-            )
+            # Before downloading, check if someone else already downloaded
+            if not os.path.exists(zip_path):
+                # Download only if zip does not exist
+                urllib.request.urlretrieve(self.model_url, zip_path)
+                logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
 
-            # Extract ZIP file
-            with zipfile.ZipFile(zip_path, "r") as zip_ref:
-                zip_ref.extractall(self.cache_dir)
-            logger.info(
-                f"Model extracted to {self.cache_dir}."
-            )
-
-            # Determine the subfolder to search in
+            # Before extracting, check if model folder exists already
             model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
-
-            # Assign extracted files to correct paths
+            if not os.path.exists(model_subdir):
+                with zipfile.ZipFile(zip_path, "r") as zip_ref:
+                    zip_ref.extractall(self.cache_dir)
+                logger.info(f"Model extracted to {self.cache_dir}.")
+
+            # Assign extracted files to correct paths
             for file in os.listdir(model_subdir):
                 file_path = os.path.join(model_subdir, file)
                 if file == self.expected_dfp_model:
@@ -196,14 +197,15 @@ class MemryXDetector(DetectionApi):
             logger.info(f"Assigned Model Path: {self.memx_model_path}")
             logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
 
-            if self.memx_model_type == ModelTypeEnum.yolov9:
-                self.load_yolo_constants()
+            if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
+                self.load_yolo_constants()
 
         except Exception as e:
             logger.error(f"Failed to prepare model: {e}")
            raise
+
         finally:
-            # Cleanup: Remove the ZIP file after extraction
+            # Remove zip only if we still have it and models exist
             if os.path.exists(zip_path):
                 os.remove(zip_path)
                 logger.info("Cleaned up ZIP file after extraction.")
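
Note on device selection (not part of the patch itself): the detector's device string, e.g. "PCIe:0", is split on ":" and the numeric index is passed to AsyncAccl as group_id, so each detector instance binds to its own MemryX accelerator. A minimal sketch of that parsing, assuming the "PCIe:<index>" form used above:

    # Hypothetical value of detector_config.device; the plugin reads it from the Frigate config.
    device_str = "PCIe:1"
    device_id = int(device_str.split(":")[1])  # -> 1; other forms raise IndexError/ValueError
    # The patch forwards this as AsyncAccl(..., group_id=device_id).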
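The download/extract rework follows a check-before-act pattern at each step, so several detector instances sharing one model cache do not redo or clobber each other's work. A condensed sketch of that flow with hypothetical url, cache_dir, and model_name values (in the plugin, an early return on fully cached models happens before this point):

    import os
    import urllib.request
    import zipfile

    def fetch_model(url: str, cache_dir: str, model_name: str) -> str:
        zip_path = os.path.join(cache_dir, f"{model_name}.zip")
        model_subdir = os.path.join(cache_dir, model_name)
        if not os.path.exists(zip_path):
            # Download only if no other process has fetched the archive yet.
            urllib.request.urlretrieve(url, zip_path)
        if not os.path.exists(model_subdir):
            # Extract only if the model folder is not already in place.
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(cache_dir)
        if os.path.exists(zip_path):
            # Remove the archive once the extracted files exist.
            os.remove(zip_path)
        return model_subdir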