mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Improve object classification (#18908)
* Ui improvements * Improve image cropping and model saving * Improve naming * Add logs for training * Improve model labeling * Don't set sub label for none object classification * Cleanup
This commit is contained in:
parent
4b18d54d3d
commit
53315342c0
@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
super().__init__(config, metrics)
|
||||
self.model_config = model_config
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(self.model_dir, "train")
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Interpreter = None
|
||||
self.sub_label_publisher = sub_label_publisher
|
||||
self.tensor_input_details: dict[str, Any] = None
|
||||
@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
obj_data["box"][1],
|
||||
obj_data["box"][2],
|
||||
obj_data["box"][3],
|
||||
224,
|
||||
max(
|
||||
obj_data["box"][1] - obj_data["box"][0],
|
||||
obj_data["box"][3] - obj_data["box"][2],
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||
input = rgb[
|
||||
crop = rgb[
|
||||
y:y2,
|
||||
x:x2,
|
||||
]
|
||||
|
||||
if input.shape != (224, 224):
|
||||
input = cv2.resize(input, (224, 224))
|
||||
if crop.shape != (224, 224):
|
||||
crop = cv2.resize(crop, (224, 224))
|
||||
|
||||
input = np.expand_dims(input, axis=0)
|
||||
input = np.expand_dims(crop, axis=0)
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
res: np.ndarray = self.interpreter.get_tensor(
|
||||
@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
now,
|
||||
self.labelmap[best_id],
|
||||
score,
|
||||
@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
||||
return
|
||||
|
||||
self.sub_label_publisher.publish(
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
(obj_data["id"], self.labelmap[best_id], score),
|
||||
)
|
||||
sub_label = self.labelmap[best_id]
|
||||
self.detected_objects[obj_data["id"]] = score
|
||||
|
||||
if sub_label != "none":
|
||||
self.sub_label_publisher.publish(
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
(obj_data["id"], sub_label, score),
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
|
@ -49,6 +49,7 @@ def __train_classification_model(model_name: str) -> bool:
|
||||
from tensorflow.keras.applications import MobileNetV2
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
logger.info(f"Kicking off classification training for {model_name}.")
|
||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
||||
num_classes = len(
|
||||
|
@ -82,7 +82,7 @@ export default function ClassificationSelectionDialog({
|
||||
);
|
||||
|
||||
// control
|
||||
const [newFace, setNewFace] = useState(false);
|
||||
const [newClass, setNewClass] = useState(false);
|
||||
|
||||
// components
|
||||
const Selector = isDesktop ? DropdownMenu : Drawer;
|
||||
@ -98,10 +98,10 @@ export default function ClassificationSelectionDialog({
|
||||
|
||||
return (
|
||||
<div className={className ?? ""}>
|
||||
{newFace && (
|
||||
{newClass && (
|
||||
<TextEntryDialog
|
||||
open={true}
|
||||
setOpen={setNewFace}
|
||||
setOpen={setNewClass}
|
||||
title={t("createCategory.new")}
|
||||
onSave={(newCat) => onCategorizeImage(newCat)}
|
||||
/>
|
||||
@ -130,7 +130,7 @@ export default function ClassificationSelectionDialog({
|
||||
>
|
||||
<SelectorItem
|
||||
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||
onClick={() => setNewFace(true)}
|
||||
onClick={() => setNewClass(true)}
|
||||
>
|
||||
<LuPlus />
|
||||
{t("createCategory.new")}
|
||||
@ -142,7 +142,7 @@ export default function ClassificationSelectionDialog({
|
||||
onClick={() => onCategorizeImage(category)}
|
||||
>
|
||||
<MdCategory />
|
||||
{category}
|
||||
{category.replaceAll("_", " ")}
|
||||
</SelectorItem>
|
||||
))}
|
||||
</div>
|
||||
|
@ -375,7 +375,7 @@ function LibrarySelector({
|
||||
}: LibrarySelectorProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
|
||||
const [renameFace, setRenameFace] = useState<string | null>(null);
|
||||
const [renameClass, setRenameFace] = useState<string | null>(null);
|
||||
|
||||
const handleDeleteFace = useCallback(
|
||||
(name: string) => {
|
||||
@ -390,9 +390,9 @@ function LibrarySelector({
|
||||
|
||||
const handleSetOpen = useCallback(
|
||||
(open: boolean) => {
|
||||
setRenameFace(open ? renameFace : null);
|
||||
setRenameFace(open ? renameClass : null);
|
||||
},
|
||||
[renameFace],
|
||||
[renameClass],
|
||||
);
|
||||
|
||||
return (
|
||||
@ -428,15 +428,15 @@ function LibrarySelector({
|
||||
</Dialog>
|
||||
|
||||
<TextEntryDialog
|
||||
open={!!renameFace}
|
||||
open={!!renameClass}
|
||||
setOpen={handleSetOpen}
|
||||
title={t("renameCategory.title")}
|
||||
description={t("renameCategory.desc", { name: renameFace })}
|
||||
description={t("renameCategory.desc", { name: renameClass })}
|
||||
onSave={(newName) => {
|
||||
onRename(renameFace!, newName);
|
||||
onRename(renameClass!, newName);
|
||||
setRenameFace(null);
|
||||
}}
|
||||
defaultValue={renameFace || ""}
|
||||
defaultValue={renameClass || ""}
|
||||
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
|
||||
regexErrorMessage={t("description.invalidName")}
|
||||
/>
|
||||
@ -484,10 +484,10 @@ function LibrarySelector({
|
||||
className="group flex items-center justify-between"
|
||||
>
|
||||
<div
|
||||
className="flex-grow cursor-pointer"
|
||||
className="flex-grow cursor-pointer capitalize"
|
||||
onClick={() => setPageToggle(id)}
|
||||
>
|
||||
{id}
|
||||
{id.replaceAll("_", " ")}
|
||||
<span className="ml-2 text-muted-foreground">
|
||||
({dataset?.[id].length})
|
||||
</span>
|
||||
@ -681,7 +681,9 @@ function TrainGrid({
|
||||
<div className="rounded-b-lg bg-card p-3">
|
||||
<div className="flex w-full flex-row items-center justify-between gap-2">
|
||||
<div className="flex flex-col items-start text-xs text-primary-variant">
|
||||
<div className="smart-capitalize">{data.label}</div>
|
||||
<div className="smart-capitalize">
|
||||
{data.label.replaceAll("_", " ")}
|
||||
</div>
|
||||
<div>{data.score}%</div>
|
||||
</div>
|
||||
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
|
||||
|
Loading…
Reference in New Issue
Block a user