From 7ea288fe32d5bb1b41d8f007e30c970129197e3f Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 16 Jul 2025 20:46:59 -0600 Subject: [PATCH] Classification train updates (#19173) * Improve model train button * Add filters for classification * Cleanup * Don't run classification on false positives * Cleanup filter * Fix icon color --- .../real_time/custom_classification.py | 3 + web/public/locales/en/components/filter.json | 6 + .../locales/en/views/classificationModel.json | 3 +- .../overlay/dialog/TrainFilterDialog.tsx | 262 ++++++++++++++++++ web/src/types/classification.ts | 8 + .../classification/ModelTrainingView.tsx | 66 ++++- 6 files changed, 338 insertions(+), 10 deletions(-) create mode 100644 web/src/components/overlay/dialog/TrainFilterDialog.tsx create mode 100644 web/src/types/classification.ts diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 05a555701..0ba8b3d17 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -225,6 +225,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.model_config.name ].value = self.classifications_per_second.eps() + if obj_data["false_positive"]: + return + if obj_data["label"] not in self.model_config.object_config.objects: return diff --git a/web/public/locales/en/components/filter.json b/web/public/locales/en/components/filter.json index 08a0ee2b2..1eaccbb69 100644 --- a/web/public/locales/en/components/filter.json +++ b/web/public/locales/en/components/filter.json @@ -1,5 +1,11 @@ { "filter": "Filter", + "classes": { + "label": "Classes", + "all": { "title": "All Classes" }, + "count_one": "{{count}} Class", + "count_other": "{{count}} Classes" + }, "labels": { "label": "Labels", "all": { diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 0af0179b9..47b2b13bf 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -3,7 +3,8 @@ "deleteClassificationAttempts": "Delete Classification Images", "renameCategory": "Rename Class", "deleteCategory": "Delete Class", - "deleteImages": "Delete Images" + "deleteImages": "Delete Images", + "trainModel": "Train Model" }, "toast": { "success": { diff --git a/web/src/components/overlay/dialog/TrainFilterDialog.tsx b/web/src/components/overlay/dialog/TrainFilterDialog.tsx new file mode 100644 index 000000000..56037ec0a --- /dev/null +++ b/web/src/components/overlay/dialog/TrainFilterDialog.tsx @@ -0,0 +1,262 @@ +import { FaFilter } from "react-icons/fa"; + +import { useEffect, useMemo, useState } from "react"; +import { PlatformAwareSheet } from "./PlatformAwareDialog"; +import { Button } from "@/components/ui/button"; +import { isDesktop, isMobile } from "react-device-detect"; +import FilterSwitch from "@/components/filter/FilterSwitch"; +import { Switch } from "@/components/ui/switch"; +import { Label } from "@/components/ui/label"; +import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu"; +import { cn } from "@/lib/utils"; +import { DualThumbSlider } from "@/components/ui/slider"; +import { Input } from "@/components/ui/input"; +import { useTranslation } from "react-i18next"; +import { TrainFilter } from "@/types/classification"; + +type TrainFilterDialogProps = { + filter?: TrainFilter; + filterValues: { + classes: string[]; + }; + onUpdateFilter: (filter: TrainFilter) => void; +}; +export default function TrainFilterDialog({ + filter, + filterValues, + onUpdateFilter, +}: TrainFilterDialogProps) { + // data + const { t } = useTranslation(["components/filter"]); + const [currentFilter, setCurrentFilter] = useState(filter ?? {}); + + useEffect(() => { + if (filter) { + setCurrentFilter(filter); + } + }, [filter]); + + // state + + const [open, setOpen] = useState(false); + + const moreFiltersSelected = useMemo( + () => + currentFilter && + (currentFilter.classes || + (currentFilter.min_score ?? 0) > 0.5 || + (currentFilter.max_score ?? 1) < 1), + [currentFilter], + ); + + const trigger = ( + + ); + const content = ( +
+ + setCurrentFilter({ ...currentFilter, classes: newClasses }) + } + /> + + setCurrentFilter({ ...currentFilter, min_score: min, max_score: max }) + } + /> + {isDesktop && } +
+ + +
+
+ ); + + return ( + { + if (!open) { + setCurrentFilter(filter ?? {}); + } + + setOpen(open); + }} + /> + ); +} + +type ClassFilterContentProps = { + allClasses?: string[]; + classes?: string[]; + updateClasses: (classes: string[] | undefined) => void; +}; +export function ClassFilterContent({ + allClasses, + classes, + updateClasses, +}: ClassFilterContentProps) { + const { t } = useTranslation(["components/filter"]); + return ( + <> +
+ +
{t("classes.label")}
+ {allClasses && ( + <> +
+ + { + if (isChecked) { + updateClasses(undefined); + } + }} + /> +
+
+ {allClasses.map((item) => ( + { + if (isChecked) { + const updatedClasses = classes ? [...classes] : []; + + updatedClasses.push(item); + updateClasses(updatedClasses); + } else { + const updatedClasses = classes ? [...classes] : []; + + // can not deselect the last item + if (updatedClasses.length > 1) { + updatedClasses.splice(updatedClasses.indexOf(item), 1); + updateClasses(updatedClasses); + } + } + }} + /> + ))} +
+ + )} +
+ + ); +} + +type ScoreFilterContentProps = { + minScore: number | undefined; + maxScore: number | undefined; + setScoreRange: (min: number | undefined, max: number | undefined) => void; +}; +export function ScoreFilterContent({ + minScore, + maxScore, + setScoreRange, +}: ScoreFilterContentProps) { + const { t } = useTranslation(["components/filter"]); + return ( +
+ +
{t("score")}
+
+ { + const value = e.target.value; + + if (value) { + setScoreRange(parseInt(value) / 100.0, maxScore ?? 1.0); + } + }} + /> + setScoreRange(min, max)} + /> + { + const value = e.target.value; + + if (value) { + setScoreRange(minScore ?? 0.5, parseInt(value) / 100.0); + } + }} + /> +
+
+ ); +} diff --git a/web/src/types/classification.ts b/web/src/types/classification.ts new file mode 100644 index 000000000..54320175a --- /dev/null +++ b/web/src/types/classification.ts @@ -0,0 +1,8 @@ +const TRAIN_FILTERS = ["class", "score"] as const; +export type TrainFilters = (typeof TRAIN_FILTERS)[number]; + +export type TrainFilter = { + classes?: string[]; + min_score?: number; + max_score?: number; +}; diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index c8c2266d5..12a474bbb 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -50,6 +50,10 @@ import { ModelState } from "@/types/ws"; import ActivityIndicator from "@/components/indicators/activity-indicator"; import { useNavigate } from "react-router-dom"; import { IoMdArrowRoundBack } from "react-icons/io"; +import { MdAutoFixHigh } from "react-icons/md"; +import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; +import useApiFilter from "@/hooks/use-api-filter"; +import { TrainFilter } from "@/types/classification"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; @@ -96,6 +100,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { [id: string]: string[]; }>(`classification/${model.name}/dataset`); + const [trainFilter, setTrainFilter] = useApiFilter(); + // image multiselect const [selectedImages, setSelectedImages] = useState([]); @@ -340,14 +346,25 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { ) : ( - +
+ + +
)} {pageToggle == "train" ? ( @@ -355,6 +372,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { model={model} classes={Object.keys(dataset || {})} trainImages={trainImages || []} + trainFilter={trainFilter} selectedImages={selectedImages} onRefresh={refreshTrain} onClickImages={onClickImages} @@ -642,6 +660,7 @@ type TrainGridProps = { model: CustomClassificationModelConfig; classes: string[]; trainImages: string[]; + trainFilter?: TrainFilter; selectedImages: string[]; onClickImages: (images: string[], ctrl: boolean) => void; onRefresh: () => void; @@ -651,6 +670,7 @@ function TrainGrid({ model, classes, trainImages, + trainFilter, selectedImages, onClickImages, onRefresh, @@ -672,8 +692,36 @@ function TrainGrid({ truePositive: rawScore >= model.threshold, }; }) + .filter((data) => { + if (!trainFilter) { + return true; + } + + if ( + trainFilter.classes && + !trainFilter.classes.includes(data.label) + ) { + return false; + } + + if ( + trainFilter.min_score && + trainFilter.min_score > data.score / 100.0 + ) { + return false; + } + + if ( + trainFilter.max_score && + trainFilter.max_score <= data.score / 100.0 + ) { + return false; + } + + return true; + }) .sort((a, b) => b.timestamp.localeCompare(a.timestamp)), - [model, trainImages], + [model, trainImages, trainFilter], ); return (