Implement Wizard for Creating Classification Models (#20622)

* Implement extraction of images for classification state models

* Add object classification dataset preparation

* Add first step wizard

* Update i18n

* Add state classification image selection step

* Improve box handling

* Add object selector

* Improve object cropping implementation

* Fix state classification selection

* Finalize training and image selection step

* Cleanup

* Design optimizations

* Cleanup mobile styling

* Update no models screen

* Cleanups and fixes

* Fix bugs

* Improve model training and creation process

* Cleanup

* Dynamically add metrics for new model

* Add loading when hitting continue

* Improve image selection mechanism

* Remove unused translation keys

* Adjust wording

* Add retry button for image generation

* Make no models view more specific

* Adjust plus icon

* Adjust form label

* Start with correct type selected

* Cleanup sizing and more font colors

* Small tweaks

* Add tips and more info

* Cleanup dialog sizing

* Add cursor rule for frontend

* Cleanup

* remove underline

* Lazy loading
This commit is contained in:
Nicolas Mowen
2025-10-23 13:27:28 -06:00
committed by GitHub
parent 4df7793587
commit f5a57edcc9
18 changed files with 2450 additions and 79 deletions

View File

@@ -126,6 +126,7 @@ export const ClassificationCard = forwardRef<
imgClassName,
isMobile && "w-full",
)}
loading="lazy"
onLoad={() => setImageLoaded(true)}
src={`${baseUrl}${data.filepath}`}
/>

View File

@@ -7,58 +7,198 @@ import {
DialogHeader,
DialogTitle,
} from "../ui/dialog";
import { useState } from "react";
import { useReducer, useMemo } from "react";
import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine";
import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea";
import Step3ChooseExamples, {
Step3FormData,
} from "./wizard/Step3ChooseExamples";
import { cn } from "@/lib/utils";
import { isDesktop } from "react-device-detect";
const STEPS = [
"classificationWizard.steps.nameAndDefine",
"classificationWizard.steps.stateArea",
"classificationWizard.steps.chooseExamples",
"classificationWizard.steps.train",
const OBJECT_STEPS = [
"wizard.steps.nameAndDefine",
"wizard.steps.chooseExamples",
];
const STATE_STEPS = [
"wizard.steps.nameAndDefine",
"wizard.steps.stateArea",
"wizard.steps.chooseExamples",
];
type ClassificationModelWizardDialogProps = {
open: boolean;
onClose: () => void;
defaultModelType?: "state" | "object";
};
type WizardState = {
currentStep: number;
step1Data?: Step1FormData;
step2Data?: Step2FormData;
step3Data?: Step3FormData;
};
type WizardAction =
| { type: "NEXT_STEP"; payload?: Partial<WizardState> }
| { type: "PREVIOUS_STEP" }
| { type: "SET_STEP_1"; payload: Step1FormData }
| { type: "SET_STEP_2"; payload: Step2FormData }
| { type: "SET_STEP_3"; payload: Step3FormData }
| { type: "RESET" };
const initialState: WizardState = {
currentStep: 0,
};
function wizardReducer(state: WizardState, action: WizardAction): WizardState {
switch (action.type) {
case "SET_STEP_1":
return {
...state,
step1Data: action.payload,
currentStep: 1,
};
case "SET_STEP_2":
return {
...state,
step2Data: action.payload,
currentStep: 2,
};
case "SET_STEP_3":
return {
...state,
step3Data: action.payload,
currentStep: 3,
};
case "NEXT_STEP":
return {
...state,
...action.payload,
currentStep: state.currentStep + 1,
};
case "PREVIOUS_STEP":
return {
...state,
currentStep: Math.max(0, state.currentStep - 1),
};
case "RESET":
return initialState;
default:
return state;
}
}
export default function ClassificationModelWizardDialog({
open,
onClose,
defaultModelType,
}: ClassificationModelWizardDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
// step management
const [currentStep, _] = useState(0);
const [wizardState, dispatch] = useReducer(wizardReducer, initialState);
const steps = useMemo(() => {
if (!wizardState.step1Data) {
return OBJECT_STEPS;
}
return wizardState.step1Data.modelType === "state"
? STATE_STEPS
: OBJECT_STEPS;
}, [wizardState.step1Data]);
const handleStep1Next = (data: Step1FormData) => {
dispatch({ type: "SET_STEP_1", payload: data });
};
const handleStep2Next = (data: Step2FormData) => {
dispatch({ type: "SET_STEP_2", payload: data });
};
const handleBack = () => {
dispatch({ type: "PREVIOUS_STEP" });
};
const handleCancel = () => {
dispatch({ type: "RESET" });
onClose();
};
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) {
onClose;
handleCancel();
}
}}
>
<DialogContent
className="max-h-[90dvh] max-w-4xl overflow-y-auto"
className={cn(
"",
isDesktop &&
wizardState.currentStep == 0 &&
"max-h-[90%] overflow-y-auto xl:max-h-[80%]",
isDesktop &&
wizardState.currentStep > 0 &&
"max-h-[90%] max-w-[70%] overflow-y-auto xl:max-h-[80%]",
)}
onInteractOutside={(e) => {
e.preventDefault();
}}
>
<StepIndicator
steps={STEPS}
currentStep={currentStep}
steps={steps}
currentStep={wizardState.currentStep}
variant="dots"
className="mb-4 justify-start"
/>
<DialogHeader>
<DialogTitle>{t("wizard.title")}</DialogTitle>
{currentStep === 0 && (
<DialogDescription>{t("wizard.description")}</DialogDescription>
{wizardState.currentStep === 0 && (
<DialogDescription>
{t("wizard.step1.description")}
</DialogDescription>
)}
{wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "state" && (
<DialogDescription>
{t("wizard.step2.description")}
</DialogDescription>
)}
</DialogHeader>
<div className="pb-4">
<div className="size-full"></div>
{wizardState.currentStep === 0 && (
<Step1NameAndDefine
initialData={wizardState.step1Data}
defaultModelType={defaultModelType}
onNext={handleStep1Next}
onCancel={handleCancel}
/>
)}
{wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "state" && (
<Step2StateArea
initialData={wizardState.step2Data}
onNext={handleStep2Next}
onBack={handleBack}
/>
)}
{((wizardState.currentStep === 2 &&
wizardState.step1Data?.modelType === "state") ||
(wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "object")) &&
wizardState.step1Data && (
<Step3ChooseExamples
step1Data={wizardState.step1Data}
step2Data={wizardState.step2Data}
initialData={wizardState.step3Data}
onClose={onClose}
onBack={handleBack}
/>
)}
</div>
</DialogContent>
</Dialog>

View File

@@ -0,0 +1,498 @@
import { Button } from "@/components/ui/button";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { z } from "zod";
import { useTranslation } from "react-i18next";
import { useMemo } from "react";
import { LuX, LuPlus, LuInfo, LuExternalLink } from "react-icons/lu";
import useSWR from "swr";
import { FrigateConfig } from "@/types/frigateConfig";
import { getTranslatedLabel } from "@/utils/i18n";
import { useDocDomain } from "@/hooks/use-doc-domain";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
export type ModelType = "state" | "object";
export type ObjectClassificationType = "sub_label" | "attribute";
export type Step1FormData = {
modelName: string;
modelType: ModelType;
objectLabel?: string;
objectType?: ObjectClassificationType;
classes: string[];
};
type Step1NameAndDefineProps = {
initialData?: Partial<Step1FormData>;
defaultModelType?: "state" | "object";
onNext: (data: Step1FormData) => void;
onCancel: () => void;
};
export default function Step1NameAndDefine({
initialData,
defaultModelType,
onNext,
onCancel,
}: Step1NameAndDefineProps) {
const { t } = useTranslation(["views/classificationModel"]);
const { data: config } = useSWR<FrigateConfig>("config");
const { getLocaleDocUrl } = useDocDomain();
const objectLabels = useMemo(() => {
if (!config) return [];
const labels = new Set<string>();
Object.values(config.cameras).forEach((cameraConfig) => {
if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
return;
}
cameraConfig.objects.track.forEach((label) => {
if (!config.model.all_attributes.includes(label)) {
labels.add(label);
}
});
});
return [...labels].sort();
}, [config]);
const step1FormData = z
.object({
modelName: z
.string()
.min(1, t("wizard.step1.errors.nameRequired"))
.max(64, t("wizard.step1.errors.nameLength"))
.refine((value) => !/^\d+$/.test(value), {
message: t("wizard.step1.errors.nameOnlyNumbers"),
}),
modelType: z.enum(["state", "object"]),
objectLabel: z.string().optional(),
objectType: z.enum(["sub_label", "attribute"]).optional(),
classes: z
.array(z.string())
.min(1, t("wizard.step1.errors.classRequired"))
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
return nonEmpty.length >= 1;
},
{ message: t("wizard.step1.errors.classRequired") },
)
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
return unique.size === nonEmpty.length;
},
{ message: t("wizard.step1.errors.classesUnique") },
),
})
.refine(
(data) => {
// State models require at least 2 classes
if (data.modelType === "state") {
const nonEmpty = data.classes.filter((c) => c.trim().length > 0);
return nonEmpty.length >= 2;
}
return true;
},
{
message: t("wizard.step1.errors.stateRequiresTwoClasses"),
path: ["classes"],
},
)
.refine(
(data) => {
if (data.modelType === "object") {
return data.objectLabel !== undefined && data.objectLabel !== "";
}
return true;
},
{
message: t("wizard.step1.errors.objectLabelRequired"),
path: ["objectLabel"],
},
)
.refine(
(data) => {
if (data.modelType === "object") {
return data.objectType !== undefined;
}
return true;
},
{
message: t("wizard.step1.errors.objectTypeRequired"),
path: ["objectType"],
},
);
const form = useForm<z.infer<typeof step1FormData>>({
resolver: zodResolver(step1FormData),
defaultValues: {
modelName: initialData?.modelName || "",
modelType: initialData?.modelType || defaultModelType || "state",
objectLabel: initialData?.objectLabel,
objectType: initialData?.objectType || "sub_label",
classes: initialData?.classes?.length ? initialData.classes : [""],
},
mode: "onChange",
});
const watchedClasses = form.watch("classes");
const watchedModelType = form.watch("modelType");
const watchedObjectType = form.watch("objectType");
const handleAddClass = () => {
const currentClasses = form.getValues("classes");
form.setValue("classes", [...currentClasses, ""], { shouldValidate: true });
};
const handleRemoveClass = (index: number) => {
const currentClasses = form.getValues("classes");
const newClasses = currentClasses.filter((_, i) => i !== index);
// Ensure at least one field remains (even if empty)
if (newClasses.length === 0) {
form.setValue("classes", [""], { shouldValidate: true });
} else {
form.setValue("classes", newClasses, { shouldValidate: true });
}
};
const onSubmit = (data: z.infer<typeof step1FormData>) => {
// Filter out empty classes
const filteredClasses = data.classes.filter((c) => c.trim().length > 0);
onNext({
...data,
classes: filteredClasses,
});
};
return (
<div className="space-y-6">
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="modelName"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.name")}
</FormLabel>
<FormControl>
<Input
className="h-8"
placeholder={t("wizard.step1.namePlaceholder")}
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="modelType"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.type")}
</FormLabel>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
defaultValue={field.value}
className="flex flex-col gap-4 pt-2"
>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedModelType === "state"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="state"
value="state"
/>
<Label className="cursor-pointer" htmlFor="state">
{t("wizard.step1.typeState")}
</Label>
</div>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedModelType === "object"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="object"
value="object"
/>
<Label className="cursor-pointer" htmlFor="object">
{t("wizard.step1.typeObject")}
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
{watchedModelType === "object" && (
<>
<FormField
control={form.control}
name="objectLabel"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.objectLabel")}
</FormLabel>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
>
<FormControl>
<SelectTrigger className="h-8">
<SelectValue
placeholder={t(
"wizard.step1.objectLabelPlaceholder",
)}
/>
</SelectTrigger>
</FormControl>
<SelectContent>
{objectLabels.map((label) => (
<SelectItem
key={label}
value={label}
className="cursor-pointer hover:bg-secondary-highlight"
>
{getTranslatedLabel(label)}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="objectType"
render={({ field }) => (
<FormItem>
<div className="flex items-center gap-1">
<FormLabel className="text-primary-variant">
{t("wizard.step1.classificationType")}
</FormLabel>
<Popover>
<PopoverTrigger asChild>
<Button
variant="ghost"
size="sm"
className="h-4 w-4 p-0"
>
<LuInfo className="size-3" />
</Button>
</PopoverTrigger>
<PopoverContent className="pointer-events-auto w-80 text-xs">
<div className="flex flex-col gap-2">
<div className="text-sm">
{t("wizard.step1.classificationTypeDesc")}
</div>
<div className="mt-3 flex items-center text-primary">
<a
href={getLocaleDocUrl(
"configuration/custom_classification/object_classification#classification-type",
)}
target="_blank"
rel="noopener noreferrer"
className="inline cursor-pointer"
>
{t("readTheDocumentation", { ns: "common" })}
<LuExternalLink className="ml-2 inline-flex size-3" />
</a>
</div>
</div>
</PopoverContent>
</Popover>
</div>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
defaultValue={field.value}
className="flex flex-col gap-4 pt-2"
>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "sub_label"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="sub_label"
value="sub_label"
/>
<Label className="cursor-pointer" htmlFor="sub_label">
{t("wizard.step1.classificationSubLabel")}
</Label>
</div>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "attribute"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="attribute"
value="attribute"
/>
<Label className="cursor-pointer" htmlFor="attribute">
{t("wizard.step1.classificationAttribute")}
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</>
)}
<div className="space-y-2">
<div className="flex items-center justify-between">
<div className="flex items-center gap-1">
<FormLabel className="text-primary-variant">
{t("wizard.step1.classes")}
</FormLabel>
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="sm" className="h-4 w-4 p-0">
<LuInfo className="size-3" />
</Button>
</PopoverTrigger>
<PopoverContent className="pointer-events-auto w-80 text-xs">
<div className="flex flex-col gap-2">
<div className="text-sm">
{watchedModelType === "state"
? t("wizard.step1.classesStateDesc")
: t("wizard.step1.classesObjectDesc")}
</div>
<div className="mt-3 flex items-center text-primary">
<a
href={getLocaleDocUrl(
watchedModelType === "state"
? "configuration/custom_classification/state_classification"
: "configuration/custom_classification/object_classification",
)}
target="_blank"
rel="noopener noreferrer"
className="inline cursor-pointer"
>
{t("readTheDocumentation", { ns: "common" })}
<LuExternalLink className="ml-2 inline-flex size-3" />
</a>
</div>
</div>
</PopoverContent>
</Popover>
</div>
<Button
type="button"
variant="secondary"
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
onClick={handleAddClass}
>
<LuPlus />
</Button>
</div>
<div className="space-y-2">
{watchedClasses.map((_, index) => (
<FormField
key={index}
control={form.control}
name={`classes.${index}`}
render={({ field }) => (
<FormItem>
<FormControl>
<div className="flex items-center gap-2">
<Input
className="h-8"
placeholder={t("wizard.step1.classPlaceholder")}
{...field}
/>
{watchedClasses.length > 1 && (
<Button
type="button"
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
onClick={() => handleRemoveClass(index)}
>
<LuX className="size-4" />
</Button>
)}
</div>
</FormControl>
</FormItem>
)}
/>
))}
</div>
{form.formState.errors.classes && (
<p className="text-sm font-medium text-destructive">
{form.formState.errors.classes.message}
</p>
)}
</div>
</form>
</Form>
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onCancel} className="sm:flex-1">
{t("button.cancel", { ns: "common" })}
</Button>
<Button
type="button"
onClick={form.handleSubmit(onSubmit)}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!form.formState.isValid}
>
{t("button.continue", { ns: "common" })}
</Button>
</div>
</div>
);
}

View File

@@ -0,0 +1,479 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useMemo, useRef, useCallback, useEffect } from "react";
import useSWR from "swr";
import { FrigateConfig } from "@/types/frigateConfig";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { LuX, LuPlus } from "react-icons/lu";
import { Stage, Layer, Rect, Transformer } from "react-konva";
import Konva from "konva";
import { useResizeObserver } from "@/hooks/resize-observer";
import { useApiHost } from "@/api";
import { resolveCameraName } from "@/hooks/use-camera-friendly-name";
import Heading from "@/components/ui/heading";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
export type CameraAreaConfig = {
camera: string;
crop: [number, number, number, number];
};
export type Step2FormData = {
cameraAreas: CameraAreaConfig[];
};
type Step2StateAreaProps = {
initialData?: Partial<Step2FormData>;
onNext: (data: Step2FormData) => void;
onBack: () => void;
};
export default function Step2StateArea({
initialData,
onNext,
onBack,
}: Step2StateAreaProps) {
const { t } = useTranslation(["views/classificationModel"]);
const { data: config } = useSWR<FrigateConfig>("config");
const apiHost = useApiHost();
const [cameraAreas, setCameraAreas] = useState<CameraAreaConfig[]>(
initialData?.cameraAreas || [],
);
const [selectedCameraIndex, setSelectedCameraIndex] = useState<number>(0);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const [imageLoaded, setImageLoaded] = useState(false);
const containerRef = useRef<HTMLDivElement>(null);
const imageRef = useRef<HTMLImageElement>(null);
const stageRef = useRef<Konva.Stage>(null);
const rectRef = useRef<Konva.Rect>(null);
const transformerRef = useRef<Konva.Transformer>(null);
const [{ width: containerWidth }] = useResizeObserver(containerRef);
const availableCameras = useMemo(() => {
if (!config) return [];
const selectedCameraNames = cameraAreas.map((ca) => ca.camera);
return Object.entries(config.cameras)
.sort()
.filter(
([name, cam]) =>
cam.enabled &&
cam.enabled_in_config &&
!selectedCameraNames.includes(name),
)
.map(([name]) => ({
name,
displayName: resolveCameraName(config, name),
}));
}, [config, cameraAreas]);
const selectedCamera = useMemo(() => {
if (cameraAreas.length === 0) return null;
return cameraAreas[selectedCameraIndex];
}, [cameraAreas, selectedCameraIndex]);
const selectedCameraConfig = useMemo(() => {
if (!config || !selectedCamera) return null;
return config.cameras[selectedCamera.camera];
}, [config, selectedCamera]);
const imageSize = useMemo(() => {
if (!containerWidth || !selectedCameraConfig) {
return { width: 0, height: 0 };
}
const containerAspectRatio = 16 / 9;
const containerHeight = containerWidth / containerAspectRatio;
const cameraAspectRatio =
selectedCameraConfig.detect.width / selectedCameraConfig.detect.height;
// Fit camera within 16:9 container
let imageWidth, imageHeight;
if (cameraAspectRatio > containerAspectRatio) {
imageWidth = containerWidth;
imageHeight = imageWidth / cameraAspectRatio;
} else {
imageHeight = containerHeight;
imageWidth = imageHeight * cameraAspectRatio;
}
return { width: imageWidth, height: imageHeight };
}, [containerWidth, selectedCameraConfig]);
const handleAddCamera = useCallback(
(cameraName: string) => {
// Calculate a square crop in pixel space
const camera = config?.cameras[cameraName];
if (!camera) return;
const cameraAspect = camera.detect.width / camera.detect.height;
const cropSize = 0.3;
let x1, y1, x2, y2;
if (cameraAspect >= 1) {
const pixelSize = cropSize * camera.detect.height;
const normalizedWidth = pixelSize / camera.detect.width;
x1 = (1 - normalizedWidth) / 2;
y1 = (1 - cropSize) / 2;
x2 = x1 + normalizedWidth;
y2 = y1 + cropSize;
} else {
const pixelSize = cropSize * camera.detect.width;
const normalizedHeight = pixelSize / camera.detect.height;
x1 = (1 - cropSize) / 2;
y1 = (1 - normalizedHeight) / 2;
x2 = x1 + cropSize;
y2 = y1 + normalizedHeight;
}
const newArea: CameraAreaConfig = {
camera: cameraName,
crop: [x1, y1, x2, y2],
};
setCameraAreas([...cameraAreas, newArea]);
setSelectedCameraIndex(cameraAreas.length);
setIsPopoverOpen(false);
},
[cameraAreas, config],
);
const handleRemoveCamera = useCallback(
(index: number) => {
const newAreas = cameraAreas.filter((_, i) => i !== index);
setCameraAreas(newAreas);
if (selectedCameraIndex >= newAreas.length) {
setSelectedCameraIndex(Math.max(0, newAreas.length - 1));
}
},
[cameraAreas, selectedCameraIndex],
);
const handleCropChange = useCallback(
(crop: [number, number, number, number]) => {
const newAreas = [...cameraAreas];
newAreas[selectedCameraIndex] = {
...newAreas[selectedCameraIndex],
crop,
};
setCameraAreas(newAreas);
},
[cameraAreas, selectedCameraIndex],
);
useEffect(() => {
setImageLoaded(false);
}, [selectedCamera]);
useEffect(() => {
const rect = rectRef.current;
const transformer = transformerRef.current;
if (
rect &&
transformer &&
selectedCamera &&
imageSize.width > 0 &&
imageLoaded
) {
rect.scaleX(1);
rect.scaleY(1);
transformer.nodes([rect]);
transformer.getLayer()?.batchDraw();
}
}, [selectedCamera, imageSize, imageLoaded]);
const handleRectChange = useCallback(() => {
const rect = rectRef.current;
if (rect && imageSize.width > 0) {
const actualWidth = rect.width() * rect.scaleX();
const actualHeight = rect.height() * rect.scaleY();
// Average dimensions to maintain perfect square
const size = (actualWidth + actualHeight) / 2;
rect.width(size);
rect.height(size);
rect.scaleX(1);
rect.scaleY(1);
const x1 = rect.x() / imageSize.width;
const y1 = rect.y() / imageSize.height;
const x2 = (rect.x() + size) / imageSize.width;
const y2 = (rect.y() + size) / imageSize.height;
handleCropChange([x1, y1, x2, y2]);
}
}, [imageSize, handleCropChange]);
const handleContinue = useCallback(() => {
onNext({ cameraAreas });
}, [cameraAreas, onNext]);
const canContinue = cameraAreas.length > 0;
return (
<div className="flex flex-col gap-4">
<div
className={cn(
"flex gap-4 overflow-hidden",
isMobile ? "flex-col" : "flex-row",
)}
>
<div
className={cn(
"flex flex-shrink-0 flex-col gap-2 overflow-y-auto rounded-lg bg-secondary p-4",
isMobile ? "w-full" : "w-64",
)}
>
<div className="flex items-center justify-between">
<h3 className="text-sm font-medium">{t("wizard.step2.cameras")}</h3>
{availableCameras.length > 0 ? (
<Popover
open={isPopoverOpen}
onOpenChange={setIsPopoverOpen}
modal={true}
>
<PopoverTrigger asChild>
<Button
type="button"
variant="secondary"
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
aria-label="Add camera"
>
<LuPlus />
</Button>
</PopoverTrigger>
<PopoverContent
className="scrollbar-container w-64 border bg-background p-3 shadow-lg"
align="start"
sideOffset={5}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<div className="flex flex-col gap-2">
<Heading as="h4" className="text-sm text-primary-variant">
{t("wizard.step2.selectCamera")}
</Heading>
<div className="scrollbar-container flex max-h-[30vh] flex-col gap-1 overflow-y-auto">
{availableCameras.map((cam) => (
<Button
key={cam.name}
type="button"
variant="ghost"
size="sm"
className="h-auto justify-start p-2 capitalize text-primary"
onClick={() => {
handleAddCamera(cam.name);
}}
>
{cam.displayName}
</Button>
))}
</div>
</div>
</PopoverContent>
</Popover>
) : (
<Button
variant="secondary"
className="size-6 cursor-not-allowed rounded-md bg-muted p-1 text-muted-foreground"
disabled
>
<LuPlus />
</Button>
)}
</div>
<div className="flex flex-col gap-1">
{cameraAreas.map((area, index) => {
const isSelected = index === selectedCameraIndex;
const displayName = resolveCameraName(config, area.camera);
return (
<div
key={area.camera}
className={`flex items-center justify-between rounded-md p-2 ${
isSelected
? "bg-selected/20 ring-1 ring-selected"
: "hover:bg-secondary/50"
} cursor-pointer`}
onClick={() => setSelectedCameraIndex(index)}
>
<span className="text-sm capitalize">{displayName}</span>
<Button
type="button"
variant="ghost"
size="sm"
className="h-6 w-6 p-0"
onClick={(e) => {
e.stopPropagation();
handleRemoveCamera(index);
}}
>
<LuX className="size-4" />
</Button>
</div>
);
})}
</div>
{cameraAreas.length === 0 && (
<div className="flex flex-1 items-center justify-center text-center text-sm text-muted-foreground">
{t("wizard.step2.noCameras")}
</div>
)}
</div>
<div className="flex flex-1 items-center justify-center overflow-hidden rounded-lg p-4">
<div
ref={containerRef}
className="flex items-center justify-center"
style={{
width: "100%",
aspectRatio: "16 / 9",
maxHeight: "100%",
}}
>
{selectedCamera && selectedCameraConfig && imageSize.width > 0 ? (
<div
style={{
width: imageSize.width,
height: imageSize.height,
position: "relative",
}}
>
<img
ref={imageRef}
src={`${apiHost}api/${selectedCamera.camera}/latest.jpg?h=500`}
alt={resolveCameraName(config, selectedCamera.camera)}
className="h-full w-full object-contain"
onLoad={() => setImageLoaded(true)}
/>
<Stage
ref={stageRef}
width={imageSize.width}
height={imageSize.height}
className="absolute inset-0"
>
<Layer>
<Rect
ref={rectRef}
x={selectedCamera.crop[0] * imageSize.width}
y={selectedCamera.crop[1] * imageSize.height}
width={
(selectedCamera.crop[2] - selectedCamera.crop[0]) *
imageSize.width
}
height={
(selectedCamera.crop[3] - selectedCamera.crop[1]) *
imageSize.height
}
stroke="#3b82f6"
strokeWidth={2}
fill="rgba(59, 130, 246, 0.1)"
draggable
dragBoundFunc={(pos) => {
const rect = rectRef.current;
if (!rect) return pos;
const size = rect.width();
const x = Math.max(
0,
Math.min(pos.x, imageSize.width - size),
);
const y = Math.max(
0,
Math.min(pos.y, imageSize.height - size),
);
return { x, y };
}}
onDragEnd={handleRectChange}
onTransformEnd={handleRectChange}
/>
<Transformer
ref={transformerRef}
rotateEnabled={false}
enabledAnchors={[
"top-left",
"top-right",
"bottom-left",
"bottom-right",
]}
boundBoxFunc={(_oldBox, newBox) => {
const minSize = 50;
const maxSize = Math.min(
imageSize.width,
imageSize.height,
);
// Clamp dimensions to stage bounds first
const clampedWidth = Math.max(
minSize,
Math.min(newBox.width, maxSize),
);
const clampedHeight = Math.max(
minSize,
Math.min(newBox.height, maxSize),
);
// Enforce square using average
const size = (clampedWidth + clampedHeight) / 2;
// Clamp position to keep square within bounds
const x = Math.max(
0,
Math.min(newBox.x, imageSize.width - size),
);
const y = Math.max(
0,
Math.min(newBox.y, imageSize.height - size),
);
return {
...newBox,
x,
y,
width: size,
height: size,
};
}}
/>
</Layer>
</Stage>
</div>
) : (
<div className="flex items-center justify-center text-muted-foreground">
{t("wizard.step2.selectCameraPrompt")}
</div>
)}
</div>
</div>
</div>
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onBack} className="sm:flex-1">
{t("button.back", { ns: "common" })}
</Button>
<Button
type="button"
onClick={handleContinue}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!canContinue}
>
{t("button.continue", { ns: "common" })}
</Button>
</div>
</div>
);
}

View File

@@ -0,0 +1,444 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useEffect, useCallback, useMemo } from "react";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import axios from "axios";
import { toast } from "sonner";
import { Step1FormData } from "./Step1NameAndDefine";
import { Step2FormData } from "./Step2StateArea";
import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
export type Step3FormData = {
examplesGenerated: boolean;
imageClassifications?: { [imageName: string]: string };
};
type Step3ChooseExamplesProps = {
step1Data: Step1FormData;
step2Data?: Step2FormData;
initialData?: Partial<Step3FormData>;
onClose: () => void;
onBack: () => void;
};
export default function Step3ChooseExamples({
step1Data,
step2Data,
initialData,
onClose,
onBack,
}: Step3ChooseExamplesProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [isGenerating, setIsGenerating] = useState(false);
const [hasGenerated, setHasGenerated] = useState(
initialData?.examplesGenerated || false,
);
const [imageClassifications, setImageClassifications] = useState<{
[imageName: string]: string;
}>(initialData?.imageClassifications || {});
const [isTraining, setIsTraining] = useState(false);
const [isProcessing, setIsProcessing] = useState(false);
const [currentClassIndex, setCurrentClassIndex] = useState(0);
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
);
const unknownImages = useMemo(() => {
if (!trainImages) return [];
return trainImages;
}, [trainImages]);
const toggleImageSelection = useCallback((imageName: string) => {
setSelectedImages((prev) => {
const newSet = new Set(prev);
if (newSet.has(imageName)) {
newSet.delete(imageName);
} else {
newSet.add(imageName);
}
return newSet;
});
}, []);
// Get all classes (excluding "none" - it will be auto-assigned)
const allClasses = useMemo(() => {
return [...step1Data.classes];
}, [step1Data.classes]);
const currentClass = allClasses[currentClassIndex];
const processClassificationsAndTrain = useCallback(
async (classifications: { [imageName: string]: string }) => {
// Step 1: Create config for the new model
const modelConfig: {
enabled: boolean;
name: string;
threshold: number;
state_config?: {
cameras: Record<string, { crop: number[] }>;
motion: boolean;
};
object_config?: { objects: string[]; classification_type: string };
} = {
enabled: true,
name: step1Data.modelName,
threshold: 0.8,
};
if (step1Data.modelType === "state") {
// State model config
const cameras: Record<string, { crop: number[] }> = {};
step2Data?.cameraAreas.forEach((area) => {
cameras[area.camera] = {
crop: area.crop,
};
});
modelConfig.state_config = {
cameras,
motion: true,
};
} else {
// Object model config
modelConfig.object_config = {
objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
classification_type: step1Data.objectType || "sub_label",
} as { objects: string[]; classification_type: string };
}
// Update config via config API
await axios.put("/config/set", {
requires_restart: 0,
update_topic: `config/classification/custom/${step1Data.modelName}`,
config_data: {
classification: {
custom: {
[step1Data.modelName]: modelConfig,
},
},
},
});
// Step 2: Classify each image by moving it to the correct category folder
const categorizePromises = Object.entries(classifications).map(
([imageName, className]) => {
if (!className) return Promise.resolve();
return axios.post(
`/classification/${step1Data.modelName}/dataset/categorize`,
{
training_file: imageName,
category: className === "none" ? "none" : className,
},
);
},
);
await Promise.all(categorizePromises);
// Step 3: Kick off training
await axios.post(`/classification/${step1Data.modelName}/train`);
toast.success(t("wizard.step3.trainingStarted"));
setIsTraining(true);
},
[step1Data, step2Data, t],
);
const handleContinueClassification = useCallback(async () => {
// Mark selected images with current class
const newClassifications = { ...imageClassifications };
selectedImages.forEach((imageName) => {
newClassifications[imageName] = currentClass;
});
// Check if we're on the last class to select
const isLastClass = currentClassIndex === allClasses.length - 1;
if (isLastClass) {
// Assign remaining unclassified images
unknownImages.slice(0, 24).forEach((imageName) => {
if (!newClassifications[imageName]) {
// For state models with 2 classes, assign to the last class
// For object models, assign to "none"
if (step1Data.modelType === "state" && allClasses.length === 2) {
newClassifications[imageName] = allClasses[allClasses.length - 1];
} else {
newClassifications[imageName] = "none";
}
}
});
// All done, trigger training immediately
setImageClassifications(newClassifications);
setIsProcessing(true);
try {
await processClassificationsAndTrain(newClassifications);
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to classify images";
toast.error(
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
);
setIsProcessing(false);
}
} else {
// Move to next class
setImageClassifications(newClassifications);
setCurrentClassIndex((prev) => prev + 1);
setSelectedImages(new Set());
}
}, [
selectedImages,
currentClass,
currentClassIndex,
allClasses,
imageClassifications,
unknownImages,
step1Data,
processClassificationsAndTrain,
t,
]);
const generateExamples = useCallback(async () => {
setIsGenerating(true);
try {
if (step1Data.modelType === "state") {
// For state models, use cameras and crop areas
if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) {
toast.error(t("wizard.step3.errors.noCameras"));
setIsGenerating(false);
return;
}
const cameras: { [key: string]: [number, number, number, number] } = {};
step2Data.cameraAreas.forEach((area) => {
cameras[area.camera] = area.crop;
});
await axios.post("/classification/generate_examples/state", {
model_name: step1Data.modelName,
cameras,
});
} else {
// For object models, use label
if (!step1Data.objectLabel) {
toast.error(t("wizard.step3.errors.noObjectLabel"));
setIsGenerating(false);
return;
}
// For now, use all enabled cameras
// TODO: In the future, we might want to let users select specific cameras
await axios.post("/classification/generate_examples/object", {
model_name: step1Data.modelName,
label: step1Data.objectLabel,
});
}
setHasGenerated(true);
toast.success(t("wizard.step3.generateSuccess"));
await refreshTrainImages();
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to generate examples";
toast.error(
t("wizard.step3.errors.generateFailed", { error: errorMessage }),
);
} finally {
setIsGenerating(false);
}
}, [step1Data, step2Data, t, refreshTrainImages]);
useEffect(() => {
if (!hasGenerated && !isGenerating) {
generateExamples();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const handleContinue = useCallback(async () => {
setIsProcessing(true);
try {
await processClassificationsAndTrain(imageClassifications);
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to classify images";
toast.error(
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
);
setIsProcessing(false);
}
}, [imageClassifications, processClassificationsAndTrain, t]);
const unclassifiedImages = useMemo(() => {
if (!unknownImages) return [];
const images = unknownImages.slice(0, 24);
// Only filter if we have any classifications
if (Object.keys(imageClassifications).length === 0) {
return images;
}
return images.filter((img) => !imageClassifications[img]);
}, [unknownImages, imageClassifications]);
const allImagesClassified = useMemo(() => {
return unclassifiedImages.length === 0;
}, [unclassifiedImages]);
return (
<div className="flex flex-col gap-6">
{isTraining ? (
<div className="flex flex-col items-center gap-6 py-12">
<ActivityIndicator className="size-12" />
<div className="text-center">
<h3 className="mb-2 text-lg font-medium">
{t("wizard.step3.training.title")}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.training.description")}
</p>
</div>
<Button onClick={onClose} variant="select" className="mt-4">
{t("button.close", { ns: "common" })}
</Button>
</div>
) : isGenerating ? (
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
<ActivityIndicator className="size-12" />
<div className="text-center">
<h3 className="mb-2 text-lg font-medium">
{t("wizard.step3.generating.title")}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.generating.description")}
</p>
</div>
</div>
) : hasGenerated ? (
<div className="flex flex-col gap-4">
{!allImagesClassified && (
<div className="text-center">
<h3 className="text-lg font-medium">
{t("wizard.step3.selectImagesPrompt", {
className: currentClass,
})}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.selectImagesDescription")}
</p>
</div>
)}
<div
className={cn(
"rounded-lg bg-secondary/30 p-4",
isMobile && "max-h-[60vh] overflow-y-auto",
)}
>
{!unknownImages || unknownImages.length === 0 ? (
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
{t("wizard.step3.noImages")}
</p>
<Button onClick={generateExamples} variant="select">
{t("wizard.step3.retryGenerate")}
</Button>
</div>
) : allImagesClassified && isProcessing ? (
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
<ActivityIndicator className="size-12" />
<p className="text-lg font-medium">
{t("wizard.step3.classifying")}
</p>
</div>
) : (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
{unclassifiedImages.map((imageName, index) => {
const isSelected = selectedImages.has(imageName);
return (
<div
key={imageName}
className={cn(
"aspect-square cursor-pointer overflow-hidden rounded-lg border-2 bg-background transition-all",
isSelected && "border-selected ring-2 ring-selected",
)}
onClick={() => toggleImageSelection(imageName)}
>
<img
src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
alt={`Example ${index + 1}`}
className="h-full w-full object-cover"
/>
</div>
);
})}
</div>
)}
</div>
</div>
) : (
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
<p className="text-sm text-destructive">
{t("wizard.step3.errors.generationFailed")}
</p>
<Button onClick={generateExamples} variant="select">
{t("wizard.step3.retryGenerate")}
</Button>
</div>
)}
{!isTraining && (
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onBack} className="sm:flex-1">
{t("button.back", { ns: "common" })}
</Button>
<Button
type="button"
onClick={
allImagesClassified
? handleContinue
: handleContinueClassification
}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!hasGenerated || isGenerating || isProcessing}
>
{isProcessing && <ActivityIndicator className="size-4" />}
{t("button.continue", { ns: "common" })}
</Button>
</div>
)}
</div>
);
}

View File

@@ -10,11 +10,14 @@ import {
CustomClassificationModelConfig,
FrigateConfig,
} from "@/types/frigateConfig";
import { useMemo, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { isMobile } from "react-device-detect";
import { useTranslation } from "react-i18next";
import { FaFolderPlus } from "react-icons/fa";
import { MdModelTraining } from "react-icons/md";
import useSWR from "swr";
import Heading from "@/components/ui/heading";
import { useOverlayState } from "@/hooks/use-overlay-state";
const allModelTypes = ["objects", "states"] as const;
type ModelType = (typeof allModelTypes)[number];
@@ -26,11 +29,24 @@ export default function ModelSelectionView({
onClick,
}: ModelSelectionViewProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [page, setPage] = useState<ModelType>("objects");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
const { data: config } = useSWR<FrigateConfig>("config", {
revalidateOnFocus: false,
});
const [page, setPage] = useOverlayState<ModelType>("objects", "objects");
const [pageToggle, setPageToggle] = useOptimisticState(
page || "objects",
setPage,
100,
);
const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
"config",
{
revalidateOnFocus: false,
},
);
// title
useEffect(() => {
document.title = t("documentTitle");
}, [t]);
// data
@@ -64,15 +80,15 @@ export default function ModelSelectionView({
return <ActivityIndicator />;
}
if (classificationConfigs.length == 0) {
return <div>You need to setup a custom model configuration.</div>;
}
return (
<div className="flex size-full flex-col p-2">
<ClassificationModelWizardDialog
open={newModel}
onClose={() => setNewModel(false)}
defaultModelType={pageToggle === "objects" ? "object" : "state"}
onClose={() => {
setNewModel(false);
refreshConfig();
}}
/>
<div className="flex h-12 w-full items-center justify-between">
@@ -84,7 +100,6 @@ export default function ModelSelectionView({
value={pageToggle}
onValueChange={(value: ModelType) => {
if (value) {
// Restrict viewer navigation
setPageToggle(value);
}
}}
@@ -117,13 +132,46 @@ export default function ModelSelectionView({
</div>
</div>
<div className="flex size-full gap-2 p-2">
{selectedClassificationConfigs.map((config) => (
<ModelCard
key={config.name}
config={config}
onClick={() => onClick(config)}
{selectedClassificationConfigs.length === 0 ? (
<NoModelsView
onCreateModel={() => setNewModel(true)}
modelType={pageToggle}
/>
))}
) : (
selectedClassificationConfigs.map((config) => (
<ModelCard
key={config.name}
config={config}
onClick={() => onClick(config)}
/>
))
)}
</div>
</div>
);
}
function NoModelsView({
onCreateModel,
modelType,
}: {
onCreateModel: () => void;
modelType: ModelType;
}) {
const { t } = useTranslation(["views/classificationModel"]);
const typeKey = modelType === "objects" ? "object" : "state";
return (
<div className="flex size-full items-center justify-center">
<div className="flex flex-col items-center gap-2">
<MdModelTraining className="size-8" />
<Heading as="h4">{t(`noModels.${typeKey}.title`)}</Heading>
<div className="mb-3 text-center text-secondary-foreground">
{t(`noModels.${typeKey}.description`)}
</div>
<Button size="sm" variant="select" onClick={onCreateModel}>
{t(`noModels.${typeKey}.buttonText`)}
</Button>
</div>
</div>
);
@@ -139,13 +187,17 @@ function ModelCard({ config, onClick }: ModelCardProps) {
}>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });
const coverImage = useMemo(() => {
if (!dataset?.length) {
if (!dataset) {
return undefined;
}
const keys = Object.keys(dataset).filter((key) => key != "none");
const selectedKey = keys[0];
if (!dataset[selectedKey]) {
return undefined;
}
return {
name: selectedKey,
img: dataset[selectedKey][0],

View File

@@ -642,6 +642,7 @@ function DatasetGrid({
filepath: `clips/${modelName}/dataset/${categoryName}/${image}`,
name: "",
}}
showArea={false}
selected={selectedImages.includes(image)}
i18nLibrary="views/classificationModel"
onClick={(data, _) => onClickImages([data.filename], true)}