ruff formatting

This commit is contained in:
Tim 2025-04-15 11:33:59 -04:00
parent 6aef5e80d3
commit cf429549b9
2 changed files with 139 additions and 84 deletions

View File

@ -11,7 +11,9 @@ try:
# from memryx import AsyncAccl # Import MemryX SDK
from memryx import AsyncAccl
except ModuleNotFoundError:
raise ImportError("MemryX SDK is not installed. Install it and set up MIX environment.")
raise ImportError(
"MemryX SDK is not installed. Install it and set up MIX environment."
)
from pydantic import BaseModel, Field
from typing_extensions import Literal
@ -23,15 +25,18 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "memryx"
# Configuration class for model settings
class ModelConfig(BaseModel):
path: str = Field(default=None, title="Model Path") # Path to the DFP file
labelmap_path: str = Field(default=None, title="Path to Label Map")
class MemryXDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY]
device: str = Field(default="PCIe", title="Device Path")
class MemryXDetector(DetectionApi):
type_key = DETECTOR_KEY # Set the type key
supported_models = [
@ -61,38 +66,60 @@ class MemryXDetector(DetectionApi):
self.cache_dir = "/memryx_models"
if self.memx_model_type == ModelTypeEnum.yolov9:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
)
elif self.memx_model_type == ModelTypeEnum.yolov8:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
)
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
# Shared constants for both yolov8 and yolov9 post-processing
self.const_A = np.load("/memryx_models/yolov9/_model_22_Constant_9_output_0.npy")
self.const_B = np.load("/memryx_models/yolov9/_model_22_Constant_10_output_0.npy")
self.const_C = np.load("/memryx_models/yolov9/_model_22_Constant_12_output_0.npy")
self.const_A = np.load(
"/memryx_models/yolov9/_model_22_Constant_9_output_0.npy"
)
self.const_B = np.load(
"/memryx_models/yolov9/_model_22_Constant_10_output_0.npy"
)
self.const_C = np.load(
"/memryx_models/yolov9/_model_22_Constant_12_output_0.npy"
)
elif self.memx_model_type == ModelTypeEnum.yolonas:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
)
self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx"
elif self.memx_model_type == ModelTypeEnum.yolox:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
)
self.set_strides_grids()
elif self.memx_model_type == ModelTypeEnum.ssd:
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
self.expected_post_model = "ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
self.model_url = (
"https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
)
self.expected_post_model = (
"ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
)
self.check_and_prepare_model()
logger.info(f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}")
logger.info(
f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}"
)
try:
# Load MemryX Model
logger.info(f"dfp path: {self.memx_model_path}")
# Initialization code
self.accl = AsyncAccl(self.memx_model_path, mxserver_addr="host.docker.internal")
self.accl = AsyncAccl(
self.memx_model_path, mxserver_addr="host.docker.internal"
)
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
@ -102,7 +129,9 @@ class MemryXDetector(DetectionApi):
self.accl.connect_input(self.process_input)
self.accl.connect_output(self.process_output)
logger.info(f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}")
logger.info(
f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}"
)
except Exception as e:
logger.error(f"Failed to initialize MemryX model: {e}")
@ -117,14 +146,20 @@ class MemryXDetector(DetectionApi):
logger.info(f"Assigned Model Path: {self.memx_model_path}")
else:
post_model_file_path = os.path.join(self.cache_dir, self.expected_post_model)
post_model_file_path = os.path.join(
self.cache_dir, self.expected_post_model
)
# Check if both post model file exist
if os.path.isfile(post_model_file_path):
self.memx_post_model = post_model_file_path
logger.info(f"Post-processing model found at {post_model_file_path}, skipping download.")
logger.info(
f"Post-processing model found at {post_model_file_path}, skipping download."
)
else:
logger.info(f"Model files not found. Downloading from {self.model_url}...")
logger.info(
f"Model files not found. Downloading from {self.model_url}..."
)
zip_path = os.path.join(self.cache_dir, "memryx_model.zip")
# Download the ZIP file
@ -143,7 +178,9 @@ class MemryXDetector(DetectionApi):
self.memx_post_model = os.path.join(self.cache_dir, file)
logger.info(f"Assigned Model Path: {self.memx_model_path}")
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
logger.info(
f"Assigned Post-processing Model Path: {self.memx_post_model}"
)
# Cleanup: Remove the ZIP file after extraction
os.remove(zip_path)
@ -157,17 +194,21 @@ class MemryXDetector(DetectionApi):
if self.memx_model_type == ModelTypeEnum.yolox:
tensor_input = tensor_input.squeeze(0)
padded_img = np.ones((640, 640, 3),
dtype=np.uint8) * 114
padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
scale = min(640 / float(tensor_input.shape[0]),
640 / float(tensor_input.shape[1]))
sx,sy = int(tensor_input.shape[1] * scale), int(tensor_input.shape[0] * scale)
scale = min(
640 / float(tensor_input.shape[0]), 640 / float(tensor_input.shape[1])
)
sx, sy = (
int(tensor_input.shape[1] * scale),
int(tensor_input.shape[0] * scale),
)
resized_img = cv2.resize(tensor_input, (sx,sy), interpolation=cv2.INTER_LINEAR)
resized_img = cv2.resize(
tensor_input, (sx, sy), interpolation=cv2.INTER_LINEAR
)
padded_img[:sy, :sx] = resized_img.astype(np.uint8)
# Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels
x0 = padded_img[0::2, 0::2, :] # Top-left
x1 = padded_img[1::2, 0::2, :] # Bottom-left
@ -191,7 +232,9 @@ class MemryXDetector(DetectionApi):
while True:
try:
# Wait for a frame from the queue (blocking call)
frame = self.capture_queue.get(block=True) # Blocks until data is available
frame = self.capture_queue.get(
block=True
) # Blocks until data is available
return frame
@ -201,7 +244,9 @@ class MemryXDetector(DetectionApi):
def receive_output(self):
"""Retrieve processed results from MemryX output queue + a copy of the original frame"""
connection_id = self.capture_id_queue.get() # Get the corresponding connection ID
connection_id = (
self.capture_id_queue.get()
) # Get the corresponding connection ID
detections = self.output_queue.get() # Get detections from MemryX
return connection_id, detections
@ -244,7 +289,6 @@ class MemryXDetector(DetectionApi):
(pos[0] - (pos[2] / 2)) / self.memx_model_width, # x_min
(pos[1] + (pos[3] / 2)) / self.memx_model_height, # y_max
(pos[0] + (pos[2] / 2)) / self.memx_model_width, # x_max
]
def set_strides_grids(self):
@ -266,11 +310,9 @@ class MemryXDetector(DetectionApi):
self.expanded_strides = np.concatenate(expanded_strides, 1)
def sigmoid(self, x: np.ndarray) -> np.ndarray:
return 1 / (1 + np.exp(-x))
def onnx_concat(self, inputs: list, axis: int) -> np.ndarray:
# Ensure all inputs are numpy arrays
if not all(isinstance(x, np.ndarray) for x in inputs):
raise TypeError("All inputs must be numpy arrays.")
@ -282,12 +324,13 @@ class MemryXDetector(DetectionApi):
if ax == axis:
continue
if tensor.shape[ax] != ref_shape[ax]:
raise ValueError(f"Shape mismatch at axis {ax} between input[0] and input[{i}]")
raise ValueError(
f"Shape mismatch at axis {ax} between input[0] and input[{i}]"
)
return np.concatenate(inputs, axis=axis)
def onnx_reshape(self, data: np.ndarray, shape: np.ndarray) -> np.ndarray:
# Ensure shape is a 1D array of integers
target_shape = shape.astype(int).tolist()
@ -297,8 +340,9 @@ class MemryXDetector(DetectionApi):
return reshaped
def post_process_yolox(self, output):
output = [np.expand_dims(tensor, axis=0) for tensor in output] # Shape: (1, H, W, C)
output = [
np.expand_dims(tensor, axis=0) for tensor in output
] # Shape: (1, H, W, C)
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
output = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in output]
@ -343,7 +387,9 @@ class MemryXDetector(DetectionApi):
results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides
image_pred = results[0, ...]
class_conf = np.max(image_pred[:, 5:5 + self.num_classes], axis=1, keepdims=True)
class_conf = np.max(
image_pred[:, 5 : 5 + self.num_classes], axis=1, keepdims=True
)
class_pred = np.argmax(image_pred[:, 5 : 5 + self.num_classes], axis=1)
class_pred = np.expand_dims(class_pred, axis=1)
@ -364,7 +410,6 @@ class MemryXDetector(DetectionApi):
self.output_queue.put(final_detections)
def post_process_ssdlite(self, outputs):
dets = outputs[0].squeeze(0) # Shape: (1, num_dets, 5)
labels = outputs[1].squeeze(0)
@ -418,8 +463,9 @@ class MemryXDetector(DetectionApi):
self.output_queue.put(final_detections)
def onnx_reshape_with_allowzero(self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0) -> np.ndarray:
def onnx_reshape_with_allowzero(
self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0
) -> np.ndarray:
shape = shape.astype(int)
input_shape = data.shape
output_shape = []
@ -438,7 +484,9 @@ class MemryXDetector(DetectionApi):
def process_output(self, *outputs):
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
outputs = [np.expand_dims(tensor, axis=0) for tensor in outputs] # Shape: (1, H, W, C)
outputs = [
np.expand_dims(tensor, axis=0) for tensor in outputs
] # Shape: (1, H, W, C)
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
outputs = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in outputs]
@ -466,7 +514,9 @@ class MemryXDetector(DetectionApi):
split_sizes = [64, 80]
# Calculate indices at which to split
indices = np.cumsum(split_sizes)[:-1] # [64] — split before the second chunk
indices = np.cumsum(split_sizes)[
:-1
] # [64] — split before the second chunk
# Perform split along axis 1
split_0, split_1 = np.split(concat_4, indices, axis=axis)
@ -488,10 +538,14 @@ class MemryXDetector(DetectionApi):
W = np.arange(16, dtype=np.float32).reshape(1, 16, 1, 1) # (1, 16, 1, 1)
# Apply 1x1 convolution: this is a weighted sum over channels
conv_output = np.sum(softmax_output * W, axis=1, keepdims=True) # shape: (1, 1, 4, 8400)
conv_output = np.sum(
softmax_output * W, axis=1, keepdims=True
) # shape: (1, 1, 4, 8400)
shape2 = np.array([1, 4, 8400])
reshape_5 = self.onnx_reshape_with_allowzero(conv_output, shape2, allowzero=0)
reshape_5 = self.onnx_reshape_with_allowzero(
conv_output, shape2, allowzero=0
)
# ONNX Slice — get first 2 channels: [0:2] along axis 1
slice_output1 = reshape_5[:, 0:2, :] # Result: (1, 2, 8400)
@ -522,7 +576,9 @@ class MemryXDetector(DetectionApi):
sigmoid_output = self.sigmoid(split_1)
outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
final_detections = post_process_yolov9(outputs, self.memx_model_width, self.memx_model_height)
final_detections = post_process_yolov9(
outputs, self.memx_model_width, self.memx_model_height
)
self.output_queue.put(final_detections)
elif self.memx_model_type == ModelTypeEnum.yolonas:

View File

@ -231,7 +231,6 @@ def async_run_detector(
logger.info("Exited async detection process...")
class ObjectDetectProcess:
def __init__(
self,
@ -266,7 +265,7 @@ class ObjectDetectProcess:
self.detection_start.value = 0.0
if (self.detect_process is not None) and self.detect_process.is_alive():
self.stop()
if (self.detector_config.type == 'memryx'):
if self.detector_config.type == "memryx":
# MemryX requires asynchronous detection handling using async_run_detector
self.detect_process = util.Process(
target=async_run_detector,