download models at runtime

This commit is contained in:
Abinila Siva 2025-04-17 09:43:18 -04:00
parent a154a13f98
commit 2304457557
3 changed files with 83 additions and 73 deletions

View File

@ -294,18 +294,8 @@ install_and_hold() { \
}; \
install_and_hold memx-accl'
# Set the working directory for model files
WORKDIR /memryx_models
# Download and extract MemryX model files
RUN wget -O yolox.zip https://developer.memryx.com/example_files/1p2_frigate/yolox.zip && unzip yolox.zip && \
wget -O ssdlite.zip https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip && unzip ssdlite.zip && \
wget -O yolo_nas.zip https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip && unzip yolo_nas.zip && \
wget -O yolov9.zip https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip && unzip yolov9.zip && \
wget -O yolov8.zip https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip && unzip yolov8.zip
# Set permissions for the models directory
RUN chmod -R 755 /memryx_models
# Copy the 81-class COCO label map
RUN wget -O coco_81class_labelmap.txt https://developer.memryx.com/example_files/1p2_frigate/labelmap.txt
###### End MemryX setup

View File

@ -39,8 +39,6 @@ class ModelTypeEnum(str, Enum):
rfdetr = "rfdetr"
ssd = "ssd"
yolox = "yolox"
yolov9 = "yolov9"
yolov8 = "yolov8"
yolonas = "yolonas"
yologeneric = "yolo-generic"

View File

@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "memryx"
# Configuration class for model settings
class ModelConfig(BaseModel):
path: str = Field(default=None, title="Model Path") # Path to the DFP file
@ -38,12 +37,11 @@ class MemryXDetectorConfig(BaseDetectorConfig):
class MemryXDetector(DetectionApi):
type_key = DETECTOR_KEY # Set the type key
type_key = DETECTOR_KEY # Set the type key
supported_models = [
ModelTypeEnum.ssd,
ModelTypeEnum.yolonas,
ModelTypeEnum.yolov9,
ModelTypeEnum.yolov8,
ModelTypeEnum.yologeneric,
ModelTypeEnum.yolox,
]
@ -63,48 +61,49 @@ class MemryXDetector(DetectionApi):
self.memx_model_width = detector_config.model.width
self.memx_model_type = detector_config.model.model_type
# If it's yologeneric, treat it as yolov9
if self.memx_model_type == ModelTypeEnum.yologeneric:
self.memx_model_type = ModelTypeEnum.yolov9
self.cache_dir = "/memryx_models"
if self.memx_model_type == ModelTypeEnum.yolov9:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
)
elif self.memx_model_type == ModelTypeEnum.yolov8:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
)
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
# Shared constants for both yolov8 and yolov9 post-processing
self.const_A = np.load(
"/memryx_models/yolov9/_model_22_Constant_9_output_0.npy"
)
self.const_B = np.load(
"/memryx_models/yolov9/_model_22_Constant_10_output_0.npy"
)
self.const_C = np.load(
"/memryx_models/yolov9/_model_22_Constant_12_output_0.npy"
self.expected_dfp_model = (
"YOLO_v9_small_640_640_3_onnx.dfp"
)
elif self.memx_model_type == ModelTypeEnum.yolonas:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
"https://developer.memryx.com/example_files/1p2_frigate/yolonas.zip"
)
self.expected_dfp_model = (
"yolo_nas_s.dfp"
)
self.expected_post_model = (
"yolo_nas_s_post.onnx"
)
self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx"
elif self.memx_model_type == ModelTypeEnum.yolox:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
)
self.expected_dfp_model = (
"YOLOX_640_640_3_onnx.dfp"
)
self.set_strides_grids()
elif self.memx_model_type == ModelTypeEnum.ssd:
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
"https://developer.memryx.com/example_files/1p2_frigate/ssd.zip"
)
self.expected_dfp_model = (
"SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
)
self.expected_post_model = (
"ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
"SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
)
self.check_and_prepare_model()
@ -136,53 +135,76 @@ class MemryXDetector(DetectionApi):
except Exception as e:
logger.error(f"Failed to initialize MemryX model: {e}")
raise
def load_yolo_constants(self):
base = f"{self.cache_dir}/{self.memx_model_type.value}"
# constants for yolov9 post-processing
self.const_A = np.load(
f"{base}/_model_22_Constant_9_output_0.npy"
)
self.const_B = np.load(
f"{base}/_model_22_Constant_10_output_0.npy"
)
self.const_C = np.load(
f"{base}/_model_22_Constant_12_output_0.npy"
)
def check_and_prepare_model(self):
"""Check if both models exist; if not, download and extract them."""
"""Check if models exist; if not, download and extract them."""
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
if not self.expected_post_model:
logger.info(f"Assigned Model Path: {self.memx_model_path}")
dfp_exists = os.path.exists(self.memx_model_path) if self.memx_model_path else False
post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True # ok if no post model
else:
post_model_file_path = os.path.join(
self.cache_dir, self.expected_post_model
if dfp_exists and post_exists:
logger.info(
f"Using cached models."
)
return
logger.info(
f"Model files not found. Downloading from {self.model_url}..."
)
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
try:
# Download the ZIP file
urllib.request.urlretrieve(self.model_url, zip_path)
logger.info(
f"Model ZIP downloaded to {zip_path}. Extracting..."
)
# Check if both post model file exist
if os.path.isfile(post_model_file_path):
self.memx_post_model = post_model_file_path
logger.info(
f"Post-processing model found at {post_model_file_path}, skipping download."
)
else:
logger.info(
f"Model files not found. Downloading from {self.model_url}..."
)
zip_path = os.path.join(self.cache_dir, "memryx_model.zip")
# Extract ZIP file
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
logger.info(
f"Model extracted to {self.cache_dir}."
)
# Download the ZIP file
urllib.request.urlretrieve(self.model_url, zip_path)
logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
# Determine the subfolder to search in
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
# Assign extracted files to correct paths
for file in os.listdir(model_subdir):
file_path = os.path.join(model_subdir, file)
if file == self.expected_dfp_model:
self.memx_model_path = file_path
elif file == self.expected_post_model:
self.memx_post_model = file_path
# Extract ZIP file
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
logger.info(f"Assigned Model Path: {self.memx_model_path}")
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
logger.info(f"Model extracted to {self.cache_dir}.")
if self.memx_model_type == ModelTypeEnum.yolov9:
self.load_yolo_constants()
# Assign extracted files to correct paths
for file in os.listdir(self.cache_dir):
if file == self.expected_post_model:
self.memx_post_model = os.path.join(self.cache_dir, file)
logger.info(f"Assigned Model Path: {self.memx_model_path}")
logger.info(
f"Assigned Post-processing Model Path: {self.memx_post_model}"
)
# Cleanup: Remove the ZIP file after extraction
except Exception as e:
logger.error(f"Failed to prepare model: {e}")
raise
finally:
# Cleanup: Remove the ZIP file after extraction
if os.path.exists(zip_path):
os.remove(zip_path)
logger.info("Cleaned up ZIP file after extraction.")