Improve classification UI (#18910)

* Move threhsold to base model config

* Improve score handling

* Add back button
This commit is contained in:
Nicolas Mowen 2025-06-27 08:35:02 -06:00 committed by GitHub
parent bd6dee5b38
commit 0f4cac736a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 16 deletions

View File

@ -59,9 +59,6 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on."
)
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
class CustomClassificationStateConfig(FrigateBaseModel):
@ -86,6 +83,9 @@ class CustomClassificationObjectConfig(FrigateBaseModel):
class CustomClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable running the model.")
name: str | None = Field(default=None, title="Name of classification model.")
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: CustomClassificationStateConfig | None = Field(default=None)

View File

@ -152,7 +152,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
score,
)
if score >= camera_config.threshold:
if score >= self.model_config.threshold:
self.requestor.send_data(
f"{camera}/classification/{self.model_config.name}",
self.labelmap[best_id],
@ -271,6 +271,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
score,
)
if score < self.model_config.threshold:
logger.debug(f"Score {score} is less than threshold.")
return
if score <= previous_score:
logger.debug(f"Score {score} is worse than previous score {previous_score}")
return

View File

@ -282,6 +282,7 @@ export type CameraStreamingSettings = {
export type CustomClassificationModelConfig = {
enabled: boolean;
name: string;
threshold: number;
object_config: null | {
objects: string[];
};
@ -289,7 +290,6 @@ export type CustomClassificationModelConfig = {
cameras: {
[cameraName: string]: {
crop: [number, number, number, number];
threshold: number;
};
};
motion: boolean;

View File

@ -48,12 +48,15 @@ import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useNavigate } from "react-router-dom";
import { IoMdArrowRoundBack } from "react-icons/io";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
};
export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { t } = useTranslation(["views/classificationModel"]);
const navigate = useNavigate();
const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
@ -294,14 +297,28 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</AlertDialog>
<div className="flex flex-row justify-between gap-2 p-2 align-middle">
<LibrarySelector
pageToggle={pageToggle}
dataset={dataset || {}}
trainImages={trainImages || []}
setPageToggle={setPageToggle}
onDelete={onDelete}
onRename={() => {}}
/>
<div className="flex flex-row items-center justify-center gap-2">
<Button
className="flex items-center gap-2.5 rounded-lg"
aria-label={t("label.back", { ns: "common" })}
onClick={() => navigate(-1)}
>
<IoMdArrowRoundBack className="size-5 text-secondary-foreground" />
{isDesktop && (
<div className="text-primary">
{t("button.back", { ns: "common" })}
</div>
)}
</Button>
<LibrarySelector
pageToggle={pageToggle}
dataset={dataset || {}}
trainImages={trainImages || []}
setPageToggle={setPageToggle}
onDelete={onDelete}
onRename={() => {}}
/>
</div>
{selectedImages?.length > 0 ? (
<div className="flex items-center justify-center gap-2">
<div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground">
@ -640,15 +657,17 @@ function TrainGrid({
trainImages
.map((raw) => {
const parts = raw.replaceAll(".webp", "").split("-");
const rawScore = Number.parseFloat(parts[2]);
return {
raw,
timestamp: parts[0],
label: parts[1],
score: Number.parseFloat(parts[2]) * 100,
score: rawScore * 100,
truePositive: rawScore >= model.threshold,
};
})
.sort((a, b) => b.timestamp.localeCompare(a.timestamp)),
[trainImages],
[model, trainImages],
);
return (
@ -684,7 +703,14 @@ function TrainGrid({
<div className="smart-capitalize">
{data.label.replaceAll("_", " ")}
</div>
<div>{data.score}%</div>
<div
className={cn(
"",
data.truePositive ? "text-success" : "text-danger",
)}
>
{data.score}%
</div>
</div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<ClassificationSelectionDialog