diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py
index 05a555701..0ba8b3d17 100644
--- a/frigate/data_processing/real_time/custom_classification.py
+++ b/frigate/data_processing/real_time/custom_classification.py
@@ -225,6 +225,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.model_config.name
].value = self.classifications_per_second.eps()
+ if obj_data["false_positive"]:
+ return
+
if obj_data["label"] not in self.model_config.object_config.objects:
return
diff --git a/web/public/locales/en/components/filter.json b/web/public/locales/en/components/filter.json
index 08a0ee2b2..1eaccbb69 100644
--- a/web/public/locales/en/components/filter.json
+++ b/web/public/locales/en/components/filter.json
@@ -1,5 +1,11 @@
{
"filter": "Filter",
+ "classes": {
+ "label": "Classes",
+ "all": { "title": "All Classes" },
+ "count_one": "{{count}} Class",
+ "count_other": "{{count}} Classes"
+ },
"labels": {
"label": "Labels",
"all": {
diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json
index 0af0179b9..47b2b13bf 100644
--- a/web/public/locales/en/views/classificationModel.json
+++ b/web/public/locales/en/views/classificationModel.json
@@ -3,7 +3,8 @@
"deleteClassificationAttempts": "Delete Classification Images",
"renameCategory": "Rename Class",
"deleteCategory": "Delete Class",
- "deleteImages": "Delete Images"
+ "deleteImages": "Delete Images",
+ "trainModel": "Train Model"
},
"toast": {
"success": {
diff --git a/web/src/components/overlay/dialog/TrainFilterDialog.tsx b/web/src/components/overlay/dialog/TrainFilterDialog.tsx
new file mode 100644
index 000000000..56037ec0a
--- /dev/null
+++ b/web/src/components/overlay/dialog/TrainFilterDialog.tsx
@@ -0,0 +1,262 @@
+import { FaFilter } from "react-icons/fa";
+
+import { useEffect, useMemo, useState } from "react";
+import { PlatformAwareSheet } from "./PlatformAwareDialog";
+import { Button } from "@/components/ui/button";
+import { isDesktop, isMobile } from "react-device-detect";
+import FilterSwitch from "@/components/filter/FilterSwitch";
+import { Switch } from "@/components/ui/switch";
+import { Label } from "@/components/ui/label";
+import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
+import { cn } from "@/lib/utils";
+import { DualThumbSlider } from "@/components/ui/slider";
+import { Input } from "@/components/ui/input";
+import { useTranslation } from "react-i18next";
+import { TrainFilter } from "@/types/classification";
+
+type TrainFilterDialogProps = {
+ filter?: TrainFilter;
+ filterValues: {
+ classes: string[];
+ };
+ onUpdateFilter: (filter: TrainFilter) => void;
+};
+export default function TrainFilterDialog({
+ filter,
+ filterValues,
+ onUpdateFilter,
+}: TrainFilterDialogProps) {
+ // data
+ const { t } = useTranslation(["components/filter"]);
+ const [currentFilter, setCurrentFilter] = useState(filter ?? {});
+
+ useEffect(() => {
+ if (filter) {
+ setCurrentFilter(filter);
+ }
+ }, [filter]);
+
+ // state
+
+ const [open, setOpen] = useState(false);
+
+ const moreFiltersSelected = useMemo(
+ () =>
+ currentFilter &&
+ (currentFilter.classes ||
+ (currentFilter.min_score ?? 0) > 0.5 ||
+ (currentFilter.max_score ?? 1) < 1),
+ [currentFilter],
+ );
+
+ const trigger = (
+
+ );
+ const content = (
+
+
+ setCurrentFilter({ ...currentFilter, classes: newClasses })
+ }
+ />
+
+ setCurrentFilter({ ...currentFilter, min_score: min, max_score: max })
+ }
+ />
+ {isDesktop && }
+
+
+
+
+
+ );
+
+ return (
+ {
+ if (!open) {
+ setCurrentFilter(filter ?? {});
+ }
+
+ setOpen(open);
+ }}
+ />
+ );
+}
+
+type ClassFilterContentProps = {
+ allClasses?: string[];
+ classes?: string[];
+ updateClasses: (classes: string[] | undefined) => void;
+};
+export function ClassFilterContent({
+ allClasses,
+ classes,
+ updateClasses,
+}: ClassFilterContentProps) {
+ const { t } = useTranslation(["components/filter"]);
+ return (
+ <>
+
+
+
{t("classes.label")}
+ {allClasses && (
+ <>
+
+
+ {
+ if (isChecked) {
+ updateClasses(undefined);
+ }
+ }}
+ />
+
+
+ {allClasses.map((item) => (
+ {
+ if (isChecked) {
+ const updatedClasses = classes ? [...classes] : [];
+
+ updatedClasses.push(item);
+ updateClasses(updatedClasses);
+ } else {
+ const updatedClasses = classes ? [...classes] : [];
+
+ // can not deselect the last item
+ if (updatedClasses.length > 1) {
+ updatedClasses.splice(updatedClasses.indexOf(item), 1);
+ updateClasses(updatedClasses);
+ }
+ }
+ }}
+ />
+ ))}
+
+ >
+ )}
+
+ >
+ );
+}
+
+type ScoreFilterContentProps = {
+ minScore: number | undefined;
+ maxScore: number | undefined;
+ setScoreRange: (min: number | undefined, max: number | undefined) => void;
+};
+export function ScoreFilterContent({
+ minScore,
+ maxScore,
+ setScoreRange,
+}: ScoreFilterContentProps) {
+ const { t } = useTranslation(["components/filter"]);
+ return (
+
+
+
{t("score")}
+
+ {
+ const value = e.target.value;
+
+ if (value) {
+ setScoreRange(parseInt(value) / 100.0, maxScore ?? 1.0);
+ }
+ }}
+ />
+ setScoreRange(min, max)}
+ />
+ {
+ const value = e.target.value;
+
+ if (value) {
+ setScoreRange(minScore ?? 0.5, parseInt(value) / 100.0);
+ }
+ }}
+ />
+
+
+ );
+}
diff --git a/web/src/types/classification.ts b/web/src/types/classification.ts
new file mode 100644
index 000000000..54320175a
--- /dev/null
+++ b/web/src/types/classification.ts
@@ -0,0 +1,8 @@
+const TRAIN_FILTERS = ["class", "score"] as const;
+export type TrainFilters = (typeof TRAIN_FILTERS)[number];
+
+export type TrainFilter = {
+ classes?: string[];
+ min_score?: number;
+ max_score?: number;
+};
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index c8c2266d5..12a474bbb 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -50,6 +50,10 @@ import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useNavigate } from "react-router-dom";
import { IoMdArrowRoundBack } from "react-icons/io";
+import { MdAutoFixHigh } from "react-icons/md";
+import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
+import useApiFilter from "@/hooks/use-api-filter";
+import { TrainFilter } from "@/types/classification";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@@ -96,6 +100,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
[id: string]: string[];
}>(`classification/${model.name}/dataset`);
+ const [trainFilter, setTrainFilter] = useApiFilter();
+
// image multiselect
const [selectedImages, setSelectedImages] = useState([]);
@@ -340,14 +346,25 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
) : (
-
+
+
+
+
)}
{pageToggle == "train" ? (
@@ -355,6 +372,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
model={model}
classes={Object.keys(dataset || {})}
trainImages={trainImages || []}
+ trainFilter={trainFilter}
selectedImages={selectedImages}
onRefresh={refreshTrain}
onClickImages={onClickImages}
@@ -642,6 +660,7 @@ type TrainGridProps = {
model: CustomClassificationModelConfig;
classes: string[];
trainImages: string[];
+ trainFilter?: TrainFilter;
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
@@ -651,6 +670,7 @@ function TrainGrid({
model,
classes,
trainImages,
+ trainFilter,
selectedImages,
onClickImages,
onRefresh,
@@ -672,8 +692,36 @@ function TrainGrid({
truePositive: rawScore >= model.threshold,
};
})
+ .filter((data) => {
+ if (!trainFilter) {
+ return true;
+ }
+
+ if (
+ trainFilter.classes &&
+ !trainFilter.classes.includes(data.label)
+ ) {
+ return false;
+ }
+
+ if (
+ trainFilter.min_score &&
+ trainFilter.min_score > data.score / 100.0
+ ) {
+ return false;
+ }
+
+ if (
+ trainFilter.max_score &&
+ trainFilter.max_score <= data.score / 100.0
+ ) {
+ return false;
+ }
+
+ return true;
+ })
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
- [model, trainImages],
+ [model, trainImages, trainFilter],
);
return (