From 0f4cac736a146729a2cdf049d40c020a04d207f5 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Fri, 27 Jun 2025 08:35:02 -0600 Subject: [PATCH] Improve classification UI (#18910) * Move threhsold to base model config * Improve score handling * Add back button --- frigate/config/classification.py | 6 +-- .../real_time/custom_classification.py | 6 ++- web/src/types/frigateConfig.ts | 2 +- .../classification/ModelTrainingView.tsx | 48 ++++++++++++++----- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/frigate/config/classification.py b/frigate/config/classification.py index c0584ce63..6430c96fa 100644 --- a/frigate/config/classification.py +++ b/frigate/config/classification.py @@ -59,9 +59,6 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel): crop: list[int, int, int, int] = Field( title="Crop of image frame on this camera to run classification on." ) - threshold: float = Field( - default=0.8, title="Classification score threshold to change the state." - ) class CustomClassificationStateConfig(FrigateBaseModel): @@ -86,6 +83,9 @@ class CustomClassificationObjectConfig(FrigateBaseModel): class CustomClassificationConfig(FrigateBaseModel): enabled: bool = Field(default=True, title="Enable running the model.") name: str | None = Field(default=None, title="Name of classification model.") + threshold: float = Field( + default=0.8, title="Classification score threshold to change the state." + ) object_config: CustomClassificationObjectConfig | None = Field(default=None) state_config: CustomClassificationStateConfig | None = Field(default=None) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1e2b91a2d..05a555701 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -152,7 +152,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): score, ) - if score >= camera_config.threshold: + if score >= self.model_config.threshold: self.requestor.send_data( f"{camera}/classification/{self.model_config.name}", self.labelmap[best_id], @@ -271,6 +271,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): score, ) + if score < self.model_config.threshold: + logger.debug(f"Score {score} is less than threshold.") + return + if score <= previous_score: logger.debug(f"Score {score} is worse than previous score {previous_score}") return diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts index 3ccc5b06d..7d4c27794 100644 --- a/web/src/types/frigateConfig.ts +++ b/web/src/types/frigateConfig.ts @@ -282,6 +282,7 @@ export type CameraStreamingSettings = { export type CustomClassificationModelConfig = { enabled: boolean; name: string; + threshold: number; object_config: null | { objects: string[]; }; @@ -289,7 +290,6 @@ export type CustomClassificationModelConfig = { cameras: { [cameraName: string]: { crop: [number, number, number, number]; - threshold: number; }; }; motion: boolean; diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index ea265bd51..14de1a118 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -48,12 +48,15 @@ import { TbCategoryPlus } from "react-icons/tb"; import { useModelState } from "@/api/ws"; import { ModelState } from "@/types/ws"; import ActivityIndicator from "@/components/indicators/activity-indicator"; +import { useNavigate } from "react-router-dom"; +import { IoMdArrowRoundBack } from "react-icons/io"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; }; export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const { t } = useTranslation(["views/classificationModel"]); + const navigate = useNavigate(); const [page, setPage] = useState("train"); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); @@ -294,14 +297,28 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
- {}} - /> +
+ + {}} + /> +
{selectedImages?.length > 0 ? (
@@ -640,15 +657,17 @@ function TrainGrid({ trainImages .map((raw) => { const parts = raw.replaceAll(".webp", "").split("-"); + const rawScore = Number.parseFloat(parts[2]); return { raw, timestamp: parts[0], label: parts[1], - score: Number.parseFloat(parts[2]) * 100, + score: rawScore * 100, + truePositive: rawScore >= model.threshold, }; }) .sort((a, b) => b.timestamp.localeCompare(a.timestamp)), - [trainImages], + [model, trainImages], ); return ( @@ -684,7 +703,14 @@ function TrainGrid({
{data.label.replaceAll("_", " ")}
-
{data.score}%
+
+ {data.score}% +