Added variable model size support

This commit is contained in:
Abinila Siva 2025-05-21 12:13:41 -04:00 committed by GitHub
parent 821cc1842d
commit 99b2b37238
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -73,16 +73,26 @@ class MemryXDetector(DetectionApi):
self.cache_dir = "/memryx_models"
if self.memx_model_type == ModelTypeEnum.yologeneric:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolo-generic.zip"
model_mapping = {
(640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_640.zip", "yolov9_640"),
(320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320")
}
self.model_url, self.model_folder = model_mapping.get(
(self.memx_model_height, self.memx_model_width),
("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320")
)
self.expected_dfp_model = (
"YOLO_v9_small_640_640_3_onnx.dfp"
"YOLO_v9_small_onnx.dfp"
)
elif self.memx_model_type == ModelTypeEnum.yolonas:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolonas.zip"
model_mapping = {
(640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_640.zip", "yolonas_640"),
(320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320")
}
self.model_url, self.model_folder = model_mapping.get(
(self.memx_model_height, self.memx_model_width),
("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320")
)
self.expected_dfp_model = (
"yolo_nas_s.dfp"
@ -92,6 +102,7 @@ class MemryXDetector(DetectionApi):
)
elif self.memx_model_type == ModelTypeEnum.yolox:
self.model_folder = "yolox"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
)
@ -101,6 +112,7 @@ class MemryXDetector(DetectionApi):
self.set_strides_grids()
elif self.memx_model_type == ModelTypeEnum.ssd:
self.model_folder = "ssd"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/ssd.zip"
)
@ -124,7 +136,7 @@ class MemryXDetector(DetectionApi):
# Load MemryX Model with a unique device target
self.accl = AsyncAccl(
self.memx_model_path,
mxserver_addr=mxserver_addr,
mxserver_addr = mxserver_addr,
group_id=device_id, # AsyncAccl device id
)
@ -145,7 +157,7 @@ class MemryXDetector(DetectionApi):
raise
def load_yolo_constants(self):
base = f"{self.cache_dir}/{self.memx_model_type.value}"
base = f"{self.cache_dir}/{self.model_folder}"
# constants for yolov9 post-processing
self.const_A = np.load(
f"{base}/_model_22_Constant_9_output_0.npy"
@ -170,7 +182,7 @@ class MemryXDetector(DetectionApi):
return
logger.info(f"Model files not found. Downloading from {self.model_url}...")
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
zip_path = os.path.join(self.cache_dir, f"{self.model_folder}.zip")
try:
# Before downloading, check if already downloaded
@ -180,7 +192,7 @@ class MemryXDetector(DetectionApi):
logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
# Before extracting, check if model folder exists already
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
model_subdir = os.path.join(self.cache_dir, self.model_folder)
if not os.path.exists(model_subdir):
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
@ -541,7 +553,8 @@ class MemryXDetector(DetectionApi):
# Perform split along axis 1
split_0, split_1 = np.split(concat_4, indices, axis=axis)
shape1 = np.array([1, 4, 16, 8400])
num_boxes = 2100 if self.memx_model_height == 320 else 8400
shape1 = np.array([1, 4, 16, num_boxes])
reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0)
transpose_1 = reshape_4.transpose(0, 2, 1, 3)
@ -562,7 +575,7 @@ class MemryXDetector(DetectionApi):
softmax_output * W, axis=1, keepdims=True
) # shape: (1, 1, 4, 8400)
shape2 = np.array([1, 4, 8400])
shape2 = np.array([1, 4, num_boxes])
reshape_5 = self.onnx_reshape_with_allowzero(
conv_output, shape2, allowzero=0
)
@ -617,4 +630,4 @@ class MemryXDetector(DetectionApi):
def detect_raw(self, tensor_input: np.ndarray):
"""Removed synchronous detect_raw() function so that we only use async"""
return 0
return 0