mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
ruff formatting
This commit is contained in:
parent
6aef5e80d3
commit
cf429549b9
@ -11,7 +11,9 @@ try:
|
|||||||
# from memryx import AsyncAccl # Import MemryX SDK
|
# from memryx import AsyncAccl # Import MemryX SDK
|
||||||
from memryx import AsyncAccl
|
from memryx import AsyncAccl
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
raise ImportError("MemryX SDK is not installed. Install it and set up MIX environment.")
|
raise ImportError(
|
||||||
|
"MemryX SDK is not installed. Install it and set up MIX environment."
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
@ -23,17 +25,20 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
DETECTOR_KEY = "memryx"
|
DETECTOR_KEY = "memryx"
|
||||||
|
|
||||||
|
|
||||||
# Configuration class for model settings
|
# Configuration class for model settings
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
||||||
labelmap_path: str = Field(default=None, title="Path to Label Map")
|
labelmap_path: str = Field(default=None, title="Path to Label Map")
|
||||||
|
|
||||||
|
|
||||||
class MemryXDetectorConfig(BaseDetectorConfig):
|
class MemryXDetectorConfig(BaseDetectorConfig):
|
||||||
type: Literal[DETECTOR_KEY]
|
type: Literal[DETECTOR_KEY]
|
||||||
device: str = Field(default="PCIe", title="Device Path")
|
device: str = Field(default="PCIe", title="Device Path")
|
||||||
|
|
||||||
|
|
||||||
class MemryXDetector(DetectionApi):
|
class MemryXDetector(DetectionApi):
|
||||||
type_key = DETECTOR_KEY # Set the type key
|
type_key = DETECTOR_KEY # Set the type key
|
||||||
supported_models = [
|
supported_models = [
|
||||||
ModelTypeEnum.ssd,
|
ModelTypeEnum.ssd,
|
||||||
ModelTypeEnum.yolonas,
|
ModelTypeEnum.yolonas,
|
||||||
@ -51,7 +56,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
||||||
self.memx_post_model = None # Path to .post file
|
self.memx_post_model = None # Path to .post file
|
||||||
self.expected_post_model = None
|
self.expected_post_model = None
|
||||||
self.memx_device_path = detector_config.device # Device path
|
self.memx_device_path = detector_config.device # Device path
|
||||||
self.memx_model_height = detector_config.model.height
|
self.memx_model_height = detector_config.model.height
|
||||||
@ -61,38 +66,60 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.cache_dir = "/memryx_models"
|
self.cache_dir = "/memryx_models"
|
||||||
|
|
||||||
if self.memx_model_type == ModelTypeEnum.yolov9:
|
if self.memx_model_type == ModelTypeEnum.yolov9:
|
||||||
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
|
self.model_url = (
|
||||||
|
"https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
|
||||||
|
)
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolov8:
|
elif self.memx_model_type == ModelTypeEnum.yolov8:
|
||||||
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
|
self.model_url = (
|
||||||
|
"https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
|
||||||
|
)
|
||||||
|
|
||||||
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
||||||
# Shared constants for both yolov8 and yolov9 post-processing
|
# Shared constants for both yolov8 and yolov9 post-processing
|
||||||
self.const_A = np.load("/memryx_models/yolov9/_model_22_Constant_9_output_0.npy")
|
self.const_A = np.load(
|
||||||
self.const_B = np.load("/memryx_models/yolov9/_model_22_Constant_10_output_0.npy")
|
"/memryx_models/yolov9/_model_22_Constant_9_output_0.npy"
|
||||||
self.const_C = np.load("/memryx_models/yolov9/_model_22_Constant_12_output_0.npy")
|
)
|
||||||
|
self.const_B = np.load(
|
||||||
|
"/memryx_models/yolov9/_model_22_Constant_10_output_0.npy"
|
||||||
|
)
|
||||||
|
self.const_C = np.load(
|
||||||
|
"/memryx_models/yolov9/_model_22_Constant_12_output_0.npy"
|
||||||
|
)
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||||
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
|
self.model_url = (
|
||||||
|
"https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
|
||||||
|
)
|
||||||
self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx"
|
self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx"
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||||
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
|
self.model_url = (
|
||||||
|
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
|
||||||
|
)
|
||||||
self.set_strides_grids()
|
self.set_strides_grids()
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||||
self.model_url = "https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
|
self.model_url = (
|
||||||
self.expected_post_model = "ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
"https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
|
||||||
|
)
|
||||||
|
self.expected_post_model = (
|
||||||
|
"ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||||
|
)
|
||||||
|
|
||||||
self.check_and_prepare_model()
|
self.check_and_prepare_model()
|
||||||
logger.info(f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}")
|
logger.info(
|
||||||
|
f"Initializing MemryX with model: {self.memx_model_path} on device {self.memx_device_path}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load MemryX Model
|
# Load MemryX Model
|
||||||
logger.info(f"dfp path: {self.memx_model_path}")
|
logger.info(f"dfp path: {self.memx_model_path}")
|
||||||
|
|
||||||
# Initialization code
|
# Initialization code
|
||||||
self.accl = AsyncAccl(self.memx_model_path, mxserver_addr="host.docker.internal")
|
self.accl = AsyncAccl(
|
||||||
|
self.memx_model_path, mxserver_addr="host.docker.internal"
|
||||||
|
)
|
||||||
|
|
||||||
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
||||||
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
||||||
@ -102,7 +129,9 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.accl.connect_input(self.process_input)
|
self.accl.connect_input(self.process_input)
|
||||||
self.accl.connect_output(self.process_output)
|
self.accl.connect_output(self.process_output)
|
||||||
|
|
||||||
logger.info(f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}")
|
logger.info(
|
||||||
|
f"Loaded MemryX model from {self.memx_model_path} and {self.memx_post_model}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize MemryX model: {e}")
|
logger.error(f"Failed to initialize MemryX model: {e}")
|
||||||
@ -117,14 +146,20 @@ class MemryXDetector(DetectionApi):
|
|||||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
post_model_file_path = os.path.join(self.cache_dir, self.expected_post_model)
|
post_model_file_path = os.path.join(
|
||||||
|
self.cache_dir, self.expected_post_model
|
||||||
|
)
|
||||||
|
|
||||||
# Check if both post model file exist
|
# Check if both post model file exist
|
||||||
if os.path.isfile(post_model_file_path):
|
if os.path.isfile(post_model_file_path):
|
||||||
self.memx_post_model = post_model_file_path
|
self.memx_post_model = post_model_file_path
|
||||||
logger.info(f"Post-processing model found at {post_model_file_path}, skipping download.")
|
logger.info(
|
||||||
|
f"Post-processing model found at {post_model_file_path}, skipping download."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Model files not found. Downloading from {self.model_url}...")
|
logger.info(
|
||||||
|
f"Model files not found. Downloading from {self.model_url}..."
|
||||||
|
)
|
||||||
zip_path = os.path.join(self.cache_dir, "memryx_model.zip")
|
zip_path = os.path.join(self.cache_dir, "memryx_model.zip")
|
||||||
|
|
||||||
# Download the ZIP file
|
# Download the ZIP file
|
||||||
@ -143,30 +178,36 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.memx_post_model = os.path.join(self.cache_dir, file)
|
self.memx_post_model = os.path.join(self.cache_dir, file)
|
||||||
|
|
||||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||||
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
|
logger.info(
|
||||||
|
f"Assigned Post-processing Model Path: {self.memx_post_model}"
|
||||||
|
)
|
||||||
|
|
||||||
# Cleanup: Remove the ZIP file after extraction
|
# Cleanup: Remove the ZIP file after extraction
|
||||||
os.remove(zip_path)
|
os.remove(zip_path)
|
||||||
logger.info("Cleaned up ZIP file after extraction.")
|
logger.info("Cleaned up ZIP file after extraction.")
|
||||||
|
|
||||||
def send_input(self, connection_id, tensor_input: np.ndarray):
|
def send_input(self, connection_id, tensor_input: np.ndarray):
|
||||||
"""Pre-process (if needed) and send frame to MemryX input queue"""
|
"""Pre-process (if needed) and send frame to MemryX input queue"""
|
||||||
if tensor_input is None:
|
if tensor_input is None:
|
||||||
raise ValueError("[send_input] No image data provided for inference")
|
raise ValueError("[send_input] No image data provided for inference")
|
||||||
|
|
||||||
if self.memx_model_type == ModelTypeEnum.yolox:
|
if self.memx_model_type == ModelTypeEnum.yolox:
|
||||||
tensor_input = tensor_input.squeeze(0)
|
tensor_input = tensor_input.squeeze(0)
|
||||||
|
|
||||||
padded_img = np.ones((640, 640, 3),
|
padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
|
||||||
dtype=np.uint8) * 114
|
|
||||||
|
|
||||||
scale = min(640 / float(tensor_input.shape[0]),
|
scale = min(
|
||||||
640 / float(tensor_input.shape[1]))
|
640 / float(tensor_input.shape[0]), 640 / float(tensor_input.shape[1])
|
||||||
sx,sy = int(tensor_input.shape[1] * scale), int(tensor_input.shape[0] * scale)
|
)
|
||||||
|
sx, sy = (
|
||||||
|
int(tensor_input.shape[1] * scale),
|
||||||
|
int(tensor_input.shape[0] * scale),
|
||||||
|
)
|
||||||
|
|
||||||
resized_img = cv2.resize(tensor_input, (sx,sy), interpolation=cv2.INTER_LINEAR)
|
resized_img = cv2.resize(
|
||||||
|
tensor_input, (sx, sy), interpolation=cv2.INTER_LINEAR
|
||||||
|
)
|
||||||
padded_img[:sy, :sx] = resized_img.astype(np.uint8)
|
padded_img[:sy, :sx] = resized_img.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
# Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels
|
# Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels
|
||||||
x0 = padded_img[0::2, 0::2, :] # Top-left
|
x0 = padded_img[0::2, 0::2, :] # Top-left
|
||||||
@ -176,7 +217,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
# Step 5: Concatenate along the channel dimension (axis 2)
|
# Step 5: Concatenate along the channel dimension (axis 2)
|
||||||
concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
|
concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
|
||||||
processed_input = concatenated_img.astype(np.float32)
|
processed_input = concatenated_img.astype(np.float32)
|
||||||
else:
|
else:
|
||||||
processed_input = tensor_input.astype(np.float32) / 255.0 # Normalize
|
processed_input = tensor_input.astype(np.float32) / 255.0 # Normalize
|
||||||
# Assuming original input is always NHWC and MemryX wants HWNC:
|
# Assuming original input is always NHWC and MemryX wants HWNC:
|
||||||
@ -191,8 +232,10 @@ class MemryXDetector(DetectionApi):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# Wait for a frame from the queue (blocking call)
|
# Wait for a frame from the queue (blocking call)
|
||||||
frame = self.capture_queue.get(block=True) # Blocks until data is available
|
frame = self.capture_queue.get(
|
||||||
|
block=True
|
||||||
|
) # Blocks until data is available
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -201,7 +244,9 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
def receive_output(self):
|
def receive_output(self):
|
||||||
"""Retrieve processed results from MemryX output queue + a copy of the original frame"""
|
"""Retrieve processed results from MemryX output queue + a copy of the original frame"""
|
||||||
connection_id = self.capture_id_queue.get() # Get the corresponding connection ID
|
connection_id = (
|
||||||
|
self.capture_id_queue.get()
|
||||||
|
) # Get the corresponding connection ID
|
||||||
detections = self.output_queue.get() # Get detections from MemryX
|
detections = self.output_queue.get() # Get detections from MemryX
|
||||||
|
|
||||||
return connection_id, detections
|
return connection_id, detections
|
||||||
@ -216,10 +261,10 @@ class MemryXDetector(DetectionApi):
|
|||||||
break
|
break
|
||||||
|
|
||||||
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
|
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
|
||||||
|
|
||||||
if class_id < 0:
|
if class_id < 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
detections[i] = [
|
detections[i] = [
|
||||||
class_id,
|
class_id,
|
||||||
confidence,
|
confidence,
|
||||||
@ -228,7 +273,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
y_max / self.memx_model_height,
|
y_max / self.memx_model_height,
|
||||||
x_max / self.memx_model_width,
|
x_max / self.memx_model_width,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Return the list of final detections
|
# Return the list of final detections
|
||||||
self.output_queue.put(detections)
|
self.output_queue.put(detections)
|
||||||
|
|
||||||
@ -244,9 +289,8 @@ class MemryXDetector(DetectionApi):
|
|||||||
(pos[0] - (pos[2] / 2)) / self.memx_model_width, # x_min
|
(pos[0] - (pos[2] / 2)) / self.memx_model_width, # x_min
|
||||||
(pos[1] + (pos[3] / 2)) / self.memx_model_height, # y_max
|
(pos[1] + (pos[3] / 2)) / self.memx_model_height, # y_max
|
||||||
(pos[0] + (pos[2] / 2)) / self.memx_model_width, # x_max
|
(pos[0] + (pos[2] / 2)) / self.memx_model_width, # x_max
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def set_strides_grids(self):
|
def set_strides_grids(self):
|
||||||
grids = []
|
grids = []
|
||||||
expanded_strides = []
|
expanded_strides = []
|
||||||
@ -266,15 +310,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.expanded_strides = np.concatenate(expanded_strides, 1)
|
self.expanded_strides = np.concatenate(expanded_strides, 1)
|
||||||
|
|
||||||
def sigmoid(self, x: np.ndarray) -> np.ndarray:
|
def sigmoid(self, x: np.ndarray) -> np.ndarray:
|
||||||
|
|
||||||
return 1 / (1 + np.exp(-x))
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
def onnx_concat(self, inputs: list, axis: int) -> np.ndarray:
|
def onnx_concat(self, inputs: list, axis: int) -> np.ndarray:
|
||||||
|
|
||||||
# Ensure all inputs are numpy arrays
|
# Ensure all inputs are numpy arrays
|
||||||
if not all(isinstance(x, np.ndarray) for x in inputs):
|
if not all(isinstance(x, np.ndarray) for x in inputs):
|
||||||
raise TypeError("All inputs must be numpy arrays.")
|
raise TypeError("All inputs must be numpy arrays.")
|
||||||
|
|
||||||
# Ensure shapes match on non-concat axes
|
# Ensure shapes match on non-concat axes
|
||||||
ref_shape = list(inputs[0].shape)
|
ref_shape = list(inputs[0].shape)
|
||||||
for i, tensor in enumerate(inputs[1:], start=1):
|
for i, tensor in enumerate(inputs[1:], start=1):
|
||||||
@ -282,12 +324,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
if ax == axis:
|
if ax == axis:
|
||||||
continue
|
continue
|
||||||
if tensor.shape[ax] != ref_shape[ax]:
|
if tensor.shape[ax] != ref_shape[ax]:
|
||||||
raise ValueError(f"Shape mismatch at axis {ax} between input[0] and input[{i}]")
|
raise ValueError(
|
||||||
|
f"Shape mismatch at axis {ax} between input[0] and input[{i}]"
|
||||||
|
)
|
||||||
|
|
||||||
return np.concatenate(inputs, axis=axis)
|
return np.concatenate(inputs, axis=axis)
|
||||||
|
|
||||||
def onnx_reshape(self, data: np.ndarray, shape: np.ndarray) -> np.ndarray:
|
def onnx_reshape(self, data: np.ndarray, shape: np.ndarray) -> np.ndarray:
|
||||||
|
|
||||||
# Ensure shape is a 1D array of integers
|
# Ensure shape is a 1D array of integers
|
||||||
target_shape = shape.astype(int).tolist()
|
target_shape = shape.astype(int).tolist()
|
||||||
|
|
||||||
@ -295,23 +338,24 @@ class MemryXDetector(DetectionApi):
|
|||||||
reshaped = np.reshape(data, target_shape)
|
reshaped = np.reshape(data, target_shape)
|
||||||
|
|
||||||
return reshaped
|
return reshaped
|
||||||
|
|
||||||
def post_process_yolox(self, output):
|
|
||||||
|
|
||||||
output = [np.expand_dims(tensor, axis=0) for tensor in output] # Shape: (1, H, W, C)
|
def post_process_yolox(self, output):
|
||||||
|
output = [
|
||||||
|
np.expand_dims(tensor, axis=0) for tensor in output
|
||||||
|
] # Shape: (1, H, W, C)
|
||||||
|
|
||||||
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
|
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
|
||||||
output = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in output]
|
output = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in output]
|
||||||
|
|
||||||
output_785 = output[0] # 785
|
output_785 = output[0] # 785
|
||||||
output_794 = output[1] # 794
|
output_794 = output[1] # 794
|
||||||
output_795 = output[2] # 795
|
output_795 = output[2] # 795
|
||||||
output_811 = output[3] # 811
|
output_811 = output[3] # 811
|
||||||
output_820 = output[4] # 820
|
output_820 = output[4] # 820
|
||||||
output_821 = output[5] # 821
|
output_821 = output[5] # 821
|
||||||
output_837 = output[6] # 837
|
output_837 = output[6] # 837
|
||||||
output_846 = output[7] # 846
|
output_846 = output[7] # 846
|
||||||
output_847 = output[8] # 847
|
output_847 = output[8] # 847
|
||||||
|
|
||||||
output_795 = self.sigmoid(output_795)
|
output_795 = self.sigmoid(output_795)
|
||||||
output_785 = self.sigmoid(output_785)
|
output_785 = self.sigmoid(output_785)
|
||||||
@ -324,7 +368,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
concat_2 = self.onnx_concat([output_820, output_821, output_811], axis=1)
|
concat_2 = self.onnx_concat([output_820, output_821, output_811], axis=1)
|
||||||
concat_3 = self.onnx_concat([output_846, output_847, output_837], axis=1)
|
concat_3 = self.onnx_concat([output_846, output_847, output_837], axis=1)
|
||||||
|
|
||||||
shape = np.array([1, 85, -1], dtype=np.int64)
|
shape = np.array([1, 85, -1], dtype=np.int64)
|
||||||
|
|
||||||
reshape_1 = self.onnx_reshape(concat_1, shape)
|
reshape_1 = self.onnx_reshape(concat_1, shape)
|
||||||
reshape_2 = self.onnx_reshape(concat_2, shape)
|
reshape_2 = self.onnx_reshape(concat_2, shape)
|
||||||
@ -332,7 +376,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
concat_out = self.onnx_concat([reshape_1, reshape_2, reshape_3], axis=2)
|
concat_out = self.onnx_concat([reshape_1, reshape_2, reshape_3], axis=2)
|
||||||
|
|
||||||
output = concat_out.transpose(0,2,1) #1, 840, 85
|
output = concat_out.transpose(0, 2, 1) # 1, 840, 85
|
||||||
|
|
||||||
self.num_classes = output.shape[2] - 5
|
self.num_classes = output.shape[2] - 5
|
||||||
|
|
||||||
@ -343,8 +387,10 @@ class MemryXDetector(DetectionApi):
|
|||||||
results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides
|
results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides
|
||||||
image_pred = results[0, ...]
|
image_pred = results[0, ...]
|
||||||
|
|
||||||
class_conf = np.max(image_pred[:, 5:5 + self.num_classes], axis=1, keepdims=True)
|
class_conf = np.max(
|
||||||
class_pred = np.argmax(image_pred[:, 5:5 + self.num_classes], axis=1)
|
image_pred[:, 5 : 5 + self.num_classes], axis=1, keepdims=True
|
||||||
|
)
|
||||||
|
class_pred = np.argmax(image_pred[:, 5 : 5 + self.num_classes], axis=1)
|
||||||
class_pred = np.expand_dims(class_pred, axis=1)
|
class_pred = np.expand_dims(class_pred, axis=1)
|
||||||
|
|
||||||
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= 0.3).squeeze()
|
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= 0.3).squeeze()
|
||||||
@ -364,7 +410,6 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
self.output_queue.put(final_detections)
|
self.output_queue.put(final_detections)
|
||||||
|
|
||||||
|
|
||||||
def post_process_ssdlite(self, outputs):
|
def post_process_ssdlite(self, outputs):
|
||||||
dets = outputs[0].squeeze(0) # Shape: (1, num_dets, 5)
|
dets = outputs[0].squeeze(0) # Shape: (1, num_dets, 5)
|
||||||
labels = outputs[1].squeeze(0)
|
labels = outputs[1].squeeze(0)
|
||||||
@ -414,12 +459,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
x_max /= self.memx_model_width
|
x_max /= self.memx_model_width
|
||||||
y_max /= self.memx_model_height
|
y_max /= self.memx_model_height
|
||||||
|
|
||||||
final_detections[i] = [class_id, confidence, y_min, x_min, y_max, x_max]
|
final_detections[i] = [class_id, confidence, y_min, x_min, y_max, x_max]
|
||||||
|
|
||||||
self.output_queue.put(final_detections)
|
self.output_queue.put(final_detections)
|
||||||
|
|
||||||
def onnx_reshape_with_allowzero(self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0) -> np.ndarray:
|
def onnx_reshape_with_allowzero(
|
||||||
|
self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0
|
||||||
|
) -> np.ndarray:
|
||||||
shape = shape.astype(int)
|
shape = shape.astype(int)
|
||||||
input_shape = data.shape
|
input_shape = data.shape
|
||||||
output_shape = []
|
output_shape = []
|
||||||
@ -436,9 +482,11 @@ class MemryXDetector(DetectionApi):
|
|||||||
return reshaped
|
return reshaped
|
||||||
|
|
||||||
def process_output(self, *outputs):
|
def process_output(self, *outputs):
|
||||||
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
||||||
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
||||||
outputs = [np.expand_dims(tensor, axis=0) for tensor in outputs] # Shape: (1, H, W, C)
|
outputs = [
|
||||||
|
np.expand_dims(tensor, axis=0) for tensor in outputs
|
||||||
|
] # Shape: (1, H, W, C)
|
||||||
|
|
||||||
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
|
# Move channel axis from 3rd (last) position to 1st position → (1, C, H, W)
|
||||||
outputs = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in outputs]
|
outputs = [np.transpose(tensor, (0, 3, 1, 2)) for tensor in outputs]
|
||||||
@ -466,15 +514,17 @@ class MemryXDetector(DetectionApi):
|
|||||||
split_sizes = [64, 80]
|
split_sizes = [64, 80]
|
||||||
|
|
||||||
# Calculate indices at which to split
|
# Calculate indices at which to split
|
||||||
indices = np.cumsum(split_sizes)[:-1] # [64] — split before the second chunk
|
indices = np.cumsum(split_sizes)[
|
||||||
|
:-1
|
||||||
|
] # [64] — split before the second chunk
|
||||||
|
|
||||||
# Perform split along axis 1
|
# Perform split along axis 1
|
||||||
split_0, split_1 = np.split(concat_4, indices, axis=axis)
|
split_0, split_1 = np.split(concat_4, indices, axis=axis)
|
||||||
|
|
||||||
shape1 = np.array([1,4,16,8400])
|
shape1 = np.array([1, 4, 16, 8400])
|
||||||
reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0)
|
reshape_4 = self.onnx_reshape_with_allowzero(split_0, shape1, allowzero=0)
|
||||||
|
|
||||||
transpose_1 = reshape_4.transpose(0,2,1,3)
|
transpose_1 = reshape_4.transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
axis = 1 # As per ONNX softmax node
|
axis = 1 # As per ONNX softmax node
|
||||||
|
|
||||||
@ -488,10 +538,14 @@ class MemryXDetector(DetectionApi):
|
|||||||
W = np.arange(16, dtype=np.float32).reshape(1, 16, 1, 1) # (1, 16, 1, 1)
|
W = np.arange(16, dtype=np.float32).reshape(1, 16, 1, 1) # (1, 16, 1, 1)
|
||||||
|
|
||||||
# Apply 1x1 convolution: this is a weighted sum over channels
|
# Apply 1x1 convolution: this is a weighted sum over channels
|
||||||
conv_output = np.sum(softmax_output * W, axis=1, keepdims=True) # shape: (1, 1, 4, 8400)
|
conv_output = np.sum(
|
||||||
|
softmax_output * W, axis=1, keepdims=True
|
||||||
|
) # shape: (1, 1, 4, 8400)
|
||||||
|
|
||||||
shape2 = np.array([1,4,8400])
|
shape2 = np.array([1, 4, 8400])
|
||||||
reshape_5 = self.onnx_reshape_with_allowzero(conv_output, shape2, allowzero=0)
|
reshape_5 = self.onnx_reshape_with_allowzero(
|
||||||
|
conv_output, shape2, allowzero=0
|
||||||
|
)
|
||||||
|
|
||||||
# ONNX Slice — get first 2 channels: [0:2] along axis 1
|
# ONNX Slice — get first 2 channels: [0:2] along axis 1
|
||||||
slice_output1 = reshape_5[:, 0:2, :] # Result: (1, 2, 8400)
|
slice_output1 = reshape_5[:, 0:2, :] # Result: (1, 2, 8400)
|
||||||
@ -511,7 +565,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
div_output = add1 / 2.0
|
div_output = add1 / 2.0
|
||||||
|
|
||||||
concat_5 = self.onnx_concat([div_output, sub1], axis=1)
|
concat_5 = self.onnx_concat([div_output, sub1], axis=1)
|
||||||
|
|
||||||
# Expand B to (1, 1, 8400) so it can broadcast across axis=1 (4 channels)
|
# Expand B to (1, 1, 8400) so it can broadcast across axis=1 (4 channels)
|
||||||
const_C_expanded = self.const_C[:, np.newaxis, :] # Shape: (1, 1, 8400)
|
const_C_expanded = self.const_C[:, np.newaxis, :] # Shape: (1, 1, 8400)
|
||||||
@ -522,23 +576,25 @@ class MemryXDetector(DetectionApi):
|
|||||||
sigmoid_output = self.sigmoid(split_1)
|
sigmoid_output = self.sigmoid(split_1)
|
||||||
outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
|
outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
|
||||||
|
|
||||||
final_detections = post_process_yolov9(outputs, self.memx_model_width, self.memx_model_height)
|
final_detections = post_process_yolov9(
|
||||||
|
outputs, self.memx_model_width, self.memx_model_height
|
||||||
|
)
|
||||||
self.output_queue.put(final_detections)
|
self.output_queue.put(final_detections)
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||||
return self.post_process_yolonas(outputs)
|
return self.post_process_yolonas(outputs)
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||||
return self.post_process_yolox(outputs)
|
return self.post_process_yolox(outputs)
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||||
return self.post_process_ssdlite(outputs)
|
return self.post_process_ssdlite(outputs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
|
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
|
||||||
)
|
)
|
||||||
|
|
||||||
def detect_raw(self, tensor_input: np.ndarray):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
""" Removed synchronous detect_raw() function so that we only use async """
|
"""Removed synchronous detect_raw() function so that we only use async"""
|
||||||
return 0
|
return 0
|
||||||
|
@ -139,7 +139,7 @@ def run_detector(
|
|||||||
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
||||||
|
|
||||||
logger.info("Exited detection process...")
|
logger.info("Exited detection process...")
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
|
|
||||||
def async_run_detector(
|
def async_run_detector(
|
||||||
@ -194,7 +194,7 @@ def async_run_detector(
|
|||||||
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
#send input to Accelator
|
# send input to Accelator
|
||||||
start.value = datetime.datetime.now().timestamp()
|
start.value = datetime.datetime.now().timestamp()
|
||||||
object_detector.detect_api.send_input(connection_id, input_frame)
|
object_detector.detect_api.send_input(connection_id, input_frame)
|
||||||
|
|
||||||
@ -231,7 +231,6 @@ def async_run_detector(
|
|||||||
logger.info("Exited async detection process...")
|
logger.info("Exited async detection process...")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectDetectProcess:
|
class ObjectDetectProcess:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -266,7 +265,7 @@ class ObjectDetectProcess:
|
|||||||
self.detection_start.value = 0.0
|
self.detection_start.value = 0.0
|
||||||
if (self.detect_process is not None) and self.detect_process.is_alive():
|
if (self.detect_process is not None) and self.detect_process.is_alive():
|
||||||
self.stop()
|
self.stop()
|
||||||
if (self.detector_config.type == 'memryx'):
|
if self.detector_config.type == "memryx":
|
||||||
# MemryX requires asynchronous detection handling using async_run_detector
|
# MemryX requires asynchronous detection handling using async_run_detector
|
||||||
self.detect_process = util.Process(
|
self.detect_process = util.Process(
|
||||||
target=async_run_detector,
|
target=async_run_detector,
|
||||||
@ -292,7 +291,7 @@ class ObjectDetectProcess:
|
|||||||
self.detection_start,
|
self.detection_start,
|
||||||
self.detector_config,
|
self.detector_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.detect_process.daemon = True
|
self.detect_process.daemon = True
|
||||||
self.detect_process.start()
|
self.detect_process.start()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user