mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Classification train updates (#19173)
* Improve model train button * Add filters for classification * Cleanup * Don't run classification on false positives * Cleanup filter * Fix icon color
This commit is contained in:
parent
99885b4bdc
commit
7ea288fe32
@ -225,6 +225,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.model_config.name
|
||||
].value = self.classifications_per_second.eps()
|
||||
|
||||
if obj_data["false_positive"]:
|
||||
return
|
||||
|
||||
if obj_data["label"] not in self.model_config.object_config.objects:
|
||||
return
|
||||
|
||||
|
@ -1,5 +1,11 @@
|
||||
{
|
||||
"filter": "Filter",
|
||||
"classes": {
|
||||
"label": "Classes",
|
||||
"all": { "title": "All Classes" },
|
||||
"count_one": "{{count}} Class",
|
||||
"count_other": "{{count}} Classes"
|
||||
},
|
||||
"labels": {
|
||||
"label": "Labels",
|
||||
"all": {
|
||||
|
@ -3,7 +3,8 @@
|
||||
"deleteClassificationAttempts": "Delete Classification Images",
|
||||
"renameCategory": "Rename Class",
|
||||
"deleteCategory": "Delete Class",
|
||||
"deleteImages": "Delete Images"
|
||||
"deleteImages": "Delete Images",
|
||||
"trainModel": "Train Model"
|
||||
},
|
||||
"toast": {
|
||||
"success": {
|
||||
|
262
web/src/components/overlay/dialog/TrainFilterDialog.tsx
Normal file
262
web/src/components/overlay/dialog/TrainFilterDialog.tsx
Normal file
@ -0,0 +1,262 @@
|
||||
import { FaFilter } from "react-icons/fa";
|
||||
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { PlatformAwareSheet } from "./PlatformAwareDialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { isDesktop, isMobile } from "react-device-detect";
|
||||
import FilterSwitch from "@/components/filter/FilterSwitch";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DualThumbSlider } from "@/components/ui/slider";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { TrainFilter } from "@/types/classification";
|
||||
|
||||
type TrainFilterDialogProps = {
|
||||
filter?: TrainFilter;
|
||||
filterValues: {
|
||||
classes: string[];
|
||||
};
|
||||
onUpdateFilter: (filter: TrainFilter) => void;
|
||||
};
|
||||
export default function TrainFilterDialog({
|
||||
filter,
|
||||
filterValues,
|
||||
onUpdateFilter,
|
||||
}: TrainFilterDialogProps) {
|
||||
// data
|
||||
const { t } = useTranslation(["components/filter"]);
|
||||
const [currentFilter, setCurrentFilter] = useState(filter ?? {});
|
||||
|
||||
useEffect(() => {
|
||||
if (filter) {
|
||||
setCurrentFilter(filter);
|
||||
}
|
||||
}, [filter]);
|
||||
|
||||
// state
|
||||
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const moreFiltersSelected = useMemo(
|
||||
() =>
|
||||
currentFilter &&
|
||||
(currentFilter.classes ||
|
||||
(currentFilter.min_score ?? 0) > 0.5 ||
|
||||
(currentFilter.max_score ?? 1) < 1),
|
||||
[currentFilter],
|
||||
);
|
||||
|
||||
const trigger = (
|
||||
<Button
|
||||
className="flex items-center gap-2"
|
||||
aria-label={t("more")}
|
||||
variant={moreFiltersSelected ? "select" : "default"}
|
||||
>
|
||||
<FaFilter
|
||||
className={cn(
|
||||
moreFiltersSelected ? "text-white" : "text-secondary-foreground",
|
||||
)}
|
||||
/>
|
||||
{t("more")}
|
||||
</Button>
|
||||
);
|
||||
const content = (
|
||||
<div className="space-y-3">
|
||||
<ClassFilterContent
|
||||
allClasses={filterValues.classes}
|
||||
classes={currentFilter.classes}
|
||||
updateClasses={(newClasses) =>
|
||||
setCurrentFilter({ ...currentFilter, classes: newClasses })
|
||||
}
|
||||
/>
|
||||
<ScoreFilterContent
|
||||
minScore={currentFilter.min_score}
|
||||
maxScore={currentFilter.max_score}
|
||||
setScoreRange={(min, max) =>
|
||||
setCurrentFilter({ ...currentFilter, min_score: min, max_score: max })
|
||||
}
|
||||
/>
|
||||
{isDesktop && <DropdownMenuSeparator />}
|
||||
<div className="flex items-center justify-evenly p-2">
|
||||
<Button
|
||||
variant="select"
|
||||
aria-label={t("button.apply", { ns: "common" })}
|
||||
onClick={() => {
|
||||
if (currentFilter != filter) {
|
||||
onUpdateFilter(currentFilter);
|
||||
}
|
||||
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
{t("button.apply", { ns: "common" })}
|
||||
</Button>
|
||||
<Button
|
||||
aria-label={t("reset.label")}
|
||||
onClick={() => {
|
||||
setCurrentFilter((prevFilter) => ({
|
||||
...prevFilter,
|
||||
time_range: undefined,
|
||||
zones: undefined,
|
||||
sub_labels: undefined,
|
||||
search_type: undefined,
|
||||
min_score: undefined,
|
||||
max_score: undefined,
|
||||
min_speed: undefined,
|
||||
max_speed: undefined,
|
||||
has_snapshot: undefined,
|
||||
has_clip: undefined,
|
||||
recognized_license_plate: undefined,
|
||||
}));
|
||||
}}
|
||||
>
|
||||
{t("button.reset", { ns: "common" })}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<PlatformAwareSheet
|
||||
trigger={trigger}
|
||||
title={t("more")}
|
||||
content={content}
|
||||
contentClassName={cn(
|
||||
"w-auto lg:min-w-[275px] scrollbar-container h-full overflow-auto px-4",
|
||||
isMobile && "pb-20",
|
||||
)}
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
setCurrentFilter(filter ?? {});
|
||||
}
|
||||
|
||||
setOpen(open);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
type ClassFilterContentProps = {
|
||||
allClasses?: string[];
|
||||
classes?: string[];
|
||||
updateClasses: (classes: string[] | undefined) => void;
|
||||
};
|
||||
export function ClassFilterContent({
|
||||
allClasses,
|
||||
classes,
|
||||
updateClasses,
|
||||
}: ClassFilterContentProps) {
|
||||
const { t } = useTranslation(["components/filter"]);
|
||||
return (
|
||||
<>
|
||||
<div className="overflow-x-hidden">
|
||||
<DropdownMenuSeparator className="mb-3" />
|
||||
<div className="text-lg">{t("classes.label")}</div>
|
||||
{allClasses && (
|
||||
<>
|
||||
<div className="mb-5 mt-2.5 flex items-center justify-between">
|
||||
<Label
|
||||
className="mx-2 cursor-pointer text-primary"
|
||||
htmlFor="allClasses"
|
||||
>
|
||||
{t("classes.all.title")}
|
||||
</Label>
|
||||
<Switch
|
||||
className="ml-1"
|
||||
id="allClasses"
|
||||
checked={classes == undefined}
|
||||
onCheckedChange={(isChecked) => {
|
||||
if (isChecked) {
|
||||
updateClasses(undefined);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-2.5 flex flex-col gap-2.5">
|
||||
{allClasses.map((item) => (
|
||||
<FilterSwitch
|
||||
key={item}
|
||||
label={item.replaceAll("_", " ")}
|
||||
isChecked={classes?.includes(item) ?? false}
|
||||
onCheckedChange={(isChecked) => {
|
||||
if (isChecked) {
|
||||
const updatedClasses = classes ? [...classes] : [];
|
||||
|
||||
updatedClasses.push(item);
|
||||
updateClasses(updatedClasses);
|
||||
} else {
|
||||
const updatedClasses = classes ? [...classes] : [];
|
||||
|
||||
// can not deselect the last item
|
||||
if (updatedClasses.length > 1) {
|
||||
updatedClasses.splice(updatedClasses.indexOf(item), 1);
|
||||
updateClasses(updatedClasses);
|
||||
}
|
||||
}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
type ScoreFilterContentProps = {
|
||||
minScore: number | undefined;
|
||||
maxScore: number | undefined;
|
||||
setScoreRange: (min: number | undefined, max: number | undefined) => void;
|
||||
};
|
||||
export function ScoreFilterContent({
|
||||
minScore,
|
||||
maxScore,
|
||||
setScoreRange,
|
||||
}: ScoreFilterContentProps) {
|
||||
const { t } = useTranslation(["components/filter"]);
|
||||
return (
|
||||
<div className="overflow-x-hidden">
|
||||
<DropdownMenuSeparator className="mb-3" />
|
||||
<div className="mb-3 text-lg">{t("score")}</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<Input
|
||||
className="w-14 text-center"
|
||||
inputMode="numeric"
|
||||
value={Math.round((minScore ?? 0.5) * 100)}
|
||||
onChange={(e) => {
|
||||
const value = e.target.value;
|
||||
|
||||
if (value) {
|
||||
setScoreRange(parseInt(value) / 100.0, maxScore ?? 1.0);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<DualThumbSlider
|
||||
className="mx-2 w-full"
|
||||
min={0.5}
|
||||
max={1.0}
|
||||
step={0.01}
|
||||
value={[minScore ?? 0.5, maxScore ?? 1.0]}
|
||||
onValueChange={([min, max]) => setScoreRange(min, max)}
|
||||
/>
|
||||
<Input
|
||||
className="w-14 text-center"
|
||||
inputMode="numeric"
|
||||
value={Math.round((maxScore ?? 1.0) * 100)}
|
||||
onChange={(e) => {
|
||||
const value = e.target.value;
|
||||
|
||||
if (value) {
|
||||
setScoreRange(minScore ?? 0.5, parseInt(value) / 100.0);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
8
web/src/types/classification.ts
Normal file
8
web/src/types/classification.ts
Normal file
@ -0,0 +1,8 @@
|
||||
const TRAIN_FILTERS = ["class", "score"] as const;
|
||||
export type TrainFilters = (typeof TRAIN_FILTERS)[number];
|
||||
|
||||
export type TrainFilter = {
|
||||
classes?: string[];
|
||||
min_score?: number;
|
||||
max_score?: number;
|
||||
};
|
@ -50,6 +50,10 @@ import { ModelState } from "@/types/ws";
|
||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import { IoMdArrowRoundBack } from "react-icons/io";
|
||||
import { MdAutoFixHigh } from "react-icons/md";
|
||||
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
|
||||
import useApiFilter from "@/hooks/use-api-filter";
|
||||
import { TrainFilter } from "@/types/classification";
|
||||
|
||||
type ModelTrainingViewProps = {
|
||||
model: CustomClassificationModelConfig;
|
||||
@ -96,6 +100,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
[id: string]: string[];
|
||||
}>(`classification/${model.name}/dataset`);
|
||||
|
||||
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||
|
||||
// image multiselect
|
||||
|
||||
const [selectedImages, setSelectedImages] = useState<string[]>([]);
|
||||
@ -340,14 +346,25 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
className="flex justify-center gap-2"
|
||||
onClick={trainModel}
|
||||
disabled={modelState != "complete"}
|
||||
>
|
||||
Train Model
|
||||
{modelState == "training" && <ActivityIndicator size={20} />}
|
||||
</Button>
|
||||
<div className="flex flex-row gap-2">
|
||||
<TrainFilterDialog
|
||||
filter={trainFilter}
|
||||
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||
onUpdateFilter={setTrainFilter}
|
||||
/>
|
||||
<Button
|
||||
className="flex justify-center gap-2"
|
||||
onClick={trainModel}
|
||||
disabled={modelState != "complete"}
|
||||
>
|
||||
{modelState == "training" ? (
|
||||
<ActivityIndicator size={20} />
|
||||
) : (
|
||||
<MdAutoFixHigh className="text-secondary-foreground" />
|
||||
)}
|
||||
{t("button.trainModel")}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{pageToggle == "train" ? (
|
||||
@ -355,6 +372,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
model={model}
|
||||
classes={Object.keys(dataset || {})}
|
||||
trainImages={trainImages || []}
|
||||
trainFilter={trainFilter}
|
||||
selectedImages={selectedImages}
|
||||
onRefresh={refreshTrain}
|
||||
onClickImages={onClickImages}
|
||||
@ -642,6 +660,7 @@ type TrainGridProps = {
|
||||
model: CustomClassificationModelConfig;
|
||||
classes: string[];
|
||||
trainImages: string[];
|
||||
trainFilter?: TrainFilter;
|
||||
selectedImages: string[];
|
||||
onClickImages: (images: string[], ctrl: boolean) => void;
|
||||
onRefresh: () => void;
|
||||
@ -651,6 +670,7 @@ function TrainGrid({
|
||||
model,
|
||||
classes,
|
||||
trainImages,
|
||||
trainFilter,
|
||||
selectedImages,
|
||||
onClickImages,
|
||||
onRefresh,
|
||||
@ -672,8 +692,36 @@ function TrainGrid({
|
||||
truePositive: rawScore >= model.threshold,
|
||||
};
|
||||
})
|
||||
.filter((data) => {
|
||||
if (!trainFilter) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (
|
||||
trainFilter.classes &&
|
||||
!trainFilter.classes.includes(data.label)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
trainFilter.min_score &&
|
||||
trainFilter.min_score > data.score / 100.0
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
trainFilter.max_score &&
|
||||
trainFilter.max_score <= data.score / 100.0
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
})
|
||||
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
|
||||
[model, trainImages],
|
||||
[model, trainImages, trainFilter],
|
||||
);
|
||||
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user