mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
Added variable model size support
This commit is contained in:
parent
821cc1842d
commit
99b2b37238
@ -73,16 +73,26 @@ class MemryXDetector(DetectionApi):
|
||||
self.cache_dir = "/memryx_models"
|
||||
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolo-generic.zip"
|
||||
model_mapping = {
|
||||
(640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_640.zip", "yolov9_640"),
|
||||
(320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320")
|
||||
}
|
||||
self.model_url, self.model_folder = model_mapping.get(
|
||||
(self.memx_model_height, self.memx_model_width),
|
||||
("https://developer.memryx.com/example_files/1p2_frigate/yolov9_320.zip", "yolov9_320")
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"YOLO_v9_small_640_640_3_onnx.dfp"
|
||||
"YOLO_v9_small_onnx.dfp"
|
||||
)
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolonas.zip"
|
||||
model_mapping = {
|
||||
(640, 640): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_640.zip", "yolonas_640"),
|
||||
(320, 320): ("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320")
|
||||
}
|
||||
self.model_url, self.model_folder = model_mapping.get(
|
||||
(self.memx_model_height, self.memx_model_width),
|
||||
("https://developer.memryx.com/example_files/1p2_frigate/yolonas_320.zip", "yolonas_320")
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"yolo_nas_s.dfp"
|
||||
@ -92,6 +102,7 @@ class MemryXDetector(DetectionApi):
|
||||
)
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||
self.model_folder = "yolox"
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
|
||||
)
|
||||
@ -101,6 +112,7 @@ class MemryXDetector(DetectionApi):
|
||||
self.set_strides_grids()
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||
self.model_folder = "ssd"
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/ssd.zip"
|
||||
)
|
||||
@ -124,7 +136,7 @@ class MemryXDetector(DetectionApi):
|
||||
# Load MemryX Model with a unique device target
|
||||
self.accl = AsyncAccl(
|
||||
self.memx_model_path,
|
||||
mxserver_addr=mxserver_addr,
|
||||
mxserver_addr = mxserver_addr,
|
||||
group_id=device_id, # AsyncAccl device id
|
||||
)
|
||||
|
||||
@ -145,7 +157,7 @@ class MemryXDetector(DetectionApi):
|
||||
raise
|
||||
|
||||
def load_yolo_constants(self):
|
||||
base = f"{self.cache_dir}/{self.memx_model_type.value}"
|
||||
base = f"{self.cache_dir}/{self.model_folder}"
|
||||
# constants for yolov9 post-processing
|
||||
self.const_A = np.load(
|
||||
f"{base}/_model_22_Constant_9_output_0.npy"
|
||||
@ -170,7 +182,7 @@ class MemryXDetector(DetectionApi):
|
||||
return
|
||||
|
||||
logger.info(f"Model files not found. Downloading from {self.model_url}...")
|
||||
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
|
||||
zip_path = os.path.join(self.cache_dir, f"{self.model_folder}.zip")
|
||||
|
||||
try:
|
||||
# Before downloading, check if already downloaded
|
||||
@ -180,7 +192,7 @@ class MemryXDetector(DetectionApi):
|
||||
logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
|
||||
|
||||
# Before extracting, check if model folder exists already
|
||||
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
|
||||
model_subdir = os.path.join(self.cache_dir, self.model_folder)
|
||||
if not os.path.exists(model_subdir):
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(self.cache_dir)
|
||||
@ -541,7 +553,8 @@ class MemryXDetector(DetectionApi):
|
||||
# Perform split along axis 1
|
||||
split_0, split_1 = np.split(concat_4, indices, axis=axis)
|
||||
|
||||
shape1 = np.array([1, 4, 16, 8400])
|
||||
num_boxes = 2100 if self.memx_model_height == 320 else 8400
|
||||
shape1 = np.array([1, 4, 16, num_boxes])
|
||||
reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0)
|
||||
|
||||
transpose_1 = reshape_4.transpose(0, 2, 1, 3)
|
||||
@ -562,7 +575,7 @@ class MemryXDetector(DetectionApi):
|
||||
softmax_output * W, axis=1, keepdims=True
|
||||
) # shape: (1, 1, 4, 8400)
|
||||
|
||||
shape2 = np.array([1, 4, 8400])
|
||||
shape2 = np.array([1, 4, num_boxes])
|
||||
reshape_5 = self.onnx_reshape_with_allowzero(
|
||||
conv_output, shape2, allowzero=0
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user