From a7b7a45b23c7b541d2779e585625e7de517d61a9 Mon Sep 17 00:00:00 2001 From: Jason Hunter Date: Sun, 12 Sep 2021 02:06:37 -0400 Subject: [PATCH] allow for custom object detection model via configuration --- frigate/app.py | 3 +++ frigate/config.py | 4 +++- frigate/edgetpu.py | 14 ++++++++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/frigate/app.py b/frigate/app.py index bf3f12989..e73c56c6f 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -170,6 +170,7 @@ class FrigateApp: self.mqtt_relay.start() def start_detectors(self): + model_path = self.config.model.path model_shape = (self.config.model.height, self.config.model.width) for name in self.config.cameras.keys(): self.detection_out_events[name] = mp.Event() @@ -199,6 +200,7 @@ class FrigateApp: name, self.detection_queue, self.detection_out_events, + model_path, model_shape, "cpu", detector.num_threads, @@ -208,6 +210,7 @@ class FrigateApp: name, self.detection_queue, self.detection_out_events, + model_path, model_shape, detector.device, detector.num_threads, diff --git a/frigate/config.py b/frigate/config.py index bfd54b1ba..edba6a410 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel): class ModelConfig(FrigateBaseModel): + path: Optional[str] = Field(title="Custom Object detection model path.") + labelmap_path: Optional[str] = Field(title="Label map for custom object detector.") width: int = Field(default=320, title="Object detection model input width.") height: int = Field(default=320, title="Object detection model input height.") labelmap: Dict[int, str] = Field( @@ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel): super().__init__(**config) self._merged_labelmap = { - **load_labels("/labelmap.txt"), + **load_labels(config.get("labelmap_path", "/labelmap.txt")), **config.get("labelmap", {}), } diff --git a/frigate/edgetpu.py b/frigate/edgetpu.py index 62c35eaf5..1992c6b35 100644 --- a/frigate/edgetpu.py +++ b/frigate/edgetpu.py @@ -45,7 +45,7 @@ class ObjectDetector(ABC): class LocalObjectDetector(ObjectDetector): - def __init__(self, tf_device=None, num_threads=3, labels=None): + def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None): self.fps = EventsPerSecond() if labels is None: self.labels = {} @@ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector): edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config) logger.info("TPU found") self.interpreter = tflite.Interpreter( - model_path="/edgetpu_model.tflite", + model_path=model_path or "/edgetpu_model.tflite", experimental_delegates=[edge_tpu_delegate], ) except ValueError: @@ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector): "CPU detectors are not recommended and should only be used for testing or for trial purposes." ) self.interpreter = tflite.Interpreter( - model_path="/cpu_model.tflite", num_threads=num_threads + model_path=model_path or "/cpu_model.tflite", num_threads=num_threads ) self.interpreter.allocate_tensors() @@ -133,6 +133,7 @@ def run_detector( out_events: Dict[str, mp.Event], avg_speed, start, + model_path, model_shape, tf_device, num_threads, @@ -152,7 +153,9 @@ def run_detector( signal.signal(signal.SIGINT, receiveSignal) frame_manager = SharedMemoryFrameManager() - object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads) + object_detector = LocalObjectDetector( + tf_device=tf_device, model_path=model_path, num_threads=num_threads + ) outputs = {} for name in out_events.keys(): @@ -189,6 +192,7 @@ class EdgeTPUProcess: name, detection_queue, out_events, + model_path, model_shape, tf_device=None, num_threads=3, @@ -199,6 +203,7 @@ class EdgeTPUProcess: self.avg_inference_speed = mp.Value("d", 0.01) self.detection_start = mp.Value("d", 0.0) self.detect_process = None + self.model_path = model_path self.model_shape = model_shape self.tf_device = tf_device self.num_threads = num_threads @@ -226,6 +231,7 @@ class EdgeTPUProcess: self.out_events, self.avg_inference_speed, self.detection_start, + self.model_path, self.model_shape, self.tf_device, self.num_threads,