From b88fa9ece6887c28248368d99e59d7dcbb04481b Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Fri, 18 Jul 2025 08:28:02 -0600 Subject: [PATCH] Object attribute classification (#19205) * Add enum for type of classification for objects * Update recognized license plate topic to be used as attribute updater * Update attribute for attribute type object classification * Cleanup --- frigate/api/event.py | 3 +- frigate/comms/event_metadata_updater.py | 2 +- frigate/config/classification.py | 9 ++++++ .../common/license_plate/mixin.py | 4 +-- .../real_time/custom_classification.py | 23 +++++++++++--- frigate/track/object_processing.py | 30 ++++++++++--------- 6 files changed, 49 insertions(+), 22 deletions(-) diff --git a/frigate/api/event.py b/frigate/api/event.py index d76e5f10b..c8f423b5d 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -1148,7 +1148,8 @@ def set_plate( new_score = None request.app.event_metadata_updater.publish( - EventMetadataTypeEnum.recognized_license_plate, (event_id, new_plate, new_score) + EventMetadataTypeEnum.attribute, + (event_id, "recognized_license_plate", new_plate, new_score), ) return JSONResponse( diff --git a/frigate/comms/event_metadata_updater.py b/frigate/comms/event_metadata_updater.py index 6305de5a1..5a2d6104d 100644 --- a/frigate/comms/event_metadata_updater.py +++ b/frigate/comms/event_metadata_updater.py @@ -15,7 +15,7 @@ class EventMetadataTypeEnum(str, Enum): manual_event_end = "manual_event_end" regenerate_description = "regenerate_description" sub_label = "sub_label" - recognized_license_plate = "recognized_license_plate" + attribute = "attribute" lpr_event_create = "lpr_event_create" save_lpr_snapshot = "save_lpr_snapshot" diff --git a/frigate/config/classification.py b/frigate/config/classification.py index 6b6e0cf52..572c70e23 100644 --- a/frigate/config/classification.py +++ b/frigate/config/classification.py @@ -34,6 +34,11 @@ class TriggerAction(str, Enum): NOTIFICATION = "notification" +class ObjectClassificationType(str, Enum): + sub_label = "sub_label" + attribute = "attribute" + + class AudioTranscriptionConfig(FrigateBaseModel): enabled: bool = Field(default=False, title="Enable audio transcription.") language: str = Field( @@ -88,6 +93,10 @@ class CustomClassificationStateConfig(FrigateBaseModel): class CustomClassificationObjectConfig(FrigateBaseModel): objects: list[str] = Field(title="Object types to classify.") + classification_type: ObjectClassificationType = Field( + default=ObjectClassificationType.sub_label, + title="Type of classification that is applied.", + ) class CustomClassificationConfig(FrigateBaseModel): diff --git a/frigate/data_processing/common/license_plate/mixin.py b/frigate/data_processing/common/license_plate/mixin.py index 66a0f63c9..c88fbc982 100644 --- a/frigate/data_processing/common/license_plate/mixin.py +++ b/frigate/data_processing/common/license_plate/mixin.py @@ -1537,8 +1537,8 @@ class LicensePlateProcessingMixin: ), ) self.sub_label_publisher.publish( - EventMetadataTypeEnum.recognized_license_plate, - (id, top_plate, avg_confidence), + EventMetadataTypeEnum.attribute, + (id, "recognized_license_plate", top_plate, avg_confidence), ) # save the best snapshot for dedicated lpr cams not using frigate+ diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 0ba8b3d17..71eb1cd87 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -15,7 +15,10 @@ from frigate.comms.event_metadata_updater import ( ) from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig -from frigate.config.classification import CustomClassificationConfig +from frigate.config.classification import ( + CustomClassificationConfig, + ObjectClassificationType, +) from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR from frigate.log import redirect_output_to_logger from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels @@ -285,10 +288,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): sub_label = self.labelmap[best_id] self.detected_objects[obj_data["id"]] = score - if sub_label != "none": + if ( + self.model_config.object_config.classification_type + == ObjectClassificationType.sub_label + ): + if sub_label != "none": + self.sub_label_publisher.publish( + EventMetadataTypeEnum.sub_label, + (obj_data["id"], sub_label, score), + ) + elif ( + self.model_config.object_config.classification_type + == ObjectClassificationType.attribute + ): self.sub_label_publisher.publish( - EventMetadataTypeEnum.sub_label, - (obj_data["id"], sub_label, score), + EventMetadataTypeEnum.attribute, + (obj_data["id"], self.model_config.name, sub_label, score), ) def handle_request(self, topic, request_data): diff --git a/frigate/track/object_processing.py b/frigate/track/object_processing.py index 6409dd925..8e8836278 100644 --- a/frigate/track/object_processing.py +++ b/frigate/track/object_processing.py @@ -376,10 +376,14 @@ class TrackedObjectProcessor(threading.Thread): return True - def set_recognized_license_plate( - self, event_id: str, recognized_license_plate: str | None, score: float | None + def set_object_attribute( + self, + event_id: str, + field_name: str, + field_value: str | None, + score: float | None, ) -> None: - """Update recognized license plate for given event id.""" + """Update attribute for given event id.""" tracked_obj: TrackedObject = None for state in self.camera_states.values(): @@ -397,18 +401,18 @@ class TrackedObjectProcessor(threading.Thread): return if tracked_obj: - tracked_obj.obj_data["recognized_license_plate"] = ( - recognized_license_plate, + tracked_obj.obj_data[field_name] = ( + field_value, score, ) if event: data = event.data - data["recognized_license_plate"] = recognized_license_plate - if recognized_license_plate is None: - data["recognized_license_plate_score"] = None + data[field_name] = field_value + if field_value is None: + data[f"{field_name}_score"] = None elif score is not None: - data["recognized_license_plate_score"] = score + data[f"{field_name}_score"] = score event.data = data event.save() @@ -644,11 +648,9 @@ class TrackedObjectProcessor(threading.Thread): if topic.endswith(EventMetadataTypeEnum.sub_label.value): (event_id, sub_label, score) = payload self.set_sub_label(event_id, sub_label, score) - if topic.endswith(EventMetadataTypeEnum.recognized_license_plate.value): - (event_id, recognized_license_plate, score) = payload - self.set_recognized_license_plate( - event_id, recognized_license_plate, score - ) + if topic.endswith(EventMetadataTypeEnum.attribute.value): + (event_id, field_name, field_value, score) = payload + self.set_object_attribute(event_id, field_name, field_value, score) elif topic.endswith(EventMetadataTypeEnum.lpr_event_create.value): self.create_lpr_event(payload) elif topic.endswith(EventMetadataTypeEnum.save_lpr_snapshot.value):