Improve classification UI (#18910)

* Move threhsold to base model config

* Improve score handling

* Add back button
This commit is contained in:
Nicolas Mowen 2025-06-27 08:35:02 -06:00
parent 2159d40ad5
commit 8139bc9bbd
4 changed files with 46 additions and 16 deletions

View File

@ -59,9 +59,6 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field( crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on." title="Crop of image frame on this camera to run classification on."
) )
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
class CustomClassificationStateConfig(FrigateBaseModel): class CustomClassificationStateConfig(FrigateBaseModel):
@ -86,6 +83,9 @@ class CustomClassificationObjectConfig(FrigateBaseModel):
class CustomClassificationConfig(FrigateBaseModel): class CustomClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable running the model.") enabled: bool = Field(default=True, title="Enable running the model.")
name: str | None = Field(default=None, title="Name of classification model.") name: str | None = Field(default=None, title="Name of classification model.")
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
object_config: CustomClassificationObjectConfig | None = Field(default=None) object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: CustomClassificationStateConfig | None = Field(default=None) state_config: CustomClassificationStateConfig | None = Field(default=None)

View File

@ -152,7 +152,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
score, score,
) )
if score >= camera_config.threshold: if score >= self.model_config.threshold:
self.requestor.send_data( self.requestor.send_data(
f"{camera}/classification/{self.model_config.name}", f"{camera}/classification/{self.model_config.name}",
self.labelmap[best_id], self.labelmap[best_id],
@ -271,6 +271,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
score, score,
) )
if score < self.model_config.threshold:
logger.debug(f"Score {score} is less than threshold.")
return
if score <= previous_score: if score <= previous_score:
logger.debug(f"Score {score} is worse than previous score {previous_score}") logger.debug(f"Score {score} is worse than previous score {previous_score}")
return return

View File

@ -282,6 +282,7 @@ export type CameraStreamingSettings = {
export type CustomClassificationModelConfig = { export type CustomClassificationModelConfig = {
enabled: boolean; enabled: boolean;
name: string; name: string;
threshold: number;
object_config: null | { object_config: null | {
objects: string[]; objects: string[];
}; };
@ -289,7 +290,6 @@ export type CustomClassificationModelConfig = {
cameras: { cameras: {
[cameraName: string]: { [cameraName: string]: {
crop: [number, number, number, number]; crop: [number, number, number, number];
threshold: number;
}; };
}; };
motion: boolean; motion: boolean;

View File

@ -48,12 +48,15 @@ import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws"; import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws"; import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator"; import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useNavigate } from "react-router-dom";
import { IoMdArrowRoundBack } from "react-icons/io";
type ModelTrainingViewProps = { type ModelTrainingViewProps = {
model: CustomClassificationModelConfig; model: CustomClassificationModelConfig;
}; };
export default function ModelTrainingView({ model }: ModelTrainingViewProps) { export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
const navigate = useNavigate();
const [page, setPage] = useState<string>("train"); const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
@ -294,6 +297,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</AlertDialog> </AlertDialog>
<div className="flex flex-row justify-between gap-2 p-2 align-middle"> <div className="flex flex-row justify-between gap-2 p-2 align-middle">
<div className="flex flex-row items-center justify-center gap-2">
<Button
className="flex items-center gap-2.5 rounded-lg"
aria-label={t("label.back", { ns: "common" })}
onClick={() => navigate(-1)}
>
<IoMdArrowRoundBack className="size-5 text-secondary-foreground" />
{isDesktop && (
<div className="text-primary">
{t("button.back", { ns: "common" })}
</div>
)}
</Button>
<LibrarySelector <LibrarySelector
pageToggle={pageToggle} pageToggle={pageToggle}
dataset={dataset || {}} dataset={dataset || {}}
@ -302,6 +318,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
onDelete={onDelete} onDelete={onDelete}
onRename={() => {}} onRename={() => {}}
/> />
</div>
{selectedImages?.length > 0 ? ( {selectedImages?.length > 0 ? (
<div className="flex items-center justify-center gap-2"> <div className="flex items-center justify-center gap-2">
<div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground"> <div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground">
@ -640,15 +657,17 @@ function TrainGrid({
trainImages trainImages
.map((raw) => { .map((raw) => {
const parts = raw.replaceAll(".webp", "").split("-"); const parts = raw.replaceAll(".webp", "").split("-");
const rawScore = Number.parseFloat(parts[2]);
return { return {
raw, raw,
timestamp: parts[0], timestamp: parts[0],
label: parts[1], label: parts[1],
score: Number.parseFloat(parts[2]) * 100, score: rawScore * 100,
truePositive: rawScore >= model.threshold,
}; };
}) })
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)), .sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
[trainImages], [model, trainImages],
); );
return ( return (
@ -684,7 +703,14 @@ function TrainGrid({
<div className="smart-capitalize"> <div className="smart-capitalize">
{data.label.replaceAll("_", " ")} {data.label.replaceAll("_", " ")}
</div> </div>
<div>{data.score}%</div> <div
className={cn(
"",
data.truePositive ? "text-success" : "text-danger",
)}
>
{data.score}%
</div>
</div> </div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4"> <div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<ClassificationSelectionDialog <ClassificationSelectionDialog