Live classification model training (#18583)

* Implement model training via ZMQ and add model states to represent training

* Get model updates working

* Improve toasts and model state

* Clean up logging

* Add back in
This commit is contained in:
Nicolas Mowen 2025-06-05 09:13:12 -06:00 committed by GitHub
parent 85d721eb6b
commit be8ee068e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 169 additions and 22 deletions

View File

@ -62,6 +62,7 @@ COPY --from=rocm /opt/rocm-dist/ /
#######################################################################
FROM deps-prelim AS rocm-prelim-hsa-override0
ENV HSA_ENABLE_SDMA=0
ENV TF_ROCM_USE_IMMEDIATE_MODE=1
COPY --from=rocm-dist / /

View File

@ -7,7 +7,7 @@ import shutil
from typing import Any
import cv2
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
from fastapi import APIRouter, Depends, Request, UploadFile
from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename
from peewee import DoesNotExist
@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@ -476,9 +475,7 @@ def get_classification_images(name: str):
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
):
async def train_configured_model(request: Request, name: str):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
@ -492,7 +489,8 @@ async def train_configured_model(
status_code=404,
)
background_tasks.add_task(train_classification_model, name)
context: EmbeddingsContext = request.app.embeddings
context.start_classification_training(name)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,

View File

@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
class EmbeddingsRequestEnum(Enum):
# audio
transcribe_audio = "transcribe_audio"
# custom classification
train_classification = "train_classification"
# face
clear_face_classifier = "clear_face_classifier"
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
recognize_face = "recognize_face"
register_face = "register_face"
reprocess_face = "reprocess_face"
reprocess_plate = "reprocess_plate"
# semantic search
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
reindex = "reindex"
transcribe_audio = "transcribe_audio"
# LPR
reprocess_plate = "reprocess_plate"
class EmbeddingsResponder:

View File

@ -29,7 +29,9 @@ class LoggerConfig(FrigateBaseModel):
logging.getLogger().setLevel(self.default.value.upper())
log_levels = {
"absl": LogLevel.error,
"httpx": LogLevel.error,
"tensorflow": LogLevel.error,
"werkzeug": LogLevel.error,
"ws4py": LogLevel.error,
**self.logs,

View File

@ -3,11 +3,13 @@
import datetime
import logging
import os
import threading
from typing import Any
import cv2
import numpy as np
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
EventMetadataTypeEnum,
@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import (
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import load_labels
from frigate.util.classification import train_classification_model
from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics
@ -63,6 +67,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
prefill=0,
)
def __retrain_model(self) -> None:
train_classification_model(self.model_config.name)
self.__build_detector()
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.complete,
},
)
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
@ -143,7 +159,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
def handle_request(self, topic, request_data):
return None
if topic == EmbeddingsRequestEnum.train_classification.value:
if request_data.get("model_name") == self.model_config.name:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
},
)
threading.Thread(target=self.__retrain_model).start()
return {
"success": True,
"message": f"Began training {self.model_config.name} model.",
}
else:
return None
else:
return None
def expire_object(self, object_id, camera):
pass
@ -182,6 +215,18 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
prefill=0,
)
def __retrain_model(self) -> None:
train_classification_model(self.model_config.name)
self.__build_detector()
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.complete,
},
)
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, obj_data, frame):
if obj_data["label"] not in self.model_config.object_config.objects:
return
@ -236,7 +281,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.detected_objects[obj_data["id"]] = score
def handle_request(self, topic, request_data):
return None
if topic == EmbeddingsRequestEnum.train_classification.value:
if request_data.get("model_name") == self.model_config.name:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
},
)
threading.Thread(target=self.__retrain_model).start()
return {
"success": True,
"message": f"Began training {self.model_config.name} model.",
}
else:
return None
else:
return None
def expire_object(self, object_id, camera):
if object_id in self.detected_objects:

View File

@ -292,6 +292,11 @@ class EmbeddingsContext:
def reindex_embeddings(self) -> dict[str, Any]:
return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})
def start_classification_training(self, model_name: str) -> dict[str, Any]:
return self.requestor.send_data(
EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
)
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
return self.requestor.send_data(
EmbeddingsRequestEnum.transcribe_audio.value, {"event": event}

View File

@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum):
downloading = "downloading"
downloaded = "downloaded"
error = "error"
training = "training"
complete = "complete"
class TrackedObjectUpdateTypesEnum(str, Enum):

View File

@ -1,7 +1,7 @@
"""Util for classification models."""
import logging
import os
import sys
import cv2
import numpy as np
@ -50,7 +50,13 @@ def train_classification_model(model_name: str) -> bool:
]
)
tf.get_logger().setLevel(logging.ERROR)
# TF and Keras are very loud with logging
# we want to avoid these logs so we
# temporarily redirect stdout / stderr
original_stdout = sys.stdout
original_stderr = sys.stderr
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
@ -112,3 +118,7 @@ def train_classification_model(model_name: str) -> bool:
# write model
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
f.write(tflite_model)
# restore original stdout / stderr
sys.stdout = original_stdout
sys.stderr = original_stderr

View File

@ -9,12 +9,15 @@
"success": {
"deletedCategory": "Deleted Class",
"deletedImage": "Deleted Images",
"categorizedImage": "Successfully Classified Image"
"categorizedImage": "Successfully Classified Image",
"trainedModel": "Successfully trained model.",
"trainingModel": "Successfully started model training."
},
"error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Failed to start model training: {{errorMessage}}"
}
},
"deleteCategory": {

View File

@ -73,7 +73,9 @@ export type ModelState =
| "not_downloaded"
| "downloading"
| "downloaded"
| "error";
| "error"
| "training"
| "complete";
export type EmbeddingsReindexProgressType = {
thumbnails: number;

View File

@ -45,6 +45,9 @@ import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -54,6 +57,33 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
// model state
const [wasTraining, setWasTraining] = useState(false);
const { payload: lastModelState } = useModelState(model.name, true);
const modelState = useMemo<ModelState>(() => {
if (!lastModelState || lastModelState == "downloaded") {
return "complete";
}
return lastModelState;
}, [lastModelState]);
useEffect(() => {
if (!wasTraining) {
return;
}
if (modelState == "complete") {
toast.success(t("toast.success.trainedModel"), {
position: "top-center",
});
setWasTraining(false);
}
// only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [modelState]);
// dataset
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
@ -101,8 +131,27 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
// actions
const trainModel = useCallback(() => {
axios.post(`classification/${model.name}/train`);
}, [model]);
axios
.post(`classification/${model.name}/train`)
.then((resp) => {
if (resp.status == 200) {
setWasTraining(true);
toast.success(t("toast.success.trainingModel"), {
position: "top-center",
});
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
position: "top-center",
});
});
}, [model, t]);
const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
null,
@ -274,7 +323,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button>
</div>
) : (
<Button onClick={trainModel}>Train Model</Button>
<Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
Train Model
{modelState == "training" && <ActivityIndicator size={20} />}
</Button>
)}
</div>
{pageToggle == "train" ? (