mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
Classification Model UI (#18571)
* Setup basic training structure * Build out route * Handle model configs * Add image fetch APIs * Implement model training screen with dataset selection * Implement viewing of training images * Adjust directories * Implement viewing of images * Add support for deleting images * Implement full deletion * Implement classification model training * Improve naming * More renaming * Improve layout * Reduce logging * Cleanup
This commit is contained in:
parent
eb1fe9fe20
commit
85d721eb6b
@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import (
|
|||||||
from frigate.api.defs.tags import Tags
|
from frigate.api.defs.tags import Tags
|
||||||
from frigate.config import FrigateConfig
|
from frigate.config import FrigateConfig
|
||||||
from frigate.config.camera import DetectConfig
|
from frigate.config.camera import DetectConfig
|
||||||
from frigate.const import FACE_DIR, MODEL_CACHE_DIR
|
from frigate.const import CLIPS_DIR, FACE_DIR
|
||||||
from frigate.embeddings import EmbeddingsContext
|
from frigate.embeddings import EmbeddingsContext
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
from frigate.util.classification import train_classification_model
|
from frigate.util.classification import train_classification_model
|
||||||
@ -431,6 +431,50 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
|
|||||||
# custom classification training
|
# custom classification training
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/classification/{name}/dataset")
|
||||||
|
def get_classification_dataset(name: str):
|
||||||
|
dataset_dict: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
||||||
|
|
||||||
|
if not os.path.exists(dataset_dir):
|
||||||
|
return JSONResponse(status_code=200, content={})
|
||||||
|
|
||||||
|
for name in os.listdir(dataset_dir):
|
||||||
|
category_dir = os.path.join(dataset_dir, name)
|
||||||
|
|
||||||
|
if not os.path.isdir(category_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset_dict[name] = []
|
||||||
|
|
||||||
|
for file in filter(
|
||||||
|
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||||
|
os.listdir(category_dir),
|
||||||
|
):
|
||||||
|
dataset_dict[name].append(file)
|
||||||
|
|
||||||
|
return JSONResponse(status_code=200, content=dataset_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/classification/{name}/train")
|
||||||
|
def get_classification_images(name: str):
|
||||||
|
train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "train")
|
||||||
|
|
||||||
|
if not os.path.exists(train_dir):
|
||||||
|
return JSONResponse(status_code=200, content=[])
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content=list(
|
||||||
|
filter(
|
||||||
|
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||||
|
os.listdir(train_dir),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/classification/{name}/train")
|
@router.post("/classification/{name}/train")
|
||||||
async def train_configured_model(
|
async def train_configured_model(
|
||||||
request: Request, name: str, background_tasks: BackgroundTasks
|
request: Request, name: str, background_tasks: BackgroundTasks
|
||||||
@ -448,10 +492,131 @@ async def train_configured_model(
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
)
|
)
|
||||||
|
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(train_classification_model, name)
|
||||||
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
|
|
||||||
)
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"success": True, "message": "Started classification model training."},
|
content={"success": True, "message": "Started classification model training."},
|
||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/classification/{name}/dataset/{category}/delete",
|
||||||
|
dependencies=[Depends(require_role(["admin"]))],
|
||||||
|
)
|
||||||
|
def delete_classification_dataset_images(
|
||||||
|
request: Request, name: str, category: str, body: dict = None
|
||||||
|
):
|
||||||
|
config: FrigateConfig = request.app.frigate_config
|
||||||
|
|
||||||
|
if name not in config.classification.custom:
|
||||||
|
return JSONResponse(
|
||||||
|
content=(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"{name} is not a known classification model.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
json: dict[str, Any] = body or {}
|
||||||
|
list_of_ids = json.get("ids", "")
|
||||||
|
folder = os.path.join(
|
||||||
|
CLIPS_DIR, sanitize_filename(name), "dataset", sanitize_filename(category)
|
||||||
|
)
|
||||||
|
|
||||||
|
for id in list_of_ids:
|
||||||
|
file_path = os.path.join(folder, id)
|
||||||
|
|
||||||
|
if os.path.isfile(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content=({"success": True, "message": "Successfully deleted faces."}),
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/classification/{name}/dataset/categorize",
|
||||||
|
dependencies=[Depends(require_role(["admin"]))],
|
||||||
|
)
|
||||||
|
def categorize_classification_image(request: Request, name: str, body: dict = None):
|
||||||
|
config: FrigateConfig = request.app.frigate_config
|
||||||
|
|
||||||
|
if name not in config.classification.custom:
|
||||||
|
return JSONResponse(
|
||||||
|
content=(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"{name} is not a known classification model.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
json: dict[str, Any] = body or {}
|
||||||
|
category = sanitize_filename(json.get("category", ""))
|
||||||
|
training_file_name = sanitize_filename(json.get("training_file", ""))
|
||||||
|
training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name)
|
||||||
|
|
||||||
|
if training_file_name and not os.path.isfile(training_file):
|
||||||
|
return JSONResponse(
|
||||||
|
content=(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Invalid filename or no file exists: {training_file_name}",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
|
||||||
|
new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category)
|
||||||
|
|
||||||
|
if not os.path.exists(new_file_folder):
|
||||||
|
os.mkdir(new_file_folder)
|
||||||
|
|
||||||
|
# use opencv because webp images can not be used to train
|
||||||
|
img = cv2.imread(training_file)
|
||||||
|
cv2.imwrite(os.path.join(new_file_folder, new_name), img)
|
||||||
|
os.unlink(training_file)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content=({"success": True, "message": "Successfully deleted faces."}),
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/classification/{name}/train/delete",
|
||||||
|
dependencies=[Depends(require_role(["admin"]))],
|
||||||
|
)
|
||||||
|
def delete_classification_train_images(request: Request, name: str, body: dict = None):
|
||||||
|
config: FrigateConfig = request.app.frigate_config
|
||||||
|
|
||||||
|
if name not in config.classification.custom:
|
||||||
|
return JSONResponse(
|
||||||
|
content=(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"{name} is not a known classification model.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
json: dict[str, Any] = body or {}
|
||||||
|
list_of_ids = json.get("ids", "")
|
||||||
|
folder = os.path.join(CLIPS_DIR, sanitize_filename(name), "train")
|
||||||
|
|
||||||
|
for id in list_of_ids:
|
||||||
|
file_path = os.path.join(folder, id)
|
||||||
|
|
||||||
|
if os.path.isfile(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content=({"success": True, "message": "Successfully deleted faces."}),
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
@ -42,7 +42,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.requestor = requestor
|
self.requestor = requestor
|
||||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name)
|
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||||
self.interpreter: Interpreter = None
|
self.interpreter: Interpreter = None
|
||||||
self.tensor_input_details: dict[str, Any] = None
|
self.tensor_input_details: dict[str, Any] = None
|
||||||
self.tensor_output_details: dict[str, Any] = None
|
self.tensor_output_details: dict[str, Any] = None
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Util for classification models."""
|
"""Util for classification models."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -9,6 +10,8 @@ from tensorflow.keras import layers, models, optimizers
|
|||||||
from tensorflow.keras.applications import MobileNetV2
|
from tensorflow.keras.applications import MobileNetV2
|
||||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||||
|
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 50
|
EPOCHS = 50
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
@ -35,9 +38,10 @@ def generate_representative_dataset_factory(dataset_dir: str):
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def train_classification_model(model_dir: str) -> bool:
|
def train_classification_model(model_name: str) -> bool:
|
||||||
"""Train a classification model."""
|
"""Train a classification model."""
|
||||||
dataset_dir = os.path.join(model_dir, "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||||
|
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
||||||
num_classes = len(
|
num_classes = len(
|
||||||
[
|
[
|
||||||
d
|
d
|
||||||
@ -46,6 +50,8 @@ def train_classification_model(model_dir: str) -> bool:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf.get_logger().setLevel(logging.ERROR)
|
||||||
|
|
||||||
# Start with imagenet base model with 35% of channels in each layer
|
# Start with imagenet base model with 35% of channels in each layer
|
||||||
base_model = MobileNetV2(
|
base_model = MobileNetV2(
|
||||||
input_shape=(224, 224, 3),
|
input_shape=(224, 224, 3),
|
||||||
|
49
web/public/locales/en/views/classificationModel.json
Normal file
49
web/public/locales/en/views/classificationModel.json
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
"button": {
|
||||||
|
"deleteClassificationAttempts": "Delete Classification Images",
|
||||||
|
"renameCategory": "Rename Class",
|
||||||
|
"deleteCategory": "Delete Class",
|
||||||
|
"deleteImages": "Delete Images"
|
||||||
|
},
|
||||||
|
"toast": {
|
||||||
|
"success": {
|
||||||
|
"deletedCategory": "Deleted Class",
|
||||||
|
"deletedImage": "Deleted Images",
|
||||||
|
"categorizedImage": "Successfully Classified Image"
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
|
||||||
|
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
||||||
|
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"deleteCategory": {
|
||||||
|
"title": "Delete Class",
|
||||||
|
"desc": "Are you sure you want to delete the class {{name}}? This will permanently delete all associated images and require re-training the model."
|
||||||
|
},
|
||||||
|
"deleteDatasetImages": {
|
||||||
|
"title": "Delete Dataset Images",
|
||||||
|
"desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model."
|
||||||
|
},
|
||||||
|
"deleteTrainImages": {
|
||||||
|
"title": "Delete Train Images",
|
||||||
|
"desc": "Are you sure you want to delete {{count}} images? This action cannot be undone."
|
||||||
|
},
|
||||||
|
"renameCategory": {
|
||||||
|
"title": "Rename Class",
|
||||||
|
"desc": "Enter a new name for {{name}}. You will be required to retrain the model for the name change to take affect."
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens."
|
||||||
|
},
|
||||||
|
"train": {
|
||||||
|
"title": "Train",
|
||||||
|
"aria": "Select Train"
|
||||||
|
},
|
||||||
|
"categories": "Classes",
|
||||||
|
"createCategory": {
|
||||||
|
"new": "Create New Class"
|
||||||
|
},
|
||||||
|
"categorizeImageAs": "Classify Image As:",
|
||||||
|
"categorizeImage": "Classify Image"
|
||||||
|
}
|
@ -24,6 +24,7 @@ const System = lazy(() => import("@/pages/System"));
|
|||||||
const Settings = lazy(() => import("@/pages/Settings"));
|
const Settings = lazy(() => import("@/pages/Settings"));
|
||||||
const UIPlayground = lazy(() => import("@/pages/UIPlayground"));
|
const UIPlayground = lazy(() => import("@/pages/UIPlayground"));
|
||||||
const FaceLibrary = lazy(() => import("@/pages/FaceLibrary"));
|
const FaceLibrary = lazy(() => import("@/pages/FaceLibrary"));
|
||||||
|
const Classification = lazy(() => import("@/pages/ClassificationModel"));
|
||||||
const Logs = lazy(() => import("@/pages/Logs"));
|
const Logs = lazy(() => import("@/pages/Logs"));
|
||||||
const AccessDenied = lazy(() => import("@/pages/AccessDenied"));
|
const AccessDenied = lazy(() => import("@/pages/AccessDenied"));
|
||||||
|
|
||||||
@ -76,6 +77,7 @@ function DefaultAppView() {
|
|||||||
<Route path="/config" element={<ConfigEditor />} />
|
<Route path="/config" element={<ConfigEditor />} />
|
||||||
<Route path="/logs" element={<Logs />} />
|
<Route path="/logs" element={<Logs />} />
|
||||||
<Route path="/faces" element={<FaceLibrary />} />
|
<Route path="/faces" element={<FaceLibrary />} />
|
||||||
|
<Route path="/classification" element={<Classification />} />
|
||||||
<Route path="/playground" element={<UIPlayground />} />
|
<Route path="/playground" element={<UIPlayground />} />
|
||||||
</Route>
|
</Route>
|
||||||
<Route path="/unauthorized" element={<AccessDenied />} />
|
<Route path="/unauthorized" element={<AccessDenied />} />
|
||||||
|
155
web/src/components/overlay/ClassificationSelectionDialog.tsx
Normal file
155
web/src/components/overlay/ClassificationSelectionDialog.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import {
|
||||||
|
Drawer,
|
||||||
|
DrawerClose,
|
||||||
|
DrawerContent,
|
||||||
|
DrawerDescription,
|
||||||
|
DrawerHeader,
|
||||||
|
DrawerTitle,
|
||||||
|
DrawerTrigger,
|
||||||
|
} from "@/components/ui/drawer";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuLabel,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
import { isDesktop, isMobile } from "react-device-detect";
|
||||||
|
import { LuPlus } from "react-icons/lu";
|
||||||
|
import { useTranslation } from "react-i18next";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import React, { ReactNode, useCallback, useMemo, useState } from "react";
|
||||||
|
import TextEntryDialog from "./dialog/TextEntryDialog";
|
||||||
|
import { Button } from "../ui/button";
|
||||||
|
import { MdCategory } from "react-icons/md";
|
||||||
|
import axios from "axios";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
type ClassificationSelectionDialogProps = {
|
||||||
|
className?: string;
|
||||||
|
classes: string[];
|
||||||
|
modelName: string;
|
||||||
|
image: string;
|
||||||
|
onRefresh: () => void;
|
||||||
|
children: ReactNode;
|
||||||
|
};
|
||||||
|
export default function ClassificationSelectionDialog({
|
||||||
|
className,
|
||||||
|
classes,
|
||||||
|
modelName,
|
||||||
|
image,
|
||||||
|
onRefresh,
|
||||||
|
children,
|
||||||
|
}: ClassificationSelectionDialogProps) {
|
||||||
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
|
|
||||||
|
const onCategorizeImage = useCallback(
|
||||||
|
(category: string) => {
|
||||||
|
axios
|
||||||
|
.post(`/classification/${modelName}/dataset/categorize`, {
|
||||||
|
category,
|
||||||
|
training_file: image,
|
||||||
|
})
|
||||||
|
.then((resp) => {
|
||||||
|
if (resp.status == 200) {
|
||||||
|
toast.success(t("toast.success.categorizedImage"), {
|
||||||
|
position: "top-center",
|
||||||
|
});
|
||||||
|
onRefresh();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
const errorMessage =
|
||||||
|
error.response?.data?.message ||
|
||||||
|
error.response?.data?.detail ||
|
||||||
|
"Unknown error";
|
||||||
|
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
|
||||||
|
position: "top-center",
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[modelName, image, onRefresh, t],
|
||||||
|
);
|
||||||
|
|
||||||
|
const isChildButton = useMemo(
|
||||||
|
() => React.isValidElement(children) && children.type === Button,
|
||||||
|
[children],
|
||||||
|
);
|
||||||
|
|
||||||
|
// control
|
||||||
|
const [newFace, setNewFace] = useState(false);
|
||||||
|
|
||||||
|
// components
|
||||||
|
const Selector = isDesktop ? DropdownMenu : Drawer;
|
||||||
|
const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger;
|
||||||
|
const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent;
|
||||||
|
const SelectorItem = isDesktop
|
||||||
|
? DropdownMenuItem
|
||||||
|
: (props: React.HTMLAttributes<HTMLDivElement>) => (
|
||||||
|
<DrawerClose asChild>
|
||||||
|
<div {...props} className={cn(props.className, "my-2")} />
|
||||||
|
</DrawerClose>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={className ?? ""}>
|
||||||
|
{newFace && (
|
||||||
|
<TextEntryDialog
|
||||||
|
open={true}
|
||||||
|
setOpen={setNewFace}
|
||||||
|
title={t("createCategory.new")}
|
||||||
|
onSave={(newCat) => onCategorizeImage(newCat)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Tooltip>
|
||||||
|
<Selector>
|
||||||
|
<SelectorTrigger asChild>
|
||||||
|
<TooltipTrigger asChild={isChildButton}>{children}</TooltipTrigger>
|
||||||
|
</SelectorTrigger>
|
||||||
|
<SelectorContent
|
||||||
|
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
|
||||||
|
>
|
||||||
|
{isMobile && (
|
||||||
|
<DrawerHeader className="sr-only">
|
||||||
|
<DrawerTitle>Details</DrawerTitle>
|
||||||
|
<DrawerDescription>Details</DrawerDescription>
|
||||||
|
</DrawerHeader>
|
||||||
|
)}
|
||||||
|
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex max-h-[40dvh] flex-col overflow-y-auto",
|
||||||
|
isMobile && "gap-2 pb-4",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<SelectorItem
|
||||||
|
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||||
|
onClick={() => setNewFace(true)}
|
||||||
|
>
|
||||||
|
<LuPlus />
|
||||||
|
{t("createCategory.new")}
|
||||||
|
</SelectorItem>
|
||||||
|
{classes.sort().map((category) => (
|
||||||
|
<SelectorItem
|
||||||
|
key={category}
|
||||||
|
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||||
|
onClick={() => onCategorizeImage(category)}
|
||||||
|
>
|
||||||
|
<MdCategory />
|
||||||
|
{category}
|
||||||
|
</SelectorItem>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</SelectorContent>
|
||||||
|
</Selector>
|
||||||
|
<TooltipContent>{t("categorizeImage")}</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -6,7 +6,7 @@ import { isDesktop } from "react-device-detect";
|
|||||||
import { FaCompactDisc, FaVideo } from "react-icons/fa";
|
import { FaCompactDisc, FaVideo } from "react-icons/fa";
|
||||||
import { IoSearch } from "react-icons/io5";
|
import { IoSearch } from "react-icons/io5";
|
||||||
import { LuConstruction } from "react-icons/lu";
|
import { LuConstruction } from "react-icons/lu";
|
||||||
import { MdVideoLibrary } from "react-icons/md";
|
import { MdCategory, MdVideoLibrary } from "react-icons/md";
|
||||||
import { TbFaceId } from "react-icons/tb";
|
import { TbFaceId } from "react-icons/tb";
|
||||||
import useSWR from "swr";
|
import useSWR from "swr";
|
||||||
|
|
||||||
@ -16,6 +16,7 @@ export const ID_EXPLORE = 3;
|
|||||||
export const ID_EXPORT = 4;
|
export const ID_EXPORT = 4;
|
||||||
export const ID_PLAYGROUND = 5;
|
export const ID_PLAYGROUND = 5;
|
||||||
export const ID_FACE_LIBRARY = 6;
|
export const ID_FACE_LIBRARY = 6;
|
||||||
|
export const ID_CLASSIFICATION = 7;
|
||||||
|
|
||||||
export default function useNavigation(
|
export default function useNavigation(
|
||||||
variant: "primary" | "secondary" = "primary",
|
variant: "primary" | "secondary" = "primary",
|
||||||
@ -71,6 +72,14 @@ export default function useNavigation(
|
|||||||
url: "/faces",
|
url: "/faces",
|
||||||
enabled: isDesktop && config?.face_recognition.enabled,
|
enabled: isDesktop && config?.face_recognition.enabled,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: ID_CLASSIFICATION,
|
||||||
|
variant,
|
||||||
|
icon: MdCategory,
|
||||||
|
title: "menu.classification",
|
||||||
|
url: "/classification",
|
||||||
|
enabled: isDesktop,
|
||||||
|
},
|
||||||
] as NavData[],
|
] as NavData[],
|
||||||
[config?.face_recognition?.enabled, variant],
|
[config?.face_recognition?.enabled, variant],
|
||||||
);
|
);
|
||||||
|
18
web/src/pages/ClassificationModel.tsx
Normal file
18
web/src/pages/ClassificationModel.tsx
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import { useOverlayState } from "@/hooks/use-overlay-state";
|
||||||
|
import { CustomClassificationModelConfig } from "@/types/frigateConfig";
|
||||||
|
import ModelSelectionView from "@/views/classification/ModelSelectionView";
|
||||||
|
import ModelTrainingView from "@/views/classification/ModelTrainingView";
|
||||||
|
|
||||||
|
export default function ClassificationModelPage() {
|
||||||
|
// training
|
||||||
|
|
||||||
|
const [model, setModel] = useOverlayState<CustomClassificationModelConfig>(
|
||||||
|
"classificationModel",
|
||||||
|
);
|
||||||
|
|
||||||
|
if (model == undefined) {
|
||||||
|
return <ModelSelectionView onClick={setModel} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return <ModelTrainingView model={model} />;
|
||||||
|
}
|
@ -279,6 +279,23 @@ export type CameraStreamingSettings = {
|
|||||||
volume: number;
|
volume: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type CustomClassificationModelConfig = {
|
||||||
|
enabled: boolean;
|
||||||
|
name: string;
|
||||||
|
object_config: null | {
|
||||||
|
objects: string[];
|
||||||
|
};
|
||||||
|
state_config: null | {
|
||||||
|
cameras: {
|
||||||
|
[cameraName: string]: {
|
||||||
|
crop: [number, number, number, number];
|
||||||
|
threshold: number;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
motion: boolean;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export type GroupStreamingSettings = {
|
export type GroupStreamingSettings = {
|
||||||
[cameraName: string]: CameraStreamingSettings;
|
[cameraName: string]: CameraStreamingSettings;
|
||||||
};
|
};
|
||||||
@ -316,6 +333,9 @@ export interface FrigateConfig {
|
|||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
threshold: number;
|
threshold: number;
|
||||||
};
|
};
|
||||||
|
custom: {
|
||||||
|
[modelKey: string]: CustomClassificationModelConfig;
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
database: {
|
database: {
|
||||||
|
63
web/src/views/classification/ModelSelectionView.tsx
Normal file
63
web/src/views/classification/ModelSelectionView.tsx
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import {
|
||||||
|
CustomClassificationModelConfig,
|
||||||
|
FrigateConfig,
|
||||||
|
} from "@/types/frigateConfig";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
import { isMobile } from "react-device-detect";
|
||||||
|
import useSWR from "swr";
|
||||||
|
|
||||||
|
type ModelSelectionViewProps = {
|
||||||
|
onClick: (model: CustomClassificationModelConfig) => void;
|
||||||
|
};
|
||||||
|
export default function ModelSelectionView({
|
||||||
|
onClick,
|
||||||
|
}: ModelSelectionViewProps) {
|
||||||
|
const { data: config } = useSWR<FrigateConfig>("config", {
|
||||||
|
revalidateOnFocus: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
const classificationConfigs = useMemo(() => {
|
||||||
|
if (!config) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
return Object.values(config.classification.custom);
|
||||||
|
}, [config]);
|
||||||
|
|
||||||
|
if (!config) {
|
||||||
|
return <ActivityIndicator />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (classificationConfigs.length == 0) {
|
||||||
|
return <div>You need to setup a custom model configuration.</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex size-full gap-2 p-2">
|
||||||
|
{classificationConfigs.map((config) => (
|
||||||
|
<div
|
||||||
|
key={config.name}
|
||||||
|
className={cn(
|
||||||
|
"flex h-52 cursor-pointer flex-col gap-2 rounded-lg bg-card p-2 outline outline-[3px]",
|
||||||
|
"outline-transparent duration-500",
|
||||||
|
isMobile && "w-full",
|
||||||
|
)}
|
||||||
|
onClick={() => onClick(config)}
|
||||||
|
onContextMenu={() => {
|
||||||
|
// e.stopPropagation();
|
||||||
|
// e.preventDefault();
|
||||||
|
// handleClickEvent(true);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="size-48"></div>
|
||||||
|
<div className="smart-capitalize">
|
||||||
|
{config.name} ({config.state_config != null ? "State" : "Object"}{" "}
|
||||||
|
Classification)
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
661
web/src/views/classification/ModelTrainingView.tsx
Normal file
661
web/src/views/classification/ModelTrainingView.tsx
Normal file
@ -0,0 +1,661 @@
|
|||||||
|
import { baseUrl } from "@/api/baseUrl";
|
||||||
|
import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
|
||||||
|
import { Button, buttonVariants } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
} from "@/components/ui/dialog";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuSeparator,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import { Toaster } from "@/components/ui/sonner";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
import useKeyboardListener from "@/hooks/use-keyboard-listener";
|
||||||
|
import useOptimisticState from "@/hooks/use-optimistic-state";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { CustomClassificationModelConfig } from "@/types/frigateConfig";
|
||||||
|
import { TooltipPortal } from "@radix-ui/react-tooltip";
|
||||||
|
import axios from "axios";
|
||||||
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
|
import { isDesktop, isMobile } from "react-device-detect";
|
||||||
|
import { Trans, useTranslation } from "react-i18next";
|
||||||
|
import { LuPencil, LuTrash2 } from "react-icons/lu";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import useSWR from "swr";
|
||||||
|
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
|
||||||
|
import { TbCategoryPlus } from "react-icons/tb";
|
||||||
|
|
||||||
|
type ModelTrainingViewProps = {
|
||||||
|
model: CustomClassificationModelConfig;
|
||||||
|
};
|
||||||
|
export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||||
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
|
const [page, setPage] = useState<string>("train");
|
||||||
|
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
||||||
|
|
||||||
|
// dataset
|
||||||
|
|
||||||
|
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||||
|
`classification/${model.name}/train`,
|
||||||
|
);
|
||||||
|
const { data: dataset, mutate: refreshDataset } = useSWR<{
|
||||||
|
[id: string]: string[];
|
||||||
|
}>(`classification/${model.name}/dataset`);
|
||||||
|
|
||||||
|
// image multiselect
|
||||||
|
|
||||||
|
const [selectedImages, setSelectedImages] = useState<string[]>([]);
|
||||||
|
|
||||||
|
const onClickImages = useCallback(
|
||||||
|
(images: string[], ctrl: boolean) => {
|
||||||
|
if (selectedImages.length == 0 && !ctrl) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newSelectedImages = [...selectedImages];
|
||||||
|
|
||||||
|
images.forEach((imageId) => {
|
||||||
|
const index = newSelectedImages.indexOf(imageId);
|
||||||
|
|
||||||
|
if (index != -1) {
|
||||||
|
if (selectedImages.length == 1) {
|
||||||
|
newSelectedImages = [];
|
||||||
|
} else {
|
||||||
|
const copy = [
|
||||||
|
...newSelectedImages.slice(0, index),
|
||||||
|
...newSelectedImages.slice(index + 1),
|
||||||
|
];
|
||||||
|
newSelectedImages = copy;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
newSelectedImages.push(imageId);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
setSelectedImages(newSelectedImages);
|
||||||
|
},
|
||||||
|
[selectedImages, setSelectedImages],
|
||||||
|
);
|
||||||
|
|
||||||
|
// actions
|
||||||
|
|
||||||
|
const trainModel = useCallback(() => {
|
||||||
|
axios.post(`classification/${model.name}/train`);
|
||||||
|
}, [model]);
|
||||||
|
|
||||||
|
const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
|
||||||
|
const onDelete = useCallback(
|
||||||
|
(ids: string[], isName: boolean = false) => {
|
||||||
|
const api =
|
||||||
|
pageToggle == "train"
|
||||||
|
? `/classification/${model.name}/train/delete`
|
||||||
|
: `/classification/${model.name}/dataset/${pageToggle}/delete`;
|
||||||
|
|
||||||
|
axios
|
||||||
|
.post(api, { ids })
|
||||||
|
.then((resp) => {
|
||||||
|
setSelectedImages([]);
|
||||||
|
|
||||||
|
if (resp.status == 200) {
|
||||||
|
if (isName) {
|
||||||
|
toast.success(
|
||||||
|
t("toast.success.deletedCategory", { count: ids.length }),
|
||||||
|
{
|
||||||
|
position: "top-center",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
toast.success(
|
||||||
|
t("toast.success.deletedImage", { count: ids.length }),
|
||||||
|
{
|
||||||
|
position: "top-center",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pageToggle == "train") {
|
||||||
|
refreshTrain();
|
||||||
|
} else {
|
||||||
|
refreshDataset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
const errorMessage =
|
||||||
|
error.response?.data?.message ||
|
||||||
|
error.response?.data?.detail ||
|
||||||
|
"Unknown error";
|
||||||
|
if (isName) {
|
||||||
|
toast.error(
|
||||||
|
t("toast.error.deleteCategoryFailed", { errorMessage }),
|
||||||
|
{
|
||||||
|
position: "top-center",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
toast.error(t("toast.error.deleteImageFailed", { errorMessage }), {
|
||||||
|
position: "top-center",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[pageToggle, model, refreshTrain, refreshDataset, t],
|
||||||
|
);
|
||||||
|
|
||||||
|
// keyboard
|
||||||
|
|
||||||
|
useKeyboardListener(["a", "Escape"], (key, modifiers) => {
|
||||||
|
if (modifiers.repeat || !modifiers.down) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (key) {
|
||||||
|
case "a":
|
||||||
|
if (modifiers.ctrl) {
|
||||||
|
if (selectedImages.length) {
|
||||||
|
setSelectedImages([]);
|
||||||
|
} else {
|
||||||
|
setSelectedImages([
|
||||||
|
...(pageToggle === "train"
|
||||||
|
? trainImages || []
|
||||||
|
: dataset?.[pageToggle] || []),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "Escape":
|
||||||
|
setSelectedImages([]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setSelectedImages([]);
|
||||||
|
}, [pageToggle]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex size-full flex-col overflow-hidden">
|
||||||
|
<Toaster />
|
||||||
|
|
||||||
|
<AlertDialog
|
||||||
|
open={!!deleteDialogOpen}
|
||||||
|
onOpenChange={() => setDeleteDialogOpen(null)}
|
||||||
|
>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>
|
||||||
|
{t(
|
||||||
|
pageToggle == "train"
|
||||||
|
? "deleteTrainImages.title"
|
||||||
|
: "deleteDatasetImages.title",
|
||||||
|
)}
|
||||||
|
</AlertDialogTitle>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
<Trans
|
||||||
|
ns="views/classificationModel"
|
||||||
|
values={{ count: deleteDialogOpen?.length, dataset: pageToggle }}
|
||||||
|
>
|
||||||
|
{pageToggle == "train"
|
||||||
|
? "deleteTrainImages.desc"
|
||||||
|
: "deleteDatasetImages.desc"}
|
||||||
|
</Trans>
|
||||||
|
</AlertDialogDescription>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogCancel>
|
||||||
|
{t("button.cancel", { ns: "common" })}
|
||||||
|
</AlertDialogCancel>
|
||||||
|
<AlertDialogAction
|
||||||
|
className={buttonVariants({ variant: "destructive" })}
|
||||||
|
onClick={() => {
|
||||||
|
if (deleteDialogOpen) {
|
||||||
|
onDelete(deleteDialogOpen);
|
||||||
|
setDeleteDialogOpen(null);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t("button.delete", { ns: "common" })}
|
||||||
|
</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
|
||||||
|
<div className="flex flex-row justify-between gap-2 p-2 align-middle">
|
||||||
|
<LibrarySelector
|
||||||
|
pageToggle={pageToggle}
|
||||||
|
dataset={dataset || {}}
|
||||||
|
trainImages={trainImages || []}
|
||||||
|
setPageToggle={setPageToggle}
|
||||||
|
onDelete={onDelete}
|
||||||
|
onRename={() => {}}
|
||||||
|
/>
|
||||||
|
{selectedImages?.length > 0 ? (
|
||||||
|
<div className="flex items-center justify-center gap-2">
|
||||||
|
<div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground">
|
||||||
|
<div className="p-1">{`${selectedImages.length} selected`}</div>
|
||||||
|
<div className="p-1">{"|"}</div>
|
||||||
|
<div
|
||||||
|
className="cursor-pointer p-2 text-primary hover:rounded-lg hover:bg-secondary"
|
||||||
|
onClick={() => setSelectedImages([])}
|
||||||
|
>
|
||||||
|
{t("button.unselect", { ns: "common" })}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
className="flex gap-2"
|
||||||
|
onClick={() => setDeleteDialogOpen(selectedImages)}
|
||||||
|
>
|
||||||
|
<LuTrash2 className="size-7 rounded-md p-1 text-secondary-foreground" />
|
||||||
|
{isDesktop && t("button.deleteImages")}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<Button onClick={trainModel}>Train Model</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{pageToggle == "train" ? (
|
||||||
|
<TrainGrid
|
||||||
|
model={model}
|
||||||
|
classes={Object.keys(dataset || {})}
|
||||||
|
trainImages={trainImages || []}
|
||||||
|
selectedImages={selectedImages}
|
||||||
|
onRefresh={refreshTrain}
|
||||||
|
onClickImages={onClickImages}
|
||||||
|
onDelete={onDelete}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<DatasetGrid
|
||||||
|
modelName={model.name}
|
||||||
|
categoryName={pageToggle}
|
||||||
|
images={dataset?.[pageToggle] || []}
|
||||||
|
selectedImages={selectedImages}
|
||||||
|
onClickImages={onClickImages}
|
||||||
|
onDelete={onDelete}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type LibrarySelectorProps = {
|
||||||
|
pageToggle: string | undefined;
|
||||||
|
dataset: { [id: string]: string[] };
|
||||||
|
trainImages: string[];
|
||||||
|
setPageToggle: (toggle: string) => void;
|
||||||
|
onDelete: (ids: string[], isName: boolean) => void;
|
||||||
|
onRename: (old_name: string, new_name: string) => void;
|
||||||
|
};
|
||||||
|
function LibrarySelector({
|
||||||
|
pageToggle,
|
||||||
|
dataset,
|
||||||
|
trainImages,
|
||||||
|
setPageToggle,
|
||||||
|
onDelete,
|
||||||
|
onRename,
|
||||||
|
}: LibrarySelectorProps) {
|
||||||
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
|
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
|
||||||
|
const [renameFace, setRenameFace] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const handleDeleteFace = useCallback(
|
||||||
|
(name: string) => {
|
||||||
|
// Get all image IDs for this face
|
||||||
|
const imageIds = dataset?.[name] || [];
|
||||||
|
|
||||||
|
onDelete(imageIds, true);
|
||||||
|
setPageToggle("train");
|
||||||
|
},
|
||||||
|
[dataset, onDelete, setPageToggle],
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSetOpen = useCallback(
|
||||||
|
(open: boolean) => {
|
||||||
|
setRenameFace(open ? renameFace : null);
|
||||||
|
},
|
||||||
|
[renameFace],
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Dialog
|
||||||
|
open={!!confirmDelete}
|
||||||
|
onOpenChange={(open) => !open && setConfirmDelete(null)}
|
||||||
|
>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>{t("deleteCategory.title")}</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
{t("deleteCategory.desc", { name: confirmDelete })}
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
<div className="flex justify-end gap-2">
|
||||||
|
<Button variant="outline" onClick={() => setConfirmDelete(null)}>
|
||||||
|
{t("button.cancel", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="destructive"
|
||||||
|
onClick={() => {
|
||||||
|
if (confirmDelete) {
|
||||||
|
handleDeleteFace(confirmDelete);
|
||||||
|
setConfirmDelete(null);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t("button.delete", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
|
<TextEntryDialog
|
||||||
|
open={!!renameFace}
|
||||||
|
setOpen={handleSetOpen}
|
||||||
|
title={t("renameCategory.title")}
|
||||||
|
description={t("renameCategory.desc", { name: renameFace })}
|
||||||
|
onSave={(newName) => {
|
||||||
|
onRename(renameFace!, newName);
|
||||||
|
setRenameFace(null);
|
||||||
|
}}
|
||||||
|
defaultValue={renameFace || ""}
|
||||||
|
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
|
||||||
|
regexErrorMessage={t("description.invalidName")}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<DropdownMenu>
|
||||||
|
<DropdownMenuTrigger asChild>
|
||||||
|
<Button className="flex justify-between smart-capitalize">
|
||||||
|
{pageToggle == "train" ? t("train.title") : pageToggle}
|
||||||
|
<span className="ml-2 text-primary-variant">
|
||||||
|
(
|
||||||
|
{(pageToggle &&
|
||||||
|
(pageToggle == "train"
|
||||||
|
? trainImages.length
|
||||||
|
: dataset?.[pageToggle]?.length)) ||
|
||||||
|
0}
|
||||||
|
)
|
||||||
|
</span>
|
||||||
|
</Button>
|
||||||
|
</DropdownMenuTrigger>
|
||||||
|
<DropdownMenuContent
|
||||||
|
className="scrollbar-container max-h-[40dvh] min-w-[220px] overflow-y-auto"
|
||||||
|
align="start"
|
||||||
|
>
|
||||||
|
<DropdownMenuItem
|
||||||
|
className="flex cursor-pointer items-center justify-start gap-2"
|
||||||
|
aria-label={t("train.aria")}
|
||||||
|
onClick={() => setPageToggle("train")}
|
||||||
|
>
|
||||||
|
<div>{t("train.title")}</div>
|
||||||
|
<div className="text-secondary-foreground">
|
||||||
|
({trainImages.length})
|
||||||
|
</div>
|
||||||
|
</DropdownMenuItem>
|
||||||
|
{trainImages.length > 0 && Object.keys(dataset).length > 0 && (
|
||||||
|
<>
|
||||||
|
<DropdownMenuSeparator />
|
||||||
|
<div className="mb-1 ml-1.5 text-xs text-secondary-foreground">
|
||||||
|
{t("categories")}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{Object.keys(dataset).map((id) => (
|
||||||
|
<DropdownMenuItem
|
||||||
|
key={id}
|
||||||
|
className="group flex items-center justify-between"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className="flex-grow cursor-pointer"
|
||||||
|
onClick={() => setPageToggle(id)}
|
||||||
|
>
|
||||||
|
{id}
|
||||||
|
<span className="ml-2 text-muted-foreground">
|
||||||
|
({dataset?.[id].length})
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-0.5">
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
className="size-7 lg:opacity-0 lg:transition-opacity lg:group-hover:opacity-100"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setRenameFace(id);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<LuPencil className="size-4 text-primary" />
|
||||||
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("button.renameCategory")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
className="size-7 lg:opacity-0 lg:transition-opacity lg:group-hover:opacity-100"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setConfirmDelete(id);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<LuTrash2 className="size-4 text-destructive" />
|
||||||
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("button.deleteCategory")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</DropdownMenuItem>
|
||||||
|
))}
|
||||||
|
</DropdownMenuContent>
|
||||||
|
</DropdownMenu>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatasetGridProps = {
|
||||||
|
modelName: string;
|
||||||
|
categoryName: string;
|
||||||
|
images: string[];
|
||||||
|
selectedImages: string[];
|
||||||
|
onClickImages: (images: string[], ctrl: boolean) => void;
|
||||||
|
onDelete: (ids: string[]) => void;
|
||||||
|
};
|
||||||
|
function DatasetGrid({
|
||||||
|
modelName,
|
||||||
|
categoryName,
|
||||||
|
images,
|
||||||
|
selectedImages,
|
||||||
|
onClickImages,
|
||||||
|
onDelete,
|
||||||
|
}: DatasetGridProps) {
|
||||||
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-wrap gap-2 overflow-y-auto p-2">
|
||||||
|
{images.map((image) => (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex w-60 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]",
|
||||||
|
selectedImages.includes(image)
|
||||||
|
? "shadow-selected outline-selected"
|
||||||
|
: "outline-transparent duration-500",
|
||||||
|
)}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
|
||||||
|
if (e.ctrlKey || e.metaKey) {
|
||||||
|
onClickImages([image], true);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"w-full overflow-hidden p-2 *:text-card-foreground",
|
||||||
|
isMobile && "flex justify-center",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
className="rounded-lg"
|
||||||
|
src={`${baseUrl}clips/${modelName}/dataset/${categoryName}/${image}`}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="rounded-b-lg bg-card p-3">
|
||||||
|
<div className="flex w-full flex-row items-center justify-between gap-2">
|
||||||
|
<div className="flex w-full flex-row items-start justify-end gap-5 md:gap-4">
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger>
|
||||||
|
<LuTrash2
|
||||||
|
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onDelete([image]);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("button.deleteClassificationAttempts")}
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type TrainGridProps = {
|
||||||
|
model: CustomClassificationModelConfig;
|
||||||
|
classes: string[];
|
||||||
|
trainImages: string[];
|
||||||
|
selectedImages: string[];
|
||||||
|
onClickImages: (images: string[], ctrl: boolean) => void;
|
||||||
|
onRefresh: () => void;
|
||||||
|
onDelete: (ids: string[]) => void;
|
||||||
|
};
|
||||||
|
function TrainGrid({
|
||||||
|
model,
|
||||||
|
classes,
|
||||||
|
trainImages,
|
||||||
|
selectedImages,
|
||||||
|
onClickImages,
|
||||||
|
onRefresh,
|
||||||
|
onDelete,
|
||||||
|
}: TrainGridProps) {
|
||||||
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
|
|
||||||
|
const trainData = useMemo(
|
||||||
|
() =>
|
||||||
|
trainImages
|
||||||
|
.map((raw) => {
|
||||||
|
const parts = raw.replaceAll(".webp", "").split("-");
|
||||||
|
return {
|
||||||
|
raw,
|
||||||
|
timestamp: parts[0],
|
||||||
|
label: parts[1],
|
||||||
|
score: Number.parseFloat(parts[2]) * 100,
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
|
||||||
|
[trainImages],
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-wrap gap-2 overflow-y-auto p-2">
|
||||||
|
{trainData?.map((data) => (
|
||||||
|
<div
|
||||||
|
key={data.timestamp}
|
||||||
|
className={cn(
|
||||||
|
"flex w-56 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]",
|
||||||
|
selectedImages.includes(data.raw)
|
||||||
|
? "shadow-selected outline-selected"
|
||||||
|
: "outline-transparent duration-500",
|
||||||
|
)}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onClickImages([data.raw], e.ctrlKey || e.metaKey);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"w-full overflow-hidden p-2 *:text-card-foreground",
|
||||||
|
isMobile && "flex justify-center",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
className="w-56 rounded-lg"
|
||||||
|
src={`${baseUrl}clips/${model.name}/train/${data.raw}`}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="rounded-b-lg bg-card p-3">
|
||||||
|
<div className="flex w-full flex-row items-center justify-between gap-2">
|
||||||
|
<div className="flex flex-col items-start text-xs text-primary-variant">
|
||||||
|
<div className="smart-capitalize">{data.label}</div>
|
||||||
|
<div>{data.score}%</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
|
||||||
|
<ClassificationSelectionDialog
|
||||||
|
classes={classes}
|
||||||
|
modelName={model.name}
|
||||||
|
image={data.raw}
|
||||||
|
onRefresh={onRefresh}
|
||||||
|
>
|
||||||
|
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
|
||||||
|
</ClassificationSelectionDialog>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger>
|
||||||
|
<LuTrash2
|
||||||
|
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onDelete([data.raw]);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("button.deleteClassificationAttempts")}
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user