mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Classification train updates (#19173)
* Improve model train button * Add filters for classification * Cleanup * Don't run classification on false positives * Cleanup filter * Fix icon color
This commit is contained in:
parent
99885b4bdc
commit
7ea288fe32
@ -225,6 +225,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.model_config.name
|
self.model_config.name
|
||||||
].value = self.classifications_per_second.eps()
|
].value = self.classifications_per_second.eps()
|
||||||
|
|
||||||
|
if obj_data["false_positive"]:
|
||||||
|
return
|
||||||
|
|
||||||
if obj_data["label"] not in self.model_config.object_config.objects:
|
if obj_data["label"] not in self.model_config.object_config.objects:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1,5 +1,11 @@
|
|||||||
{
|
{
|
||||||
"filter": "Filter",
|
"filter": "Filter",
|
||||||
|
"classes": {
|
||||||
|
"label": "Classes",
|
||||||
|
"all": { "title": "All Classes" },
|
||||||
|
"count_one": "{{count}} Class",
|
||||||
|
"count_other": "{{count}} Classes"
|
||||||
|
},
|
||||||
"labels": {
|
"labels": {
|
||||||
"label": "Labels",
|
"label": "Labels",
|
||||||
"all": {
|
"all": {
|
||||||
|
@ -3,7 +3,8 @@
|
|||||||
"deleteClassificationAttempts": "Delete Classification Images",
|
"deleteClassificationAttempts": "Delete Classification Images",
|
||||||
"renameCategory": "Rename Class",
|
"renameCategory": "Rename Class",
|
||||||
"deleteCategory": "Delete Class",
|
"deleteCategory": "Delete Class",
|
||||||
"deleteImages": "Delete Images"
|
"deleteImages": "Delete Images",
|
||||||
|
"trainModel": "Train Model"
|
||||||
},
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"success": {
|
"success": {
|
||||||
|
262
web/src/components/overlay/dialog/TrainFilterDialog.tsx
Normal file
262
web/src/components/overlay/dialog/TrainFilterDialog.tsx
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
import { FaFilter } from "react-icons/fa";
|
||||||
|
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
import { PlatformAwareSheet } from "./PlatformAwareDialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { isDesktop, isMobile } from "react-device-detect";
|
||||||
|
import FilterSwitch from "@/components/filter/FilterSwitch";
|
||||||
|
import { Switch } from "@/components/ui/switch";
|
||||||
|
import { Label } from "@/components/ui/label";
|
||||||
|
import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { DualThumbSlider } from "@/components/ui/slider";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { useTranslation } from "react-i18next";
|
||||||
|
import { TrainFilter } from "@/types/classification";
|
||||||
|
|
||||||
|
type TrainFilterDialogProps = {
|
||||||
|
filter?: TrainFilter;
|
||||||
|
filterValues: {
|
||||||
|
classes: string[];
|
||||||
|
};
|
||||||
|
onUpdateFilter: (filter: TrainFilter) => void;
|
||||||
|
};
|
||||||
|
export default function TrainFilterDialog({
|
||||||
|
filter,
|
||||||
|
filterValues,
|
||||||
|
onUpdateFilter,
|
||||||
|
}: TrainFilterDialogProps) {
|
||||||
|
// data
|
||||||
|
const { t } = useTranslation(["components/filter"]);
|
||||||
|
const [currentFilter, setCurrentFilter] = useState(filter ?? {});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (filter) {
|
||||||
|
setCurrentFilter(filter);
|
||||||
|
}
|
||||||
|
}, [filter]);
|
||||||
|
|
||||||
|
// state
|
||||||
|
|
||||||
|
const [open, setOpen] = useState(false);
|
||||||
|
|
||||||
|
const moreFiltersSelected = useMemo(
|
||||||
|
() =>
|
||||||
|
currentFilter &&
|
||||||
|
(currentFilter.classes ||
|
||||||
|
(currentFilter.min_score ?? 0) > 0.5 ||
|
||||||
|
(currentFilter.max_score ?? 1) < 1),
|
||||||
|
[currentFilter],
|
||||||
|
);
|
||||||
|
|
||||||
|
const trigger = (
|
||||||
|
<Button
|
||||||
|
className="flex items-center gap-2"
|
||||||
|
aria-label={t("more")}
|
||||||
|
variant={moreFiltersSelected ? "select" : "default"}
|
||||||
|
>
|
||||||
|
<FaFilter
|
||||||
|
className={cn(
|
||||||
|
moreFiltersSelected ? "text-white" : "text-secondary-foreground",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
{t("more")}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
const content = (
|
||||||
|
<div className="space-y-3">
|
||||||
|
<ClassFilterContent
|
||||||
|
allClasses={filterValues.classes}
|
||||||
|
classes={currentFilter.classes}
|
||||||
|
updateClasses={(newClasses) =>
|
||||||
|
setCurrentFilter({ ...currentFilter, classes: newClasses })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<ScoreFilterContent
|
||||||
|
minScore={currentFilter.min_score}
|
||||||
|
maxScore={currentFilter.max_score}
|
||||||
|
setScoreRange={(min, max) =>
|
||||||
|
setCurrentFilter({ ...currentFilter, min_score: min, max_score: max })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
{isDesktop && <DropdownMenuSeparator />}
|
||||||
|
<div className="flex items-center justify-evenly p-2">
|
||||||
|
<Button
|
||||||
|
variant="select"
|
||||||
|
aria-label={t("button.apply", { ns: "common" })}
|
||||||
|
onClick={() => {
|
||||||
|
if (currentFilter != filter) {
|
||||||
|
onUpdateFilter(currentFilter);
|
||||||
|
}
|
||||||
|
|
||||||
|
setOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t("button.apply", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
aria-label={t("reset.label")}
|
||||||
|
onClick={() => {
|
||||||
|
setCurrentFilter((prevFilter) => ({
|
||||||
|
...prevFilter,
|
||||||
|
time_range: undefined,
|
||||||
|
zones: undefined,
|
||||||
|
sub_labels: undefined,
|
||||||
|
search_type: undefined,
|
||||||
|
min_score: undefined,
|
||||||
|
max_score: undefined,
|
||||||
|
min_speed: undefined,
|
||||||
|
max_speed: undefined,
|
||||||
|
has_snapshot: undefined,
|
||||||
|
has_clip: undefined,
|
||||||
|
recognized_license_plate: undefined,
|
||||||
|
}));
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t("button.reset", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PlatformAwareSheet
|
||||||
|
trigger={trigger}
|
||||||
|
title={t("more")}
|
||||||
|
content={content}
|
||||||
|
contentClassName={cn(
|
||||||
|
"w-auto lg:min-w-[275px] scrollbar-container h-full overflow-auto px-4",
|
||||||
|
isMobile && "pb-20",
|
||||||
|
)}
|
||||||
|
open={open}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
if (!open) {
|
||||||
|
setCurrentFilter(filter ?? {});
|
||||||
|
}
|
||||||
|
|
||||||
|
setOpen(open);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassFilterContentProps = {
|
||||||
|
allClasses?: string[];
|
||||||
|
classes?: string[];
|
||||||
|
updateClasses: (classes: string[] | undefined) => void;
|
||||||
|
};
|
||||||
|
export function ClassFilterContent({
|
||||||
|
allClasses,
|
||||||
|
classes,
|
||||||
|
updateClasses,
|
||||||
|
}: ClassFilterContentProps) {
|
||||||
|
const { t } = useTranslation(["components/filter"]);
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className="overflow-x-hidden">
|
||||||
|
<DropdownMenuSeparator className="mb-3" />
|
||||||
|
<div className="text-lg">{t("classes.label")}</div>
|
||||||
|
{allClasses && (
|
||||||
|
<>
|
||||||
|
<div className="mb-5 mt-2.5 flex items-center justify-between">
|
||||||
|
<Label
|
||||||
|
className="mx-2 cursor-pointer text-primary"
|
||||||
|
htmlFor="allClasses"
|
||||||
|
>
|
||||||
|
{t("classes.all.title")}
|
||||||
|
</Label>
|
||||||
|
<Switch
|
||||||
|
className="ml-1"
|
||||||
|
id="allClasses"
|
||||||
|
checked={classes == undefined}
|
||||||
|
onCheckedChange={(isChecked) => {
|
||||||
|
if (isChecked) {
|
||||||
|
updateClasses(undefined);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="mt-2.5 flex flex-col gap-2.5">
|
||||||
|
{allClasses.map((item) => (
|
||||||
|
<FilterSwitch
|
||||||
|
key={item}
|
||||||
|
label={item.replaceAll("_", " ")}
|
||||||
|
isChecked={classes?.includes(item) ?? false}
|
||||||
|
onCheckedChange={(isChecked) => {
|
||||||
|
if (isChecked) {
|
||||||
|
const updatedClasses = classes ? [...classes] : [];
|
||||||
|
|
||||||
|
updatedClasses.push(item);
|
||||||
|
updateClasses(updatedClasses);
|
||||||
|
} else {
|
||||||
|
const updatedClasses = classes ? [...classes] : [];
|
||||||
|
|
||||||
|
// can not deselect the last item
|
||||||
|
if (updatedClasses.length > 1) {
|
||||||
|
updatedClasses.splice(updatedClasses.indexOf(item), 1);
|
||||||
|
updateClasses(updatedClasses);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type ScoreFilterContentProps = {
|
||||||
|
minScore: number | undefined;
|
||||||
|
maxScore: number | undefined;
|
||||||
|
setScoreRange: (min: number | undefined, max: number | undefined) => void;
|
||||||
|
};
|
||||||
|
export function ScoreFilterContent({
|
||||||
|
minScore,
|
||||||
|
maxScore,
|
||||||
|
setScoreRange,
|
||||||
|
}: ScoreFilterContentProps) {
|
||||||
|
const { t } = useTranslation(["components/filter"]);
|
||||||
|
return (
|
||||||
|
<div className="overflow-x-hidden">
|
||||||
|
<DropdownMenuSeparator className="mb-3" />
|
||||||
|
<div className="mb-3 text-lg">{t("score")}</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<Input
|
||||||
|
className="w-14 text-center"
|
||||||
|
inputMode="numeric"
|
||||||
|
value={Math.round((minScore ?? 0.5) * 100)}
|
||||||
|
onChange={(e) => {
|
||||||
|
const value = e.target.value;
|
||||||
|
|
||||||
|
if (value) {
|
||||||
|
setScoreRange(parseInt(value) / 100.0, maxScore ?? 1.0);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<DualThumbSlider
|
||||||
|
className="mx-2 w-full"
|
||||||
|
min={0.5}
|
||||||
|
max={1.0}
|
||||||
|
step={0.01}
|
||||||
|
value={[minScore ?? 0.5, maxScore ?? 1.0]}
|
||||||
|
onValueChange={([min, max]) => setScoreRange(min, max)}
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
className="w-14 text-center"
|
||||||
|
inputMode="numeric"
|
||||||
|
value={Math.round((maxScore ?? 1.0) * 100)}
|
||||||
|
onChange={(e) => {
|
||||||
|
const value = e.target.value;
|
||||||
|
|
||||||
|
if (value) {
|
||||||
|
setScoreRange(minScore ?? 0.5, parseInt(value) / 100.0);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
8
web/src/types/classification.ts
Normal file
8
web/src/types/classification.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
const TRAIN_FILTERS = ["class", "score"] as const;
|
||||||
|
export type TrainFilters = (typeof TRAIN_FILTERS)[number];
|
||||||
|
|
||||||
|
export type TrainFilter = {
|
||||||
|
classes?: string[];
|
||||||
|
min_score?: number;
|
||||||
|
max_score?: number;
|
||||||
|
};
|
@ -50,6 +50,10 @@ import { ModelState } from "@/types/ws";
|
|||||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||||
import { useNavigate } from "react-router-dom";
|
import { useNavigate } from "react-router-dom";
|
||||||
import { IoMdArrowRoundBack } from "react-icons/io";
|
import { IoMdArrowRoundBack } from "react-icons/io";
|
||||||
|
import { MdAutoFixHigh } from "react-icons/md";
|
||||||
|
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
|
||||||
|
import useApiFilter from "@/hooks/use-api-filter";
|
||||||
|
import { TrainFilter } from "@/types/classification";
|
||||||
|
|
||||||
type ModelTrainingViewProps = {
|
type ModelTrainingViewProps = {
|
||||||
model: CustomClassificationModelConfig;
|
model: CustomClassificationModelConfig;
|
||||||
@ -96,6 +100,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
[id: string]: string[];
|
[id: string]: string[];
|
||||||
}>(`classification/${model.name}/dataset`);
|
}>(`classification/${model.name}/dataset`);
|
||||||
|
|
||||||
|
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||||
|
|
||||||
// image multiselect
|
// image multiselect
|
||||||
|
|
||||||
const [selectedImages, setSelectedImages] = useState<string[]>([]);
|
const [selectedImages, setSelectedImages] = useState<string[]>([]);
|
||||||
@ -340,14 +346,25 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<Button
|
<div className="flex flex-row gap-2">
|
||||||
className="flex justify-center gap-2"
|
<TrainFilterDialog
|
||||||
onClick={trainModel}
|
filter={trainFilter}
|
||||||
disabled={modelState != "complete"}
|
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||||
>
|
onUpdateFilter={setTrainFilter}
|
||||||
Train Model
|
/>
|
||||||
{modelState == "training" && <ActivityIndicator size={20} />}
|
<Button
|
||||||
</Button>
|
className="flex justify-center gap-2"
|
||||||
|
onClick={trainModel}
|
||||||
|
disabled={modelState != "complete"}
|
||||||
|
>
|
||||||
|
{modelState == "training" ? (
|
||||||
|
<ActivityIndicator size={20} />
|
||||||
|
) : (
|
||||||
|
<MdAutoFixHigh className="text-secondary-foreground" />
|
||||||
|
)}
|
||||||
|
{t("button.trainModel")}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
{pageToggle == "train" ? (
|
{pageToggle == "train" ? (
|
||||||
@ -355,6 +372,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
model={model}
|
model={model}
|
||||||
classes={Object.keys(dataset || {})}
|
classes={Object.keys(dataset || {})}
|
||||||
trainImages={trainImages || []}
|
trainImages={trainImages || []}
|
||||||
|
trainFilter={trainFilter}
|
||||||
selectedImages={selectedImages}
|
selectedImages={selectedImages}
|
||||||
onRefresh={refreshTrain}
|
onRefresh={refreshTrain}
|
||||||
onClickImages={onClickImages}
|
onClickImages={onClickImages}
|
||||||
@ -642,6 +660,7 @@ type TrainGridProps = {
|
|||||||
model: CustomClassificationModelConfig;
|
model: CustomClassificationModelConfig;
|
||||||
classes: string[];
|
classes: string[];
|
||||||
trainImages: string[];
|
trainImages: string[];
|
||||||
|
trainFilter?: TrainFilter;
|
||||||
selectedImages: string[];
|
selectedImages: string[];
|
||||||
onClickImages: (images: string[], ctrl: boolean) => void;
|
onClickImages: (images: string[], ctrl: boolean) => void;
|
||||||
onRefresh: () => void;
|
onRefresh: () => void;
|
||||||
@ -651,6 +670,7 @@ function TrainGrid({
|
|||||||
model,
|
model,
|
||||||
classes,
|
classes,
|
||||||
trainImages,
|
trainImages,
|
||||||
|
trainFilter,
|
||||||
selectedImages,
|
selectedImages,
|
||||||
onClickImages,
|
onClickImages,
|
||||||
onRefresh,
|
onRefresh,
|
||||||
@ -672,8 +692,36 @@ function TrainGrid({
|
|||||||
truePositive: rawScore >= model.threshold,
|
truePositive: rawScore >= model.threshold,
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
|
.filter((data) => {
|
||||||
|
if (!trainFilter) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
trainFilter.classes &&
|
||||||
|
!trainFilter.classes.includes(data.label)
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
trainFilter.min_score &&
|
||||||
|
trainFilter.min_score > data.score / 100.0
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
trainFilter.max_score &&
|
||||||
|
trainFilter.max_score <= data.score / 100.0
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
})
|
||||||
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
|
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
|
||||||
[model, trainImages],
|
[model, trainImages, trainFilter],
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
Loading…
Reference in New Issue
Block a user