mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Don't fail on invalid class IDs for TensorRT detector (#8438)
* Don't fail on invalid class IDs * Fix whitespace * Make log warning
This commit is contained in:
parent
ac53993f70
commit
4f7b710112
@ -293,6 +293,16 @@ class TensorRtDetector(DetectionApi):
|
||||
# raw_detections: Nx7 numpy arrays of
|
||||
# [[x, y, w, h, box_confidence, class_id, class_prob],
|
||||
|
||||
# throw out any detections with negative class IDs
|
||||
valid_detections = []
|
||||
for r in raw_detections:
|
||||
if r[5] >= 0:
|
||||
valid_detections.append(r)
|
||||
else:
|
||||
logger.warning(f"Found TensorRT detection with invalid class id {r}")
|
||||
|
||||
raw_detections = valid_detections
|
||||
|
||||
# Calculate score as box_confidence x class_prob
|
||||
raw_detections[:, 4] = raw_detections[:, 4] * raw_detections[:, 6]
|
||||
# Reorder elements by the score, best on top, remove class_prob
|
||||
@ -303,6 +313,7 @@ class TensorRtDetector(DetectionApi):
|
||||
ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1)
|
||||
# put result into the correct order and limit to top 20
|
||||
detections = ordered[:, [5, 4, 1, 0, 3, 2]][:20]
|
||||
|
||||
# pad to 20x6 shape
|
||||
append_cnt = 20 - len(detections)
|
||||
if append_cnt > 0:
|
||||
|
Loading…
Reference in New Issue
Block a user