mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	allow for custom object detection model via configuration
This commit is contained in:
		
							parent
							
								
									89e317a6bb
								
							
						
					
					
						commit
						a7b7a45b23
					
				| @ -170,6 +170,7 @@ class FrigateApp: | |||||||
|         self.mqtt_relay.start() |         self.mqtt_relay.start() | ||||||
| 
 | 
 | ||||||
|     def start_detectors(self): |     def start_detectors(self): | ||||||
|  |         model_path = self.config.model.path | ||||||
|         model_shape = (self.config.model.height, self.config.model.width) |         model_shape = (self.config.model.height, self.config.model.width) | ||||||
|         for name in self.config.cameras.keys(): |         for name in self.config.cameras.keys(): | ||||||
|             self.detection_out_events[name] = mp.Event() |             self.detection_out_events[name] = mp.Event() | ||||||
| @ -199,6 +200,7 @@ class FrigateApp: | |||||||
|                     name, |                     name, | ||||||
|                     self.detection_queue, |                     self.detection_queue, | ||||||
|                     self.detection_out_events, |                     self.detection_out_events, | ||||||
|  |                     model_path, | ||||||
|                     model_shape, |                     model_shape, | ||||||
|                     "cpu", |                     "cpu", | ||||||
|                     detector.num_threads, |                     detector.num_threads, | ||||||
| @ -208,6 +210,7 @@ class FrigateApp: | |||||||
|                     name, |                     name, | ||||||
|                     self.detection_queue, |                     self.detection_queue, | ||||||
|                     self.detection_out_events, |                     self.detection_out_events, | ||||||
|  |                     model_path, | ||||||
|                     model_shape, |                     model_shape, | ||||||
|                     detector.device, |                     detector.device, | ||||||
|                     detector.num_threads, |                     detector.num_threads, | ||||||
|  | |||||||
| @ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ModelConfig(FrigateBaseModel): | class ModelConfig(FrigateBaseModel): | ||||||
|  |     path: Optional[str] = Field(title="Custom Object detection model path.") | ||||||
|  |     labelmap_path: Optional[str] = Field(title="Label map for custom object detector.") | ||||||
|     width: int = Field(default=320, title="Object detection model input width.") |     width: int = Field(default=320, title="Object detection model input width.") | ||||||
|     height: int = Field(default=320, title="Object detection model input height.") |     height: int = Field(default=320, title="Object detection model input height.") | ||||||
|     labelmap: Dict[int, str] = Field( |     labelmap: Dict[int, str] = Field( | ||||||
| @ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel): | |||||||
|         super().__init__(**config) |         super().__init__(**config) | ||||||
| 
 | 
 | ||||||
|         self._merged_labelmap = { |         self._merged_labelmap = { | ||||||
|             **load_labels("/labelmap.txt"), |             **load_labels(config.get("labelmap_path", "/labelmap.txt")), | ||||||
|             **config.get("labelmap", {}), |             **config.get("labelmap", {}), | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -45,7 +45,7 @@ class ObjectDetector(ABC): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class LocalObjectDetector(ObjectDetector): | class LocalObjectDetector(ObjectDetector): | ||||||
|     def __init__(self, tf_device=None, num_threads=3, labels=None): |     def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None): | ||||||
|         self.fps = EventsPerSecond() |         self.fps = EventsPerSecond() | ||||||
|         if labels is None: |         if labels is None: | ||||||
|             self.labels = {} |             self.labels = {} | ||||||
| @ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector): | |||||||
|                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config) |                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config) | ||||||
|                 logger.info("TPU found") |                 logger.info("TPU found") | ||||||
|                 self.interpreter = tflite.Interpreter( |                 self.interpreter = tflite.Interpreter( | ||||||
|                     model_path="/edgetpu_model.tflite", |                     model_path=model_path or "/edgetpu_model.tflite", | ||||||
|                     experimental_delegates=[edge_tpu_delegate], |                     experimental_delegates=[edge_tpu_delegate], | ||||||
|                 ) |                 ) | ||||||
|             except ValueError: |             except ValueError: | ||||||
| @ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector): | |||||||
|                 "CPU detectors are not recommended and should only be used for testing or for trial purposes." |                 "CPU detectors are not recommended and should only be used for testing or for trial purposes." | ||||||
|             ) |             ) | ||||||
|             self.interpreter = tflite.Interpreter( |             self.interpreter = tflite.Interpreter( | ||||||
|                 model_path="/cpu_model.tflite", num_threads=num_threads |                 model_path=model_path or "/cpu_model.tflite", num_threads=num_threads | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         self.interpreter.allocate_tensors() |         self.interpreter.allocate_tensors() | ||||||
| @ -133,6 +133,7 @@ def run_detector( | |||||||
|     out_events: Dict[str, mp.Event], |     out_events: Dict[str, mp.Event], | ||||||
|     avg_speed, |     avg_speed, | ||||||
|     start, |     start, | ||||||
|  |     model_path, | ||||||
|     model_shape, |     model_shape, | ||||||
|     tf_device, |     tf_device, | ||||||
|     num_threads, |     num_threads, | ||||||
| @ -152,7 +153,9 @@ def run_detector( | |||||||
|     signal.signal(signal.SIGINT, receiveSignal) |     signal.signal(signal.SIGINT, receiveSignal) | ||||||
| 
 | 
 | ||||||
|     frame_manager = SharedMemoryFrameManager() |     frame_manager = SharedMemoryFrameManager() | ||||||
|     object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads) |     object_detector = LocalObjectDetector( | ||||||
|  |         tf_device=tf_device, model_path=model_path, num_threads=num_threads | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
|     outputs = {} |     outputs = {} | ||||||
|     for name in out_events.keys(): |     for name in out_events.keys(): | ||||||
| @ -189,6 +192,7 @@ class EdgeTPUProcess: | |||||||
|         name, |         name, | ||||||
|         detection_queue, |         detection_queue, | ||||||
|         out_events, |         out_events, | ||||||
|  |         model_path, | ||||||
|         model_shape, |         model_shape, | ||||||
|         tf_device=None, |         tf_device=None, | ||||||
|         num_threads=3, |         num_threads=3, | ||||||
| @ -199,6 +203,7 @@ class EdgeTPUProcess: | |||||||
|         self.avg_inference_speed = mp.Value("d", 0.01) |         self.avg_inference_speed = mp.Value("d", 0.01) | ||||||
|         self.detection_start = mp.Value("d", 0.0) |         self.detection_start = mp.Value("d", 0.0) | ||||||
|         self.detect_process = None |         self.detect_process = None | ||||||
|  |         self.model_path = model_path | ||||||
|         self.model_shape = model_shape |         self.model_shape = model_shape | ||||||
|         self.tf_device = tf_device |         self.tf_device = tf_device | ||||||
|         self.num_threads = num_threads |         self.num_threads = num_threads | ||||||
| @ -226,6 +231,7 @@ class EdgeTPUProcess: | |||||||
|                 self.out_events, |                 self.out_events, | ||||||
|                 self.avg_inference_speed, |                 self.avg_inference_speed, | ||||||
|                 self.detection_start, |                 self.detection_start, | ||||||
|  |                 self.model_path, | ||||||
|                 self.model_shape, |                 self.model_shape, | ||||||
|                 self.tf_device, |                 self.tf_device, | ||||||
|                 self.num_threads, |                 self.num_threads, | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user