allow for custom object detection model via configuration

Jason Hunter 2021-09-12 02:06:37 -04:00 committed by Blake Blackshear
parent 89e317a6bb
commit a7b7a45b23
3 changed files with 16 additions and 5 deletions

frigate/app.py

@@ -170,6 +170,7 @@ class FrigateApp:
         self.mqtt_relay.start()

     def start_detectors(self):
+        model_path = self.config.model.path
         model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()
@@ -199,6 +200,7 @@ class FrigateApp:
                     name,
                     self.detection_queue,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     "cpu",
                     detector.num_threads,
@@ -208,6 +210,7 @@ class FrigateApp:
                     name,
                     self.detection_queue,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     detector.device,
                     detector.num_threads,

frigate/config.py

@@ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel):


 class ModelConfig(FrigateBaseModel):
+    path: Optional[str] = Field(title="Custom Object detection model path.")
+    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
     width: int = Field(default=320, title="Object detection model input width.")
     height: int = Field(default=320, title="Object detection model input height.")
     labelmap: Dict[int, str] = Field(
@@ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel):
         super().__init__(**config)
         self._merged_labelmap = {
-            **load_labels("/labelmap.txt"),
+            **load_labels(config.get("labelmap_path", "/labelmap.txt")),
             **config.get("labelmap", {}),
         }
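Usage note (not part of the diff): these two optional fields let a custom model and label map be supplied through the Frigate config (model.path and model.labelmap_path); when they are omitted, the bundled /labelmap.txt and default .tflite models are still used. A minimal sketch of the updated ModelConfig, assuming the frigate package is importable and the referenced files exist; both paths below are hypothetical examples:

from frigate.config import ModelConfig

# Hypothetical example paths; any .tflite model with a matching label map would work.
model_config = ModelConfig(
    path="/config/custom_model.tflite",           # new optional field added above
    labelmap_path="/config/custom_labelmap.txt",  # new optional field added above
    width=320,
    height=320,
)
# load_labels() now reads the configured label map instead of /labelmap.txt,
# and any inline `labelmap` overrides are merged on top of it.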

frigate/edgetpu.py

@@ -45,7 +45,7 @@ class ObjectDetector(ABC):


 class LocalObjectDetector(ObjectDetector):
-    def __init__(self, tf_device=None, num_threads=3, labels=None):
+    def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None):
         self.fps = EventsPerSecond()
         if labels is None:
             self.labels = {}
@@ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector):
                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
                 logger.info("TPU found")
                 self.interpreter = tflite.Interpreter(
-                    model_path="/edgetpu_model.tflite",
+                    model_path=model_path or "/edgetpu_model.tflite",
                     experimental_delegates=[edge_tpu_delegate],
                 )
             except ValueError:
@@ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector):
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
             self.interpreter = tflite.Interpreter(
-                model_path="/cpu_model.tflite", num_threads=num_threads
+                model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
             )

         self.interpreter.allocate_tensors()
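Aside on the fallback just introduced (illustrative only, not code from this commit): LocalObjectDetector now prefers a caller-supplied model_path and only falls back to the bundled models when none is given. The standalone sketch below mirrors that `model_path or default` pattern with a hypothetical helper:

# Hypothetical helper mirroring the fallback above; not part of frigate.
def resolve_model_path(model_path=None, tf_device=None):
    # EdgeTPU devices fall back to the bundled EdgeTPU model,
    # the "cpu" device falls back to the bundled CPU model.
    default = "/cpu_model.tflite" if tf_device == "cpu" else "/edgetpu_model.tflite"
    return model_path or default

assert resolve_model_path(None, "cpu") == "/cpu_model.tflite"
assert resolve_model_path(None, "usb") == "/edgetpu_model.tflite"
assert resolve_model_path("/config/custom.tflite", "cpu") == "/config/custom.tflite"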
@@ -133,6 +133,7 @@ def run_detector(
     out_events: Dict[str, mp.Event],
     avg_speed,
     start,
+    model_path,
     model_shape,
     tf_device,
     num_threads,
@@ -152,7 +153,9 @@ def run_detector(
     signal.signal(signal.SIGINT, receiveSignal)

     frame_manager = SharedMemoryFrameManager()
-    object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
+    object_detector = LocalObjectDetector(
+        tf_device=tf_device, model_path=model_path, num_threads=num_threads
+    )

     outputs = {}
     for name in out_events.keys():
@@ -189,6 +192,7 @@ class EdgeTPUProcess:
         name,
         detection_queue,
         out_events,
+        model_path,
         model_shape,
         tf_device=None,
         num_threads=3,
@@ -199,6 +203,7 @@ class EdgeTPUProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
+        self.model_path = model_path
         self.model_shape = model_shape
         self.tf_device = tf_device
         self.num_threads = num_threads
@@ -226,6 +231,7 @@ class EdgeTPUProcess:
                 self.out_events,
                 self.avg_inference_speed,
                 self.detection_start,
+                self.model_path,
                 self.model_shape,
                 self.tf_device,
                 self.num_threads,
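To summarize the plumbing (a simplified, hypothetical sketch, not the real classes): the path configured on ModelConfig travels from FrigateApp.start_detectors through EdgeTPUProcess into run_detector, which finally hands it to LocalObjectDetector. Roughly:

# Simplified stand-ins for EdgeTPUProcess / run_detector showing how the new
# model_path argument is threaded through; names suffixed "Sketch" are hypothetical.
def run_detector_sketch(model_path, model_shape, tf_device, num_threads):
    # the real run_detector() builds
    # LocalObjectDetector(tf_device=..., model_path=..., num_threads=...),
    # which falls back to the bundled models when model_path is None
    return {"model_path": model_path, "shape": model_shape, "device": tf_device}

class EdgeTPUProcessSketch:
    def __init__(self, model_path, model_shape, tf_device=None, num_threads=3):
        self.model_path = model_path      # new attribute, mirrors self.model_path above
        self.model_shape = model_shape
        self.tf_device = tf_device
        self.num_threads = num_threads

    def start_or_restart(self):
        # the real method forwards self.model_path to run_detector via mp.Process args
        return run_detector_sketch(
            self.model_path, self.model_shape, self.tf_device, self.num_threads
        )

proc = EdgeTPUProcessSketch("/config/custom.tflite", (320, 320), tf_device="cpu")
print(proc.start_or_restart())  # model_path flows through unchanged: '/config/custom.tflite'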