From 85d721eb6b76de07658b1187a1c9c614d771bf0a Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 4 Jun 2025 17:09:55 -0600 Subject: [PATCH] Classification Model UI (#18571) * Setup basic training structure * Build out route * Handle model configs * Add image fetch APIs * Implement model training screen with dataset selection * Implement viewing of training images * Adjust directories * Implement viewing of images * Add support for deleting images * Implement full deletion * Implement classification model training * Improve naming * More renaming * Improve layout * Reduce logging * Cleanup --- frigate/api/classification.py | 173 ++++- .../real_time/custom_classification.py | 2 +- frigate/util/classification.py | 10 +- .../locales/en/views/classificationModel.json | 49 ++ web/src/App.tsx | 2 + .../overlay/ClassificationSelectionDialog.tsx | 155 ++++ web/src/hooks/use-navigation.ts | 11 +- web/src/pages/ClassificationModel.tsx | 18 + web/src/types/frigateConfig.ts | 20 + .../classification/ModelSelectionView.tsx | 63 ++ .../classification/ModelTrainingView.tsx | 661 ++++++++++++++++++ 11 files changed, 1156 insertions(+), 8 deletions(-) create mode 100644 web/public/locales/en/views/classificationModel.json create mode 100644 web/src/components/overlay/ClassificationSelectionDialog.tsx create mode 100644 web/src/pages/ClassificationModel.tsx create mode 100644 web/src/views/classification/ModelSelectionView.tsx create mode 100644 web/src/views/classification/ModelTrainingView.tsx diff --git a/frigate/api/classification.py b/frigate/api/classification.py index f2c6ac06b..f5acc437c 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import ( from frigate.api.defs.tags import Tags from frigate.config import FrigateConfig from frigate.config.camera import DetectConfig -from frigate.const import FACE_DIR, MODEL_CACHE_DIR +from frigate.const import CLIPS_DIR, FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event from frigate.util.classification import train_classification_model @@ -431,6 +431,50 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody): # custom classification training +@router.get("/classification/{name}/dataset") +def get_classification_dataset(name: str): + dataset_dict: dict[str, list[str]] = {} + + dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") + + if not os.path.exists(dataset_dir): + return JSONResponse(status_code=200, content={}) + + for name in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, name) + + if not os.path.isdir(category_dir): + continue + + dataset_dict[name] = [] + + for file in filter( + lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), + os.listdir(category_dir), + ): + dataset_dict[name].append(file) + + return JSONResponse(status_code=200, content=dataset_dict) + + +@router.get("/classification/{name}/train") +def get_classification_images(name: str): + train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "train") + + if not os.path.exists(train_dir): + return JSONResponse(status_code=200, content=[]) + + return JSONResponse( + status_code=200, + content=list( + filter( + lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), + os.listdir(train_dir), + ) + ), + ) + + @router.post("/classification/{name}/train") async def train_configured_model( request: Request, name: str, background_tasks: BackgroundTasks @@ -448,10 
+492,131 @@ async def train_configured_model( status_code=404, ) - background_tasks.add_task( - train_classification_model, os.path.join(MODEL_CACHE_DIR, name) - ) + background_tasks.add_task(train_classification_model, name) return JSONResponse( content={"success": True, "message": "Started classification model training."}, status_code=200, ) + + +@router.post( + "/classification/{name}/dataset/{category}/delete", + dependencies=[Depends(require_role(["admin"]))], +) +def delete_classification_dataset_images( + request: Request, name: str, category: str, body: dict = None +): + config: FrigateConfig = request.app.frigate_config + + if name not in config.classification.custom: + return JSONResponse( + content=( + { + "success": False, + "message": f"{name} is not a known classification model.", + } + ), + status_code=404, + ) + + json: dict[str, Any] = body or {} + list_of_ids = json.get("ids", "") + folder = os.path.join( + CLIPS_DIR, sanitize_filename(name), "dataset", sanitize_filename(category) + ) + + for id in list_of_ids: + file_path = os.path.join(folder, id) + + if os.path.isfile(file_path): + os.unlink(file_path) + + return JSONResponse( + content=({"success": True, "message": "Successfully deleted images."}), + status_code=200, + ) + + +@router.post( + "/classification/{name}/dataset/categorize", + dependencies=[Depends(require_role(["admin"]))], +) +def categorize_classification_image(request: Request, name: str, body: dict = None): + config: FrigateConfig = request.app.frigate_config + + if name not in config.classification.custom: + return JSONResponse( + content=( + { + "success": False, + "message": f"{name} is not a known classification model.", + } + ), + status_code=404, + ) + + json: dict[str, Any] = body or {} + category = sanitize_filename(json.get("category", "")) + training_file_name = sanitize_filename(json.get("training_file", "")) + training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name) + + if training_file_name and not os.path.isfile(training_file): + return JSONResponse( + content=( + { + "success": False, + "message": f"Invalid filename or no file exists: {training_file_name}", + } + ), + status_code=404, + ) + + new_name = f"{category}-{datetime.datetime.now().timestamp()}.png" + new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category) + + if not os.path.exists(new_file_folder): + os.mkdir(new_file_folder) + + # use opencv because webp images can not be used to train + img = cv2.imread(training_file) + cv2.imwrite(os.path.join(new_file_folder, new_name), img) + os.unlink(training_file) + + return JSONResponse( + content=({"success": True, "message": "Successfully categorized image."}), + status_code=200, + ) + + +@router.post( + "/classification/{name}/train/delete", + dependencies=[Depends(require_role(["admin"]))], +) +def delete_classification_train_images(request: Request, name: str, body: dict = None): + config: FrigateConfig = request.app.frigate_config + + if name not in config.classification.custom: + return JSONResponse( + content=( + { + "success": False, + "message": f"{name} is not a known classification model.", + } + ), + status_code=404, + ) + + json: dict[str, Any] = body or {} + list_of_ids = json.get("ids", "") + folder = os.path.join(CLIPS_DIR, sanitize_filename(name), "train") + + for id in list_of_ids: + file_path = os.path.join(folder, id) + + if os.path.isfile(file_path): + os.unlink(file_path) + + return JSONResponse( + content=({"success": True, "message": "Successfully deleted images."}), + status_code=200, 
+ ) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index f94c2b28c..0e254ab0d 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -42,7 +42,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.model_config = model_config self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) - self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name) + self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") self.interpreter: Interpreter = None self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 4ee5e1d54..a8624870b 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -1,5 +1,6 @@ """Util for classification models.""" +import logging import os import cv2 @@ -9,6 +10,8 @@ from tensorflow.keras import layers, models, optimizers from tensorflow.keras.applications import MobileNetV2 from tensorflow.keras.preprocessing.image import ImageDataGenerator +from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR + BATCH_SIZE = 16 EPOCHS = 50 LEARNING_RATE = 0.001 @@ -35,9 +38,10 @@ def generate_representative_dataset_factory(dataset_dir: str): @staticmethod -def train_classification_model(model_dir: str) -> bool: +def train_classification_model(model_name: str) -> bool: """Train a classification model.""" - dataset_dir = os.path.join(model_dir, "dataset") + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + model_dir = os.path.join(MODEL_CACHE_DIR, model_name) num_classes = len( [ d @@ -46,6 +50,8 @@ def train_classification_model(model_dir: str) -> bool: ] ) + tf.get_logger().setLevel(logging.ERROR) + # Start with imagenet base model with 35% of channels in each layer base_model = MobileNetV2( input_shape=(224, 224, 3), diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json new file mode 100644 index 000000000..eb09ecaa0 --- /dev/null +++ b/web/public/locales/en/views/classificationModel.json @@ -0,0 +1,49 @@ +{ + "button": { + "deleteClassificationAttempts": "Delete Classification Images", + "renameCategory": "Rename Class", + "deleteCategory": "Delete Class", + "deleteImages": "Delete Images" + }, + "toast": { + "success": { + "deletedCategory": "Deleted Class", + "deletedImage": "Deleted Images", + "categorizedImage": "Successfully Classified Image" + }, + "error": { + "deleteImageFailed": "Failed to delete: {{errorMessage}}", + "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", + "categorizeFailed": "Failed to categorize image: {{errorMessage}}" + } + }, + "deleteCategory": { + "title": "Delete Class", + "desc": "Are you sure you want to delete the class {{name}}? This will permanently delete all associated images and require re-training the model." + }, + "deleteDatasetImages": { + "title": "Delete Dataset Images", + "desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model." + }, + "deleteTrainImages": { + "title": "Delete Train Images", + "desc": "Are you sure you want to delete {{count}} images? This action cannot be undone." 
+ }, + "renameCategory": { + "title": "Rename Class", + "desc": "Enter a new name for {{name}}. You will be required to retrain the model for the name change to take affect." + }, + "description": { + "invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens." + }, + "train": { + "title": "Train", + "aria": "Select Train" + }, + "categories": "Classes", + "createCategory": { + "new": "Create New Class" + }, + "categorizeImageAs": "Classify Image As:", + "categorizeImage": "Classify Image" +} diff --git a/web/src/App.tsx b/web/src/App.tsx index d3edbc3a2..cd7906e97 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -24,6 +24,7 @@ const System = lazy(() => import("@/pages/System")); const Settings = lazy(() => import("@/pages/Settings")); const UIPlayground = lazy(() => import("@/pages/UIPlayground")); const FaceLibrary = lazy(() => import("@/pages/FaceLibrary")); +const Classification = lazy(() => import("@/pages/ClassificationModel")); const Logs = lazy(() => import("@/pages/Logs")); const AccessDenied = lazy(() => import("@/pages/AccessDenied")); @@ -76,6 +77,7 @@ function DefaultAppView() { } /> } /> } /> + } /> } /> } /> diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx new file mode 100644 index 000000000..7cb8ca156 --- /dev/null +++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx @@ -0,0 +1,155 @@ +import { + Drawer, + DrawerClose, + DrawerContent, + DrawerDescription, + DrawerHeader, + DrawerTitle, + DrawerTrigger, +} from "@/components/ui/drawer"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { isDesktop, isMobile } from "react-device-detect"; +import { LuPlus } from "react-icons/lu"; +import { useTranslation } from "react-i18next"; +import { cn } from "@/lib/utils"; +import React, { ReactNode, useCallback, useMemo, useState } from "react"; +import TextEntryDialog from "./dialog/TextEntryDialog"; +import { Button } from "../ui/button"; +import { MdCategory } from "react-icons/md"; +import axios from "axios"; +import { toast } from "sonner"; + +type ClassificationSelectionDialogProps = { + className?: string; + classes: string[]; + modelName: string; + image: string; + onRefresh: () => void; + children: ReactNode; +}; +export default function ClassificationSelectionDialog({ + className, + classes, + modelName, + image, + onRefresh, + children, +}: ClassificationSelectionDialogProps) { + const { t } = useTranslation(["views/classificationModel"]); + + const onCategorizeImage = useCallback( + (category: string) => { + axios + .post(`/classification/${modelName}/dataset/categorize`, { + category, + training_file: image, + }) + .then((resp) => { + if (resp.status == 200) { + toast.success(t("toast.success.categorizedImage"), { + position: "top-center", + }); + onRefresh(); + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error(t("toast.error.categorizeFailed", { errorMessage }), { + position: "top-center", + }); + }); + }, + [modelName, image, onRefresh, t], + ); + + const isChildButton = useMemo( + () => React.isValidElement(children) && children.type === Button, + [children], + ); + + // control + const [newFace, 
setNewFace] = useState(false); + + // components + const Selector = isDesktop ? DropdownMenu : Drawer; + const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger; + const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent; + const SelectorItem = isDesktop + ? DropdownMenuItem + : (props: React.HTMLAttributes) => ( + +
+ + ); + + return ( +
+ {newFace && ( + onCategorizeImage(newCat)} + /> + )} + + + + + {children} + + + {isMobile && ( + + Details + Details + + )} + {t("categorizeImageAs")} +
+ setNewFace(true)} + > + + {t("createCategory.new")} + + {classes.sort().map((category) => ( + onCategorizeImage(category)} + > + + {category} + + ))} +
+
+
+ {t("categorizeImage")} +
+
+ ); +} diff --git a/web/src/hooks/use-navigation.ts b/web/src/hooks/use-navigation.ts index 41ec7227f..d9bd6f6a4 100644 --- a/web/src/hooks/use-navigation.ts +++ b/web/src/hooks/use-navigation.ts @@ -6,7 +6,7 @@ import { isDesktop } from "react-device-detect"; import { FaCompactDisc, FaVideo } from "react-icons/fa"; import { IoSearch } from "react-icons/io5"; import { LuConstruction } from "react-icons/lu"; -import { MdVideoLibrary } from "react-icons/md"; +import { MdCategory, MdVideoLibrary } from "react-icons/md"; import { TbFaceId } from "react-icons/tb"; import useSWR from "swr"; @@ -16,6 +16,7 @@ export const ID_EXPLORE = 3; export const ID_EXPORT = 4; export const ID_PLAYGROUND = 5; export const ID_FACE_LIBRARY = 6; +export const ID_CLASSIFICATION = 7; export default function useNavigation( variant: "primary" | "secondary" = "primary", @@ -71,6 +72,14 @@ export default function useNavigation( url: "/faces", enabled: isDesktop && config?.face_recognition.enabled, }, + { + id: ID_CLASSIFICATION, + variant, + icon: MdCategory, + title: "menu.classification", + url: "/classification", + enabled: isDesktop, + }, ] as NavData[], [config?.face_recognition?.enabled, variant], ); diff --git a/web/src/pages/ClassificationModel.tsx b/web/src/pages/ClassificationModel.tsx new file mode 100644 index 000000000..c37d0b454 --- /dev/null +++ b/web/src/pages/ClassificationModel.tsx @@ -0,0 +1,18 @@ +import { useOverlayState } from "@/hooks/use-overlay-state"; +import { CustomClassificationModelConfig } from "@/types/frigateConfig"; +import ModelSelectionView from "@/views/classification/ModelSelectionView"; +import ModelTrainingView from "@/views/classification/ModelTrainingView"; + +export default function ClassificationModelPage() { + // training + + const [model, setModel] = useOverlayState( + "classificationModel", + ); + + if (model == undefined) { + return ; + } + + return ; +} diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts index cf2bf1476..3ccc5b06d 100644 --- a/web/src/types/frigateConfig.ts +++ b/web/src/types/frigateConfig.ts @@ -279,6 +279,23 @@ export type CameraStreamingSettings = { volume: number; }; +export type CustomClassificationModelConfig = { + enabled: boolean; + name: string; + object_config: null | { + objects: string[]; + }; + state_config: null | { + cameras: { + [cameraName: string]: { + crop: [number, number, number, number]; + threshold: number; + }; + }; + motion: boolean; + }; +}; + export type GroupStreamingSettings = { [cameraName: string]: CameraStreamingSettings; }; @@ -316,6 +333,9 @@ export interface FrigateConfig { enabled: boolean; threshold: number; }; + custom: { + [modelKey: string]: CustomClassificationModelConfig; + }; }; database: { diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx new file mode 100644 index 000000000..63133842a --- /dev/null +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -0,0 +1,63 @@ +import ActivityIndicator from "@/components/indicators/activity-indicator"; +import { cn } from "@/lib/utils"; +import { + CustomClassificationModelConfig, + FrigateConfig, +} from "@/types/frigateConfig"; +import { useMemo } from "react"; +import { isMobile } from "react-device-detect"; +import useSWR from "swr"; + +type ModelSelectionViewProps = { + onClick: (model: CustomClassificationModelConfig) => void; +}; +export default function ModelSelectionView({ + onClick, +}: ModelSelectionViewProps) { + const { data: config } = useSWR("config", { + 
revalidateOnFocus: false, + }); + + const classificationConfigs = useMemo(() => { + if (!config) { + return []; + } + + return Object.values(config.classification.custom); + }, [config]); + + if (!config) { + return ; + } + + if (classificationConfigs.length == 0) { + return
You need to set up a custom model configuration.
; + } + + return ( +
+ {classificationConfigs.map((config) => ( +
onClick(config)} + onContextMenu={() => { + // e.stopPropagation(); + // e.preventDefault(); + // handleClickEvent(true); + }} + > +
+
+ {config.name} ({config.state_config != null ? "State" : "Object"}{" "} + Classification) +
+
+ ))} +
+ ); +} diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx new file mode 100644 index 000000000..53ef7fa66 --- /dev/null +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -0,0 +1,661 @@ +import { baseUrl } from "@/api/baseUrl"; +import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog"; +import { Button, buttonVariants } from "@/components/ui/button"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Toaster } from "@/components/ui/sonner"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import useKeyboardListener from "@/hooks/use-keyboard-listener"; +import useOptimisticState from "@/hooks/use-optimistic-state"; +import { cn } from "@/lib/utils"; +import { CustomClassificationModelConfig } from "@/types/frigateConfig"; +import { TooltipPortal } from "@radix-ui/react-tooltip"; +import axios from "axios"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { isDesktop, isMobile } from "react-device-detect"; +import { Trans, useTranslation } from "react-i18next"; +import { LuPencil, LuTrash2 } from "react-icons/lu"; +import { toast } from "sonner"; +import useSWR from "swr"; +import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; +import { TbCategoryPlus } from "react-icons/tb"; + +type ModelTrainingViewProps = { + model: CustomClassificationModelConfig; +}; +export default function ModelTrainingView({ model }: ModelTrainingViewProps) { + const { t } = useTranslation(["views/classificationModel"]); + const [page, setPage] = useState("train"); + const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); + + // dataset + + const { data: trainImages, mutate: refreshTrain } = useSWR( + `classification/${model.name}/train`, + ); + const { data: dataset, mutate: refreshDataset } = useSWR<{ + [id: string]: string[]; + }>(`classification/${model.name}/dataset`); + + // image multiselect + + const [selectedImages, setSelectedImages] = useState([]); + + const onClickImages = useCallback( + (images: string[], ctrl: boolean) => { + if (selectedImages.length == 0 && !ctrl) { + return; + } + + let newSelectedImages = [...selectedImages]; + + images.forEach((imageId) => { + const index = newSelectedImages.indexOf(imageId); + + if (index != -1) { + if (selectedImages.length == 1) { + newSelectedImages = []; + } else { + const copy = [ + ...newSelectedImages.slice(0, index), + ...newSelectedImages.slice(index + 1), + ]; + newSelectedImages = copy; + } + } else { + newSelectedImages.push(imageId); + } + }); + + setSelectedImages(newSelectedImages); + }, + [selectedImages, setSelectedImages], + ); + + // actions + + const trainModel = useCallback(() => { + axios.post(`classification/${model.name}/train`); + }, [model]); + + const [deleteDialogOpen, setDeleteDialogOpen] = useState( + null, + ); + + const onDelete = useCallback( + (ids: string[], isName: boolean = false) => { + const api = + pageToggle == "train" + ? 
`/classification/${model.name}/train/delete` + : `/classification/${model.name}/dataset/${pageToggle}/delete`; + + axios + .post(api, { ids }) + .then((resp) => { + setSelectedImages([]); + + if (resp.status == 200) { + if (isName) { + toast.success( + t("toast.success.deletedCategory", { count: ids.length }), + { + position: "top-center", + }, + ); + } else { + toast.success( + t("toast.success.deletedImage", { count: ids.length }), + { + position: "top-center", + }, + ); + } + + if (pageToggle == "train") { + refreshTrain(); + } else { + refreshDataset(); + } + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + if (isName) { + toast.error( + t("toast.error.deleteCategoryFailed", { errorMessage }), + { + position: "top-center", + }, + ); + } else { + toast.error(t("toast.error.deleteImageFailed", { errorMessage }), { + position: "top-center", + }); + } + }); + }, + [pageToggle, model, refreshTrain, refreshDataset, t], + ); + + // keyboard + + useKeyboardListener(["a", "Escape"], (key, modifiers) => { + if (modifiers.repeat || !modifiers.down) { + return; + } + + switch (key) { + case "a": + if (modifiers.ctrl) { + if (selectedImages.length) { + setSelectedImages([]); + } else { + setSelectedImages([ + ...(pageToggle === "train" + ? trainImages || [] + : dataset?.[pageToggle] || []), + ]); + } + } + break; + case "Escape": + setSelectedImages([]); + break; + } + }); + + useEffect(() => { + setSelectedImages([]); + }, [pageToggle]); + + return ( +
+ + + setDeleteDialogOpen(null)} + > + + + + {t( + pageToggle == "train" + ? "deleteTrainImages.title" + : "deleteDatasetImages.title", + )} + + + + + {pageToggle == "train" + ? "deleteTrainImages.desc" + : "deleteDatasetImages.desc"} + + + + + {t("button.cancel", { ns: "common" })} + + { + if (deleteDialogOpen) { + onDelete(deleteDialogOpen); + setDeleteDialogOpen(null); + } + }} + > + {t("button.delete", { ns: "common" })} + + + + + +
+ {}} + /> + {selectedImages?.length > 0 ? ( +
+
+
{`${selectedImages.length} selected`}
+
{"|"}
+
setSelectedImages([])} + > + {t("button.unselect", { ns: "common" })} +
+
+ +
+ ) : ( + + )} +
+ {pageToggle == "train" ? ( + + ) : ( + + )} +
+ ); +} + +type LibrarySelectorProps = { + pageToggle: string | undefined; + dataset: { [id: string]: string[] }; + trainImages: string[]; + setPageToggle: (toggle: string) => void; + onDelete: (ids: string[], isName: boolean) => void; + onRename: (old_name: string, new_name: string) => void; +}; +function LibrarySelector({ + pageToggle, + dataset, + trainImages, + setPageToggle, + onDelete, + onRename, +}: LibrarySelectorProps) { + const { t } = useTranslation(["views/classificationModel"]); + const [confirmDelete, setConfirmDelete] = useState(null); + const [renameFace, setRenameFace] = useState(null); + + const handleDeleteFace = useCallback( + (name: string) => { + // Get all image IDs for this face + const imageIds = dataset?.[name] || []; + + onDelete(imageIds, true); + setPageToggle("train"); + }, + [dataset, onDelete, setPageToggle], + ); + + const handleSetOpen = useCallback( + (open: boolean) => { + setRenameFace(open ? renameFace : null); + }, + [renameFace], + ); + + return ( + <> + !open && setConfirmDelete(null)} + > + + + {t("deleteCategory.title")} + + {t("deleteCategory.desc", { name: confirmDelete })} + + +
+ + +
+
+
+ + { + onRename(renameFace!, newName); + setRenameFace(null); + }} + defaultValue={renameFace || ""} + regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u} + regexErrorMessage={t("description.invalidName")} + /> + + + + + + + setPageToggle("train")} + > +
{t("train.title")}
+
+ ({trainImages.length}) +
+
+ {trainImages.length > 0 && Object.keys(dataset).length > 0 && ( + <> + +
+ {t("categories")} +
+ + )} + {Object.keys(dataset).map((id) => ( + +
setPageToggle(id)} + > + {id} + + ({dataset?.[id].length}) + +
+
+ + + + + + + {t("button.renameCategory")} + + + + + + + + + + {t("button.deleteCategory")} + + + +
+
+ ))} +
+
+ + ); +} + +type DatasetGridProps = { + modelName: string; + categoryName: string; + images: string[]; + selectedImages: string[]; + onClickImages: (images: string[], ctrl: boolean) => void; + onDelete: (ids: string[]) => void; +}; +function DatasetGrid({ + modelName, + categoryName, + images, + selectedImages, + onClickImages, + onDelete, +}: DatasetGridProps) { + const { t } = useTranslation(["views/classificationModel"]); + + return ( +
+ {images.map((image) => ( +
{ + e.stopPropagation(); + + if (e.ctrlKey || e.metaKey) { + onClickImages([image], true); + } + }} + > +
+ +
+
+
+
+ + + { + e.stopPropagation(); + onDelete([image]); + }} + /> + + + {t("button.deleteClassificationAttempts")} + + +
+
+
+
+ ))} +
+ ); +} + +type TrainGridProps = { + model: CustomClassificationModelConfig; + classes: string[]; + trainImages: string[]; + selectedImages: string[]; + onClickImages: (images: string[], ctrl: boolean) => void; + onRefresh: () => void; + onDelete: (ids: string[]) => void; +}; +function TrainGrid({ + model, + classes, + trainImages, + selectedImages, + onClickImages, + onRefresh, + onDelete, +}: TrainGridProps) { + const { t } = useTranslation(["views/classificationModel"]); + + const trainData = useMemo( + () => + trainImages + .map((raw) => { + const parts = raw.replaceAll(".webp", "").split("-"); + return { + raw, + timestamp: parts[0], + label: parts[1], + score: Number.parseFloat(parts[2]) * 100, + }; + }) + .sort((a, b) => b.timestamp.localeCompare(a.timestamp)), + [trainImages], + ); + + return ( +
+ {trainData?.map((data) => ( +
{ + e.stopPropagation(); + onClickImages([data.raw], e.ctrlKey || e.metaKey); + }} + > +
+ +
+
+
+
+
{data.label}
+
{data.score}%
+
+
+ + + + + + { + e.stopPropagation(); + onDelete([data.raw]); + }} + /> + + + {t("button.deleteClassificationAttempts")} + + +
+
+
+
+ ))} +
+ ); +}