mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	* Fix the `Any` typing hint treewide. There has been confusion between the `Any` type [1] and the `any` function [2] in typing hints. [1] https://docs.python.org/3/library/typing.html#typing.Any [2] https://docs.python.org/3/library/functions.html#any * Fix typing for various frame_shape members. Frame shapes are most likely defined by height and width, so a single int cannot express that. * Wrap GPU stats functions in Optional[]. These can return `None`, so they need to be `Type | None`, which is what `Optional` expresses very nicely. * Fix return type in get_latest_segment_datetime. It returns a datetime object, not an integer. * Make the return type of FrameManager.write optional. This is necessary since the SharedMemoryFrameManager.write function can return None. * Fix total_seconds() return type in get_tz_modifiers. The function returns a float, not an int. https://docs.python.org/3/library/datetime.html#datetime.timedelta.total_seconds * Account for floating point results in to_relative_box. Because the function uses division, the return types may be either int or float. * Resolve ruff deprecation warning. The config has been split into formatter and linter, and the global options are deprecated.
		
			
				
	
	
		
			159 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			159 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Handle processing images to classify birds."""
 | |
| 
 | |
import logging
import os
from typing import Any, Optional

import cv2
import numpy as np

from frigate.comms.event_metadata_updater import (
    EventMetadataPublisher,
    EventMetadataTypeEnum,
)
from frigate.config import FrigateConfig
from frigate.const import MODEL_CACHE_DIR
from frigate.util.object import calculate_region

from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
 | |
| 
 | |
| try:
 | |
|     from tflite_runtime.interpreter import Interpreter
 | |
| except ModuleNotFoundError:
 | |
|     from tensorflow.lite.python.interpreter import Interpreter
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class BirdRealTimeProcessor(RealTimeProcessorApi):
    """Classify tracked ``bird`` objects with a TFLite bird model.

    For each frame of a tracked bird, a square region around its box is
    cropped from the RGB-converted frame and run through the model. The best
    label is published as the object's sub label when its score reaches the
    configured threshold and beats the best score already published for that
    object.
    """

    # Side length (pixels) of the square crop fed to the model.
    # NOTE(review): assumed to match the input tensor of bird.tflite.
    INPUT_SIZE = 224
    # Model output index treated as "no bird classification". It is presumably
    # the background class at the end of birdmap.txt; confirm against the file.
    BACKGROUND_LABEL_ID = 964

    def __init__(
        self,
        config: FrigateConfig,
        sub_label_publisher: EventMetadataPublisher,
        metrics: DataProcessorMetrics,
    ):
        super().__init__(config, metrics)
        # These are filled in by __build_detector. That method may run after
        # __init__ returns, because it is the downloader's completion callback
        # when the model files are not cached yet.
        self.interpreter: Optional[Interpreter] = None
        self.sub_label_publisher = sub_label_publisher
        # get_input_details()/get_output_details() return one dict per tensor;
        # process_frame reads element [0]["index"].
        self.tensor_input_details: Optional[list[dict[str, Any]]] = None
        self.tensor_output_details: Optional[list[dict[str, Any]]] = None
        # Best score published so far, keyed by tracked object id.
        self.detected_birds: dict[str, float] = {}
        # Model output index -> label text parsed from birdmap.txt.
        self.labelmap: dict[int, str] = {}

        self.download_path = os.path.join(MODEL_CACHE_DIR, "bird")
        self.model_files = {
            "bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
            "birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt",
        }

        if not all(
            os.path.exists(os.path.join(self.download_path, name))
            for name in self.model_files
        ):
            # conditionally import ModelDownloader
            from frigate.util.downloader import ModelDownloader

            self.downloader = ModelDownloader(
                model_name="bird",
                download_path=self.download_path,
                file_names=self.model_files.keys(),
                download_func=self.__download_models,
                complete_func=self.__build_detector,
            )
            self.downloader.ensure_model_files()
        else:
            self.__build_detector()

    def __download_models(self, path: str) -> None:
        """Download one model file to ``path``.

        This is the ``download_func`` given to ModelDownloader. The source URL
        is looked up in ``self.model_files`` by the base name of ``path``.
        Failures are logged rather than raised.
        """
        try:
            file_name = os.path.basename(path)

            # conditionally import ModelDownloader
            from frigate.util.downloader import ModelDownloader

            ModelDownloader.download_from_url(self.model_files[file_name], path)
        except Exception as e:
            logger.error(f"Failed to download {path}: {e}")

    def __build_detector(self) -> None:
        """Load the label map and the TFLite interpreter from the model directory.

        All state is built in locals first, and ``self.interpreter`` is
        assigned last. ``process_frame`` treats a non-None interpreter as
        "ready", so it never sees the tensor details or the label map
        half-initialised, even if this runs as a download callback while
        frames are being processed.
        """
        labelmap: dict[int, str] = {}
        with open(
            os.path.join(self.download_path, "birdmap.txt"), encoding="utf-8"
        ) as f:
            # The line number is the model output index.
            for index, line in enumerate(f):
                start = line.find("(")
                end = line.find(")", start + 1)
                if start != -1 and end != -1:
                    # Keep only the text inside the first pair of parentheses.
                    labelmap[index] = line[start + 1 : end]
                else:
                    # No parenthesised name: keep the whole line, minus the
                    # newline and surrounding whitespace.
                    labelmap[index] = line.strip()

        interpreter = Interpreter(
            model_path=os.path.join(self.download_path, "bird.tflite"),
            num_threads=2,
        )
        interpreter.allocate_tensors()

        self.tensor_input_details = interpreter.get_input_details()
        self.tensor_output_details = interpreter.get_output_details()
        self.labelmap = labelmap
        # Publish the interpreter last; see the docstring.
        self.interpreter = interpreter

    def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
        """Classify a tracked bird and publish the best label as its sub label.

        Args:
            obj_data: tracked-object dict. Reads ``label``, ``id`` and ``box``
                (indices 0-3 are passed to calculate_region as the box
                corners; presumably pixel coordinates of the decoded frame,
                TODO confirm).
            frame: a YUV I420 frame, converted to RGB here before cropping.
        """
        if obj_data["label"] != "bird":
            return

        if self.interpreter is None:
            # The detector has not been built yet (e.g. the model files are
            # still downloading).
            return

        # Convert before computing the region so the region is sized against
        # the RGB image. An I420 buffer is 1.5x taller than the picture it
        # encodes (the chroma planes sit below the luma plane), so using
        # frame.shape let regions near the bottom edge extend past the image
        # and come back truncated.
        rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
        x, y, x2, y2 = calculate_region(
            rgb.shape,
            obj_data["box"][0],
            obj_data["box"][1],
            obj_data["box"][2],
            obj_data["box"][3],
            self.INPUT_SIZE,
            1.0,
        )
        crop = rgb[y:y2, x:x2]

        if crop.size == 0:
            # cv2.resize raises on an empty image.
            logger.debug("Empty crop for bird classification, skipping.")
            return

        # crop is (height, width, channels); compare only the spatial dims.
        if crop.shape[:2] != (self.INPUT_SIZE, self.INPUT_SIZE):
            crop = cv2.resize(crop, (self.INPUT_SIZE, self.INPUT_SIZE))

        tensor = np.expand_dims(crop, axis=0)
        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor)
        self.interpreter.invoke()
        res: np.ndarray = self.interpreter.get_tensor(
            self.tensor_output_details[0]["index"]
        )[0]

        total = res.sum(axis=0)
        if total <= 0:
            # Normalising would divide by zero and yield a NaN score, which
            # passes both threshold comparisons below.
            logger.debug("Bird model returned an all-zero output, skipping.")
            return

        probs = res / total
        best_id = int(np.argmax(probs))

        if best_id == self.BACKGROUND_LABEL_ID:
            logger.debug("No bird classification was detected.")
            return

        # Plain float so the published/stored score is not a numpy scalar.
        score = float(round(probs[best_id], 2))

        if score < self.config.classification.bird.threshold:
            logger.debug("Score %s is not above required threshold", score)
            return

        previous_score = self.detected_birds.get(obj_data["id"], 0.0)

        if score <= previous_score:
            logger.debug(
                "Score %s is worse than previous score %s", score, previous_score
            )
            return

        self.sub_label_publisher.publish(
            EventMetadataTypeEnum.sub_label,
            (obj_data["id"], self.labelmap[best_id], score),
        )
        self.detected_birds[obj_data["id"]] = score

    def handle_request(self, topic, request_data) -> None:
        """Ignore requests: this processor handles none and always returns None."""
        return None

    def expire_object(self, object_id, camera) -> None:
        """Forget the best score recorded for ``object_id``, if there is one."""
        self.detected_birds.pop(object_id, None)
 |