Update memryx.py

This commit is contained in:
abinila siva 2025-04-14 16:40:28 -04:00 committed by GitHub
parent 1fb98066fc
commit 4d1a97f059
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -78,7 +78,6 @@ class MemryXDetector(DetectionApi):
elif self.memx_model_type == ModelTypeEnum.yolox:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
# self.expected_post_model = "YOLOX_640_640_3_onnx_post.onnx"
self.set_strides_grids()
elif self.memx_model_type == ModelTypeEnum.ssd:
@ -98,7 +97,7 @@ class MemryXDetector(DetectionApi):
self.accl.set_postprocessing_model(self.memx_post_model, model_idx=0)
self.accl.connect_input(self.process_input)
self.accl.connect_output(self.process_output)
# self.accl.wait() # Wait for the accelerator to finish
# self.accl.wait() # This ensures a clean exit, but Frigate manages process termination
logger.info(f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}")
@ -117,7 +116,7 @@ class MemryXDetector(DetectionApi):
else:
post_model_file_path = os.path.join(self.cache_dir, self.expected_post_model)
# Check if both required model files exist
# Check if both post model file exist
if os.path.isfile(post_model_file_path):
self.memx_post_model = post_model_file_path
logger.info(f"Post-processing model found at {post_model_file_path}, skipping download.")
@ -147,16 +146,42 @@ class MemryXDetector(DetectionApi):
os.remove(zip_path)
logger.info("Cleaned up ZIP file after extraction.")
def send_input(self, connection_id, input_frame):
def send_input(self, connection_id, tensor_input: np.ndarray):
"""Send frame directly to MemryX processing."""
# logging.info(f"Processing frame for connection ID: {connection_id}")
if input_frame is None:
if tensor_input is None:
raise ValueError("[send_input] No image data provided for inference")
if self.memx_model_type == ModelTypeEnum.yolox:
tensor_input = tensor_input.squeeze(0)
padded_img = np.ones((640, 640, 3),
dtype=np.uint8) * 114
scale = min(640 / float(tensor_input.shape[0]),
640 / float(tensor_input.shape[1]))
sx,sy = int(tensor_input.shape[1] * scale), int(tensor_input.shape[0] * scale)
resized_img = cv2.resize(tensor_input, (sx,sy), interpolation=cv2.INTER_LINEAR)
padded_img[:sy, :sx] = resized_img.astype(np.uint8)
# Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels
x0 = padded_img[0::2, 0::2, :] # Top-left
x1 = padded_img[1::2, 0::2, :] # Bottom-left
x2 = padded_img[0::2, 1::2, :] # Top-right
x3 = padded_img[1::2, 1::2, :] # Bottom-right
# Step 5: Concatenate along the channel dimension (axis 2)
concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
processed_input = concatenated_img.astype(np.float32)
else:
processed_input = tensor_input.astype(np.float32) / 255.0 # Normalize
# Assuming original input is always NHWC and MemryX wants HWNC:
processed_input = processed_input.transpose(1, 2, 0, 3) # NHWC -> HWNC
# Send frame to MemryX for processing
self.capture_queue.put(input_frame) # MemryX will process this
self.capture_id_queue.put(connection_id) # Keep track of connection ID
self.capture_queue.put(processed_input)
self.capture_id_queue.put(connection_id)
def process_input(self):
"""
@ -389,7 +414,6 @@ class MemryXDetector(DetectionApi):
final_detections[i] = [class_id, confidence, y_min, x_min, y_max, x_max]
# logger.info(f"Final detections: {final_detections}")
self.output_queue.put(final_detections)
def onnx_reshape_with_allowzero(self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0) -> np.ndarray:
@ -487,8 +511,6 @@ class MemryXDetector(DetectionApi):
concat_5 = self.onnx_concat([div_output, sub1], axis=1)
# const_C = np.load("_model_22_Constant_12_output_0.npy") # Shape: (1, 8400)
# Expand B to (1, 1, 8400) so it can broadcast across axis=1 (4 channels)
const_C_expanded = self.const_C[:, np.newaxis, :] # Shape: (1, 1, 8400)
@ -520,5 +542,4 @@ class MemryXDetector(DetectionApi):
Run inference on the input image and return raw results.
tensor_input: Preprocessed image (normalized & resized)
"""
# logger.info("[detect_raw] Running inference on MemryX")
return 0