Classification Model UI (#18571)

* Setup basic training structure

* Build out route

* Handle model configs

* Add image fetch APIs

* Implement model training screen with dataset selection

* Implement viewing of training images

* Adjust directories

* Implement viewing of images

* Add support for deleting images

* Implement full deletion

* Implement classification model training

* Improve naming

* More renaming

* Improve layout

* Reduce logging

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-06-04 17:09:55 -06:00 committed by GitHub
parent eb1fe9fe20
commit 85d721eb6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1156 additions and 8 deletions

View File

@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import (
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.const import FACE_DIR, MODEL_CACHE_DIR
from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
@ -431,6 +431,50 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
# custom classification training
@router.get("/classification/{name}/dataset")
def get_classification_dataset(name: str):
dataset_dict: dict[str, list[str]] = {}
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir):
return JSONResponse(status_code=200, content={})
for name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, name)
if not os.path.isdir(category_dir):
continue
dataset_dict[name] = []
for file in filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(category_dir),
):
dataset_dict[name].append(file)
return JSONResponse(status_code=200, content=dataset_dict)
@router.get("/classification/{name}/train")
def get_classification_images(name: str):
train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "train")
if not os.path.exists(train_dir):
return JSONResponse(status_code=200, content=[])
return JSONResponse(
status_code=200,
content=list(
filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(train_dir),
)
),
)
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
@ -448,10 +492,131 @@ async def train_configured_model(
status_code=404,
)
background_tasks.add_task(
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
)
background_tasks.add_task(train_classification_model, name)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
)
@router.post(
"/classification/{name}/dataset/{category}/delete",
dependencies=[Depends(require_role(["admin"]))],
)
def delete_classification_dataset_images(
request: Request, name: str, category: str, body: dict = None
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
json: dict[str, Any] = body or {}
list_of_ids = json.get("ids", "")
folder = os.path.join(
CLIPS_DIR, sanitize_filename(name), "dataset", sanitize_filename(category)
)
for id in list_of_ids:
file_path = os.path.join(folder, id)
if os.path.isfile(file_path):
os.unlink(file_path)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)
@router.post(
"/classification/{name}/dataset/categorize",
dependencies=[Depends(require_role(["admin"]))],
)
def categorize_classification_image(request: Request, name: str, body: dict = None):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
json: dict[str, Any] = body or {}
category = sanitize_filename(json.get("category", ""))
training_file_name = sanitize_filename(json.get("training_file", ""))
training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name)
if training_file_name and not os.path.isfile(training_file):
return JSONResponse(
content=(
{
"success": False,
"message": f"Invalid filename or no file exists: {training_file_name}",
}
),
status_code=404,
)
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category)
if not os.path.exists(new_file_folder):
os.mkdir(new_file_folder)
# use opencv because webp images can not be used to train
img = cv2.imread(training_file)
cv2.imwrite(os.path.join(new_file_folder, new_name), img)
os.unlink(training_file)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)
@router.post(
"/classification/{name}/train/delete",
dependencies=[Depends(require_role(["admin"]))],
)
def delete_classification_train_images(request: Request, name: str, body: dict = None):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
json: dict[str, Any] = body or {}
list_of_ids = json.get("ids", "")
folder = os.path.join(CLIPS_DIR, sanitize_filename(name), "train")
for id in list_of_ids:
file_path = os.path.join(folder, id)
if os.path.isfile(file_path):
os.unlink(file_path)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)

View File

@ -42,7 +42,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.model_config = model_config
self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None

View File

@ -1,5 +1,6 @@
"""Util for classification models."""
import logging
import os
import cv2
@ -9,6 +10,8 @@ from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 0.001
@ -35,9 +38,10 @@ def generate_representative_dataset_factory(dataset_dir: str):
@staticmethod
def train_classification_model(model_dir: str) -> bool:
def train_classification_model(model_name: str) -> bool:
"""Train a classification model."""
dataset_dir = os.path.join(model_dir, "dataset")
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
num_classes = len(
[
d
@ -46,6 +50,8 @@ def train_classification_model(model_dir: str) -> bool:
]
)
tf.get_logger().setLevel(logging.ERROR)
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
input_shape=(224, 224, 3),

View File

@ -0,0 +1,49 @@
{
"button": {
"deleteClassificationAttempts": "Delete Classification Images",
"renameCategory": "Rename Class",
"deleteCategory": "Delete Class",
"deleteImages": "Delete Images"
},
"toast": {
"success": {
"deletedCategory": "Deleted Class",
"deletedImage": "Deleted Images",
"categorizedImage": "Successfully Classified Image"
},
"error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
}
},
"deleteCategory": {
"title": "Delete Class",
"desc": "Are you sure you want to delete the class {{name}}? This will permanently delete all associated images and require re-training the model."
},
"deleteDatasetImages": {
"title": "Delete Dataset Images",
"desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model."
},
"deleteTrainImages": {
"title": "Delete Train Images",
"desc": "Are you sure you want to delete {{count}} images? This action cannot be undone."
},
"renameCategory": {
"title": "Rename Class",
"desc": "Enter a new name for {{name}}. You will be required to retrain the model for the name change to take affect."
},
"description": {
"invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens."
},
"train": {
"title": "Train",
"aria": "Select Train"
},
"categories": "Classes",
"createCategory": {
"new": "Create New Class"
},
"categorizeImageAs": "Classify Image As:",
"categorizeImage": "Classify Image"
}

View File

@ -24,6 +24,7 @@ const System = lazy(() => import("@/pages/System"));
const Settings = lazy(() => import("@/pages/Settings"));
const UIPlayground = lazy(() => import("@/pages/UIPlayground"));
const FaceLibrary = lazy(() => import("@/pages/FaceLibrary"));
const Classification = lazy(() => import("@/pages/ClassificationModel"));
const Logs = lazy(() => import("@/pages/Logs"));
const AccessDenied = lazy(() => import("@/pages/AccessDenied"));
@ -76,6 +77,7 @@ function DefaultAppView() {
<Route path="/config" element={<ConfigEditor />} />
<Route path="/logs" element={<Logs />} />
<Route path="/faces" element={<FaceLibrary />} />
<Route path="/classification" element={<Classification />} />
<Route path="/playground" element={<UIPlayground />} />
</Route>
<Route path="/unauthorized" element={<AccessDenied />} />

View File

@ -0,0 +1,155 @@
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "@/components/ui/drawer";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { isDesktop, isMobile } from "react-device-detect";
import { LuPlus } from "react-icons/lu";
import { useTranslation } from "react-i18next";
import { cn } from "@/lib/utils";
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import TextEntryDialog from "./dialog/TextEntryDialog";
import { Button } from "../ui/button";
import { MdCategory } from "react-icons/md";
import axios from "axios";
import { toast } from "sonner";
type ClassificationSelectionDialogProps = {
className?: string;
classes: string[];
modelName: string;
image: string;
onRefresh: () => void;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
className,
classes,
modelName,
image,
onRefresh,
children,
}: ClassificationSelectionDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const onCategorizeImage = useCallback(
(category: string) => {
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: image,
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.categorizedImage"), {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
});
},
[modelName, image, onRefresh, t],
);
const isChildButton = useMemo(
() => React.isValidElement(children) && children.type === Button,
[children],
);
// control
const [newFace, setNewFace] = useState(false);
// components
const Selector = isDesktop ? DropdownMenu : Drawer;
const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger;
const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent;
const SelectorItem = isDesktop
? DropdownMenuItem
: (props: React.HTMLAttributes<HTMLDivElement>) => (
<DrawerClose asChild>
<div {...props} className={cn(props.className, "my-2")} />
</DrawerClose>
);
return (
<div className={className ?? ""}>
{newFace && (
<TextEntryDialog
open={true}
setOpen={setNewFace}
title={t("createCategory.new")}
onSave={(newCat) => onCategorizeImage(newCat)}
/>
)}
<Tooltip>
<Selector>
<SelectorTrigger asChild>
<TooltipTrigger asChild={isChildButton}>{children}</TooltipTrigger>
</SelectorTrigger>
<SelectorContent
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
>
{isMobile && (
<DrawerHeader className="sr-only">
<DrawerTitle>Details</DrawerTitle>
<DrawerDescription>Details</DrawerDescription>
</DrawerHeader>
)}
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
<div
className={cn(
"flex max-h-[40dvh] flex-col overflow-y-auto",
isMobile && "gap-2 pb-4",
)}
>
<SelectorItem
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => setNewFace(true)}
>
<LuPlus />
{t("createCategory.new")}
</SelectorItem>
{classes.sort().map((category) => (
<SelectorItem
key={category}
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => onCategorizeImage(category)}
>
<MdCategory />
{category}
</SelectorItem>
))}
</div>
</SelectorContent>
</Selector>
<TooltipContent>{t("categorizeImage")}</TooltipContent>
</Tooltip>
</div>
);
}

View File

@ -6,7 +6,7 @@ import { isDesktop } from "react-device-detect";
import { FaCompactDisc, FaVideo } from "react-icons/fa";
import { IoSearch } from "react-icons/io5";
import { LuConstruction } from "react-icons/lu";
import { MdVideoLibrary } from "react-icons/md";
import { MdCategory, MdVideoLibrary } from "react-icons/md";
import { TbFaceId } from "react-icons/tb";
import useSWR from "swr";
@ -16,6 +16,7 @@ export const ID_EXPLORE = 3;
export const ID_EXPORT = 4;
export const ID_PLAYGROUND = 5;
export const ID_FACE_LIBRARY = 6;
export const ID_CLASSIFICATION = 7;
export default function useNavigation(
variant: "primary" | "secondary" = "primary",
@ -71,6 +72,14 @@ export default function useNavigation(
url: "/faces",
enabled: isDesktop && config?.face_recognition.enabled,
},
{
id: ID_CLASSIFICATION,
variant,
icon: MdCategory,
title: "menu.classification",
url: "/classification",
enabled: isDesktop,
},
] as NavData[],
[config?.face_recognition?.enabled, variant],
);

View File

@ -0,0 +1,18 @@
import { useOverlayState } from "@/hooks/use-overlay-state";
import { CustomClassificationModelConfig } from "@/types/frigateConfig";
import ModelSelectionView from "@/views/classification/ModelSelectionView";
import ModelTrainingView from "@/views/classification/ModelTrainingView";
export default function ClassificationModelPage() {
// training
const [model, setModel] = useOverlayState<CustomClassificationModelConfig>(
"classificationModel",
);
if (model == undefined) {
return <ModelSelectionView onClick={setModel} />;
}
return <ModelTrainingView model={model} />;
}

View File

@ -279,6 +279,23 @@ export type CameraStreamingSettings = {
volume: number;
};
export type CustomClassificationModelConfig = {
enabled: boolean;
name: string;
object_config: null | {
objects: string[];
};
state_config: null | {
cameras: {
[cameraName: string]: {
crop: [number, number, number, number];
threshold: number;
};
};
motion: boolean;
};
};
export type GroupStreamingSettings = {
[cameraName: string]: CameraStreamingSettings;
};
@ -316,6 +333,9 @@ export interface FrigateConfig {
enabled: boolean;
threshold: number;
};
custom: {
[modelKey: string]: CustomClassificationModelConfig;
};
};
database: {

View File

@ -0,0 +1,63 @@
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { cn } from "@/lib/utils";
import {
CustomClassificationModelConfig,
FrigateConfig,
} from "@/types/frigateConfig";
import { useMemo } from "react";
import { isMobile } from "react-device-detect";
import useSWR from "swr";
type ModelSelectionViewProps = {
onClick: (model: CustomClassificationModelConfig) => void;
};
export default function ModelSelectionView({
onClick,
}: ModelSelectionViewProps) {
const { data: config } = useSWR<FrigateConfig>("config", {
revalidateOnFocus: false,
});
const classificationConfigs = useMemo(() => {
if (!config) {
return [];
}
return Object.values(config.classification.custom);
}, [config]);
if (!config) {
return <ActivityIndicator />;
}
if (classificationConfigs.length == 0) {
return <div>You need to setup a custom model configuration.</div>;
}
return (
<div className="flex size-full gap-2 p-2">
{classificationConfigs.map((config) => (
<div
key={config.name}
className={cn(
"flex h-52 cursor-pointer flex-col gap-2 rounded-lg bg-card p-2 outline outline-[3px]",
"outline-transparent duration-500",
isMobile && "w-full",
)}
onClick={() => onClick(config)}
onContextMenu={() => {
// e.stopPropagation();
// e.preventDefault();
// handleClickEvent(true);
}}
>
<div className="size-48"></div>
<div className="smart-capitalize">
{config.name} ({config.state_config != null ? "State" : "Object"}{" "}
Classification)
</div>
</div>
))}
</div>
);
}

View File

@ -0,0 +1,661 @@
import { baseUrl } from "@/api/baseUrl";
import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
import { Button, buttonVariants } from "@/components/ui/button";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Toaster } from "@/components/ui/sonner";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import useKeyboardListener from "@/hooks/use-keyboard-listener";
import useOptimisticState from "@/hooks/use-optimistic-state";
import { cn } from "@/lib/utils";
import { CustomClassificationModelConfig } from "@/types/frigateConfig";
import { TooltipPortal } from "@radix-ui/react-tooltip";
import axios from "axios";
import { useCallback, useEffect, useMemo, useState } from "react";
import { isDesktop, isMobile } from "react-device-detect";
import { Trans, useTranslation } from "react-i18next";
import { LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
};
export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
// dataset
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
`classification/${model.name}/train`,
);
const { data: dataset, mutate: refreshDataset } = useSWR<{
[id: string]: string[];
}>(`classification/${model.name}/dataset`);
// image multiselect
const [selectedImages, setSelectedImages] = useState<string[]>([]);
const onClickImages = useCallback(
(images: string[], ctrl: boolean) => {
if (selectedImages.length == 0 && !ctrl) {
return;
}
let newSelectedImages = [...selectedImages];
images.forEach((imageId) => {
const index = newSelectedImages.indexOf(imageId);
if (index != -1) {
if (selectedImages.length == 1) {
newSelectedImages = [];
} else {
const copy = [
...newSelectedImages.slice(0, index),
...newSelectedImages.slice(index + 1),
];
newSelectedImages = copy;
}
} else {
newSelectedImages.push(imageId);
}
});
setSelectedImages(newSelectedImages);
},
[selectedImages, setSelectedImages],
);
// actions
const trainModel = useCallback(() => {
axios.post(`classification/${model.name}/train`);
}, [model]);
const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
null,
);
const onDelete = useCallback(
(ids: string[], isName: boolean = false) => {
const api =
pageToggle == "train"
? `/classification/${model.name}/train/delete`
: `/classification/${model.name}/dataset/${pageToggle}/delete`;
axios
.post(api, { ids })
.then((resp) => {
setSelectedImages([]);
if (resp.status == 200) {
if (isName) {
toast.success(
t("toast.success.deletedCategory", { count: ids.length }),
{
position: "top-center",
},
);
} else {
toast.success(
t("toast.success.deletedImage", { count: ids.length }),
{
position: "top-center",
},
);
}
if (pageToggle == "train") {
refreshTrain();
} else {
refreshDataset();
}
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
if (isName) {
toast.error(
t("toast.error.deleteCategoryFailed", { errorMessage }),
{
position: "top-center",
},
);
} else {
toast.error(t("toast.error.deleteImageFailed", { errorMessage }), {
position: "top-center",
});
}
});
},
[pageToggle, model, refreshTrain, refreshDataset, t],
);
// keyboard
useKeyboardListener(["a", "Escape"], (key, modifiers) => {
if (modifiers.repeat || !modifiers.down) {
return;
}
switch (key) {
case "a":
if (modifiers.ctrl) {
if (selectedImages.length) {
setSelectedImages([]);
} else {
setSelectedImages([
...(pageToggle === "train"
? trainImages || []
: dataset?.[pageToggle] || []),
]);
}
}
break;
case "Escape":
setSelectedImages([]);
break;
}
});
useEffect(() => {
setSelectedImages([]);
}, [pageToggle]);
return (
<div className="flex size-full flex-col overflow-hidden">
<Toaster />
<AlertDialog
open={!!deleteDialogOpen}
onOpenChange={() => setDeleteDialogOpen(null)}
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>
{t(
pageToggle == "train"
? "deleteTrainImages.title"
: "deleteDatasetImages.title",
)}
</AlertDialogTitle>
</AlertDialogHeader>
<AlertDialogDescription>
<Trans
ns="views/classificationModel"
values={{ count: deleteDialogOpen?.length, dataset: pageToggle }}
>
{pageToggle == "train"
? "deleteTrainImages.desc"
: "deleteDatasetImages.desc"}
</Trans>
</AlertDialogDescription>
<AlertDialogFooter>
<AlertDialogCancel>
{t("button.cancel", { ns: "common" })}
</AlertDialogCancel>
<AlertDialogAction
className={buttonVariants({ variant: "destructive" })}
onClick={() => {
if (deleteDialogOpen) {
onDelete(deleteDialogOpen);
setDeleteDialogOpen(null);
}
}}
>
{t("button.delete", { ns: "common" })}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
<div className="flex flex-row justify-between gap-2 p-2 align-middle">
<LibrarySelector
pageToggle={pageToggle}
dataset={dataset || {}}
trainImages={trainImages || []}
setPageToggle={setPageToggle}
onDelete={onDelete}
onRename={() => {}}
/>
{selectedImages?.length > 0 ? (
<div className="flex items-center justify-center gap-2">
<div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground">
<div className="p-1">{`${selectedImages.length} selected`}</div>
<div className="p-1">{"|"}</div>
<div
className="cursor-pointer p-2 text-primary hover:rounded-lg hover:bg-secondary"
onClick={() => setSelectedImages([])}
>
{t("button.unselect", { ns: "common" })}
</div>
</div>
<Button
className="flex gap-2"
onClick={() => setDeleteDialogOpen(selectedImages)}
>
<LuTrash2 className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop && t("button.deleteImages")}
</Button>
</div>
) : (
<Button onClick={trainModel}>Train Model</Button>
)}
</div>
{pageToggle == "train" ? (
<TrainGrid
model={model}
classes={Object.keys(dataset || {})}
trainImages={trainImages || []}
selectedImages={selectedImages}
onRefresh={refreshTrain}
onClickImages={onClickImages}
onDelete={onDelete}
/>
) : (
<DatasetGrid
modelName={model.name}
categoryName={pageToggle}
images={dataset?.[pageToggle] || []}
selectedImages={selectedImages}
onClickImages={onClickImages}
onDelete={onDelete}
/>
)}
</div>
);
}
type LibrarySelectorProps = {
pageToggle: string | undefined;
dataset: { [id: string]: string[] };
trainImages: string[];
setPageToggle: (toggle: string) => void;
onDelete: (ids: string[], isName: boolean) => void;
onRename: (old_name: string, new_name: string) => void;
};
function LibrarySelector({
pageToggle,
dataset,
trainImages,
setPageToggle,
onDelete,
onRename,
}: LibrarySelectorProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
const [renameFace, setRenameFace] = useState<string | null>(null);
const handleDeleteFace = useCallback(
(name: string) => {
// Get all image IDs for this face
const imageIds = dataset?.[name] || [];
onDelete(imageIds, true);
setPageToggle("train");
},
[dataset, onDelete, setPageToggle],
);
const handleSetOpen = useCallback(
(open: boolean) => {
setRenameFace(open ? renameFace : null);
},
[renameFace],
);
return (
<>
<Dialog
open={!!confirmDelete}
onOpenChange={(open) => !open && setConfirmDelete(null)}
>
<DialogContent>
<DialogHeader>
<DialogTitle>{t("deleteCategory.title")}</DialogTitle>
<DialogDescription>
{t("deleteCategory.desc", { name: confirmDelete })}
</DialogDescription>
</DialogHeader>
<div className="flex justify-end gap-2">
<Button variant="outline" onClick={() => setConfirmDelete(null)}>
{t("button.cancel", { ns: "common" })}
</Button>
<Button
variant="destructive"
onClick={() => {
if (confirmDelete) {
handleDeleteFace(confirmDelete);
setConfirmDelete(null);
}
}}
>
{t("button.delete", { ns: "common" })}
</Button>
</div>
</DialogContent>
</Dialog>
<TextEntryDialog
open={!!renameFace}
setOpen={handleSetOpen}
title={t("renameCategory.title")}
description={t("renameCategory.desc", { name: renameFace })}
onSave={(newName) => {
onRename(renameFace!, newName);
setRenameFace(null);
}}
defaultValue={renameFace || ""}
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
regexErrorMessage={t("description.invalidName")}
/>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button className="flex justify-between smart-capitalize">
{pageToggle == "train" ? t("train.title") : pageToggle}
<span className="ml-2 text-primary-variant">
(
{(pageToggle &&
(pageToggle == "train"
? trainImages.length
: dataset?.[pageToggle]?.length)) ||
0}
)
</span>
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent
className="scrollbar-container max-h-[40dvh] min-w-[220px] overflow-y-auto"
align="start"
>
<DropdownMenuItem
className="flex cursor-pointer items-center justify-start gap-2"
aria-label={t("train.aria")}
onClick={() => setPageToggle("train")}
>
<div>{t("train.title")}</div>
<div className="text-secondary-foreground">
({trainImages.length})
</div>
</DropdownMenuItem>
{trainImages.length > 0 && Object.keys(dataset).length > 0 && (
<>
<DropdownMenuSeparator />
<div className="mb-1 ml-1.5 text-xs text-secondary-foreground">
{t("categories")}
</div>
</>
)}
{Object.keys(dataset).map((id) => (
<DropdownMenuItem
key={id}
className="group flex items-center justify-between"
>
<div
className="flex-grow cursor-pointer"
onClick={() => setPageToggle(id)}
>
{id}
<span className="ml-2 text-muted-foreground">
({dataset?.[id].length})
</span>
</div>
<div className="flex gap-0.5">
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="icon"
className="size-7 lg:opacity-0 lg:transition-opacity lg:group-hover:opacity-100"
onClick={(e) => {
e.stopPropagation();
setRenameFace(id);
}}
>
<LuPencil className="size-4 text-primary" />
</Button>
</TooltipTrigger>
<TooltipPortal>
<TooltipContent>
{t("button.renameCategory")}
</TooltipContent>
</TooltipPortal>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="icon"
className="size-7 lg:opacity-0 lg:transition-opacity lg:group-hover:opacity-100"
onClick={(e) => {
e.stopPropagation();
setConfirmDelete(id);
}}
>
<LuTrash2 className="size-4 text-destructive" />
</Button>
</TooltipTrigger>
<TooltipPortal>
<TooltipContent>
{t("button.deleteCategory")}
</TooltipContent>
</TooltipPortal>
</Tooltip>
</div>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
</>
);
}
type DatasetGridProps = {
modelName: string;
categoryName: string;
images: string[];
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onDelete: (ids: string[]) => void;
};
function DatasetGrid({
modelName,
categoryName,
images,
selectedImages,
onClickImages,
onDelete,
}: DatasetGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
return (
<div className="flex flex-wrap gap-2 overflow-y-auto p-2">
{images.map((image) => (
<div
className={cn(
"flex w-60 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]",
selectedImages.includes(image)
? "shadow-selected outline-selected"
: "outline-transparent duration-500",
)}
onClick={(e) => {
e.stopPropagation();
if (e.ctrlKey || e.metaKey) {
onClickImages([image], true);
}
}}
>
<div
className={cn(
"w-full overflow-hidden p-2 *:text-card-foreground",
isMobile && "flex justify-center",
)}
>
<img
className="rounded-lg"
src={`${baseUrl}clips/${modelName}/dataset/${categoryName}/${image}`}
/>
</div>
<div className="rounded-b-lg bg-card p-3">
<div className="flex w-full flex-row items-center justify-between gap-2">
<div className="flex w-full flex-row items-start justify-end gap-5 md:gap-4">
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={(e) => {
e.stopPropagation();
onDelete([image]);
}}
/>
</TooltipTrigger>
<TooltipContent>
{t("button.deleteClassificationAttempts")}
</TooltipContent>
</Tooltip>
</div>
</div>
</div>
</div>
))}
</div>
);
}
type TrainGridProps = {
model: CustomClassificationModelConfig;
classes: string[];
trainImages: string[];
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
};
function TrainGrid({
model,
classes,
trainImages,
selectedImages,
onClickImages,
onRefresh,
onDelete,
}: TrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
const trainData = useMemo(
() =>
trainImages
.map((raw) => {
const parts = raw.replaceAll(".webp", "").split("-");
return {
raw,
timestamp: parts[0],
label: parts[1],
score: Number.parseFloat(parts[2]) * 100,
};
})
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
[trainImages],
);
return (
<div className="flex flex-wrap gap-2 overflow-y-auto p-2">
{trainData?.map((data) => (
<div
key={data.timestamp}
className={cn(
"flex w-56 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]",
selectedImages.includes(data.raw)
? "shadow-selected outline-selected"
: "outline-transparent duration-500",
)}
onClick={(e) => {
e.stopPropagation();
onClickImages([data.raw], e.ctrlKey || e.metaKey);
}}
>
<div
className={cn(
"w-full overflow-hidden p-2 *:text-card-foreground",
isMobile && "flex justify-center",
)}
>
<img
className="w-56 rounded-lg"
src={`${baseUrl}clips/${model.name}/train/${data.raw}`}
/>
</div>
<div className="rounded-b-lg bg-card p-3">
<div className="flex w-full flex-row items-center justify-between gap-2">
<div className="flex flex-col items-start text-xs text-primary-variant">
<div className="smart-capitalize">{data.label}</div>
<div>{data.score}%</div>
</div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<ClassificationSelectionDialog
classes={classes}
modelName={model.name}
image={data.raw}
onRefresh={onRefresh}
>
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</ClassificationSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={(e) => {
e.stopPropagation();
onDelete([data.raw]);
}}
/>
</TooltipTrigger>
<TooltipContent>
{t("button.deleteClassificationAttempts")}
</TooltipContent>
</Tooltip>
</div>
</div>
</div>
</div>
))}
</div>
);
}