mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-08 13:51:01 +02:00
Added multi-device MemryX support
This commit is contained in:
parent
2a1e00b2fd
commit
32bb3cd25d
@ -56,7 +56,12 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
||||||
self.memx_post_model = None # Path to .post file
|
self.memx_post_model = None # Path to .post file
|
||||||
self.expected_post_model = None
|
self.expected_post_model = None
|
||||||
|
|
||||||
self.memx_device_path = detector_config.device # Device path
|
self.memx_device_path = detector_config.device # Device path
|
||||||
|
# Parse the device string to split PCIe:<index>
|
||||||
|
device_str = self.memx_device_path
|
||||||
|
device_id = int(device_str.split(":")[1])
|
||||||
|
|
||||||
self.memx_model_height = detector_config.model.height
|
self.memx_model_height = detector_config.model.height
|
||||||
self.memx_model_width = detector_config.model.width
|
self.memx_model_width = detector_config.model.width
|
||||||
self.memx_model_type = detector_config.model.model_type
|
self.memx_model_type = detector_config.model.model_type
|
||||||
@ -116,10 +121,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
logger.info(f"dfp path: {self.memx_model_path}")
|
logger.info(f"dfp path: {self.memx_model_path}")
|
||||||
|
|
||||||
# Initialization code
|
# Initialization code
|
||||||
|
# Load MemryX Model with a unique device target
|
||||||
self.accl = AsyncAccl(
|
self.accl = AsyncAccl(
|
||||||
self.memx_model_path, mxserver_addr="host.docker.internal"
|
self.memx_model_path,
|
||||||
|
mxserver_addr="host.docker.internal",
|
||||||
|
group_id=device_id, # AsyncAccl device id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
||||||
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
||||||
if self.memx_post_model:
|
if self.memx_post_model:
|
||||||
@ -158,34 +166,27 @@ class MemryXDetector(DetectionApi):
|
|||||||
post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True # ok if no post model
|
post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True # ok if no post model
|
||||||
|
|
||||||
if dfp_exists and post_exists:
|
if dfp_exists and post_exists:
|
||||||
logger.info(
|
logger.info("Using cached models.")
|
||||||
f"Using cached models."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Model files not found. Downloading from {self.model_url}...")
|
||||||
f"Model files not found. Downloading from {self.model_url}..."
|
|
||||||
)
|
|
||||||
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
|
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Download the ZIP file
|
# Before downloading, check if someone else already downloaded
|
||||||
urllib.request.urlretrieve(self.model_url, zip_path)
|
if not os.path.exists(zip_path):
|
||||||
logger.info(
|
# Download only if zip does not exist
|
||||||
f"Model ZIP downloaded to {zip_path}. Extracting..."
|
urllib.request.urlretrieve(self.model_url, zip_path)
|
||||||
)
|
logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
|
||||||
|
|
||||||
# Extract ZIP file
|
# Before extracting, check if model folder exists already
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(self.cache_dir)
|
|
||||||
logger.info(
|
|
||||||
f"Model extracted to {self.cache_dir}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the subfolder to search in
|
|
||||||
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
|
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
|
||||||
|
if not os.path.exists(model_subdir):
|
||||||
# Assign extracted files to correct paths
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(self.cache_dir)
|
||||||
|
logger.info(f"Model extracted to {self.cache_dir}.")
|
||||||
|
|
||||||
|
# Assign extracted files to correct paths
|
||||||
for file in os.listdir(model_subdir):
|
for file in os.listdir(model_subdir):
|
||||||
file_path = os.path.join(model_subdir, file)
|
file_path = os.path.join(model_subdir, file)
|
||||||
if file == self.expected_dfp_model:
|
if file == self.expected_dfp_model:
|
||||||
@ -196,14 +197,15 @@ class MemryXDetector(DetectionApi):
|
|||||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||||
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
|
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
|
||||||
|
|
||||||
if self.memx_model_type == ModelTypeEnum.yolov9:
|
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
||||||
self.load_yolo_constants()
|
self.load_yolo_constants()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to prepare model: {e}")
|
logger.error(f"Failed to prepare model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Cleanup: Remove the ZIP file after extraction
|
# Remove zip only if we still have it and models exist
|
||||||
if os.path.exists(zip_path):
|
if os.path.exists(zip_path):
|
||||||
os.remove(zip_path)
|
os.remove(zip_path)
|
||||||
logger.info("Cleaned up ZIP file after extraction.")
|
logger.info("Cleaned up ZIP file after extraction.")
|
||||||
|
Loading…
Reference in New Issue
Block a user