Reclassification (#22603)

* add ability to reclassify images

* add ability to reclassify faces

* work around radix pointer events issue again
This commit is contained in:
Josh Hawkins
2026-03-24 08:18:06 -05:00
committed by GitHub
parent 91ef3b2ceb
commit 854ef320de
7 changed files with 324 additions and 13 deletions

View File

@@ -34,7 +34,11 @@ type ClassificationSelectionDialogProps = {
classes: string[];
modelName: string;
image: string;
onRefresh: () => void;
onRefresh?: () => void;
onCategorize?: (category: string) => void;
excludeCategory?: string;
dialogLabel?: string;
tooltipLabel?: string;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
@@ -43,12 +47,21 @@ export default function ClassificationSelectionDialog({
modelName,
image,
onRefresh,
onCategorize,
excludeCategory,
dialogLabel,
tooltipLabel,
children,
}: ClassificationSelectionDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const onCategorizeImage = useCallback(
(category: string) => {
if (onCategorize) {
onCategorize(category);
return;
}
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
@@ -59,7 +72,7 @@ export default function ClassificationSelectionDialog({
toast.success(t("toast.success.categorizedImage"), {
position: "top-center",
});
onRefresh();
onRefresh?.();
}
})
.catch((error) => {
@@ -72,7 +85,13 @@ export default function ClassificationSelectionDialog({
});
});
},
[modelName, image, onRefresh, t],
[modelName, image, onRefresh, onCategorize, t],
);
const filteredClasses = useMemo(
() =>
excludeCategory ? classes.filter((c) => c !== excludeCategory) : classes,
[classes, excludeCategory],
);
const isChildButton = useMemo(
@@ -111,6 +130,7 @@ export default function ClassificationSelectionDialog({
</SelectorTrigger>
<SelectorContent
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
onCloseAutoFocus={(e) => e.preventDefault()}
>
{isMobile && (
<DrawerHeader className="sr-only">
@@ -118,14 +138,16 @@ export default function ClassificationSelectionDialog({
<DrawerDescription>Details</DrawerDescription>
</DrawerHeader>
)}
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
<DropdownMenuLabel>
{dialogLabel ?? t("categorizeImageAs")}
</DropdownMenuLabel>
<div
className={cn(
"flex max-h-[40dvh] flex-col overflow-y-auto",
isMobile && "gap-2 pb-4",
)}
>
{classes
{filteredClasses
.sort((a, b) => {
if (a === "none") return 1;
if (b === "none") return -1;
@@ -152,7 +174,7 @@ export default function ClassificationSelectionDialog({
</div>
</SelectorContent>
</Selector>
<TooltipContent>{t("categorizeImage")}</TooltipContent>
<TooltipContent>{tooltipLabel ?? t("categorizeImage")}</TooltipContent>
</Tooltip>
</div>
);

View File

@@ -30,17 +30,29 @@ import { Button } from "../ui/button";
type FaceSelectionDialogProps = {
className?: string;
faceNames: string[];
excludeName?: string;
dialogLabel?: string;
tooltipLabel?: string;
onTrainAttempt: (name: string) => void;
children: ReactNode;
};
export default function FaceSelectionDialog({
className,
faceNames,
excludeName,
dialogLabel,
tooltipLabel,
onTrainAttempt,
children,
}: FaceSelectionDialogProps) {
const { t } = useTranslation(["views/faceLibrary"]);
const filteredNames = useMemo(
() =>
excludeName ? faceNames.filter((n) => n !== excludeName) : faceNames,
[faceNames, excludeName],
);
const isChildButton = useMemo(
() => React.isValidElement(children) && children.type === Button,
[children],
@@ -79,6 +91,7 @@ export default function FaceSelectionDialog({
</SelectorTrigger>
<SelectorContent
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
onCloseAutoFocus={(e) => e.preventDefault()}
>
{isMobile && (
<DrawerHeader className="sr-only">
@@ -86,14 +99,16 @@ export default function FaceSelectionDialog({
<DrawerDescription>Details</DrawerDescription>
</DrawerHeader>
)}
<DropdownMenuLabel>{t("trainFaceAs")}</DropdownMenuLabel>
<DropdownMenuLabel>
{dialogLabel ?? t("trainFaceAs")}
</DropdownMenuLabel>
<div
className={cn(
"flex max-h-[40dvh] flex-col overflow-y-auto overflow-x-hidden",
isMobile && "gap-2 pb-4",
)}
>
{faceNames.sort().map((faceName) => (
{filteredNames.sort().map((faceName) => (
<SelectorItem
key={faceName}
className="flex cursor-pointer gap-2 smart-capitalize"
@@ -112,7 +127,7 @@ export default function FaceSelectionDialog({
</div>
</SelectorContent>
</Selector>
<TooltipContent>{t("trainFace")}</TooltipContent>
<TooltipContent>{tooltipLabel ?? t("trainFace")}</TooltipContent>
</Tooltip>
</div>
);

View File

@@ -266,6 +266,34 @@ export default function FaceLibrary() {
[setPageToggle, refreshFaces, t],
);
const onReclassify = useCallback(
(image: string, newName: string) => {
axios
.post(`/faces/${pageToggle}/reclassify`, {
id: image,
new_name: newName,
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.reclassifiedFace"), {
position: "top-center",
});
refreshFaces();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.reclassifyFailed", { errorMessage }), {
position: "top-center",
});
});
},
[pageToggle, refreshFaces, t],
);
// keyboard
const contentRef = useRef<HTMLDivElement | null>(null);
@@ -452,10 +480,12 @@ export default function FaceLibrary() {
<FaceGrid
contentRef={contentRef}
faceImages={faceImages}
faceNames={faces}
pageToggle={pageToggle}
selectedFaces={selectedFaces}
onClickFaces={onClickFaces}
onDelete={onDelete}
onReclassify={onReclassify}
/>
))
)}
@@ -601,11 +631,11 @@ function LibrarySelector({
className="group flex items-center justify-between p-0"
>
<div
className="flex-grow cursor-pointer"
className="flex-grow cursor-pointer px-2 py-1.5"
onClick={() => setPageToggle(face)}
>
{face}
<span className="ml-2 px-2 py-1.5 text-muted-foreground">
<span className="ml-2 text-muted-foreground">
({faceData?.[face].length})
</span>
</div>
@@ -983,18 +1013,22 @@ function FaceAttemptGroup({
type FaceGridProps = {
contentRef: MutableRefObject<HTMLDivElement | null>;
faceImages: string[];
faceNames: string[];
pageToggle: string;
selectedFaces: string[];
onClickFaces: (images: string[], ctrl: boolean) => void;
onDelete: (name: string, ids: string[]) => void;
onReclassify: (image: string, newName: string) => void;
};
function FaceGrid({
contentRef,
faceImages,
faceNames,
pageToggle,
selectedFaces,
onClickFaces,
onDelete,
onReclassify,
}: FaceGridProps) {
const { t } = useTranslation(["views/faceLibrary"]);
@@ -1032,6 +1066,17 @@ function FaceGrid({
i18nLibrary="views/faceLibrary"
onClick={(data, meta) => onClickFaces([data.filename], meta)}
>
<FaceSelectionDialog
faceNames={faceNames}
excludeName={pageToggle}
dialogLabel={t("reclassifyFaceAs")}
tooltipLabel={t("reclassifyFace")}
onTrainAttempt={(newName) => onReclassify(image, newName)}
>
<BlurredIconButton>
<AddFaceIcon className="size-5" />
</BlurredIconButton>
</FaceSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2

View File

@@ -304,6 +304,37 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
[pageToggle, model, refreshTrain, refreshDataset, t],
);
const onReclassify = useCallback(
(image: string, newCategory: string) => {
axios
.post(
`/classification/${model.name}/dataset/${pageToggle}/reclassify`,
{
id: image,
new_category: newCategory,
},
)
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.reclassifiedImage"), {
position: "top-center",
});
refreshDataset();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.reclassifyFailed", { errorMessage }), {
position: "top-center",
});
});
},
[pageToggle, model, refreshDataset, t],
);
// keyboard
const contentRef = useRef<HTMLDivElement | null>(null);
@@ -535,10 +566,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
contentRef={contentRef}
modelName={model.name}
categoryName={pageToggle}
classes={Object.keys(dataset || {})}
images={dataset?.[pageToggle] || []}
selectedImages={selectedImages}
onClickImages={onClickImages}
onDelete={onDelete}
onReclassify={onReclassify}
/>
)}
</div>
@@ -776,19 +809,23 @@ type DatasetGridProps = {
contentRef: MutableRefObject<HTMLDivElement | null>;
modelName: string;
categoryName: string;
classes: string[];
images: string[];
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onDelete: (ids: string[]) => void;
onReclassify: (image: string, newCategory: string) => void;
};
function DatasetGrid({
contentRef,
modelName,
categoryName,
classes,
images,
selectedImages,
onClickImages,
onDelete,
onReclassify,
}: DatasetGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
@@ -816,10 +853,23 @@ function DatasetGrid({
i18nLibrary="views/classificationModel"
onClick={(data, _) => onClickImages([data.filename], true)}
>
<ClassificationSelectionDialog
classes={classes}
modelName={modelName}
image={image}
excludeCategory={categoryName}
dialogLabel={t("reclassifyImageAs")}
tooltipLabel={t("reclassifyImage")}
onCategorize={(newCat) => onReclassify(image, newCat)}
>
<BlurredIconButton>
<TbCategoryPlus className="size-5" />
</BlurredIconButton>
</ClassificationSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-danger"
className="size-5 cursor-pointer text-gray-200 hover:text-danger"
onClick={(e) => {
e.stopPropagation();
onDelete([image]);