mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
allow for custom object detection model via configuration
This commit is contained in:
parent
89e317a6bb
commit
a7b7a45b23
@ -170,6 +170,7 @@ class FrigateApp:
|
|||||||
self.mqtt_relay.start()
|
self.mqtt_relay.start()
|
||||||
|
|
||||||
def start_detectors(self):
|
def start_detectors(self):
|
||||||
|
model_path = self.config.model.path
|
||||||
model_shape = (self.config.model.height, self.config.model.width)
|
model_shape = (self.config.model.height, self.config.model.width)
|
||||||
for name in self.config.cameras.keys():
|
for name in self.config.cameras.keys():
|
||||||
self.detection_out_events[name] = mp.Event()
|
self.detection_out_events[name] = mp.Event()
|
||||||
@ -199,6 +200,7 @@ class FrigateApp:
|
|||||||
name,
|
name,
|
||||||
self.detection_queue,
|
self.detection_queue,
|
||||||
self.detection_out_events,
|
self.detection_out_events,
|
||||||
|
model_path,
|
||||||
model_shape,
|
model_shape,
|
||||||
"cpu",
|
"cpu",
|
||||||
detector.num_threads,
|
detector.num_threads,
|
||||||
@ -208,6 +210,7 @@ class FrigateApp:
|
|||||||
name,
|
name,
|
||||||
self.detection_queue,
|
self.detection_queue,
|
||||||
self.detection_out_events,
|
self.detection_out_events,
|
||||||
|
model_path,
|
||||||
model_shape,
|
model_shape,
|
||||||
detector.device,
|
detector.device,
|
||||||
detector.num_threads,
|
detector.num_threads,
|
||||||
|
@ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ModelConfig(FrigateBaseModel):
|
class ModelConfig(FrigateBaseModel):
|
||||||
|
path: Optional[str] = Field(title="Custom Object detection model path.")
|
||||||
|
labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
|
||||||
width: int = Field(default=320, title="Object detection model input width.")
|
width: int = Field(default=320, title="Object detection model input width.")
|
||||||
height: int = Field(default=320, title="Object detection model input height.")
|
height: int = Field(default=320, title="Object detection model input height.")
|
||||||
labelmap: Dict[int, str] = Field(
|
labelmap: Dict[int, str] = Field(
|
||||||
@ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel):
|
|||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
|
|
||||||
self._merged_labelmap = {
|
self._merged_labelmap = {
|
||||||
**load_labels("/labelmap.txt"),
|
**load_labels(config.get("labelmap_path", "/labelmap.txt")),
|
||||||
**config.get("labelmap", {}),
|
**config.get("labelmap", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class ObjectDetector(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class LocalObjectDetector(ObjectDetector):
|
class LocalObjectDetector(ObjectDetector):
|
||||||
def __init__(self, tf_device=None, num_threads=3, labels=None):
|
def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None):
|
||||||
self.fps = EventsPerSecond()
|
self.fps = EventsPerSecond()
|
||||||
if labels is None:
|
if labels is None:
|
||||||
self.labels = {}
|
self.labels = {}
|
||||||
@ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
|
edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
|
||||||
logger.info("TPU found")
|
logger.info("TPU found")
|
||||||
self.interpreter = tflite.Interpreter(
|
self.interpreter = tflite.Interpreter(
|
||||||
model_path="/edgetpu_model.tflite",
|
model_path=model_path or "/edgetpu_model.tflite",
|
||||||
experimental_delegates=[edge_tpu_delegate],
|
experimental_delegates=[edge_tpu_delegate],
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||||
)
|
)
|
||||||
self.interpreter = tflite.Interpreter(
|
self.interpreter = tflite.Interpreter(
|
||||||
model_path="/cpu_model.tflite", num_threads=num_threads
|
model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
@ -133,6 +133,7 @@ def run_detector(
|
|||||||
out_events: Dict[str, mp.Event],
|
out_events: Dict[str, mp.Event],
|
||||||
avg_speed,
|
avg_speed,
|
||||||
start,
|
start,
|
||||||
|
model_path,
|
||||||
model_shape,
|
model_shape,
|
||||||
tf_device,
|
tf_device,
|
||||||
num_threads,
|
num_threads,
|
||||||
@ -152,7 +153,9 @@ def run_detector(
|
|||||||
signal.signal(signal.SIGINT, receiveSignal)
|
signal.signal(signal.SIGINT, receiveSignal)
|
||||||
|
|
||||||
frame_manager = SharedMemoryFrameManager()
|
frame_manager = SharedMemoryFrameManager()
|
||||||
object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
|
object_detector = LocalObjectDetector(
|
||||||
|
tf_device=tf_device, model_path=model_path, num_threads=num_threads
|
||||||
|
)
|
||||||
|
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for name in out_events.keys():
|
for name in out_events.keys():
|
||||||
@ -189,6 +192,7 @@ class EdgeTPUProcess:
|
|||||||
name,
|
name,
|
||||||
detection_queue,
|
detection_queue,
|
||||||
out_events,
|
out_events,
|
||||||
|
model_path,
|
||||||
model_shape,
|
model_shape,
|
||||||
tf_device=None,
|
tf_device=None,
|
||||||
num_threads=3,
|
num_threads=3,
|
||||||
@ -199,6 +203,7 @@ class EdgeTPUProcess:
|
|||||||
self.avg_inference_speed = mp.Value("d", 0.01)
|
self.avg_inference_speed = mp.Value("d", 0.01)
|
||||||
self.detection_start = mp.Value("d", 0.0)
|
self.detection_start = mp.Value("d", 0.0)
|
||||||
self.detect_process = None
|
self.detect_process = None
|
||||||
|
self.model_path = model_path
|
||||||
self.model_shape = model_shape
|
self.model_shape = model_shape
|
||||||
self.tf_device = tf_device
|
self.tf_device = tf_device
|
||||||
self.num_threads = num_threads
|
self.num_threads = num_threads
|
||||||
@ -226,6 +231,7 @@ class EdgeTPUProcess:
|
|||||||
self.out_events,
|
self.out_events,
|
||||||
self.avg_inference_speed,
|
self.avg_inference_speed,
|
||||||
self.detection_start,
|
self.detection_start,
|
||||||
|
self.model_path,
|
||||||
self.model_shape,
|
self.model_shape,
|
||||||
self.tf_device,
|
self.tf_device,
|
||||||
self.num_threads,
|
self.num_threads,
|
||||||
|
Loading…
Reference in New Issue
Block a user