Don't fail on invalid class IDs for TensorRT detector (#8438)

* Don't fail on invalid class IDs

* Fix whitespace

* Make log warning
This commit is contained in:
Nicolas Mowen 2023-11-03 20:19:58 -06:00 committed by GitHub
parent ac53993f70
commit 4f7b710112
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -293,6 +293,16 @@ class TensorRtDetector(DetectionApi):
# raw_detections: Nx7 numpy arrays of # raw_detections: Nx7 numpy arrays of
# [[x, y, w, h, box_confidence, class_id, class_prob], # [[x, y, w, h, box_confidence, class_id, class_prob],
# throw out any detections with negative class IDs
valid_detections = []
for r in raw_detections:
if r[5] >= 0:
valid_detections.append(r)
else:
logger.warning(f"Found TensorRT detection with invalid class id {r}")
raw_detections = valid_detections
# Calculate score as box_confidence x class_prob # Calculate score as box_confidence x class_prob
raw_detections[:, 4] = raw_detections[:, 4] * raw_detections[:, 6] raw_detections[:, 4] = raw_detections[:, 4] * raw_detections[:, 6]
# Reorder elements by the score, best on top, remove class_prob # Reorder elements by the score, best on top, remove class_prob
@ -303,6 +313,7 @@ class TensorRtDetector(DetectionApi):
ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1) ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1)
# put result into the correct order and limit to top 20 # put result into the correct order and limit to top 20
detections = ordered[:, [5, 4, 1, 0, 3, 2]][:20] detections = ordered[:, [5, 4, 1, 0, 3, 2]][:20]
# pad to 20x6 shape # pad to 20x6 shape
append_cnt = 20 - len(detections) append_cnt = 20 - len(detections)
if append_cnt > 0: if append_cnt > 0: