Bird classification (#15966)

* Start working on bird processor

* Initial setup for bird processing

* Improvements to handling

* Get classification working

* Cleanup classification

* Add classification config

* Update sort
This commit is contained in:
Nicolas Mowen 2025-01-13 08:09:04 -07:00
parent c9c011f05b
commit 9a2de78fc9
6 changed files with 185 additions and 7 deletions

View File

@ -3,13 +3,13 @@ from frigate.detectors import DetectorConfig, ModelConfig # noqa: F401
from .auth import * # noqa: F403 from .auth import * # noqa: F403
from .camera import * # noqa: F403 from .camera import * # noqa: F403
from .camera_group import * # noqa: F403 from .camera_group import * # noqa: F403
from .classification import * # noqa: F403
from .config import * # noqa: F403 from .config import * # noqa: F403
from .database import * # noqa: F403 from .database import * # noqa: F403
from .logger import * # noqa: F403 from .logger import * # noqa: F403
from .mqtt import * # noqa: F403 from .mqtt import * # noqa: F403
from .notification import * # noqa: F403 from .notification import * # noqa: F403
from .proxy import * # noqa: F403 from .proxy import * # noqa: F403
from .semantic_search import * # noqa: F403
from .telemetry import * # noqa: F403 from .telemetry import * # noqa: F403
from .tls import * # noqa: F403 from .tls import * # noqa: F403
from .ui import * # noqa: F403 from .ui import * # noqa: F403

View File

@ -11,6 +11,22 @@ __all__ = [
] ]
class BirdClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable bird classification.")
threshold: float = Field(
default=0.9,
title="Minimum classification score required to be considered a match.",
gt=0.0,
le=1.0,
)
class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config."
)
class SemanticSearchConfig(FrigateBaseModel): class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable semantic search.") enabled: bool = Field(default=False, title="Enable semantic search.")
reindex: Optional[bool] = Field( reindex: Optional[bool] = Field(

View File

@ -51,17 +51,18 @@ from .camera.review import ReviewConfig
from .camera.snapshots import SnapshotsConfig from .camera.snapshots import SnapshotsConfig
from .camera.timestamp import TimestampStyleConfig from .camera.timestamp import TimestampStyleConfig
from .camera_group import CameraGroupConfig from .camera_group import CameraGroupConfig
from .classification import (
ClassificationConfig,
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .database import DatabaseConfig from .database import DatabaseConfig
from .env import EnvVars from .env import EnvVars
from .logger import LoggerConfig from .logger import LoggerConfig
from .mqtt import MqttConfig from .mqtt import MqttConfig
from .notification import NotificationConfig from .notification import NotificationConfig
from .proxy import ProxyConfig from .proxy import ProxyConfig
from .semantic_search import (
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .telemetry import TelemetryConfig from .telemetry import TelemetryConfig
from .tls import TlsConfig from .tls import TlsConfig
from .ui import UIConfig from .ui import UIConfig
@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel):
default_factory=TelemetryConfig, title="Telemetry configuration." default_factory=TelemetryConfig, title="Telemetry configuration."
) )
tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.") tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.")
classification: ClassificationConfig = Field(
default_factory=ClassificationConfig, title="Object classification config."
)
semantic_search: SemanticSearchConfig = Field( semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration." default_factory=SemanticSearchConfig, title="Semantic search configuration."
) )

View File

@ -0,0 +1,154 @@
"""Handle processing images to classify birds."""
import logging
import os
import cv2
import numpy as np
import requests
from frigate.config import FrigateConfig
from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR
from frigate.util.object import calculate_region
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter
logger = logging.getLogger(__name__)
class BirdProcessor(RealTimeProcessorApi):
def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics):
super().__init__(config, metrics)
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, any] = None
self.tensor_output_details: dict[str, any] = None
self.detected_birds: dict[str, float] = {}
self.labelmap: dict[int, str] = {}
download_path = os.path.join(MODEL_CACHE_DIR, "bird")
self.model_files = {
"bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
"birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt",
}
if not all(
os.path.exists(os.path.join(download_path, n))
for n in self.model_files.keys()
):
# conditionally import ModelDownloader
from frigate.util.downloader import ModelDownloader
self.downloader = ModelDownloader(
model_name="bird",
download_path=download_path,
file_names=self.model_files.keys(),
download_func=self.__download_models,
complete_func=self.__build_detector,
)
self.downloader.ensure_model_files()
else:
self.__build_detector()
def __download_models(self, path: str) -> None:
try:
file_name = os.path.basename(path)
# conditionally import ModelDownloader
from frigate.util.downloader import ModelDownloader
ModelDownloader.download_from_url(self.model_files[file_name], path)
except Exception as e:
logger.error(f"Failed to download {path}: {e}")
def __build_detector(self) -> None:
self.interpreter = Interpreter(
model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"),
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
i = 0
with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f:
line = f.readline()
while line:
start = line.find("(")
end = line.find(")")
self.labelmap[i] = line[start + 1 : end]
i += 1
line = f.readline()
def process_frame(self, obj_data, frame):
if obj_data["label"] != "bird":
return
x, y, x2, y2 = calculate_region(
frame.shape,
obj_data["box"][0],
obj_data["box"][1],
obj_data["box"][2],
obj_data["box"][3],
224,
1.0,
)
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[
y:y2,
x:x2,
]
cv2.imwrite("/media/frigate/test_class.png", input)
input = np.expand_dims(input, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"]
)[0]
probs = res / res.sum(axis=0)
best_id = np.argmax(probs)
if best_id == 964:
logger.debug("No bird classification was detected.")
return
score = round(probs[best_id], 2)
if score < self.config.classification.bird.threshold:
logger.debug(f"Score {score} is not above required threshold")
return
previous_score = self.detected_birds.get(obj_data["id"], 0.0)
if score <= previous_score:
logger.debug(f"Score {score} is worse than previous score {previous_score}")
return
resp = requests.post(
f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
json={
"camera": obj_data.get("camera"),
"subLabel": self.labelmap[best_id],
"subLabelScore": score,
},
)
if resp.status_code == 200:
self.detected_birds[obj_data["id"]] = score
def handle_request(self, request_data):
return None
def expire_object(self, object_id):
if object_id in self.detected_birds:
self.detected_birds.pop(object_id)

View File

@ -8,7 +8,7 @@ from pyclipper import ET_CLOSEDPOLYGON, JT_ROUND, PyclipperOffset
from shapely.geometry import Polygon from shapely.geometry import Polygon
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import LicensePlateRecognitionConfig from frigate.config.classification import LicensePlateRecognitionConfig
from frigate.embeddings.embeddings import Embeddings from frigate.embeddings.embeddings import Embeddings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -30,6 +30,7 @@ from frigate.const import (
UPDATE_EVENT_DESCRIPTION, UPDATE_EVENT_DESCRIPTION,
) )
from frigate.data_processing.real_time.api import RealTimeProcessorApi from frigate.data_processing.real_time.api import RealTimeProcessorApi
from frigate.data_processing.real_time.bird_processor import BirdProcessor
from frigate.data_processing.real_time.face_processor import FaceProcessor from frigate.data_processing.real_time.face_processor import FaceProcessor
from frigate.data_processing.types import DataProcessorMetrics from frigate.data_processing.types import DataProcessorMetrics
from frigate.embeddings.lpr.lpr import LicensePlateRecognition from frigate.embeddings.lpr.lpr import LicensePlateRecognition
@ -78,6 +79,9 @@ class EmbeddingMaintainer(threading.Thread):
if self.config.face_recognition.enabled: if self.config.face_recognition.enabled:
self.processors.append(FaceProcessor(self.config, metrics)) self.processors.append(FaceProcessor(self.config, metrics))
if self.config.classification.bird.enabled:
self.processors.append(BirdProcessor(self.config, metrics))
# create communication for updating event descriptions # create communication for updating event descriptions
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
self.stop_event = stop_event self.stop_event = stop_event