mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Improve object classification (#18908)
* UI improvements * Improve image cropping and model saving * Improve naming * Add logs for training * Improve model labeling * Don't set sub label for "none" object classification * Cleanup
This commit is contained in:
parent
4b18d54d3d
commit
53315342c0
@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
super().__init__(config, metrics)
|
super().__init__(config, metrics)
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||||
self.train_dir = os.path.join(self.model_dir, "train")
|
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||||
self.interpreter: Interpreter = None
|
self.interpreter: Interpreter = None
|
||||||
self.sub_label_publisher = sub_label_publisher
|
self.sub_label_publisher = sub_label_publisher
|
||||||
self.tensor_input_details: dict[str, Any] = None
|
self.tensor_input_details: dict[str, Any] = None
|
||||||
@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
obj_data["box"][1],
|
obj_data["box"][1],
|
||||||
obj_data["box"][2],
|
obj_data["box"][2],
|
||||||
obj_data["box"][3],
|
obj_data["box"][3],
|
||||||
224,
|
max(
|
||||||
|
obj_data["box"][1] - obj_data["box"][0],
|
||||||
|
obj_data["box"][3] - obj_data["box"][2],
|
||||||
|
),
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||||
input = rgb[
|
crop = rgb[
|
||||||
y:y2,
|
y:y2,
|
||||||
x:x2,
|
x:x2,
|
||||||
]
|
]
|
||||||
|
|
||||||
if input.shape != (224, 224):
|
if crop.shape != (224, 224):
|
||||||
input = cv2.resize(input, (224, 224))
|
crop = cv2.resize(crop, (224, 224))
|
||||||
|
|
||||||
input = np.expand_dims(input, axis=0)
|
input = np.expand_dims(crop, axis=0)
|
||||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||||
self.interpreter.invoke()
|
self.interpreter.invoke()
|
||||||
res: np.ndarray = self.interpreter.get_tensor(
|
res: np.ndarray = self.interpreter.get_tensor(
|
||||||
@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
|
|
||||||
write_classification_attempt(
|
write_classification_attempt(
|
||||||
self.train_dir,
|
self.train_dir,
|
||||||
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||||
now,
|
now,
|
||||||
self.labelmap[best_id],
|
self.labelmap[best_id],
|
||||||
score,
|
score,
|
||||||
@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.sub_label_publisher.publish(
|
sub_label = self.labelmap[best_id]
|
||||||
EventMetadataTypeEnum.sub_label,
|
|
||||||
(obj_data["id"], self.labelmap[best_id], score),
|
|
||||||
)
|
|
||||||
self.detected_objects[obj_data["id"]] = score
|
self.detected_objects[obj_data["id"]] = score
|
||||||
|
|
||||||
|
if sub_label != "none":
|
||||||
|
self.sub_label_publisher.publish(
|
||||||
|
EventMetadataTypeEnum.sub_label,
|
||||||
|
(obj_data["id"], sub_label, score),
|
||||||
|
)
|
||||||
|
|
||||||
def handle_request(self, topic, request_data):
|
def handle_request(self, topic, request_data):
|
||||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||||
if request_data.get("model_name") == self.model_config.name:
|
if request_data.get("model_name") == self.model_config.name:
|
||||||
|
@ -49,6 +49,7 @@ def __train_classification_model(model_name: str) -> bool:
|
|||||||
from tensorflow.keras.applications import MobileNetV2
|
from tensorflow.keras.applications import MobileNetV2
|
||||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
logger.info(f"Kicking off classification training for {model_name}.")
|
||||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||||
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
||||||
num_classes = len(
|
num_classes = len(
|
||||||
|
@ -82,7 +82,7 @@ export default function ClassificationSelectionDialog({
|
|||||||
);
|
);
|
||||||
|
|
||||||
// control
|
// control
|
||||||
const [newFace, setNewFace] = useState(false);
|
const [newClass, setNewClass] = useState(false);
|
||||||
|
|
||||||
// components
|
// components
|
||||||
const Selector = isDesktop ? DropdownMenu : Drawer;
|
const Selector = isDesktop ? DropdownMenu : Drawer;
|
||||||
@ -98,10 +98,10 @@ export default function ClassificationSelectionDialog({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={className ?? ""}>
|
<div className={className ?? ""}>
|
||||||
{newFace && (
|
{newClass && (
|
||||||
<TextEntryDialog
|
<TextEntryDialog
|
||||||
open={true}
|
open={true}
|
||||||
setOpen={setNewFace}
|
setOpen={setNewClass}
|
||||||
title={t("createCategory.new")}
|
title={t("createCategory.new")}
|
||||||
onSave={(newCat) => onCategorizeImage(newCat)}
|
onSave={(newCat) => onCategorizeImage(newCat)}
|
||||||
/>
|
/>
|
||||||
@ -130,7 +130,7 @@ export default function ClassificationSelectionDialog({
|
|||||||
>
|
>
|
||||||
<SelectorItem
|
<SelectorItem
|
||||||
className="flex cursor-pointer gap-2 smart-capitalize"
|
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||||
onClick={() => setNewFace(true)}
|
onClick={() => setNewClass(true)}
|
||||||
>
|
>
|
||||||
<LuPlus />
|
<LuPlus />
|
||||||
{t("createCategory.new")}
|
{t("createCategory.new")}
|
||||||
@ -142,7 +142,7 @@ export default function ClassificationSelectionDialog({
|
|||||||
onClick={() => onCategorizeImage(category)}
|
onClick={() => onCategorizeImage(category)}
|
||||||
>
|
>
|
||||||
<MdCategory />
|
<MdCategory />
|
||||||
{category}
|
{category.replaceAll("_", " ")}
|
||||||
</SelectorItem>
|
</SelectorItem>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
@ -375,7 +375,7 @@ function LibrarySelector({
|
|||||||
}: LibrarySelectorProps) {
|
}: LibrarySelectorProps) {
|
||||||
const { t } = useTranslation(["views/classificationModel"]);
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
|
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
|
||||||
const [renameFace, setRenameFace] = useState<string | null>(null);
|
const [renameClass, setRenameFace] = useState<string | null>(null);
|
||||||
|
|
||||||
const handleDeleteFace = useCallback(
|
const handleDeleteFace = useCallback(
|
||||||
(name: string) => {
|
(name: string) => {
|
||||||
@ -390,9 +390,9 @@ function LibrarySelector({
|
|||||||
|
|
||||||
const handleSetOpen = useCallback(
|
const handleSetOpen = useCallback(
|
||||||
(open: boolean) => {
|
(open: boolean) => {
|
||||||
setRenameFace(open ? renameFace : null);
|
setRenameFace(open ? renameClass : null);
|
||||||
},
|
},
|
||||||
[renameFace],
|
[renameClass],
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -428,15 +428,15 @@ function LibrarySelector({
|
|||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
<TextEntryDialog
|
<TextEntryDialog
|
||||||
open={!!renameFace}
|
open={!!renameClass}
|
||||||
setOpen={handleSetOpen}
|
setOpen={handleSetOpen}
|
||||||
title={t("renameCategory.title")}
|
title={t("renameCategory.title")}
|
||||||
description={t("renameCategory.desc", { name: renameFace })}
|
description={t("renameCategory.desc", { name: renameClass })}
|
||||||
onSave={(newName) => {
|
onSave={(newName) => {
|
||||||
onRename(renameFace!, newName);
|
onRename(renameClass!, newName);
|
||||||
setRenameFace(null);
|
setRenameFace(null);
|
||||||
}}
|
}}
|
||||||
defaultValue={renameFace || ""}
|
defaultValue={renameClass || ""}
|
||||||
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
|
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
|
||||||
regexErrorMessage={t("description.invalidName")}
|
regexErrorMessage={t("description.invalidName")}
|
||||||
/>
|
/>
|
||||||
@ -484,10 +484,10 @@ function LibrarySelector({
|
|||||||
className="group flex items-center justify-between"
|
className="group flex items-center justify-between"
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
className="flex-grow cursor-pointer"
|
className="flex-grow cursor-pointer capitalize"
|
||||||
onClick={() => setPageToggle(id)}
|
onClick={() => setPageToggle(id)}
|
||||||
>
|
>
|
||||||
{id}
|
{id.replaceAll("_", " ")}
|
||||||
<span className="ml-2 text-muted-foreground">
|
<span className="ml-2 text-muted-foreground">
|
||||||
({dataset?.[id].length})
|
({dataset?.[id].length})
|
||||||
</span>
|
</span>
|
||||||
@ -681,7 +681,9 @@ function TrainGrid({
|
|||||||
<div className="rounded-b-lg bg-card p-3">
|
<div className="rounded-b-lg bg-card p-3">
|
||||||
<div className="flex w-full flex-row items-center justify-between gap-2">
|
<div className="flex w-full flex-row items-center justify-between gap-2">
|
||||||
<div className="flex flex-col items-start text-xs text-primary-variant">
|
<div className="flex flex-col items-start text-xs text-primary-variant">
|
||||||
<div className="smart-capitalize">{data.label}</div>
|
<div className="smart-capitalize">
|
||||||
|
{data.label.replaceAll("_", " ")}
|
||||||
|
</div>
|
||||||
<div>{data.score}%</div>
|
<div>{data.score}%</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
|
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
|
||||||
|
Loading…
Reference in New Issue
Block a user