mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-26 19:06:11 +01:00
Implement common post_processing (#11408)
* implement common post_processing * fix formatting * rename yolonas to post_process_yolonas
This commit is contained in:
parent
97f5ba0145
commit
a70dd02788
@ -80,7 +80,7 @@ model:
|
|||||||
# Valid values are nhwc or nchw (default: shown below)
|
# Valid values are nhwc or nchw (default: shown below)
|
||||||
input_tensor: nhwc
|
input_tensor: nhwc
|
||||||
# Optional: Object detection model type, currently only used with the OpenVINO detector
|
# Optional: Object detection model type, currently only used with the OpenVINO detector
|
||||||
# Valid values are ssd, yolox (default: shown below)
|
# Valid values are ssd, yolox, yolonas (default: shown below)
|
||||||
model_type: ssd
|
model_type: ssd
|
||||||
# Optional: Label name modifications. These are merged into the standard labelmap.
|
# Optional: Label name modifications. These are merged into the standard labelmap.
|
||||||
labelmap:
|
labelmap:
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from frigate.detectors.detector_config import ModelTypeEnum
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -9,8 +13,67 @@ class DetectionApi(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, detector_config):
|
def __init__(self, detector_config):
|
||||||
pass
|
self.detector_config = detector_config
|
||||||
|
self.thresh = 0.5
|
||||||
|
self.height = detector_config.model.height
|
||||||
|
self.width = detector_config.model.width
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def post_process_yolonas(self, output):
|
||||||
|
"""
|
||||||
|
@param output: output of inference
|
||||||
|
expected shape: [np.array(1, N, 4), np.array(1, N, 80)]
|
||||||
|
where N depends on the input size e.g. N=2100 for 320x320 images
|
||||||
|
|
||||||
|
@return: best results: np.array(20, 6) where each row is
|
||||||
|
in this order (class_id, score, y1/height, x1/width, y2/height, x2/width)
|
||||||
|
"""
|
||||||
|
|
||||||
|
N = output[0].shape[1]
|
||||||
|
|
||||||
|
boxes = output[0].reshape(N, 4)
|
||||||
|
scores = output[1].reshape(N, 80)
|
||||||
|
|
||||||
|
class_ids = np.argmax(scores, axis=1)
|
||||||
|
scores = scores[np.arange(N), class_ids]
|
||||||
|
|
||||||
|
args_best = np.argwhere(scores > self.thresh)[:, 0]
|
||||||
|
|
||||||
|
num_matches = len(args_best)
|
||||||
|
if num_matches == 0:
|
||||||
|
return np.zeros((20, 6), np.float32)
|
||||||
|
elif num_matches > 20:
|
||||||
|
args_best20 = np.argpartition(scores[args_best], -20)[-20:]
|
||||||
|
args_best = args_best[args_best20]
|
||||||
|
|
||||||
|
boxes = boxes[args_best]
|
||||||
|
class_ids = class_ids[args_best]
|
||||||
|
scores = scores[args_best]
|
||||||
|
|
||||||
|
boxes = np.transpose(
|
||||||
|
np.vstack(
|
||||||
|
(
|
||||||
|
boxes[:, 1] / self.height,
|
||||||
|
boxes[:, 0] / self.width,
|
||||||
|
boxes[:, 3] / self.height,
|
||||||
|
boxes[:, 2] / self.width,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = np.hstack(
|
||||||
|
(class_ids[..., np.newaxis], scores[..., np.newaxis], boxes)
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.resize(results, (20, 6))
|
||||||
|
|
||||||
|
def post_process(self, output):
|
||||||
|
if self.detector_config.model.model_type == ModelTypeEnum.yolonas:
|
||||||
|
return self.yolonas(output)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Model type "{self.detector_config.model.model_type}" is currently not supported.'
|
||||||
|
)
|
||||||
|
@ -30,6 +30,7 @@ class InputTensorEnum(str, Enum):
|
|||||||
class ModelTypeEnum(str, Enum):
|
class ModelTypeEnum(str, Enum):
|
||||||
ssd = "ssd"
|
ssd = "ssd"
|
||||||
yolox = "yolox"
|
yolox = "yolox"
|
||||||
|
yolonas = "yolonas"
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
|
Loading…
Reference in New Issue
Block a user