Added multi-device MemryX support

This commit is contained in:
abinila siva 2025-04-25 16:06:54 -04:00 committed by GitHub
parent 2a1e00b2fd
commit 32bb3cd25d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -56,7 +56,12 @@ class MemryXDetector(DetectionApi):
         self.memx_model_path = detector_config.model.path  # Path to .dfp file
         self.memx_post_model = None  # Path to .post file
         self.expected_post_model = None
         self.memx_device_path = detector_config.device  # Device path
+
+        # Parse the device string to split PCIe:<index>
+        device_str = self.memx_device_path
+        device_id = int(device_str.split(":")[1])
+
         self.memx_model_height = detector_config.model.height
         self.memx_model_width = detector_config.model.width
         self.memx_model_type = detector_config.model.model_type
@@ -116,10 +121,13 @@ class MemryXDetector(DetectionApi):
         logger.info(f"dfp path: {self.memx_model_path}")

         # Initialization code
+        # Load MemryX Model with a unique device target
         self.accl = AsyncAccl(
-            self.memx_model_path, mxserver_addr="host.docker.internal"
+            self.memx_model_path,
+            mxserver_addr="host.docker.internal",
+            group_id=device_id,  # AsyncAccl device id
         )

         # Models that use cropped post-processing sections (YOLO-NAS and SSD)
         # --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
         if self.memx_post_model:
@@ -158,34 +166,27 @@ class MemryXDetector(DetectionApi):
         post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True  # ok if no post model

         if dfp_exists and post_exists:
-            logger.info(
-                f"Using cached models."
-            )
+            logger.info("Using cached models.")
             return

-        logger.info(
-            f"Model files not found. Downloading from {self.model_url}..."
-        )
+        logger.info(f"Model files not found. Downloading from {self.model_url}...")
         zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")

         try:
-            # Download the ZIP file
-            urllib.request.urlretrieve(self.model_url, zip_path)
-            logger.info(
-                f"Model ZIP downloaded to {zip_path}. Extracting..."
-            )
-            # Extract ZIP file
-            with zipfile.ZipFile(zip_path, "r") as zip_ref:
-                zip_ref.extractall(self.cache_dir)
-            logger.info(
-                f"Model extracted to {self.cache_dir}."
-            )
-            # Determine the subfolder to search in
+            # Before downloading, check if someone else already downloaded
+            if not os.path.exists(zip_path):
+                # Download only if zip does not exist
+                urllib.request.urlretrieve(self.model_url, zip_path)
+                logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
+
+            # Before extracting, check if model folder exists already
             model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
+            if not os.path.exists(model_subdir):
+                with zipfile.ZipFile(zip_path, "r") as zip_ref:
+                    zip_ref.extractall(self.cache_dir)
+                logger.info(f"Model extracted to {self.cache_dir}.")

             # Assign extracted files to correct paths
             for file in os.listdir(model_subdir):
                 file_path = os.path.join(model_subdir, file)
                 if file == self.expected_dfp_model:
@@ -196,14 +197,15 @@ class MemryXDetector(DetectionApi):
             logger.info(f"Assigned Model Path: {self.memx_model_path}")
             logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")

-            if self.memx_model_type == ModelTypeEnum.yolov9:
+            if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
                 self.load_yolo_constants()

         except Exception as e:
             logger.error(f"Failed to prepare model: {e}")
             raise
         finally:
-            # Cleanup: Remove the ZIP file after extraction
+            # Remove zip only if we still have it and models exist
             if os.path.exists(zip_path):
                 os.remove(zip_path)
                 logger.info("Cleaned up ZIP file after extraction.")