mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-07 02:18:07 +01:00
Implement Wizard for Creating Classification Models (#20622)
* Implement extraction of images for classification state models * Add object classification dataset preparation * Add first step wizard * Update i18n * Add state classification image selection step * Improve box handling * Add object selector * Improve object cropping implementation * Fix state classification selection * Finalize training and image selection step * Cleanup * Design optimizations * Cleanup mobile styling * Update no models screen * Cleanups and fixes * Fix bugs * Improve model training and creation process * Cleanup * Dynamically add metrics for new model * Add loading when hitting continue * Improve image selection mechanism * Remove unused translation keys * Adjust wording * Add retry button for image generation * Make no models view more specific * Adjust plus icon * Adjust form label * Start with correct type selected * Cleanup sizing and more font colors * Small tweaks * Add tips and more info * Cleanup dialog sizing * Add cursor rule for frontend * Cleanup * remove underline * Lazy loading
This commit is contained in:
@@ -126,6 +126,7 @@ export const ClassificationCard = forwardRef<
|
||||
imgClassName,
|
||||
isMobile && "w-full",
|
||||
)}
|
||||
loading="lazy"
|
||||
onLoad={() => setImageLoaded(true)}
|
||||
src={`${baseUrl}${data.filepath}`}
|
||||
/>
|
||||
|
||||
@@ -7,58 +7,198 @@ import {
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "../ui/dialog";
|
||||
import { useState } from "react";
|
||||
import { useReducer, useMemo } from "react";
|
||||
import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine";
|
||||
import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea";
|
||||
import Step3ChooseExamples, {
|
||||
Step3FormData,
|
||||
} from "./wizard/Step3ChooseExamples";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { isDesktop } from "react-device-detect";
|
||||
|
||||
const STEPS = [
|
||||
"classificationWizard.steps.nameAndDefine",
|
||||
"classificationWizard.steps.stateArea",
|
||||
"classificationWizard.steps.chooseExamples",
|
||||
"classificationWizard.steps.train",
|
||||
const OBJECT_STEPS = [
|
||||
"wizard.steps.nameAndDefine",
|
||||
"wizard.steps.chooseExamples",
|
||||
];
|
||||
|
||||
const STATE_STEPS = [
|
||||
"wizard.steps.nameAndDefine",
|
||||
"wizard.steps.stateArea",
|
||||
"wizard.steps.chooseExamples",
|
||||
];
|
||||
|
||||
type ClassificationModelWizardDialogProps = {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
defaultModelType?: "state" | "object";
|
||||
};
|
||||
|
||||
type WizardState = {
|
||||
currentStep: number;
|
||||
step1Data?: Step1FormData;
|
||||
step2Data?: Step2FormData;
|
||||
step3Data?: Step3FormData;
|
||||
};
|
||||
|
||||
type WizardAction =
|
||||
| { type: "NEXT_STEP"; payload?: Partial<WizardState> }
|
||||
| { type: "PREVIOUS_STEP" }
|
||||
| { type: "SET_STEP_1"; payload: Step1FormData }
|
||||
| { type: "SET_STEP_2"; payload: Step2FormData }
|
||||
| { type: "SET_STEP_3"; payload: Step3FormData }
|
||||
| { type: "RESET" };
|
||||
|
||||
const initialState: WizardState = {
|
||||
currentStep: 0,
|
||||
};
|
||||
|
||||
function wizardReducer(state: WizardState, action: WizardAction): WizardState {
|
||||
switch (action.type) {
|
||||
case "SET_STEP_1":
|
||||
return {
|
||||
...state,
|
||||
step1Data: action.payload,
|
||||
currentStep: 1,
|
||||
};
|
||||
case "SET_STEP_2":
|
||||
return {
|
||||
...state,
|
||||
step2Data: action.payload,
|
||||
currentStep: 2,
|
||||
};
|
||||
case "SET_STEP_3":
|
||||
return {
|
||||
...state,
|
||||
step3Data: action.payload,
|
||||
currentStep: 3,
|
||||
};
|
||||
case "NEXT_STEP":
|
||||
return {
|
||||
...state,
|
||||
...action.payload,
|
||||
currentStep: state.currentStep + 1,
|
||||
};
|
||||
case "PREVIOUS_STEP":
|
||||
return {
|
||||
...state,
|
||||
currentStep: Math.max(0, state.currentStep - 1),
|
||||
};
|
||||
case "RESET":
|
||||
return initialState;
|
||||
default:
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
export default function ClassificationModelWizardDialog({
|
||||
open,
|
||||
onClose,
|
||||
defaultModelType,
|
||||
}: ClassificationModelWizardDialogProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
|
||||
// step management
|
||||
const [currentStep, _] = useState(0);
|
||||
const [wizardState, dispatch] = useReducer(wizardReducer, initialState);
|
||||
|
||||
const steps = useMemo(() => {
|
||||
if (!wizardState.step1Data) {
|
||||
return OBJECT_STEPS;
|
||||
}
|
||||
return wizardState.step1Data.modelType === "state"
|
||||
? STATE_STEPS
|
||||
: OBJECT_STEPS;
|
||||
}, [wizardState.step1Data]);
|
||||
|
||||
const handleStep1Next = (data: Step1FormData) => {
|
||||
dispatch({ type: "SET_STEP_1", payload: data });
|
||||
};
|
||||
|
||||
const handleStep2Next = (data: Step2FormData) => {
|
||||
dispatch({ type: "SET_STEP_2", payload: data });
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
dispatch({ type: "PREVIOUS_STEP" });
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
dispatch({ type: "RESET" });
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
onClose;
|
||||
handleCancel();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
className="max-h-[90dvh] max-w-4xl overflow-y-auto"
|
||||
className={cn(
|
||||
"",
|
||||
isDesktop &&
|
||||
wizardState.currentStep == 0 &&
|
||||
"max-h-[90%] overflow-y-auto xl:max-h-[80%]",
|
||||
isDesktop &&
|
||||
wizardState.currentStep > 0 &&
|
||||
"max-h-[90%] max-w-[70%] overflow-y-auto xl:max-h-[80%]",
|
||||
)}
|
||||
onInteractOutside={(e) => {
|
||||
e.preventDefault();
|
||||
}}
|
||||
>
|
||||
<StepIndicator
|
||||
steps={STEPS}
|
||||
currentStep={currentStep}
|
||||
steps={steps}
|
||||
currentStep={wizardState.currentStep}
|
||||
variant="dots"
|
||||
className="mb-4 justify-start"
|
||||
/>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t("wizard.title")}</DialogTitle>
|
||||
{currentStep === 0 && (
|
||||
<DialogDescription>{t("wizard.description")}</DialogDescription>
|
||||
{wizardState.currentStep === 0 && (
|
||||
<DialogDescription>
|
||||
{t("wizard.step1.description")}
|
||||
</DialogDescription>
|
||||
)}
|
||||
{wizardState.currentStep === 1 &&
|
||||
wizardState.step1Data?.modelType === "state" && (
|
||||
<DialogDescription>
|
||||
{t("wizard.step2.description")}
|
||||
</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<div className="pb-4">
|
||||
<div className="size-full"></div>
|
||||
{wizardState.currentStep === 0 && (
|
||||
<Step1NameAndDefine
|
||||
initialData={wizardState.step1Data}
|
||||
defaultModelType={defaultModelType}
|
||||
onNext={handleStep1Next}
|
||||
onCancel={handleCancel}
|
||||
/>
|
||||
)}
|
||||
{wizardState.currentStep === 1 &&
|
||||
wizardState.step1Data?.modelType === "state" && (
|
||||
<Step2StateArea
|
||||
initialData={wizardState.step2Data}
|
||||
onNext={handleStep2Next}
|
||||
onBack={handleBack}
|
||||
/>
|
||||
)}
|
||||
{((wizardState.currentStep === 2 &&
|
||||
wizardState.step1Data?.modelType === "state") ||
|
||||
(wizardState.currentStep === 1 &&
|
||||
wizardState.step1Data?.modelType === "object")) &&
|
||||
wizardState.step1Data && (
|
||||
<Step3ChooseExamples
|
||||
step1Data={wizardState.step1Data}
|
||||
step2Data={wizardState.step2Data}
|
||||
initialData={wizardState.step3Data}
|
||||
onClose={onClose}
|
||||
onBack={handleBack}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
498
web/src/components/classification/wizard/Step1NameAndDefine.tsx
Normal file
498
web/src/components/classification/wizard/Step1NameAndDefine.tsx
Normal file
@@ -0,0 +1,498 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { z } from "zod";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useMemo } from "react";
|
||||
import { LuX, LuPlus, LuInfo, LuExternalLink } from "react-icons/lu";
|
||||
import useSWR from "swr";
|
||||
import { FrigateConfig } from "@/types/frigateConfig";
|
||||
import { getTranslatedLabel } from "@/utils/i18n";
|
||||
import { useDocDomain } from "@/hooks/use-doc-domain";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
|
||||
export type ModelType = "state" | "object";
|
||||
export type ObjectClassificationType = "sub_label" | "attribute";
|
||||
|
||||
export type Step1FormData = {
|
||||
modelName: string;
|
||||
modelType: ModelType;
|
||||
objectLabel?: string;
|
||||
objectType?: ObjectClassificationType;
|
||||
classes: string[];
|
||||
};
|
||||
|
||||
type Step1NameAndDefineProps = {
|
||||
initialData?: Partial<Step1FormData>;
|
||||
defaultModelType?: "state" | "object";
|
||||
onNext: (data: Step1FormData) => void;
|
||||
onCancel: () => void;
|
||||
};
|
||||
|
||||
export default function Step1NameAndDefine({
|
||||
initialData,
|
||||
defaultModelType,
|
||||
onNext,
|
||||
onCancel,
|
||||
}: Step1NameAndDefineProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const { data: config } = useSWR<FrigateConfig>("config");
|
||||
const { getLocaleDocUrl } = useDocDomain();
|
||||
|
||||
const objectLabels = useMemo(() => {
|
||||
if (!config) return [];
|
||||
|
||||
const labels = new Set<string>();
|
||||
|
||||
Object.values(config.cameras).forEach((cameraConfig) => {
|
||||
if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
cameraConfig.objects.track.forEach((label) => {
|
||||
if (!config.model.all_attributes.includes(label)) {
|
||||
labels.add(label);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return [...labels].sort();
|
||||
}, [config]);
|
||||
|
||||
const step1FormData = z
|
||||
.object({
|
||||
modelName: z
|
||||
.string()
|
||||
.min(1, t("wizard.step1.errors.nameRequired"))
|
||||
.max(64, t("wizard.step1.errors.nameLength"))
|
||||
.refine((value) => !/^\d+$/.test(value), {
|
||||
message: t("wizard.step1.errors.nameOnlyNumbers"),
|
||||
}),
|
||||
modelType: z.enum(["state", "object"]),
|
||||
objectLabel: z.string().optional(),
|
||||
objectType: z.enum(["sub_label", "attribute"]).optional(),
|
||||
classes: z
|
||||
.array(z.string())
|
||||
.min(1, t("wizard.step1.errors.classRequired"))
|
||||
.refine(
|
||||
(classes) => {
|
||||
const nonEmpty = classes.filter((c) => c.trim().length > 0);
|
||||
return nonEmpty.length >= 1;
|
||||
},
|
||||
{ message: t("wizard.step1.errors.classRequired") },
|
||||
)
|
||||
.refine(
|
||||
(classes) => {
|
||||
const nonEmpty = classes.filter((c) => c.trim().length > 0);
|
||||
const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
|
||||
return unique.size === nonEmpty.length;
|
||||
},
|
||||
{ message: t("wizard.step1.errors.classesUnique") },
|
||||
),
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
// State models require at least 2 classes
|
||||
if (data.modelType === "state") {
|
||||
const nonEmpty = data.classes.filter((c) => c.trim().length > 0);
|
||||
return nonEmpty.length >= 2;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.stateRequiresTwoClasses"),
|
||||
path: ["classes"],
|
||||
},
|
||||
)
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.modelType === "object") {
|
||||
return data.objectLabel !== undefined && data.objectLabel !== "";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.objectLabelRequired"),
|
||||
path: ["objectLabel"],
|
||||
},
|
||||
)
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.modelType === "object") {
|
||||
return data.objectType !== undefined;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.objectTypeRequired"),
|
||||
path: ["objectType"],
|
||||
},
|
||||
);
|
||||
|
||||
const form = useForm<z.infer<typeof step1FormData>>({
|
||||
resolver: zodResolver(step1FormData),
|
||||
defaultValues: {
|
||||
modelName: initialData?.modelName || "",
|
||||
modelType: initialData?.modelType || defaultModelType || "state",
|
||||
objectLabel: initialData?.objectLabel,
|
||||
objectType: initialData?.objectType || "sub_label",
|
||||
classes: initialData?.classes?.length ? initialData.classes : [""],
|
||||
},
|
||||
mode: "onChange",
|
||||
});
|
||||
|
||||
const watchedClasses = form.watch("classes");
|
||||
const watchedModelType = form.watch("modelType");
|
||||
const watchedObjectType = form.watch("objectType");
|
||||
|
||||
const handleAddClass = () => {
|
||||
const currentClasses = form.getValues("classes");
|
||||
form.setValue("classes", [...currentClasses, ""], { shouldValidate: true });
|
||||
};
|
||||
|
||||
const handleRemoveClass = (index: number) => {
|
||||
const currentClasses = form.getValues("classes");
|
||||
const newClasses = currentClasses.filter((_, i) => i !== index);
|
||||
|
||||
// Ensure at least one field remains (even if empty)
|
||||
if (newClasses.length === 0) {
|
||||
form.setValue("classes", [""], { shouldValidate: true });
|
||||
} else {
|
||||
form.setValue("classes", newClasses, { shouldValidate: true });
|
||||
}
|
||||
};
|
||||
|
||||
const onSubmit = (data: z.infer<typeof step1FormData>) => {
|
||||
// Filter out empty classes
|
||||
const filteredClasses = data.classes.filter((c) => c.trim().length > 0);
|
||||
onNext({
|
||||
...data,
|
||||
classes: filteredClasses,
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="modelName"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel className="text-primary-variant">
|
||||
{t("wizard.step1.name")}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
className="h-8"
|
||||
placeholder={t("wizard.step1.namePlaceholder")}
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="modelType"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel className="text-primary-variant">
|
||||
{t("wizard.step1.type")}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<RadioGroup
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
className="flex flex-col gap-4 pt-2"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<RadioGroupItem
|
||||
className={
|
||||
watchedModelType === "state"
|
||||
? "bg-selected from-selected/50 to-selected/90 text-selected"
|
||||
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
|
||||
}
|
||||
id="state"
|
||||
value="state"
|
||||
/>
|
||||
<Label className="cursor-pointer" htmlFor="state">
|
||||
{t("wizard.step1.typeState")}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<RadioGroupItem
|
||||
className={
|
||||
watchedModelType === "object"
|
||||
? "bg-selected from-selected/50 to-selected/90 text-selected"
|
||||
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
|
||||
}
|
||||
id="object"
|
||||
value="object"
|
||||
/>
|
||||
<Label className="cursor-pointer" htmlFor="object">
|
||||
{t("wizard.step1.typeObject")}
|
||||
</Label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{watchedModelType === "object" && (
|
||||
<>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="objectLabel"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel className="text-primary-variant">
|
||||
{t("wizard.step1.objectLabel")}
|
||||
</FormLabel>
|
||||
<Select
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="h-8">
|
||||
<SelectValue
|
||||
placeholder={t(
|
||||
"wizard.step1.objectLabelPlaceholder",
|
||||
)}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{objectLabels.map((label) => (
|
||||
<SelectItem
|
||||
key={label}
|
||||
value={label}
|
||||
className="cursor-pointer hover:bg-secondary-highlight"
|
||||
>
|
||||
{getTranslatedLabel(label)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="objectType"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<div className="flex items-center gap-1">
|
||||
<FormLabel className="text-primary-variant">
|
||||
{t("wizard.step1.classificationType")}
|
||||
</FormLabel>
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-4 w-4 p-0"
|
||||
>
|
||||
<LuInfo className="size-3" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="pointer-events-auto w-80 text-xs">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="text-sm">
|
||||
{t("wizard.step1.classificationTypeDesc")}
|
||||
</div>
|
||||
<div className="mt-3 flex items-center text-primary">
|
||||
<a
|
||||
href={getLocaleDocUrl(
|
||||
"configuration/custom_classification/object_classification#classification-type",
|
||||
)}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="inline cursor-pointer"
|
||||
>
|
||||
{t("readTheDocumentation", { ns: "common" })}
|
||||
<LuExternalLink className="ml-2 inline-flex size-3" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
<FormControl>
|
||||
<RadioGroup
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
className="flex flex-col gap-4 pt-2"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<RadioGroupItem
|
||||
className={
|
||||
watchedObjectType === "sub_label"
|
||||
? "bg-selected from-selected/50 to-selected/90 text-selected"
|
||||
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
|
||||
}
|
||||
id="sub_label"
|
||||
value="sub_label"
|
||||
/>
|
||||
<Label className="cursor-pointer" htmlFor="sub_label">
|
||||
{t("wizard.step1.classificationSubLabel")}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<RadioGroupItem
|
||||
className={
|
||||
watchedObjectType === "attribute"
|
||||
? "bg-selected from-selected/50 to-selected/90 text-selected"
|
||||
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
|
||||
}
|
||||
id="attribute"
|
||||
value="attribute"
|
||||
/>
|
||||
<Label className="cursor-pointer" htmlFor="attribute">
|
||||
{t("wizard.step1.classificationAttribute")}
|
||||
</Label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-1">
|
||||
<FormLabel className="text-primary-variant">
|
||||
{t("wizard.step1.classes")}
|
||||
</FormLabel>
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="sm" className="h-4 w-4 p-0">
|
||||
<LuInfo className="size-3" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="pointer-events-auto w-80 text-xs">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="text-sm">
|
||||
{watchedModelType === "state"
|
||||
? t("wizard.step1.classesStateDesc")
|
||||
: t("wizard.step1.classesObjectDesc")}
|
||||
</div>
|
||||
<div className="mt-3 flex items-center text-primary">
|
||||
<a
|
||||
href={getLocaleDocUrl(
|
||||
watchedModelType === "state"
|
||||
? "configuration/custom_classification/state_classification"
|
||||
: "configuration/custom_classification/object_classification",
|
||||
)}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="inline cursor-pointer"
|
||||
>
|
||||
{t("readTheDocumentation", { ns: "common" })}
|
||||
<LuExternalLink className="ml-2 inline-flex size-3" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
|
||||
onClick={handleAddClass}
|
||||
>
|
||||
<LuPlus />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
{watchedClasses.map((_, index) => (
|
||||
<FormField
|
||||
key={index}
|
||||
control={form.control}
|
||||
name={`classes.${index}`}
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormControl>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
className="h-8"
|
||||
placeholder={t("wizard.step1.classPlaceholder")}
|
||||
{...field}
|
||||
/>
|
||||
{watchedClasses.length > 1 && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-8 w-8 p-0"
|
||||
onClick={() => handleRemoveClass(index)}
|
||||
>
|
||||
<LuX className="size-4" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
{form.formState.errors.classes && (
|
||||
<p className="text-sm font-medium text-destructive">
|
||||
{form.formState.errors.classes.message}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</form>
|
||||
</Form>
|
||||
|
||||
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
|
||||
<Button type="button" onClick={onCancel} className="sm:flex-1">
|
||||
{t("button.cancel", { ns: "common" })}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={form.handleSubmit(onSubmit)}
|
||||
variant="select"
|
||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||
disabled={!form.formState.isValid}
|
||||
>
|
||||
{t("button.continue", { ns: "common" })}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
479
web/src/components/classification/wizard/Step2StateArea.tsx
Normal file
479
web/src/components/classification/wizard/Step2StateArea.tsx
Normal file
@@ -0,0 +1,479 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useState, useMemo, useRef, useCallback, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import { FrigateConfig } from "@/types/frigateConfig";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { LuX, LuPlus } from "react-icons/lu";
|
||||
import { Stage, Layer, Rect, Transformer } from "react-konva";
|
||||
import Konva from "konva";
|
||||
import { useResizeObserver } from "@/hooks/resize-observer";
|
||||
import { useApiHost } from "@/api";
|
||||
import { resolveCameraName } from "@/hooks/use-camera-friendly-name";
|
||||
import Heading from "@/components/ui/heading";
|
||||
import { isMobile } from "react-device-detect";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type CameraAreaConfig = {
|
||||
camera: string;
|
||||
crop: [number, number, number, number];
|
||||
};
|
||||
|
||||
export type Step2FormData = {
|
||||
cameraAreas: CameraAreaConfig[];
|
||||
};
|
||||
|
||||
type Step2StateAreaProps = {
|
||||
initialData?: Partial<Step2FormData>;
|
||||
onNext: (data: Step2FormData) => void;
|
||||
onBack: () => void;
|
||||
};
|
||||
|
||||
export default function Step2StateArea({
|
||||
initialData,
|
||||
onNext,
|
||||
onBack,
|
||||
}: Step2StateAreaProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const { data: config } = useSWR<FrigateConfig>("config");
|
||||
const apiHost = useApiHost();
|
||||
|
||||
const [cameraAreas, setCameraAreas] = useState<CameraAreaConfig[]>(
|
||||
initialData?.cameraAreas || [],
|
||||
);
|
||||
const [selectedCameraIndex, setSelectedCameraIndex] = useState<number>(0);
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const imageRef = useRef<HTMLImageElement>(null);
|
||||
const stageRef = useRef<Konva.Stage>(null);
|
||||
const rectRef = useRef<Konva.Rect>(null);
|
||||
const transformerRef = useRef<Konva.Transformer>(null);
|
||||
|
||||
const [{ width: containerWidth }] = useResizeObserver(containerRef);
|
||||
|
||||
const availableCameras = useMemo(() => {
|
||||
if (!config) return [];
|
||||
|
||||
const selectedCameraNames = cameraAreas.map((ca) => ca.camera);
|
||||
return Object.entries(config.cameras)
|
||||
.sort()
|
||||
.filter(
|
||||
([name, cam]) =>
|
||||
cam.enabled &&
|
||||
cam.enabled_in_config &&
|
||||
!selectedCameraNames.includes(name),
|
||||
)
|
||||
.map(([name]) => ({
|
||||
name,
|
||||
displayName: resolveCameraName(config, name),
|
||||
}));
|
||||
}, [config, cameraAreas]);
|
||||
|
||||
const selectedCamera = useMemo(() => {
|
||||
if (cameraAreas.length === 0) return null;
|
||||
return cameraAreas[selectedCameraIndex];
|
||||
}, [cameraAreas, selectedCameraIndex]);
|
||||
|
||||
const selectedCameraConfig = useMemo(() => {
|
||||
if (!config || !selectedCamera) return null;
|
||||
return config.cameras[selectedCamera.camera];
|
||||
}, [config, selectedCamera]);
|
||||
|
||||
const imageSize = useMemo(() => {
|
||||
if (!containerWidth || !selectedCameraConfig) {
|
||||
return { width: 0, height: 0 };
|
||||
}
|
||||
|
||||
const containerAspectRatio = 16 / 9;
|
||||
const containerHeight = containerWidth / containerAspectRatio;
|
||||
|
||||
const cameraAspectRatio =
|
||||
selectedCameraConfig.detect.width / selectedCameraConfig.detect.height;
|
||||
|
||||
// Fit camera within 16:9 container
|
||||
let imageWidth, imageHeight;
|
||||
if (cameraAspectRatio > containerAspectRatio) {
|
||||
imageWidth = containerWidth;
|
||||
imageHeight = imageWidth / cameraAspectRatio;
|
||||
} else {
|
||||
imageHeight = containerHeight;
|
||||
imageWidth = imageHeight * cameraAspectRatio;
|
||||
}
|
||||
|
||||
return { width: imageWidth, height: imageHeight };
|
||||
}, [containerWidth, selectedCameraConfig]);
|
||||
|
||||
const handleAddCamera = useCallback(
|
||||
(cameraName: string) => {
|
||||
// Calculate a square crop in pixel space
|
||||
const camera = config?.cameras[cameraName];
|
||||
if (!camera) return;
|
||||
|
||||
const cameraAspect = camera.detect.width / camera.detect.height;
|
||||
const cropSize = 0.3;
|
||||
let x1, y1, x2, y2;
|
||||
|
||||
if (cameraAspect >= 1) {
|
||||
const pixelSize = cropSize * camera.detect.height;
|
||||
const normalizedWidth = pixelSize / camera.detect.width;
|
||||
x1 = (1 - normalizedWidth) / 2;
|
||||
y1 = (1 - cropSize) / 2;
|
||||
x2 = x1 + normalizedWidth;
|
||||
y2 = y1 + cropSize;
|
||||
} else {
|
||||
const pixelSize = cropSize * camera.detect.width;
|
||||
const normalizedHeight = pixelSize / camera.detect.height;
|
||||
x1 = (1 - cropSize) / 2;
|
||||
y1 = (1 - normalizedHeight) / 2;
|
||||
x2 = x1 + cropSize;
|
||||
y2 = y1 + normalizedHeight;
|
||||
}
|
||||
|
||||
const newArea: CameraAreaConfig = {
|
||||
camera: cameraName,
|
||||
crop: [x1, y1, x2, y2],
|
||||
};
|
||||
setCameraAreas([...cameraAreas, newArea]);
|
||||
setSelectedCameraIndex(cameraAreas.length);
|
||||
setIsPopoverOpen(false);
|
||||
},
|
||||
[cameraAreas, config],
|
||||
);
|
||||
|
||||
const handleRemoveCamera = useCallback(
|
||||
(index: number) => {
|
||||
const newAreas = cameraAreas.filter((_, i) => i !== index);
|
||||
setCameraAreas(newAreas);
|
||||
if (selectedCameraIndex >= newAreas.length) {
|
||||
setSelectedCameraIndex(Math.max(0, newAreas.length - 1));
|
||||
}
|
||||
},
|
||||
[cameraAreas, selectedCameraIndex],
|
||||
);
|
||||
|
||||
const handleCropChange = useCallback(
|
||||
(crop: [number, number, number, number]) => {
|
||||
const newAreas = [...cameraAreas];
|
||||
newAreas[selectedCameraIndex] = {
|
||||
...newAreas[selectedCameraIndex],
|
||||
crop,
|
||||
};
|
||||
setCameraAreas(newAreas);
|
||||
},
|
||||
[cameraAreas, selectedCameraIndex],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setImageLoaded(false);
|
||||
}, [selectedCamera]);
|
||||
|
||||
useEffect(() => {
|
||||
const rect = rectRef.current;
|
||||
const transformer = transformerRef.current;
|
||||
|
||||
if (
|
||||
rect &&
|
||||
transformer &&
|
||||
selectedCamera &&
|
||||
imageSize.width > 0 &&
|
||||
imageLoaded
|
||||
) {
|
||||
rect.scaleX(1);
|
||||
rect.scaleY(1);
|
||||
transformer.nodes([rect]);
|
||||
transformer.getLayer()?.batchDraw();
|
||||
}
|
||||
}, [selectedCamera, imageSize, imageLoaded]);
|
||||
|
||||
const handleRectChange = useCallback(() => {
|
||||
const rect = rectRef.current;
|
||||
|
||||
if (rect && imageSize.width > 0) {
|
||||
const actualWidth = rect.width() * rect.scaleX();
|
||||
const actualHeight = rect.height() * rect.scaleY();
|
||||
|
||||
// Average dimensions to maintain perfect square
|
||||
const size = (actualWidth + actualHeight) / 2;
|
||||
|
||||
rect.width(size);
|
||||
rect.height(size);
|
||||
rect.scaleX(1);
|
||||
rect.scaleY(1);
|
||||
|
||||
const x1 = rect.x() / imageSize.width;
|
||||
const y1 = rect.y() / imageSize.height;
|
||||
const x2 = (rect.x() + size) / imageSize.width;
|
||||
const y2 = (rect.y() + size) / imageSize.height;
|
||||
|
||||
handleCropChange([x1, y1, x2, y2]);
|
||||
}
|
||||
}, [imageSize, handleCropChange]);
|
||||
|
||||
const handleContinue = useCallback(() => {
|
||||
onNext({ cameraAreas });
|
||||
}, [cameraAreas, onNext]);
|
||||
|
||||
const canContinue = cameraAreas.length > 0;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div
|
||||
className={cn(
|
||||
"flex gap-4 overflow-hidden",
|
||||
isMobile ? "flex-col" : "flex-row",
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-shrink-0 flex-col gap-2 overflow-y-auto rounded-lg bg-secondary p-4",
|
||||
isMobile ? "w-full" : "w-64",
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium">{t("wizard.step2.cameras")}</h3>
|
||||
{availableCameras.length > 0 ? (
|
||||
<Popover
|
||||
open={isPopoverOpen}
|
||||
onOpenChange={setIsPopoverOpen}
|
||||
modal={true}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
|
||||
aria-label="Add camera"
|
||||
>
|
||||
<LuPlus />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="scrollbar-container w-64 border bg-background p-3 shadow-lg"
|
||||
align="start"
|
||||
sideOffset={5}
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<div className="flex flex-col gap-2">
|
||||
<Heading as="h4" className="text-sm text-primary-variant">
|
||||
{t("wizard.step2.selectCamera")}
|
||||
</Heading>
|
||||
<div className="scrollbar-container flex max-h-[30vh] flex-col gap-1 overflow-y-auto">
|
||||
{availableCameras.map((cam) => (
|
||||
<Button
|
||||
key={cam.name}
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-auto justify-start p-2 capitalize text-primary"
|
||||
onClick={() => {
|
||||
handleAddCamera(cam.name);
|
||||
}}
|
||||
>
|
||||
{cam.displayName}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : (
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="size-6 cursor-not-allowed rounded-md bg-muted p-1 text-muted-foreground"
|
||||
disabled
|
||||
>
|
||||
<LuPlus />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-1">
|
||||
{cameraAreas.map((area, index) => {
|
||||
const isSelected = index === selectedCameraIndex;
|
||||
const displayName = resolveCameraName(config, area.camera);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={area.camera}
|
||||
className={`flex items-center justify-between rounded-md p-2 ${
|
||||
isSelected
|
||||
? "bg-selected/20 ring-1 ring-selected"
|
||||
: "hover:bg-secondary/50"
|
||||
} cursor-pointer`}
|
||||
onClick={() => setSelectedCameraIndex(index)}
|
||||
>
|
||||
<span className="text-sm capitalize">{displayName}</span>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 w-6 p-0"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleRemoveCamera(index);
|
||||
}}
|
||||
>
|
||||
<LuX className="size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{cameraAreas.length === 0 && (
|
||||
<div className="flex flex-1 items-center justify-center text-center text-sm text-muted-foreground">
|
||||
{t("wizard.step2.noCameras")}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-1 items-center justify-center overflow-hidden rounded-lg p-4">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="flex items-center justify-center"
|
||||
style={{
|
||||
width: "100%",
|
||||
aspectRatio: "16 / 9",
|
||||
maxHeight: "100%",
|
||||
}}
|
||||
>
|
||||
{selectedCamera && selectedCameraConfig && imageSize.width > 0 ? (
|
||||
<div
|
||||
style={{
|
||||
width: imageSize.width,
|
||||
height: imageSize.height,
|
||||
position: "relative",
|
||||
}}
|
||||
>
|
||||
<img
|
||||
ref={imageRef}
|
||||
src={`${apiHost}api/${selectedCamera.camera}/latest.jpg?h=500`}
|
||||
alt={resolveCameraName(config, selectedCamera.camera)}
|
||||
className="h-full w-full object-contain"
|
||||
onLoad={() => setImageLoaded(true)}
|
||||
/>
|
||||
<Stage
|
||||
ref={stageRef}
|
||||
width={imageSize.width}
|
||||
height={imageSize.height}
|
||||
className="absolute inset-0"
|
||||
>
|
||||
<Layer>
|
||||
<Rect
|
||||
ref={rectRef}
|
||||
x={selectedCamera.crop[0] * imageSize.width}
|
||||
y={selectedCamera.crop[1] * imageSize.height}
|
||||
width={
|
||||
(selectedCamera.crop[2] - selectedCamera.crop[0]) *
|
||||
imageSize.width
|
||||
}
|
||||
height={
|
||||
(selectedCamera.crop[3] - selectedCamera.crop[1]) *
|
||||
imageSize.height
|
||||
}
|
||||
stroke="#3b82f6"
|
||||
strokeWidth={2}
|
||||
fill="rgba(59, 130, 246, 0.1)"
|
||||
draggable
|
||||
dragBoundFunc={(pos) => {
|
||||
const rect = rectRef.current;
|
||||
if (!rect) return pos;
|
||||
|
||||
const size = rect.width();
|
||||
const x = Math.max(
|
||||
0,
|
||||
Math.min(pos.x, imageSize.width - size),
|
||||
);
|
||||
const y = Math.max(
|
||||
0,
|
||||
Math.min(pos.y, imageSize.height - size),
|
||||
);
|
||||
|
||||
return { x, y };
|
||||
}}
|
||||
onDragEnd={handleRectChange}
|
||||
onTransformEnd={handleRectChange}
|
||||
/>
|
||||
<Transformer
|
||||
ref={transformerRef}
|
||||
rotateEnabled={false}
|
||||
enabledAnchors={[
|
||||
"top-left",
|
||||
"top-right",
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
]}
|
||||
boundBoxFunc={(_oldBox, newBox) => {
|
||||
const minSize = 50;
|
||||
const maxSize = Math.min(
|
||||
imageSize.width,
|
||||
imageSize.height,
|
||||
);
|
||||
|
||||
// Clamp dimensions to stage bounds first
|
||||
const clampedWidth = Math.max(
|
||||
minSize,
|
||||
Math.min(newBox.width, maxSize),
|
||||
);
|
||||
const clampedHeight = Math.max(
|
||||
minSize,
|
||||
Math.min(newBox.height, maxSize),
|
||||
);
|
||||
|
||||
// Enforce square using average
|
||||
const size = (clampedWidth + clampedHeight) / 2;
|
||||
|
||||
// Clamp position to keep square within bounds
|
||||
const x = Math.max(
|
||||
0,
|
||||
Math.min(newBox.x, imageSize.width - size),
|
||||
);
|
||||
const y = Math.max(
|
||||
0,
|
||||
Math.min(newBox.y, imageSize.height - size),
|
||||
);
|
||||
|
||||
return {
|
||||
...newBox,
|
||||
x,
|
||||
y,
|
||||
width: size,
|
||||
height: size,
|
||||
};
|
||||
}}
|
||||
/>
|
||||
</Layer>
|
||||
</Stage>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center justify-center text-muted-foreground">
|
||||
{t("wizard.step2.selectCameraPrompt")}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
|
||||
<Button type="button" onClick={onBack} className="sm:flex-1">
|
||||
{t("button.back", { ns: "common" })}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={handleContinue}
|
||||
variant="select"
|
||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||
disabled={!canContinue}
|
||||
>
|
||||
{t("button.continue", { ns: "common" })}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
444
web/src/components/classification/wizard/Step3ChooseExamples.tsx
Normal file
444
web/src/components/classification/wizard/Step3ChooseExamples.tsx
Normal file
@@ -0,0 +1,444 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useState, useEffect, useCallback, useMemo } from "react";
|
||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||
import axios from "axios";
|
||||
import { toast } from "sonner";
|
||||
import { Step1FormData } from "./Step1NameAndDefine";
|
||||
import { Step2FormData } from "./Step2StateArea";
|
||||
import useSWR from "swr";
|
||||
import { baseUrl } from "@/api/baseUrl";
|
||||
import { isMobile } from "react-device-detect";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type Step3FormData = {
|
||||
examplesGenerated: boolean;
|
||||
imageClassifications?: { [imageName: string]: string };
|
||||
};
|
||||
|
||||
type Step3ChooseExamplesProps = {
|
||||
step1Data: Step1FormData;
|
||||
step2Data?: Step2FormData;
|
||||
initialData?: Partial<Step3FormData>;
|
||||
onClose: () => void;
|
||||
onBack: () => void;
|
||||
};
|
||||
|
||||
export default function Step3ChooseExamples({
|
||||
step1Data,
|
||||
step2Data,
|
||||
initialData,
|
||||
onClose,
|
||||
onBack,
|
||||
}: Step3ChooseExamplesProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const [isGenerating, setIsGenerating] = useState(false);
|
||||
const [hasGenerated, setHasGenerated] = useState(
|
||||
initialData?.examplesGenerated || false,
|
||||
);
|
||||
const [imageClassifications, setImageClassifications] = useState<{
|
||||
[imageName: string]: string;
|
||||
}>(initialData?.imageClassifications || {});
|
||||
const [isTraining, setIsTraining] = useState(false);
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [currentClassIndex, setCurrentClassIndex] = useState(0);
|
||||
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
|
||||
|
||||
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
|
||||
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
|
||||
);
|
||||
|
||||
const unknownImages = useMemo(() => {
|
||||
if (!trainImages) return [];
|
||||
return trainImages;
|
||||
}, [trainImages]);
|
||||
|
||||
const toggleImageSelection = useCallback((imageName: string) => {
|
||||
setSelectedImages((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
if (newSet.has(imageName)) {
|
||||
newSet.delete(imageName);
|
||||
} else {
|
||||
newSet.add(imageName);
|
||||
}
|
||||
return newSet;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Get all classes (excluding "none" - it will be auto-assigned)
|
||||
const allClasses = useMemo(() => {
|
||||
return [...step1Data.classes];
|
||||
}, [step1Data.classes]);
|
||||
|
||||
const currentClass = allClasses[currentClassIndex];
|
||||
|
||||
const processClassificationsAndTrain = useCallback(
|
||||
async (classifications: { [imageName: string]: string }) => {
|
||||
// Step 1: Create config for the new model
|
||||
const modelConfig: {
|
||||
enabled: boolean;
|
||||
name: string;
|
||||
threshold: number;
|
||||
state_config?: {
|
||||
cameras: Record<string, { crop: number[] }>;
|
||||
motion: boolean;
|
||||
};
|
||||
object_config?: { objects: string[]; classification_type: string };
|
||||
} = {
|
||||
enabled: true,
|
||||
name: step1Data.modelName,
|
||||
threshold: 0.8,
|
||||
};
|
||||
|
||||
if (step1Data.modelType === "state") {
|
||||
// State model config
|
||||
const cameras: Record<string, { crop: number[] }> = {};
|
||||
step2Data?.cameraAreas.forEach((area) => {
|
||||
cameras[area.camera] = {
|
||||
crop: area.crop,
|
||||
};
|
||||
});
|
||||
|
||||
modelConfig.state_config = {
|
||||
cameras,
|
||||
motion: true,
|
||||
};
|
||||
} else {
|
||||
// Object model config
|
||||
modelConfig.object_config = {
|
||||
objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
|
||||
classification_type: step1Data.objectType || "sub_label",
|
||||
} as { objects: string[]; classification_type: string };
|
||||
}
|
||||
|
||||
// Update config via config API
|
||||
await axios.put("/config/set", {
|
||||
requires_restart: 0,
|
||||
update_topic: `config/classification/custom/${step1Data.modelName}`,
|
||||
config_data: {
|
||||
classification: {
|
||||
custom: {
|
||||
[step1Data.modelName]: modelConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Step 2: Classify each image by moving it to the correct category folder
|
||||
const categorizePromises = Object.entries(classifications).map(
|
||||
([imageName, className]) => {
|
||||
if (!className) return Promise.resolve();
|
||||
return axios.post(
|
||||
`/classification/${step1Data.modelName}/dataset/categorize`,
|
||||
{
|
||||
training_file: imageName,
|
||||
category: className === "none" ? "none" : className,
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
await Promise.all(categorizePromises);
|
||||
|
||||
// Step 3: Kick off training
|
||||
await axios.post(`/classification/${step1Data.modelName}/train`);
|
||||
|
||||
toast.success(t("wizard.step3.trainingStarted"));
|
||||
setIsTraining(true);
|
||||
},
|
||||
[step1Data, step2Data, t],
|
||||
);
|
||||
|
||||
const handleContinueClassification = useCallback(async () => {
|
||||
// Mark selected images with current class
|
||||
const newClassifications = { ...imageClassifications };
|
||||
selectedImages.forEach((imageName) => {
|
||||
newClassifications[imageName] = currentClass;
|
||||
});
|
||||
|
||||
// Check if we're on the last class to select
|
||||
const isLastClass = currentClassIndex === allClasses.length - 1;
|
||||
|
||||
if (isLastClass) {
|
||||
// Assign remaining unclassified images
|
||||
unknownImages.slice(0, 24).forEach((imageName) => {
|
||||
if (!newClassifications[imageName]) {
|
||||
// For state models with 2 classes, assign to the last class
|
||||
// For object models, assign to "none"
|
||||
if (step1Data.modelType === "state" && allClasses.length === 2) {
|
||||
newClassifications[imageName] = allClasses[allClasses.length - 1];
|
||||
} else {
|
||||
newClassifications[imageName] = "none";
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// All done, trigger training immediately
|
||||
setImageClassifications(newClassifications);
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
await processClassificationsAndTrain(newClassifications);
|
||||
} catch (error) {
|
||||
const axiosError = error as {
|
||||
response?: { data?: { message?: string; detail?: string } };
|
||||
message?: string;
|
||||
};
|
||||
const errorMessage =
|
||||
axiosError.response?.data?.message ||
|
||||
axiosError.response?.data?.detail ||
|
||||
axiosError.message ||
|
||||
"Failed to classify images";
|
||||
|
||||
toast.error(
|
||||
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
|
||||
);
|
||||
setIsProcessing(false);
|
||||
}
|
||||
} else {
|
||||
// Move to next class
|
||||
setImageClassifications(newClassifications);
|
||||
setCurrentClassIndex((prev) => prev + 1);
|
||||
setSelectedImages(new Set());
|
||||
}
|
||||
}, [
|
||||
selectedImages,
|
||||
currentClass,
|
||||
currentClassIndex,
|
||||
allClasses,
|
||||
imageClassifications,
|
||||
unknownImages,
|
||||
step1Data,
|
||||
processClassificationsAndTrain,
|
||||
t,
|
||||
]);
|
||||
|
||||
const generateExamples = useCallback(async () => {
|
||||
setIsGenerating(true);
|
||||
|
||||
try {
|
||||
if (step1Data.modelType === "state") {
|
||||
// For state models, use cameras and crop areas
|
||||
if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) {
|
||||
toast.error(t("wizard.step3.errors.noCameras"));
|
||||
setIsGenerating(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const cameras: { [key: string]: [number, number, number, number] } = {};
|
||||
step2Data.cameraAreas.forEach((area) => {
|
||||
cameras[area.camera] = area.crop;
|
||||
});
|
||||
|
||||
await axios.post("/classification/generate_examples/state", {
|
||||
model_name: step1Data.modelName,
|
||||
cameras,
|
||||
});
|
||||
} else {
|
||||
// For object models, use label
|
||||
if (!step1Data.objectLabel) {
|
||||
toast.error(t("wizard.step3.errors.noObjectLabel"));
|
||||
setIsGenerating(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// For now, use all enabled cameras
|
||||
// TODO: In the future, we might want to let users select specific cameras
|
||||
await axios.post("/classification/generate_examples/object", {
|
||||
model_name: step1Data.modelName,
|
||||
label: step1Data.objectLabel,
|
||||
});
|
||||
}
|
||||
|
||||
setHasGenerated(true);
|
||||
toast.success(t("wizard.step3.generateSuccess"));
|
||||
|
||||
await refreshTrainImages();
|
||||
} catch (error) {
|
||||
const axiosError = error as {
|
||||
response?: { data?: { message?: string; detail?: string } };
|
||||
message?: string;
|
||||
};
|
||||
const errorMessage =
|
||||
axiosError.response?.data?.message ||
|
||||
axiosError.response?.data?.detail ||
|
||||
axiosError.message ||
|
||||
"Failed to generate examples";
|
||||
|
||||
toast.error(
|
||||
t("wizard.step3.errors.generateFailed", { error: errorMessage }),
|
||||
);
|
||||
} finally {
|
||||
setIsGenerating(false);
|
||||
}
|
||||
}, [step1Data, step2Data, t, refreshTrainImages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!hasGenerated && !isGenerating) {
|
||||
generateExamples();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
const handleContinue = useCallback(async () => {
|
||||
setIsProcessing(true);
|
||||
try {
|
||||
await processClassificationsAndTrain(imageClassifications);
|
||||
} catch (error) {
|
||||
const axiosError = error as {
|
||||
response?: { data?: { message?: string; detail?: string } };
|
||||
message?: string;
|
||||
};
|
||||
const errorMessage =
|
||||
axiosError.response?.data?.message ||
|
||||
axiosError.response?.data?.detail ||
|
||||
axiosError.message ||
|
||||
"Failed to classify images";
|
||||
|
||||
toast.error(
|
||||
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
|
||||
);
|
||||
setIsProcessing(false);
|
||||
}
|
||||
}, [imageClassifications, processClassificationsAndTrain, t]);
|
||||
|
||||
const unclassifiedImages = useMemo(() => {
|
||||
if (!unknownImages) return [];
|
||||
const images = unknownImages.slice(0, 24);
|
||||
|
||||
// Only filter if we have any classifications
|
||||
if (Object.keys(imageClassifications).length === 0) {
|
||||
return images;
|
||||
}
|
||||
|
||||
return images.filter((img) => !imageClassifications[img]);
|
||||
}, [unknownImages, imageClassifications]);
|
||||
|
||||
const allImagesClassified = useMemo(() => {
|
||||
return unclassifiedImages.length === 0;
|
||||
}, [unclassifiedImages]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
{isTraining ? (
|
||||
<div className="flex flex-col items-center gap-6 py-12">
|
||||
<ActivityIndicator className="size-12" />
|
||||
<div className="text-center">
|
||||
<h3 className="mb-2 text-lg font-medium">
|
||||
{t("wizard.step3.training.title")}
|
||||
</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t("wizard.step3.training.description")}
|
||||
</p>
|
||||
</div>
|
||||
<Button onClick={onClose} variant="select" className="mt-4">
|
||||
{t("button.close", { ns: "common" })}
|
||||
</Button>
|
||||
</div>
|
||||
) : isGenerating ? (
|
||||
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
|
||||
<ActivityIndicator className="size-12" />
|
||||
<div className="text-center">
|
||||
<h3 className="mb-2 text-lg font-medium">
|
||||
{t("wizard.step3.generating.title")}
|
||||
</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t("wizard.step3.generating.description")}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : hasGenerated ? (
|
||||
<div className="flex flex-col gap-4">
|
||||
{!allImagesClassified && (
|
||||
<div className="text-center">
|
||||
<h3 className="text-lg font-medium">
|
||||
{t("wizard.step3.selectImagesPrompt", {
|
||||
className: currentClass,
|
||||
})}
|
||||
</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t("wizard.step3.selectImagesDescription")}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-lg bg-secondary/30 p-4",
|
||||
isMobile && "max-h-[60vh] overflow-y-auto",
|
||||
)}
|
||||
>
|
||||
{!unknownImages || unknownImages.length === 0 ? (
|
||||
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
|
||||
<p className="text-muted-foreground">
|
||||
{t("wizard.step3.noImages")}
|
||||
</p>
|
||||
<Button onClick={generateExamples} variant="select">
|
||||
{t("wizard.step3.retryGenerate")}
|
||||
</Button>
|
||||
</div>
|
||||
) : allImagesClassified && isProcessing ? (
|
||||
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
|
||||
<ActivityIndicator className="size-12" />
|
||||
<p className="text-lg font-medium">
|
||||
{t("wizard.step3.classifying")}
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
|
||||
{unclassifiedImages.map((imageName, index) => {
|
||||
const isSelected = selectedImages.has(imageName);
|
||||
return (
|
||||
<div
|
||||
key={imageName}
|
||||
className={cn(
|
||||
"aspect-square cursor-pointer overflow-hidden rounded-lg border-2 bg-background transition-all",
|
||||
isSelected && "border-selected ring-2 ring-selected",
|
||||
)}
|
||||
onClick={() => toggleImageSelection(imageName)}
|
||||
>
|
||||
<img
|
||||
src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
|
||||
alt={`Example ${index + 1}`}
|
||||
className="h-full w-full object-cover"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
|
||||
<p className="text-sm text-destructive">
|
||||
{t("wizard.step3.errors.generationFailed")}
|
||||
</p>
|
||||
<Button onClick={generateExamples} variant="select">
|
||||
{t("wizard.step3.retryGenerate")}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isTraining && (
|
||||
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
|
||||
<Button type="button" onClick={onBack} className="sm:flex-1">
|
||||
{t("button.back", { ns: "common" })}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={
|
||||
allImagesClassified
|
||||
? handleContinue
|
||||
: handleContinueClassification
|
||||
}
|
||||
variant="select"
|
||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||
disabled={!hasGenerated || isGenerating || isProcessing}
|
||||
>
|
||||
{isProcessing && <ActivityIndicator className="size-4" />}
|
||||
{t("button.continue", { ns: "common" })}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -10,11 +10,14 @@ import {
|
||||
CustomClassificationModelConfig,
|
||||
FrigateConfig,
|
||||
} from "@/types/frigateConfig";
|
||||
import { useMemo, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { isMobile } from "react-device-detect";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { FaFolderPlus } from "react-icons/fa";
|
||||
import { MdModelTraining } from "react-icons/md";
|
||||
import useSWR from "swr";
|
||||
import Heading from "@/components/ui/heading";
|
||||
import { useOverlayState } from "@/hooks/use-overlay-state";
|
||||
|
||||
const allModelTypes = ["objects", "states"] as const;
|
||||
type ModelType = (typeof allModelTypes)[number];
|
||||
@@ -26,11 +29,24 @@ export default function ModelSelectionView({
|
||||
onClick,
|
||||
}: ModelSelectionViewProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const [page, setPage] = useState<ModelType>("objects");
|
||||
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
||||
const { data: config } = useSWR<FrigateConfig>("config", {
|
||||
revalidateOnFocus: false,
|
||||
});
|
||||
const [page, setPage] = useOverlayState<ModelType>("objects", "objects");
|
||||
const [pageToggle, setPageToggle] = useOptimisticState(
|
||||
page || "objects",
|
||||
setPage,
|
||||
100,
|
||||
);
|
||||
const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
|
||||
"config",
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
},
|
||||
);
|
||||
|
||||
// title
|
||||
|
||||
useEffect(() => {
|
||||
document.title = t("documentTitle");
|
||||
}, [t]);
|
||||
|
||||
// data
|
||||
|
||||
@@ -64,15 +80,15 @@ export default function ModelSelectionView({
|
||||
return <ActivityIndicator />;
|
||||
}
|
||||
|
||||
if (classificationConfigs.length == 0) {
|
||||
return <div>You need to setup a custom model configuration.</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex size-full flex-col p-2">
|
||||
<ClassificationModelWizardDialog
|
||||
open={newModel}
|
||||
onClose={() => setNewModel(false)}
|
||||
defaultModelType={pageToggle === "objects" ? "object" : "state"}
|
||||
onClose={() => {
|
||||
setNewModel(false);
|
||||
refreshConfig();
|
||||
}}
|
||||
/>
|
||||
|
||||
<div className="flex h-12 w-full items-center justify-between">
|
||||
@@ -84,7 +100,6 @@ export default function ModelSelectionView({
|
||||
value={pageToggle}
|
||||
onValueChange={(value: ModelType) => {
|
||||
if (value) {
|
||||
// Restrict viewer navigation
|
||||
setPageToggle(value);
|
||||
}
|
||||
}}
|
||||
@@ -117,13 +132,46 @@ export default function ModelSelectionView({
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex size-full gap-2 p-2">
|
||||
{selectedClassificationConfigs.map((config) => (
|
||||
<ModelCard
|
||||
key={config.name}
|
||||
config={config}
|
||||
onClick={() => onClick(config)}
|
||||
{selectedClassificationConfigs.length === 0 ? (
|
||||
<NoModelsView
|
||||
onCreateModel={() => setNewModel(true)}
|
||||
modelType={pageToggle}
|
||||
/>
|
||||
))}
|
||||
) : (
|
||||
selectedClassificationConfigs.map((config) => (
|
||||
<ModelCard
|
||||
key={config.name}
|
||||
config={config}
|
||||
onClick={() => onClick(config)}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function NoModelsView({
|
||||
onCreateModel,
|
||||
modelType,
|
||||
}: {
|
||||
onCreateModel: () => void;
|
||||
modelType: ModelType;
|
||||
}) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const typeKey = modelType === "objects" ? "object" : "state";
|
||||
|
||||
return (
|
||||
<div className="flex size-full items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<MdModelTraining className="size-8" />
|
||||
<Heading as="h4">{t(`noModels.${typeKey}.title`)}</Heading>
|
||||
<div className="mb-3 text-center text-secondary-foreground">
|
||||
{t(`noModels.${typeKey}.description`)}
|
||||
</div>
|
||||
<Button size="sm" variant="select" onClick={onCreateModel}>
|
||||
{t(`noModels.${typeKey}.buttonText`)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -139,13 +187,17 @@ function ModelCard({ config, onClick }: ModelCardProps) {
|
||||
}>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });
|
||||
|
||||
const coverImage = useMemo(() => {
|
||||
if (!dataset?.length) {
|
||||
if (!dataset) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const keys = Object.keys(dataset).filter((key) => key != "none");
|
||||
const selectedKey = keys[0];
|
||||
|
||||
if (!dataset[selectedKey]) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
name: selectedKey,
|
||||
img: dataset[selectedKey][0],
|
||||
|
||||
@@ -642,6 +642,7 @@ function DatasetGrid({
|
||||
filepath: `clips/${modelName}/dataset/${categoryName}/${image}`,
|
||||
name: "",
|
||||
}}
|
||||
showArea={false}
|
||||
selected={selectedImages.includes(image)}
|
||||
i18nLibrary="views/classificationModel"
|
||||
onClick={(data, _) => onClickImages([data.filename], true)}
|
||||
|
||||
Reference in New Issue
Block a user