Classification train updates (#19173)

* Improve model train button

* Add filters for classification

* Cleanup

* Don't run classification on false positives

* Cleanup filter

* Fix icon color
This commit is contained in:
Nicolas Mowen 2025-07-16 20:46:59 -06:00 committed by GitHub
parent 99885b4bdc
commit 7ea288fe32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 338 additions and 10 deletions

View File

@ -225,6 +225,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.model_config.name
].value = self.classifications_per_second.eps()
if obj_data["false_positive"]:
return
if obj_data["label"] not in self.model_config.object_config.objects:
return

View File

@ -1,5 +1,11 @@
{
"filter": "Filter",
"classes": {
"label": "Classes",
"all": { "title": "All Classes" },
"count_one": "{{count}} Class",
"count_other": "{{count}} Classes"
},
"labels": {
"label": "Labels",
"all": {

View File

@ -3,7 +3,8 @@
"deleteClassificationAttempts": "Delete Classification Images",
"renameCategory": "Rename Class",
"deleteCategory": "Delete Class",
"deleteImages": "Delete Images"
"deleteImages": "Delete Images",
"trainModel": "Train Model"
},
"toast": {
"success": {

View File

@ -0,0 +1,262 @@
import { FaFilter } from "react-icons/fa";
import { useEffect, useMemo, useState } from "react";
import { PlatformAwareSheet } from "./PlatformAwareDialog";
import { Button } from "@/components/ui/button";
import { isDesktop, isMobile } from "react-device-detect";
import FilterSwitch from "@/components/filter/FilterSwitch";
import { Switch } from "@/components/ui/switch";
import { Label } from "@/components/ui/label";
import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
import { cn } from "@/lib/utils";
import { DualThumbSlider } from "@/components/ui/slider";
import { Input } from "@/components/ui/input";
import { useTranslation } from "react-i18next";
import { TrainFilter } from "@/types/classification";
type TrainFilterDialogProps = {
filter?: TrainFilter;
filterValues: {
classes: string[];
};
onUpdateFilter: (filter: TrainFilter) => void;
};
export default function TrainFilterDialog({
filter,
filterValues,
onUpdateFilter,
}: TrainFilterDialogProps) {
// data
const { t } = useTranslation(["components/filter"]);
const [currentFilter, setCurrentFilter] = useState(filter ?? {});
useEffect(() => {
if (filter) {
setCurrentFilter(filter);
}
}, [filter]);
// state
const [open, setOpen] = useState(false);
const moreFiltersSelected = useMemo(
() =>
currentFilter &&
(currentFilter.classes ||
(currentFilter.min_score ?? 0) > 0.5 ||
(currentFilter.max_score ?? 1) < 1),
[currentFilter],
);
const trigger = (
<Button
className="flex items-center gap-2"
aria-label={t("more")}
variant={moreFiltersSelected ? "select" : "default"}
>
<FaFilter
className={cn(
moreFiltersSelected ? "text-white" : "text-secondary-foreground",
)}
/>
{t("more")}
</Button>
);
const content = (
<div className="space-y-3">
<ClassFilterContent
allClasses={filterValues.classes}
classes={currentFilter.classes}
updateClasses={(newClasses) =>
setCurrentFilter({ ...currentFilter, classes: newClasses })
}
/>
<ScoreFilterContent
minScore={currentFilter.min_score}
maxScore={currentFilter.max_score}
setScoreRange={(min, max) =>
setCurrentFilter({ ...currentFilter, min_score: min, max_score: max })
}
/>
{isDesktop && <DropdownMenuSeparator />}
<div className="flex items-center justify-evenly p-2">
<Button
variant="select"
aria-label={t("button.apply", { ns: "common" })}
onClick={() => {
if (currentFilter != filter) {
onUpdateFilter(currentFilter);
}
setOpen(false);
}}
>
{t("button.apply", { ns: "common" })}
</Button>
<Button
aria-label={t("reset.label")}
onClick={() => {
setCurrentFilter((prevFilter) => ({
...prevFilter,
time_range: undefined,
zones: undefined,
sub_labels: undefined,
search_type: undefined,
min_score: undefined,
max_score: undefined,
min_speed: undefined,
max_speed: undefined,
has_snapshot: undefined,
has_clip: undefined,
recognized_license_plate: undefined,
}));
}}
>
{t("button.reset", { ns: "common" })}
</Button>
</div>
</div>
);
return (
<PlatformAwareSheet
trigger={trigger}
title={t("more")}
content={content}
contentClassName={cn(
"w-auto lg:min-w-[275px] scrollbar-container h-full overflow-auto px-4",
isMobile && "pb-20",
)}
open={open}
onOpenChange={(open) => {
if (!open) {
setCurrentFilter(filter ?? {});
}
setOpen(open);
}}
/>
);
}
type ClassFilterContentProps = {
allClasses?: string[];
classes?: string[];
updateClasses: (classes: string[] | undefined) => void;
};
export function ClassFilterContent({
allClasses,
classes,
updateClasses,
}: ClassFilterContentProps) {
const { t } = useTranslation(["components/filter"]);
return (
<>
<div className="overflow-x-hidden">
<DropdownMenuSeparator className="mb-3" />
<div className="text-lg">{t("classes.label")}</div>
{allClasses && (
<>
<div className="mb-5 mt-2.5 flex items-center justify-between">
<Label
className="mx-2 cursor-pointer text-primary"
htmlFor="allClasses"
>
{t("classes.all.title")}
</Label>
<Switch
className="ml-1"
id="allClasses"
checked={classes == undefined}
onCheckedChange={(isChecked) => {
if (isChecked) {
updateClasses(undefined);
}
}}
/>
</div>
<div className="mt-2.5 flex flex-col gap-2.5">
{allClasses.map((item) => (
<FilterSwitch
key={item}
label={item.replaceAll("_", " ")}
isChecked={classes?.includes(item) ?? false}
onCheckedChange={(isChecked) => {
if (isChecked) {
const updatedClasses = classes ? [...classes] : [];
updatedClasses.push(item);
updateClasses(updatedClasses);
} else {
const updatedClasses = classes ? [...classes] : [];
// can not deselect the last item
if (updatedClasses.length > 1) {
updatedClasses.splice(updatedClasses.indexOf(item), 1);
updateClasses(updatedClasses);
}
}
}}
/>
))}
</div>
</>
)}
</div>
</>
);
}
type ScoreFilterContentProps = {
minScore: number | undefined;
maxScore: number | undefined;
setScoreRange: (min: number | undefined, max: number | undefined) => void;
};
export function ScoreFilterContent({
minScore,
maxScore,
setScoreRange,
}: ScoreFilterContentProps) {
const { t } = useTranslation(["components/filter"]);
return (
<div className="overflow-x-hidden">
<DropdownMenuSeparator className="mb-3" />
<div className="mb-3 text-lg">{t("score")}</div>
<div className="flex items-center gap-1">
<Input
className="w-14 text-center"
inputMode="numeric"
value={Math.round((minScore ?? 0.5) * 100)}
onChange={(e) => {
const value = e.target.value;
if (value) {
setScoreRange(parseInt(value) / 100.0, maxScore ?? 1.0);
}
}}
/>
<DualThumbSlider
className="mx-2 w-full"
min={0.5}
max={1.0}
step={0.01}
value={[minScore ?? 0.5, maxScore ?? 1.0]}
onValueChange={([min, max]) => setScoreRange(min, max)}
/>
<Input
className="w-14 text-center"
inputMode="numeric"
value={Math.round((maxScore ?? 1.0) * 100)}
onChange={(e) => {
const value = e.target.value;
if (value) {
setScoreRange(minScore ?? 0.5, parseInt(value) / 100.0);
}
}}
/>
</div>
</div>
);
}

View File

@ -0,0 +1,8 @@
const TRAIN_FILTERS = ["class", "score"] as const;
export type TrainFilters = (typeof TRAIN_FILTERS)[number];
export type TrainFilter = {
classes?: string[];
min_score?: number;
max_score?: number;
};

View File

@ -50,6 +50,10 @@ import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useNavigate } from "react-router-dom";
import { IoMdArrowRoundBack } from "react-icons/io";
import { MdAutoFixHigh } from "react-icons/md";
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
import useApiFilter from "@/hooks/use-api-filter";
import { TrainFilter } from "@/types/classification";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -96,6 +100,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
[id: string]: string[];
}>(`classification/${model.name}/dataset`);
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
// image multiselect
const [selectedImages, setSelectedImages] = useState<string[]>([]);
@ -340,14 +346,25 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button>
</div>
) : (
<Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
Train Model
{modelState == "training" && <ActivityIndicator size={20} />}
</Button>
<div className="flex flex-row gap-2">
<TrainFilterDialog
filter={trainFilter}
filterValues={{ classes: Object.keys(dataset || {}) }}
onUpdateFilter={setTrainFilter}
/>
<Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
{modelState == "training" ? (
<ActivityIndicator size={20} />
) : (
<MdAutoFixHigh className="text-secondary-foreground" />
)}
{t("button.trainModel")}
</Button>
</div>
)}
</div>
{pageToggle == "train" ? (
@ -355,6 +372,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
model={model}
classes={Object.keys(dataset || {})}
trainImages={trainImages || []}
trainFilter={trainFilter}
selectedImages={selectedImages}
onRefresh={refreshTrain}
onClickImages={onClickImages}
@ -642,6 +660,7 @@ type TrainGridProps = {
model: CustomClassificationModelConfig;
classes: string[];
trainImages: string[];
trainFilter?: TrainFilter;
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
@ -651,6 +670,7 @@ function TrainGrid({
model,
classes,
trainImages,
trainFilter,
selectedImages,
onClickImages,
onRefresh,
@ -672,8 +692,36 @@ function TrainGrid({
truePositive: rawScore >= model.threshold,
};
})
.filter((data) => {
if (!trainFilter) {
return true;
}
if (
trainFilter.classes &&
!trainFilter.classes.includes(data.label)
) {
return false;
}
if (
trainFilter.min_score &&
trainFilter.min_score > data.score / 100.0
) {
return false;
}
if (
trainFilter.max_score &&
trainFilter.max_score <= data.score / 100.0
) {
return false;
}
return true;
})
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
[model, trainImages],
[model, trainImages, trainFilter],
);
return (