Implement face recognition training in UI (#15786)

* Rename debug to train

* Add api to train image as person

* Cleanup model running

* Formatting

* Fix

* Set face recognition page title
This commit is contained in:
Nicolas Mowen 2025-01-02 16:44:25 -06:00 committed by Blake Blackshear
parent 172e7d494f
commit 281407247b
5 changed files with 231 additions and 72 deletions

View File

@ -2,6 +2,9 @@
import logging import logging
import os import os
import random
import shutil
import string
from fastapi import APIRouter, Request, UploadFile from fastapi import APIRouter, Request, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -22,7 +25,13 @@ def get_faces():
for name in os.listdir(FACE_DIR): for name in os.listdir(FACE_DIR):
face_dict[name] = [] face_dict[name] = []
for file in os.listdir(os.path.join(FACE_DIR, name)):
face_dir = os.path.join(FACE_DIR, name)
if not os.path.isdir(face_dir):
continue
for file in os.listdir(face_dir):
face_dict[name].append(file) face_dict[name].append(file)
return JSONResponse(status_code=200, content=face_dict) return JSONResponse(status_code=200, content=face_dict)
@ -38,6 +47,39 @@ async def register_face(request: Request, name: str, file: UploadFile):
) )
@router.post("/faces/train/{name}/classify")
def train_face(name: str, body: dict = None):
json: dict[str, any] = body or {}
training_file = os.path.join(
FACE_DIR, f"train/{sanitize_filename(json.get('training_file', ''))}"
)
if not training_file or not os.path.isfile(training_file):
return JSONResponse(
content=(
{
"success": False,
"message": f"Invalid filename or no file exists: {training_file}",
}
),
status_code=404,
)
rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
new_name = f"{name}-{rand_id}.webp"
new_file = os.path.join(FACE_DIR, f"{name}/{new_name}")
shutil.move(training_file, new_file)
return JSONResponse(
content=(
{
"success": True,
"message": f"Successfully saved {training_file} as {new_name}.",
}
),
status_code=200,
)
@router.post("/faces/{name}/delete") @router.post("/faces/{name}/delete")
def deregister_faces(request: Request, name: str, body: dict = None): def deregister_faces(request: Request, name: str, body: dict = None):
json: dict[str, any] = body or {} json: dict[str, any] = body or {}

View File

@ -517,7 +517,7 @@ class EmbeddingMaintainer(threading.Thread):
if self.config.face_recognition.save_attempts: if self.config.face_recognition.save_attempts:
# write face to library # write face to library
folder = os.path.join(FACE_DIR, "debug") folder = os.path.join(FACE_DIR, "train")
file = os.path.join(folder, f"{id}-{sub_label}-{score}-{face_score}.webp") file = os.path.join(folder, f"{id}-{sub_label}-{score}-{face_score}.webp")
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
cv2.imwrite(file, face_frame) cv2.imwrite(file, face_frame)

View File

@ -163,7 +163,12 @@ class FaceClassificationModel:
self.config = config self.config = config
self.db = db self.db = db
self.landmark_detector = cv2.face.createFacemarkLBF() self.landmark_detector = cv2.face.createFacemarkLBF()
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
if os.path.isfile("/config/model_cache/facedet/landmarkdet.yaml"):
self.landmark_detector.loadModel(
"/config/model_cache/facedet/landmarkdet.yaml"
)
self.recognizer: cv2.face.LBPHFaceRecognizer = ( self.recognizer: cv2.face.LBPHFaceRecognizer = (
cv2.face.LBPHFaceRecognizer_create( cv2.face.LBPHFaceRecognizer_create(
radius=2, threshold=(1 - config.min_score) * 1000 radius=2, threshold=(1 - config.min_score) * 1000
@ -178,13 +183,21 @@ class FaceClassificationModel:
dir = "/media/frigate/clips/faces" dir = "/media/frigate/clips/faces"
for idx, name in enumerate(os.listdir(dir)): for idx, name in enumerate(os.listdir(dir)):
if name == "debug": if name == "train":
continue
face_folder = os.path.join(dir, name)
if not os.path.isdir(face_folder):
continue continue
self.label_map[idx] = name self.label_map[idx] = name
face_folder = os.path.join(dir, name)
for image in os.listdir(face_folder): for image in os.listdir(face_folder):
img = cv2.imread(os.path.join(face_folder, image)) img = cv2.imread(os.path.join(face_folder, image))
if img is None:
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = self.__align_face(img, img.shape[1], img.shape[0]) img = self.__align_face(img, img.shape[1], img.shape[0])
faces.append(img) faces.append(img)

View File

@ -0,0 +1,25 @@
import { forwardRef } from "react";
import { LuPlus, LuScanFace } from "react-icons/lu";
import { cn } from "@/lib/utils";
type AddFaceIconProps = {
className?: string;
onClick?: () => void;
};
const AddFaceIcon = forwardRef<HTMLDivElement, AddFaceIconProps>(
({ className, onClick }, ref) => {
return (
<div
ref={ref}
className={cn("relative flex items-center", className)}
onClick={onClick}
>
<LuScanFace className="size-full" />
<LuPlus className="absolute size-4 translate-x-3 translate-y-3" />
</div>
);
},
);
export default AddFaceIcon;

View File

@ -1,19 +1,36 @@
import { baseUrl } from "@/api/baseUrl"; import { baseUrl } from "@/api/baseUrl";
import Chip from "@/components/indicators/Chip"; import AddFaceIcon from "@/components/icons/AddFaceIcon";
import UploadImageDialog from "@/components/overlay/dialog/UploadImageDialog"; import UploadImageDialog from "@/components/overlay/dialog/UploadImageDialog";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area"; import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
import { Toaster } from "@/components/ui/sonner"; import { Toaster } from "@/components/ui/sonner";
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import useOptimisticState from "@/hooks/use-optimistic-state"; import useOptimisticState from "@/hooks/use-optimistic-state";
import axios from "axios"; import axios from "axios";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { isDesktop } from "react-device-detect"; import { LuImagePlus, LuTrash2 } from "react-icons/lu";
import { LuImagePlus, LuTrash } from "react-icons/lu";
import { toast } from "sonner"; import { toast } from "sonner";
import useSWR from "swr"; import useSWR from "swr";
export default function FaceLibrary() { export default function FaceLibrary() {
// title
useEffect(() => {
document.title = "Face Library - Frigate";
}, []);
const [page, setPage] = useState<string>(); const [page, setPage] = useState<string>();
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
const tabsRef = useRef<HTMLDivElement | null>(null); const tabsRef = useRef<HTMLDivElement | null>(null);
@ -24,7 +41,7 @@ export default function FaceLibrary() {
const faces = useMemo<string[]>( const faces = useMemo<string[]>(
() => () =>
faceData ? Object.keys(faceData).filter((face) => face != "debug") : [], faceData ? Object.keys(faceData).filter((face) => face != "train") : [],
[faceData], [faceData],
); );
const faceImages = useMemo<string[]>( const faceImages = useMemo<string[]>(
@ -32,24 +49,24 @@ export default function FaceLibrary() {
[pageToggle, faceData], [pageToggle, faceData],
); );
const faceAttempts = useMemo<string[]>( const trainImages = useMemo<string[]>(
() => faceData?.["debug"] || [], () => faceData?.["train"] || [],
[faceData], [faceData],
); );
useEffect(() => { useEffect(() => {
if (!pageToggle) { if (!pageToggle) {
if (faceAttempts.length > 0) { if (trainImages.length > 0) {
setPageToggle("attempts"); setPageToggle("train");
} else if (faces) { } else if (faces) {
setPageToggle(faces[0]); setPageToggle(faces[0]);
} }
} else if (pageToggle == "attempts" && faceAttempts.length == 0) { } else if (pageToggle == "train" && trainImages.length == 0) {
setPageToggle(faces[0]); setPageToggle(faces[0]);
} }
// we need to listen on the value of the faces list // we need to listen on the value of the faces list
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [faceAttempts, faces]); }, [trainImages, faces]);
// upload // upload
@ -117,15 +134,15 @@ export default function FaceLibrary() {
} }
}} }}
> >
{faceAttempts.length > 0 && ( {trainImages.length > 0 && (
<> <>
<ToggleGroupItem <ToggleGroupItem
value="attempts" value="train"
className={`flex scroll-mx-10 items-center justify-between gap-2 ${pageToggle == "attempts" ? "" : "*:text-muted-foreground"}`} className={`flex scroll-mx-10 items-center justify-between gap-2 ${pageToggle == "train" ? "" : "*:text-muted-foreground"}`}
data-nav-item="attempts" data-nav-item="train"
aria-label="Select attempts" aria-label="Select train"
> >
<div>Attempts</div> <div>Train</div>
</ToggleGroupItem> </ToggleGroupItem>
<div>|</div> <div>|</div>
</> </>
@ -148,8 +165,12 @@ export default function FaceLibrary() {
</ScrollArea> </ScrollArea>
</div> </div>
{pageToggle && {pageToggle &&
(pageToggle == "attempts" ? ( (pageToggle == "train" ? (
<AttemptsGrid attemptImages={faceAttempts} onRefresh={refreshFaces} /> <TrainingGrid
attemptImages={trainImages}
faceNames={faces}
onRefresh={refreshFaces}
/>
) : ( ) : (
<FaceGrid <FaceGrid
faceImages={faceImages} faceImages={faceImages}
@ -162,15 +183,25 @@ export default function FaceLibrary() {
); );
} }
type AttemptsGridProps = { type TrainingGridProps = {
attemptImages: string[]; attemptImages: string[];
faceNames: string[];
onRefresh: () => void; onRefresh: () => void;
}; };
function AttemptsGrid({ attemptImages, onRefresh }: AttemptsGridProps) { function TrainingGrid({
attemptImages,
faceNames,
onRefresh,
}: TrainingGridProps) {
return ( return (
<div className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll"> <div className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll">
{attemptImages.map((image: string) => ( {attemptImages.map((image: string) => (
<FaceAttempt key={image} image={image} onRefresh={onRefresh} /> <FaceAttempt
key={image}
image={image}
faceNames={faceNames}
onRefresh={onRefresh}
/>
))} ))}
</div> </div>
); );
@ -178,11 +209,10 @@ function AttemptsGrid({ attemptImages, onRefresh }: AttemptsGridProps) {
type FaceAttemptProps = { type FaceAttemptProps = {
image: string; image: string;
faceNames: string[];
onRefresh: () => void; onRefresh: () => void;
}; };
function FaceAttempt({ image, onRefresh }: FaceAttemptProps) { function FaceAttempt({ image, faceNames, onRefresh }: FaceAttemptProps) {
const [hovered, setHovered] = useState(false);
const data = useMemo(() => { const data = useMemo(() => {
const parts = image.split("-"); const parts = image.split("-");
@ -193,9 +223,36 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) {
}; };
}, [image]); }, [image]);
const onTrainAttempt = useCallback(
(trainName: string) => {
axios
.post(`/faces/train/${trainName}/classify`, { training_file: image })
.then((resp) => {
if (resp.status == 200) {
toast.success(`Successfully trained face.`, {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
if (error.response?.data?.message) {
toast.error(`Failed to train: ${error.response.data.message}`, {
position: "top-center",
});
} else {
toast.error(`Failed to train: ${error.message}`, {
position: "top-center",
});
}
});
},
[image, onRefresh],
);
const onDelete = useCallback(() => { const onDelete = useCallback(() => {
axios axios
.post(`/faces/debug/delete`, { ids: [image] }) .post(`/faces/train/delete`, { ids: [image] })
.then((resp) => { .then((resp) => {
if (resp.status == 200) { if (resp.status == 200) {
toast.success(`Successfully deleted face.`, { toast.success(`Successfully deleted face.`, {
@ -218,28 +275,50 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) {
}, [image, onRefresh]); }, [image, onRefresh]);
return ( return (
<div <div className="relative flex flex-col rounded-lg">
className="relative h-min" <div className="w-full overflow-hidden rounded-t-lg border border-t-0 *:text-card-foreground">
onMouseEnter={isDesktop ? () => setHovered(true) : undefined} <img className="h-40" src={`${baseUrl}clips/faces/train/${image}`} />
onMouseLeave={isDesktop ? () => setHovered(false) : undefined} </div>
onClick={isDesktop ? undefined : () => setHovered(!hovered)} <div className="rounded-b-lg bg-card p-2">
> <div className="flex w-full flex-row items-center justify-between gap-2">
{hovered && ( <div className="flex flex-col items-start text-xs text-primary-variant">
<div className="absolute right-1 top-1"> <div className="capitalize">{data.name}</div>
<Chip <div>{Number.parseFloat(data.score) * 100}%</div>
className="cursor-pointer rounded-md bg-gray-500 bg-gradient-to-br from-gray-400 to-gray-500" </div>
onClick={() => onDelete()} <div className="flex flex-row items-start justify-end gap-5 md:gap-4">
> <Tooltip>
<LuTrash className="size-4 fill-destructive text-destructive" /> <DropdownMenu>
</Chip> <DropdownMenuTrigger>
<TooltipTrigger>
<AddFaceIcon className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</TooltipTrigger>
</DropdownMenuTrigger>
<DropdownMenuContent>
<DropdownMenuLabel>Train Face as:</DropdownMenuLabel>
{faceNames.map((faceName) => (
<DropdownMenuItem
key={faceName}
className="cursor-pointer capitalize"
onClick={() => onTrainAttempt(faceName)}
>
{faceName}
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
<TooltipContent>Train Face as Person</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={onDelete}
/>
</TooltipTrigger>
<TooltipContent>Delete Face Attempt</TooltipContent>
</Tooltip>
</div>
</div> </div>
)}
<div className="rounded-md bg-secondary">
<img
className="h-40 rounded-md"
src={`${baseUrl}clips/faces/debug/${image}`}
/>
<div className="p-2">{`${data.name}: ${data.score}`}</div>
</div> </div>
</div> </div>
); );
@ -280,8 +359,6 @@ type FaceImageProps = {
onRefresh: () => void; onRefresh: () => void;
}; };
function FaceImage({ name, image, onRefresh }: FaceImageProps) { function FaceImage({ name, image, onRefresh }: FaceImageProps) {
const [hovered, setHovered] = useState(false);
const onDelete = useCallback(() => { const onDelete = useCallback(() => {
axios axios
.post(`/faces/${name}/delete`, { ids: [image] }) .post(`/faces/${name}/delete`, { ids: [image] })
@ -307,26 +384,28 @@ function FaceImage({ name, image, onRefresh }: FaceImageProps) {
}, [name, image, onRefresh]); }, [name, image, onRefresh]);
return ( return (
<div <div className="relative flex flex-col rounded-lg">
className="relative h-40" <div className="w-full overflow-hidden rounded-t-lg border border-t-0 *:text-card-foreground">
onMouseEnter={isDesktop ? () => setHovered(true) : undefined} <img className="h-40" src={`${baseUrl}clips/faces/${name}/${image}`} />
onMouseLeave={isDesktop ? () => setHovered(false) : undefined} </div>
onClick={isDesktop ? undefined : () => setHovered(!hovered)} <div className="rounded-b-lg bg-card p-2">
> <div className="flex w-full flex-row items-center justify-between gap-2">
{hovered && ( <div className="flex flex-col items-start text-xs text-primary-variant">
<div className="absolute right-1 top-1"> <div className="capitalize">{name}</div>
<Chip </div>
className="cursor-pointer rounded-md bg-gray-500 bg-gradient-to-br from-gray-400 to-gray-500" <div className="flex flex-row items-start justify-end gap-5 md:gap-4">
onClick={() => onDelete()} <Tooltip>
> <TooltipTrigger>
<LuTrash className="size-4 fill-destructive text-destructive" /> <LuTrash2
</Chip> className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={onDelete}
/>
</TooltipTrigger>
<TooltipContent>Delete Face Attempt</TooltipContent>
</Tooltip>
</div>
</div> </div>
)} </div>
<img
className="h-40 rounded-md"
src={`${baseUrl}clips/faces/${name}/${image}`}
/>
</div> </div>
); );
} }