Unify tensor handling

Abinila Siva 2025-05-06 13:50:22 -04:00
parent 18f1cc1638
commit d0cd0759ea
2 changed files with 5 additions and 10 deletions


@@ -31,6 +31,7 @@ class InputTensorEnum(str, Enum):
 class InputDTypeEnum(str, Enum):
     float = "float"
+    float_denorm = "float_denorm"  # non-normalized float
     int = "int"
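The new float_denorm value marks a float input that keeps its original 0-255 range instead of being normalized, which is presumably why the * 255.0 rescale disappears from the YOLOX branch below. A minimal sketch of how such a dtype could be consumed during preprocessing; the prepare_input helper and the exact normalization convention are illustrative assumptions, not part of this diff:

import numpy as np
from enum import Enum


class InputDTypeEnum(str, Enum):
    float = "float"
    float_denorm = "float_denorm"  # non-normalized float
    int = "int"


def prepare_input(frame, dtype):
    # Illustrative helper (not from this commit): convert a uint8 frame
    # in the 0-255 range to the dtype a detector expects.
    if dtype == InputDTypeEnum.float:
        return frame.astype(np.float32) / 255.0  # normalized to [0, 1]
    if dtype == InputDTypeEnum.float_denorm:
        return frame.astype(np.float32)  # keep 0-255, only change the dtype
    return frame  # int: pass the uint8 values through unchanged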


@@ -212,9 +212,8 @@ class MemryXDetector(DetectionApi):
             raise ValueError("[send_input] No image data provided for inference")
         if self.memx_model_type == ModelTypeEnum.yolox:
-            tensor_input = tensor_input.squeeze(0)
-            tensor_input = tensor_input * 255.0
+            tensor_input = tensor_input.squeeze(2)
             padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
             scale = min(
@@ -238,15 +237,10 @@ class MemryXDetector(DetectionApi):
             # Step 5: Concatenate along the channel dimension (axis 2)
             concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
-            processed_input = concatenated_img.astype(np.float32)
-        else:
-            tensor_input = tensor_input.squeeze(0)  # (H, W, C)
-            # Add axis=2 to create Z=1: (H, W, Z=1, C)
-            processed_input = np.expand_dims(tensor_input, axis=2)  # Now (H, W, 1, 3)
+            tensor_input = concatenated_img.astype(np.float32)
         # Send frame to MemryX for processing
-        self.capture_queue.put(processed_input)
+        self.capture_queue.put(tensor_input)
         self.capture_id_queue.put(connection_id)
     def process_input(self):
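After this change send_input always queues the incoming tensor_input: the YOLOX branch reshapes and letterboxes it in place, and the old else branch that converted a (1, H, W, C) frame into (H, W, 1, C) is gone, which suggests the frame now reaches the detector already in the (H, W, 1, C) layout the MemryX runtime consumes. A rough shape check under that assumption; the 640x640x3 size comes from the YOLOX padding above, everything else is illustrative:

import numpy as np

# Assumed unified input layout: (H, W, Z=1, C), already float and unnormalized.
tensor_input = np.zeros((640, 640, 1, 3), dtype=np.float32)

# YOLOX branch: drop the singleton Z axis before letterboxing.
# squeeze(2) replaces the old squeeze(0), which assumed an (N, H, W, C) frame.
yolox_input = tensor_input.squeeze(2)
assert yolox_input.shape == (640, 640, 3)

# Every other model type: the tensor is queued as-is, so the old
# squeeze(0) / expand_dims(axis=2) round trip is no longer needed.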