From 161e7b3fd75a9065fa6cbba082746b22708ed38f Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Fri, 3 Mar 2023 23:44:17 +0000 Subject: [PATCH] Allow using full tensorflow in cpu/edgetpu detector plugins (#5611) It supports the same entrypoints, given that tflite is a small cut-out of the big tensorflow picture. This patch was created for downstream usage in nixpkgs, where we don't have the tflite python package, but do have the full tensorflow package. --- frigate/detectors/plugins/cpu_tfl.py | 8 ++++++-- frigate/detectors/plugins/edgetpu_tfl.py | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/frigate/detectors/plugins/cpu_tfl.py b/frigate/detectors/plugins/cpu_tfl.py index 9e24cb1f4..fb9cbbfae 100644 --- a/frigate/detectors/plugins/cpu_tfl.py +++ b/frigate/detectors/plugins/cpu_tfl.py @@ -5,7 +5,11 @@ from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig from typing import Literal from pydantic import Extra, Field -import tflite_runtime.interpreter as tflite + +try: + from tflite_runtime.interpreter import Interpreter +except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter logger = logging.getLogger(__name__) @@ -22,7 +26,7 @@ class CpuTfl(DetectionApi): type_key = DETECTOR_KEY def __init__(self, detector_config: CpuDetectorConfig): - self.interpreter = tflite.Interpreter( + self.interpreter = Interpreter( model_path=detector_config.model.path or "/cpu_model.tflite", num_threads=detector_config.num_threads or 3, ) diff --git a/frigate/detectors/plugins/edgetpu_tfl.py b/frigate/detectors/plugins/edgetpu_tfl.py index 024e6574b..840d41f66 100644 --- a/frigate/detectors/plugins/edgetpu_tfl.py +++ b/frigate/detectors/plugins/edgetpu_tfl.py @@ -5,8 +5,11 @@ from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig from typing import Literal from pydantic import Extra, Field -import tflite_runtime.interpreter as tflite -from tflite_runtime.interpreter import load_delegate + +try: + from tflite_runtime.interpreter import Interpreter, load_delegate +except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter, load_delegate logger = logging.getLogger(__name__) @@ -33,7 +36,7 @@ class EdgeTpuTfl(DetectionApi): logger.info(f"Attempting to load TPU as {device_config['device']}") edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config) logger.info("TPU found") - self.interpreter = tflite.Interpreter( + self.interpreter = Interpreter( model_path=detector_config.model.path or "/edgetpu_model.tflite", experimental_delegates=[edge_tpu_delegate], )