Bird classification (#15966)

* Start working on bird processor

* Initial setup for bird processing

* Improvements to handling

* Get classification working

* Cleanup classification

* Add classification config

* Update sort
This commit is contained in:
Nicolas Mowen 2025-01-13 08:09:04 -07:00
parent c9c011f05b
commit 9a2de78fc9
6 changed files with 185 additions and 7 deletions

View File

@ -3,13 +3,13 @@ from frigate.detectors import DetectorConfig, ModelConfig # noqa: F401
from .auth import * # noqa: F403
from .camera import * # noqa: F403
from .camera_group import * # noqa: F403
from .classification import * # noqa: F403
from .config import * # noqa: F403
from .database import * # noqa: F403
from .logger import * # noqa: F403
from .mqtt import * # noqa: F403
from .notification import * # noqa: F403
from .proxy import * # noqa: F403
from .semantic_search import * # noqa: F403
from .telemetry import * # noqa: F403
from .tls import * # noqa: F403
from .ui import * # noqa: F403

View File

@ -11,6 +11,22 @@ __all__ = [
]
class BirdClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable bird classification.")
threshold: float = Field(
default=0.9,
title="Minimum classification score required to be considered a match.",
gt=0.0,
le=1.0,
)
class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config."
)
class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable semantic search.")
reindex: Optional[bool] = Field(

View File

@ -51,17 +51,18 @@ from .camera.review import ReviewConfig
from .camera.snapshots import SnapshotsConfig
from .camera.timestamp import TimestampStyleConfig
from .camera_group import CameraGroupConfig
from .classification import (
ClassificationConfig,
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .database import DatabaseConfig
from .env import EnvVars
from .logger import LoggerConfig
from .mqtt import MqttConfig
from .notification import NotificationConfig
from .proxy import ProxyConfig
from .semantic_search import (
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .telemetry import TelemetryConfig
from .tls import TlsConfig
from .ui import UIConfig
@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel):
default_factory=TelemetryConfig, title="Telemetry configuration."
)
tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.")
classification: ClassificationConfig = Field(
default_factory=ClassificationConfig, title="Object classification config."
)
semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration."
)

View File

@ -0,0 +1,154 @@
"""Handle processing images to classify birds."""
import logging
import os
import cv2
import numpy as np
import requests
from frigate.config import FrigateConfig
from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR
from frigate.util.object import calculate_region
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter
logger = logging.getLogger(__name__)
class BirdProcessor(RealTimeProcessorApi):
def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics):
super().__init__(config, metrics)
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, any] = None
self.tensor_output_details: dict[str, any] = None
self.detected_birds: dict[str, float] = {}
self.labelmap: dict[int, str] = {}
download_path = os.path.join(MODEL_CACHE_DIR, "bird")
self.model_files = {
"bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
"birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt",
}
if not all(
os.path.exists(os.path.join(download_path, n))
for n in self.model_files.keys()
):
# conditionally import ModelDownloader
from frigate.util.downloader import ModelDownloader
self.downloader = ModelDownloader(
model_name="bird",
download_path=download_path,
file_names=self.model_files.keys(),
download_func=self.__download_models,
complete_func=self.__build_detector,
)
self.downloader.ensure_model_files()
else:
self.__build_detector()
def __download_models(self, path: str) -> None:
try:
file_name = os.path.basename(path)
# conditionally import ModelDownloader
from frigate.util.downloader import ModelDownloader
ModelDownloader.download_from_url(self.model_files[file_name], path)
except Exception as e:
logger.error(f"Failed to download {path}: {e}")
def __build_detector(self) -> None:
self.interpreter = Interpreter(
model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"),
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
i = 0
with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f:
line = f.readline()
while line:
start = line.find("(")
end = line.find(")")
self.labelmap[i] = line[start + 1 : end]
i += 1
line = f.readline()
def process_frame(self, obj_data, frame):
if obj_data["label"] != "bird":
return
x, y, x2, y2 = calculate_region(
frame.shape,
obj_data["box"][0],
obj_data["box"][1],
obj_data["box"][2],
obj_data["box"][3],
224,
1.0,
)
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[
y:y2,
x:x2,
]
cv2.imwrite("/media/frigate/test_class.png", input)
input = np.expand_dims(input, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"]
)[0]
probs = res / res.sum(axis=0)
best_id = np.argmax(probs)
if best_id == 964:
logger.debug("No bird classification was detected.")
return
score = round(probs[best_id], 2)
if score < self.config.classification.bird.threshold:
logger.debug(f"Score {score} is not above required threshold")
return
previous_score = self.detected_birds.get(obj_data["id"], 0.0)
if score <= previous_score:
logger.debug(f"Score {score} is worse than previous score {previous_score}")
return
resp = requests.post(
f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
json={
"camera": obj_data.get("camera"),
"subLabel": self.labelmap[best_id],
"subLabelScore": score,
},
)
if resp.status_code == 200:
self.detected_birds[obj_data["id"]] = score
def handle_request(self, request_data):
return None
def expire_object(self, object_id):
if object_id in self.detected_birds:
self.detected_birds.pop(object_id)

View File

@ -8,7 +8,7 @@ from pyclipper import ET_CLOSEDPOLYGON, JT_ROUND, PyclipperOffset
from shapely.geometry import Polygon
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import LicensePlateRecognitionConfig
from frigate.config.classification import LicensePlateRecognitionConfig
from frigate.embeddings.embeddings import Embeddings
logger = logging.getLogger(__name__)

View File

@ -30,6 +30,7 @@ from frigate.const import (
UPDATE_EVENT_DESCRIPTION,
)
from frigate.data_processing.real_time.api import RealTimeProcessorApi
from frigate.data_processing.real_time.bird_processor import BirdProcessor
from frigate.data_processing.real_time.face_processor import FaceProcessor
from frigate.data_processing.types import DataProcessorMetrics
from frigate.embeddings.lpr.lpr import LicensePlateRecognition
@ -78,6 +79,9 @@ class EmbeddingMaintainer(threading.Thread):
if self.config.face_recognition.enabled:
self.processors.append(FaceProcessor(self.config, metrics))
if self.config.classification.bird.enabled:
self.processors.append(BirdProcessor(self.config, metrics))
# create communication for updating event descriptions
self.requestor = InterProcessRequestor()
self.stop_event = stop_event