Unify tensor handling

This commit is contained in:
Abinila Siva 2025-05-06 13:50:22 -04:00
parent 18f1cc1638
commit d0cd0759ea
2 changed files with 5 additions and 10 deletions

View File

@@ -31,6 +31,7 @@ class InputTensorEnum(str, Enum):
class InputDTypeEnum(str, Enum):
float = "float"
float_denorm = "float_denorm" # non-normalized float
int = "int"
@@ -208,4 +209,4 @@ class BaseDetectorConfig(BaseModel):
)
model_config = ConfigDict(
extra="allow", arbitrary_types_allowed=True, protected_namespaces=()
)
)

View File

@@ -212,9 +212,8 @@ class MemryXDetector(DetectionApi):
raise ValueError("[send_input] No image data provided for inference")
if self.memx_model_type == ModelTypeEnum.yolox:
tensor_input = tensor_input.squeeze(0)
tensor_input = tensor_input.squeeze(2)
tensor_input = tensor_input * 255.0
padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
scale = min(
@@ -238,15 +237,10 @@ class MemryXDetector(DetectionApi):
# Step 5: Concatenate along the channel dimension (axis 2)
concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
processed_input = concatenated_img.astype(np.float32)
else:
tensor_input = tensor_input.squeeze(0) # (H, W, C)
# Add axis=2 to create Z=1: (H, W, Z=1, C)
processed_input = np.expand_dims(tensor_input, axis=2) # Now (H, W, 1, 3)
tensor_input = concatenated_img.astype(np.float32)
# Send frame to MemryX for processing
self.capture_queue.put(processed_input)
self.capture_queue.put(tensor_input)
self.capture_id_queue.put(connection_id)
def process_input(self):