From be8ee068e260d33af24b116ff6dcedf31bcf3f27 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 5 Jun 2025 09:13:12 -0600 Subject: [PATCH] Live classification model training (#18583) * Implement model training via ZMQ and add model states to represent training * Get model updates working * Improve toasts and model state * Clean up logging * Add back in --- docker/rocm/Dockerfile | 1 + frigate/api/classification.py | 10 ++- frigate/comms/embeddings_updater.py | 16 +++-- frigate/config/logger.py | 2 + .../real_time/custom_classification.py | 68 ++++++++++++++++++- frigate/embeddings/__init__.py | 5 ++ frigate/types.py | 2 + frigate/util/classification.py | 14 +++- .../locales/en/views/classificationModel.json | 7 +- web/src/types/ws.ts | 4 +- .../classification/ModelTrainingView.tsx | 62 ++++++++++++++++- 11 files changed, 169 insertions(+), 22 deletions(-) diff --git a/docker/rocm/Dockerfile b/docker/rocm/Dockerfile index f755c9f66..3bc28cae8 100644 --- a/docker/rocm/Dockerfile +++ b/docker/rocm/Dockerfile @@ -62,6 +62,7 @@ COPY --from=rocm /opt/rocm-dist/ / ####################################################################### FROM deps-prelim AS rocm-prelim-hsa-override0 ENV HSA_ENABLE_SDMA=0 +ENV TF_ROCM_USE_IMMEDIATE_MODE=1 COPY --from=rocm-dist / / diff --git a/frigate/api/classification.py b/frigate/api/classification.py index f5acc437c..1fc17a08f 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -7,7 +7,7 @@ import shutil from typing import Any import cv2 -from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile +from fastapi import APIRouter, Depends, Request, UploadFile from fastapi.responses import JSONResponse from pathvalidate import sanitize_filename from peewee import DoesNotExist @@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig from frigate.const import CLIPS_DIR, FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event -from frigate.util.classification import train_classification_model from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -476,9 +475,7 @@ def get_classification_images(name: str): @router.post("/classification/{name}/train") -async def train_configured_model( - request: Request, name: str, background_tasks: BackgroundTasks -): +async def train_configured_model(request: Request, name: str): config: FrigateConfig = request.app.frigate_config if name not in config.classification.custom: @@ -492,7 +489,8 @@ async def train_configured_model( status_code=404, ) - background_tasks.add_task(train_classification_model, name) + context: EmbeddingsContext = request.app.embeddings + context.start_classification_training(name) return JSONResponse( content={"success": True, "message": "Started classification model training."}, status_code=200, diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index 00bc88b3d..5edb9e77d 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings" class EmbeddingsRequestEnum(Enum): + # audio + transcribe_audio = "transcribe_audio" + # custom classification + train_classification = "train_classification" + # face clear_face_classifier = "clear_face_classifier" - embed_description = "embed_description" - embed_thumbnail = "embed_thumbnail" - generate_search = "generate_search" recognize_face = "recognize_face" register_face = "register_face" reprocess_face = "reprocess_face" - reprocess_plate = "reprocess_plate" + # semantic search + embed_description = "embed_description" + embed_thumbnail = "embed_thumbnail" + generate_search = "generate_search" reindex = "reindex" - transcribe_audio = "transcribe_audio" + # LPR + reprocess_plate = "reprocess_plate" class EmbeddingsResponder: diff --git a/frigate/config/logger.py b/frigate/config/logger.py index e6e1c06d3..a3eed23d0 100644 --- a/frigate/config/logger.py +++ b/frigate/config/logger.py @@ -29,7 +29,9 @@ class LoggerConfig(FrigateBaseModel): logging.getLogger().setLevel(self.default.value.upper()) log_levels = { + "absl": LogLevel.error, "httpx": LogLevel.error, + "tensorflow": LogLevel.error, "werkzeug": LogLevel.error, "ws4py": LogLevel.error, **self.logs, diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 0e254ab0d..df4baf70b 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -3,11 +3,13 @@ import datetime import logging import os +import threading from typing import Any import cv2 import numpy as np +from frigate.comms.embeddings_updater import EmbeddingsRequestEnum from frigate.comms.event_metadata_updater import ( EventMetadataPublisher, EventMetadataTypeEnum, @@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import ( from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig from frigate.config.classification import CustomClassificationConfig -from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR +from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum from frigate.util.builtin import load_labels +from frigate.util.classification import train_classification_model from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics @@ -63,6 +67,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): prefill=0, ) + def __retrain_model(self) -> None: + train_classification_model(self.model_config.name) + self.__build_detector() + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.complete, + }, + ) + logger.info(f"Successfully loaded updated model for {self.model_config.name}") + def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): camera = frame_data.get("camera") @@ -143,7 +159,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) def handle_request(self, topic, request_data): - return None + if topic == EmbeddingsRequestEnum.train_classification.value: + if request_data.get("model_name") == self.model_config.name: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + threading.Thread(target=self.__retrain_model).start() + return { + "success": True, + "message": f"Began training {self.model_config.name} model.", + } + else: + return None + else: + return None def expire_object(self, object_id, camera): pass @@ -182,6 +215,18 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): prefill=0, ) + def __retrain_model(self) -> None: + train_classification_model(self.model_config.name) + self.__build_detector() + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.complete, + }, + ) + logger.info(f"Successfully loaded updated model for {self.model_config.name}") + def process_frame(self, obj_data, frame): if obj_data["label"] not in self.model_config.object_config.objects: return @@ -236,7 +281,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.detected_objects[obj_data["id"]] = score def handle_request(self, topic, request_data): - return None + if topic == EmbeddingsRequestEnum.train_classification.value: + if request_data.get("model_name") == self.model_config.name: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + threading.Thread(target=self.__retrain_model).start() + return { + "success": True, + "message": f"Began training {self.model_config.name} model.", + } + else: + return None + else: + return None def expire_object(self, object_id, camera): if object_id in self.detected_objects: diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index bc1887e2c..650eefc81 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -292,6 +292,11 @@ class EmbeddingsContext: def reindex_embeddings(self) -> dict[str, Any]: return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {}) + def start_classification_training(self, model_name: str) -> dict[str, Any]: + return self.requestor.send_data( + EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name} + ) + def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]: return self.requestor.send_data( EmbeddingsRequestEnum.transcribe_audio.value, {"event": event} diff --git a/frigate/types.py b/frigate/types.py index ee48cc02b..a9e27ba90 100644 --- a/frigate/types.py +++ b/frigate/types.py @@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum): downloading = "downloading" downloaded = "downloaded" error = "error" + training = "training" + complete = "complete" class TrackedObjectUpdateTypesEnum(str, Enum): diff --git a/frigate/util/classification.py b/frigate/util/classification.py index a8624870b..92da7c93e 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -1,7 +1,7 @@ """Util for classification models.""" -import logging import os +import sys import cv2 import numpy as np @@ -50,7 +50,13 @@ def train_classification_model(model_name: str) -> bool: ] ) - tf.get_logger().setLevel(logging.ERROR) + # TF and Keras are very loud with logging + # we want to avoid these logs so we + # temporarily redirect stdout / stderr + original_stdout = sys.stdout + original_stderr = sys.stderr + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") # Start with imagenet base model with 35% of channels in each layer base_model = MobileNetV2( @@ -112,3 +118,7 @@ def train_classification_model(model_name: str) -> bool: # write model with open(os.path.join(model_dir, "model.tflite"), "wb") as f: f.write(tflite_model) + + # restore original stdout / stderr + sys.stdout = original_stdout + sys.stderr = original_stderr diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index eb09ecaa0..0af0179b9 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -9,12 +9,15 @@ "success": { "deletedCategory": "Deleted Class", "deletedImage": "Deleted Images", - "categorizedImage": "Successfully Classified Image" + "categorizedImage": "Successfully Classified Image", + "trainedModel": "Successfully trained model.", + "trainingModel": "Successfully started model training." }, "error": { "deleteImageFailed": "Failed to delete: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", - "categorizeFailed": "Failed to categorize image: {{errorMessage}}" + "categorizeFailed": "Failed to categorize image: {{errorMessage}}", + "trainingFailed": "Failed to start model training: {{errorMessage}}" } }, "deleteCategory": { diff --git a/web/src/types/ws.ts b/web/src/types/ws.ts index d1e810494..06ec9ae1d 100644 --- a/web/src/types/ws.ts +++ b/web/src/types/ws.ts @@ -73,7 +73,9 @@ export type ModelState = | "not_downloaded" | "downloading" | "downloaded" - | "error"; + | "error" + | "training" + | "complete"; export type EmbeddingsReindexProgressType = { thumbnails: number; diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 53ef7fa66..1f62a4f53 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -45,6 +45,9 @@ import { toast } from "sonner"; import useSWR from "swr"; import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; import { TbCategoryPlus } from "react-icons/tb"; +import { useModelState } from "@/api/ws"; +import { ModelState } from "@/types/ws"; +import ActivityIndicator from "@/components/indicators/activity-indicator"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; @@ -54,6 +57,33 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const [page, setPage] = useState("train"); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); + // model state + + const [wasTraining, setWasTraining] = useState(false); + const { payload: lastModelState } = useModelState(model.name, true); + const modelState = useMemo(() => { + if (!lastModelState || lastModelState == "downloaded") { + return "complete"; + } + + return lastModelState; + }, [lastModelState]); + + useEffect(() => { + if (!wasTraining) { + return; + } + + if (modelState == "complete") { + toast.success(t("toast.success.trainedModel"), { + position: "top-center", + }); + setWasTraining(false); + } + // only refresh when modelState changes + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [modelState]); + // dataset const { data: trainImages, mutate: refreshTrain } = useSWR( @@ -101,8 +131,27 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { // actions const trainModel = useCallback(() => { - axios.post(`classification/${model.name}/train`); - }, [model]); + axios + .post(`classification/${model.name}/train`) + .then((resp) => { + if (resp.status == 200) { + setWasTraining(true); + toast.success(t("toast.success.trainingModel"), { + position: "top-center", + }); + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + + toast.error(t("toast.error.trainingFailed", { errorMessage }), { + position: "top-center", + }); + }); + }, [model, t]); const [deleteDialogOpen, setDeleteDialogOpen] = useState( null, @@ -274,7 +323,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { ) : ( - + )} {pageToggle == "train" ? (