Live classification model training (#18583)

* Implement model training via ZMQ and add model states to represent training
* Get model updates working
* Improve toasts and model state
* Clean up logging
* Add back in

Parent: 0c66412f73
Commit: 37ce7e7b68
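At a high level, the flow this commit introduces is: the /classification/{name}/train endpoint asks the embeddings process (over ZMQ) to train, the classification processor that owns the model marks it as "training", retrains in a background thread, reloads its detector, and then publishes "complete" so the UI can react. A condensed Python sketch of that round trip follows; names mirror the diff below, while "requestor", "processor", and the plain string topics are stand-ins for Frigate's real inter-process plumbing.

# Condensed sketch of the training round trip added here (stand-in names,
# simplified topics; the ZMQ transport behind send_data() is elided).
import threading


def start_classification_training(requestor, model_name: str) -> dict:
    # API side: POST /classification/{name}/train now forwards the request to
    # the embeddings process instead of training inside the API process.
    return requestor.send_data("train_classification", {"model_name": model_name})


def handle_train_request(processor, topic: str, request_data: dict):
    # Processor side: the classification processor that owns the model replies.
    if topic == "train_classification" and request_data.get("model_name") == processor.name:
        # Publish "training" so the UI can disable the button and show a spinner.
        processor.requestor.send_data("model_state", {"model": processor.name, "state": "training"})
        # Retrain off-thread: run train_classification_model(), rebuild the
        # detector, then publish "complete".
        threading.Thread(target=processor.retrain).start()
        return {"success": True, "message": f"Began training {processor.name} model."}
    return None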
@@ -7,7 +7,7 @@ import shutil
 from typing import Any

 import cv2
-from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
+from fastapi import APIRouter, Depends, Request, UploadFile
 from fastapi.responses import JSONResponse
 from pathvalidate import sanitize_filename
 from peewee import DoesNotExist
@@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig
 from frigate.const import CLIPS_DIR, FACE_DIR
 from frigate.embeddings import EmbeddingsContext
 from frigate.models import Event
-from frigate.util.classification import train_classification_model
 from frigate.util.path import get_event_snapshot

 logger = logging.getLogger(__name__)
@@ -494,9 +493,7 @@ def get_classification_images(name: str):


 @router.post("/classification/{name}/train")
-async def train_configured_model(
-    request: Request, name: str, background_tasks: BackgroundTasks
-):
+async def train_configured_model(request: Request, name: str):
     config: FrigateConfig = request.app.frigate_config

     if name not in config.classification.custom:
@@ -510,7 +507,8 @@ async def train_configured_model(
            status_code=404,
        )

-    background_tasks.add_task(train_classification_model, name)
+    context: EmbeddingsContext = request.app.embeddings
+    context.start_classification_training(name)
    return JSONResponse(
        content={"success": True, "message": "Started classification model training."},
        status_code=200,
@@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"


 class EmbeddingsRequestEnum(Enum):
+    # audio
+    transcribe_audio = "transcribe_audio"
+    # custom classification
+    train_classification = "train_classification"
+    # face
     clear_face_classifier = "clear_face_classifier"
-    embed_description = "embed_description"
-    embed_thumbnail = "embed_thumbnail"
-    generate_search = "generate_search"
     recognize_face = "recognize_face"
     register_face = "register_face"
     reprocess_face = "reprocess_face"
-    reprocess_plate = "reprocess_plate"
+    # semantic search
+    embed_description = "embed_description"
+    embed_thumbnail = "embed_thumbnail"
+    generate_search = "generate_search"
     reindex = "reindex"
-    transcribe_audio = "transcribe_audio"
+    # LPR
+    reprocess_plate = "reprocess_plate"


 class EmbeddingsResponder:
@@ -29,7 +29,9 @@ class LoggerConfig(FrigateBaseModel):
         logging.getLogger().setLevel(self.default.value.upper())

         log_levels = {
+            "absl": LogLevel.error,
             "httpx": LogLevel.error,
+            "tensorflow": LogLevel.error,
             "werkzeug": LogLevel.error,
             "ws4py": LogLevel.error,
             **self.logs,
@@ -3,11 +3,13 @@
 import datetime
 import logging
 import os
+import threading
 from typing import Any

 import cv2
 import numpy as np

+from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
 from frigate.comms.event_metadata_updater import (
     EventMetadataPublisher,
     EventMetadataTypeEnum,
@@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
 from frigate.config.classification import CustomClassificationConfig
-from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
 from frigate.util.builtin import load_labels
+from frigate.util.classification import train_classification_model
 from frigate.util.object import box_overlaps, calculate_region

 from ..types import DataProcessorMetrics
@@ -63,6 +67,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
             prefill=0,
         )

+    def __retrain_model(self) -> None:
+        train_classification_model(self.model_config.name)
+        self.__build_detector()
+        self.requestor.send_data(
+            UPDATE_MODEL_STATE,
+            {
+                "model": self.model_config.name,
+                "state": ModelStatusTypesEnum.complete,
+            },
+        )
+        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
+
     def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
         camera = frame_data.get("camera")

@@ -143,6 +159,23 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )

     def handle_request(self, topic, request_data):
-        return None
+        if topic == EmbeddingsRequestEnum.train_classification.value:
+            if request_data.get("model_name") == self.model_config.name:
+                self.requestor.send_data(
+                    UPDATE_MODEL_STATE,
+                    {
+                        "model": self.model_config.name,
+                        "state": ModelStatusTypesEnum.training,
+                    },
+                )
+                threading.Thread(target=self.__retrain_model).start()
+                return {
+                    "success": True,
+                    "message": f"Began training {self.model_config.name} model.",
+                }
+            else:
+                return None
+        else:
+            return None

     def expire_object(self, object_id, camera):
@@ -182,6 +215,18 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             prefill=0,
         )

+    def __retrain_model(self) -> None:
+        train_classification_model(self.model_config.name)
+        self.__build_detector()
+        self.requestor.send_data(
+            UPDATE_MODEL_STATE,
+            {
+                "model": self.model_config.name,
+                "state": ModelStatusTypesEnum.complete,
+            },
+        )
+        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
+
     def process_frame(self, obj_data, frame):
         if obj_data["label"] not in self.model_config.object_config.objects:
             return
@@ -236,6 +281,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         self.detected_objects[obj_data["id"]] = score

     def handle_request(self, topic, request_data):
-        return None
+        if topic == EmbeddingsRequestEnum.train_classification.value:
+            if request_data.get("model_name") == self.model_config.name:
+                self.requestor.send_data(
+                    UPDATE_MODEL_STATE,
+                    {
+                        "model": self.model_config.name,
+                        "state": ModelStatusTypesEnum.training,
+                    },
+                )
+                threading.Thread(target=self.__retrain_model).start()
+                return {
+                    "success": True,
+                    "message": f"Began training {self.model_config.name} model.",
+                }
+            else:
+                return None
+        else:
+            return None

     def expire_object(self, object_id, camera):
@@ -301,6 +301,11 @@ class EmbeddingsContext:
     def reindex_embeddings(self) -> dict[str, Any]:
         return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})

+    def start_classification_training(self, model_name: str) -> dict[str, Any]:
+        return self.requestor.send_data(
+            EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
+        )
+
     def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
         return self.requestor.send_data(
             EmbeddingsRequestEnum.transcribe_audio.value, {"event": event}
@@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum):
    downloading = "downloading"
    downloaded = "downloaded"
    error = "error"
+    training = "training"
+    complete = "complete"


 class TrackedObjectUpdateTypesEnum(str, Enum):
@@ -1,7 +1,7 @@
 """Util for classification models."""

-import logging
 import os
+import sys

 import cv2
 import numpy as np
@@ -50,7 +50,13 @@ def train_classification_model(model_name: str) -> bool:
         ]
     )

-    tf.get_logger().setLevel(logging.ERROR)
+    # TF and Keras are very loud with logging
+    # we want to avoid these logs so we
+    # temporarily redirect stdout / stderr
+    original_stdout = sys.stdout
+    original_stderr = sys.stderr
+    sys.stdout = open(os.devnull, "w")
+    sys.stderr = open(os.devnull, "w")

     # Start with imagenet base model with 35% of channels in each layer
     base_model = MobileNetV2(
@@ -112,3 +118,7 @@ def train_classification_model(model_name: str) -> bool:
     # write model
     with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
         f.write(tflite_model)
+
+    # restore original stdout / stderr
+    sys.stdout = original_stdout
+    sys.stderr = original_stderr
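Side note on the stdout/stderr handling above: the diff saves the streams before training and restores them afterwards. A roughly equivalent pattern using only the standard library is sketched below; the helper name is hypothetical and this is not code from this commit, just an illustration of the same suppression technique.

import contextlib
import os


def run_quietly(fn, *args, **kwargs):
    # Suppress anything fn writes to stdout/stderr (e.g. noisy TF/Keras output),
    # restoring both streams afterwards even if fn raises.
    with open(os.devnull, "w") as devnull, \
            contextlib.redirect_stdout(devnull), \
            contextlib.redirect_stderr(devnull):
        return fn(*args, **kwargs)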
@@ -9,12 +9,15 @@
     "success": {
       "deletedCategory": "Deleted Class",
       "deletedImage": "Deleted Images",
-      "categorizedImage": "Successfully Classified Image"
+      "categorizedImage": "Successfully Classified Image",
+      "trainedModel": "Successfully trained model.",
+      "trainingModel": "Successfully started model training."
     },
     "error": {
       "deleteImageFailed": "Failed to delete: {{errorMessage}}",
       "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
-      "categorizeFailed": "Failed to categorize image: {{errorMessage}}"
+      "categorizeFailed": "Failed to categorize image: {{errorMessage}}",
+      "trainingFailed": "Failed to start model training: {{errorMessage}}"
     }
   },
   "deleteCategory": {
@@ -73,7 +73,9 @@ export type ModelState =
   | "not_downloaded"
   | "downloading"
   | "downloaded"
-  | "error";
+  | "error"
+  | "training"
+  | "complete";

 export type EmbeddingsReindexProgressType = {
   thumbnails: number;
@@ -45,6 +45,9 @@ import { toast } from "sonner";
 import useSWR from "swr";
 import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
 import { TbCategoryPlus } from "react-icons/tb";
+import { useModelState } from "@/api/ws";
+import { ModelState } from "@/types/ws";
+import ActivityIndicator from "@/components/indicators/activity-indicator";

 type ModelTrainingViewProps = {
   model: CustomClassificationModelConfig;
@@ -54,6 +57,33 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
   const [page, setPage] = useState<string>("train");
   const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);

+  // model state
+
+  const [wasTraining, setWasTraining] = useState(false);
+  const { payload: lastModelState } = useModelState(model.name, true);
+  const modelState = useMemo<ModelState>(() => {
+    if (!lastModelState || lastModelState == "downloaded") {
+      return "complete";
+    }
+
+    return lastModelState;
+  }, [lastModelState]);
+
+  useEffect(() => {
+    if (!wasTraining) {
+      return;
+    }
+
+    if (modelState == "complete") {
+      toast.success(t("toast.success.trainedModel"), {
+        position: "top-center",
+      });
+      setWasTraining(false);
+    }
+    // only refresh when modelState changes
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [modelState]);
+
   // dataset

   const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
@@ -101,8 +131,27 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
   // actions

   const trainModel = useCallback(() => {
-    axios.post(`classification/${model.name}/train`);
-  }, [model]);
+    axios
+      .post(`classification/${model.name}/train`)
+      .then((resp) => {
+        if (resp.status == 200) {
+          setWasTraining(true);
+          toast.success(t("toast.success.trainingModel"), {
+            position: "top-center",
+          });
+        }
+      })
+      .catch((error) => {
+        const errorMessage =
+          error.response?.data?.message ||
+          error.response?.data?.detail ||
+          "Unknown error";
+
+        toast.error(t("toast.error.trainingFailed", { errorMessage }), {
+          position: "top-center",
+        });
+      });
+  }, [model, t]);

   const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
     null,
@@ -274,7 +323,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
             </Button>
           </div>
         ) : (
-          <Button onClick={trainModel}>Train Model</Button>
+          <Button
+            className="flex justify-center gap-2"
+            onClick={trainModel}
+            disabled={modelState != "complete"}
+          >
+            Train Model
+            {modelState == "training" && <ActivityIndicator size={20} />}
+          </Button>
         )}
       </div>
       {pageToggle == "train" ? (