mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-20 13:54:36 +01:00
* Wait for config to load before evaluating route access Fix race condition where custom role users are temporarily denied access after login while config is still loading. Defer route rendering in DefaultAppView until config is available so the complete role list is known before ProtectedRoute evaluates permissions * Use batching for state classification generation * Ignore incorrect scoring images if they make it through the deletion * Delete unclassified images * mitigate tensorflow atexit crash by pre-importing tflite/tensorflow on main thread Pre-import Interpreter in embeddings maintainer and add defensive lazy imports in classification processors to avoid worker-thread tensorflow imports causing "can't register atexit after shutdown" * don't require old password for users with admin role when changing passwords * don't render actions menu if no options are available * Remove hwaccel arg as it is not used for encoding * change password button text --------- Co-authored-by: Nicolas Mowen <nickmowen213@gmail.com>
177 lines
5.7 KiB
Python
177 lines
5.7 KiB
Python
"""Handle processing images to classify birds."""
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from frigate.comms.event_metadata_updater import (
|
|
EventMetadataPublisher,
|
|
EventMetadataTypeEnum,
|
|
)
|
|
from frigate.config import FrigateConfig
|
|
from frigate.const import MODEL_CACHE_DIR
|
|
from frigate.log import redirect_output_to_logger
|
|
from frigate.util.object import calculate_region
|
|
|
|
from ..types import DataProcessorMetrics
|
|
from .api import RealTimeProcessorApi
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BirdRealTimeProcessor(RealTimeProcessorApi):
|
|
def __init__(
|
|
self,
|
|
config: FrigateConfig,
|
|
sub_label_publisher: EventMetadataPublisher,
|
|
metrics: DataProcessorMetrics,
|
|
):
|
|
super().__init__(config, metrics)
|
|
self.interpreter: Any | None = None
|
|
self.sub_label_publisher = sub_label_publisher
|
|
self.tensor_input_details: dict[str, Any] = None
|
|
self.tensor_output_details: dict[str, Any] = None
|
|
self.detected_birds: dict[str, float] = {}
|
|
self.labelmap: dict[int, str] = {}
|
|
|
|
GITHUB_RAW_ENDPOINT = os.environ.get(
|
|
"GITHUB_RAW_ENDPOINT", "https://raw.githubusercontent.com"
|
|
)
|
|
download_path = os.path.join(MODEL_CACHE_DIR, "bird")
|
|
self.model_files = {
|
|
"bird.tflite": f"{GITHUB_RAW_ENDPOINT}/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
|
|
"birdmap.txt": f"{GITHUB_RAW_ENDPOINT}/google-coral/test_data/master/inat_bird_labels.txt",
|
|
}
|
|
|
|
if not all(
|
|
os.path.exists(os.path.join(download_path, n))
|
|
for n in self.model_files.keys()
|
|
):
|
|
# conditionally import ModelDownloader
|
|
from frigate.util.downloader import ModelDownloader
|
|
|
|
self.downloader = ModelDownloader(
|
|
model_name="bird",
|
|
download_path=download_path,
|
|
file_names=self.model_files.keys(),
|
|
download_func=self.__download_models,
|
|
complete_func=self.__build_detector,
|
|
)
|
|
self.downloader.ensure_model_files()
|
|
else:
|
|
self.__build_detector()
|
|
|
|
def __download_models(self, path: str) -> None:
|
|
try:
|
|
file_name = os.path.basename(path)
|
|
|
|
# conditionally import ModelDownloader
|
|
from frigate.util.downloader import ModelDownloader
|
|
|
|
ModelDownloader.download_from_url(self.model_files[file_name], path)
|
|
except Exception as e:
|
|
logger.error(f"Failed to download {path}: {e}")
|
|
|
|
@redirect_output_to_logger(logger, logging.DEBUG)
|
|
def __build_detector(self) -> None:
|
|
try:
|
|
from tflite_runtime.interpreter import Interpreter
|
|
except ModuleNotFoundError:
|
|
from tensorflow.lite.python.interpreter import Interpreter
|
|
|
|
self.interpreter = Interpreter(
|
|
model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"),
|
|
num_threads=2,
|
|
)
|
|
self.interpreter.allocate_tensors()
|
|
self.tensor_input_details = self.interpreter.get_input_details()
|
|
self.tensor_output_details = self.interpreter.get_output_details()
|
|
|
|
i = 0
|
|
|
|
with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f:
|
|
line = f.readline()
|
|
while line:
|
|
start = line.find("(")
|
|
end = line.find(")")
|
|
self.labelmap[i] = line[start + 1 : end]
|
|
i += 1
|
|
line = f.readline()
|
|
|
|
def process_frame(self, obj_data, frame):
|
|
if not self.interpreter:
|
|
return
|
|
|
|
if obj_data["label"] != "bird":
|
|
return
|
|
|
|
x, y, x2, y2 = calculate_region(
|
|
frame.shape,
|
|
obj_data["box"][0],
|
|
obj_data["box"][1],
|
|
obj_data["box"][2],
|
|
obj_data["box"][3],
|
|
int(
|
|
max(
|
|
obj_data["box"][1] - obj_data["box"][0],
|
|
obj_data["box"][3] - obj_data["box"][2],
|
|
)
|
|
* 1.1
|
|
),
|
|
1.0,
|
|
)
|
|
|
|
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
|
input = rgb[
|
|
y:y2,
|
|
x:x2,
|
|
]
|
|
|
|
if input.shape != (224, 224):
|
|
try:
|
|
input = cv2.resize(input, (224, 224))
|
|
except Exception:
|
|
logger.warning("Failed to resize image for bird classification")
|
|
return
|
|
|
|
input = np.expand_dims(input, axis=0)
|
|
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
|
self.interpreter.invoke()
|
|
res: np.ndarray = self.interpreter.get_tensor(
|
|
self.tensor_output_details[0]["index"]
|
|
)[0]
|
|
probs = res / res.sum(axis=0)
|
|
best_id = np.argmax(probs)
|
|
|
|
if best_id == 964:
|
|
logger.debug("No bird classification was detected.")
|
|
return
|
|
|
|
score = round(probs[best_id], 2)
|
|
|
|
if score < self.config.classification.bird.threshold:
|
|
logger.debug(f"Score {score} is not above required threshold")
|
|
return
|
|
|
|
previous_score = self.detected_birds.get(obj_data["id"], 0.0)
|
|
|
|
if score <= previous_score:
|
|
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
|
return
|
|
|
|
self.sub_label_publisher.publish(
|
|
(obj_data["id"], self.labelmap[best_id], score),
|
|
EventMetadataTypeEnum.sub_label.value,
|
|
)
|
|
self.detected_birds[obj_data["id"]] = score
|
|
|
|
def handle_request(self, topic, request_data):
|
|
return None
|
|
|
|
def expire_object(self, object_id, camera):
|
|
if object_id in self.detected_birds:
|
|
self.detected_birds.pop(object_id)
|