Improve object classification (#18908)

* Ui improvements

* Improve image cropping and model saving

* Improve naming

* Add logs for training

* Improve model labeling

* Don't set sub label for none object classification

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-06-27 06:28:40 -06:00 committed by GitHub
parent 4b18d54d3d
commit 53315342c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 26 deletions

View File

@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
super().__init__(config, metrics)
self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(self.model_dir, "train")
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None
@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
obj_data["box"][1],
obj_data["box"][2],
obj_data["box"][3],
224,
max(
obj_data["box"][1] - obj_data["box"][0],
obj_data["box"][3] - obj_data["box"][2],
),
1.0,
)
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[
crop = rgb[
y:y2,
x:x2,
]
if input.shape != (224, 224):
input = cv2.resize(input, (224, 224))
if crop.shape != (224, 224):
crop = cv2.resize(crop, (224, 224))
input = np.expand_dims(input, axis=0)
input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor(
@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
logger.debug(f"Score {score} is worse than previous score {previous_score}")
return
self.sub_label_publisher.publish(
EventMetadataTypeEnum.sub_label,
(obj_data["id"], self.labelmap[best_id], score),
)
sub_label = self.labelmap[best_id]
self.detected_objects[obj_data["id"]] = score
if sub_label != "none":
self.sub_label_publisher.publish(
EventMetadataTypeEnum.sub_label,
(obj_data["id"], sub_label, score),
)
def handle_request(self, topic, request_data):
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name:

View File

@ -49,6 +49,7 @@ def __train_classification_model(model_name: str) -> bool:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
num_classes = len(

View File

@ -82,7 +82,7 @@ export default function ClassificationSelectionDialog({
);
// control
const [newFace, setNewFace] = useState(false);
const [newClass, setNewClass] = useState(false);
// components
const Selector = isDesktop ? DropdownMenu : Drawer;
@ -98,10 +98,10 @@ export default function ClassificationSelectionDialog({
return (
<div className={className ?? ""}>
{newFace && (
{newClass && (
<TextEntryDialog
open={true}
setOpen={setNewFace}
setOpen={setNewClass}
title={t("createCategory.new")}
onSave={(newCat) => onCategorizeImage(newCat)}
/>
@ -130,7 +130,7 @@ export default function ClassificationSelectionDialog({
>
<SelectorItem
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => setNewFace(true)}
onClick={() => setNewClass(true)}
>
<LuPlus />
{t("createCategory.new")}
@ -142,7 +142,7 @@ export default function ClassificationSelectionDialog({
onClick={() => onCategorizeImage(category)}
>
<MdCategory />
{category}
{category.replaceAll("_", " ")}
</SelectorItem>
))}
</div>

View File

@ -375,7 +375,7 @@ function LibrarySelector({
}: LibrarySelectorProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
const [renameFace, setRenameFace] = useState<string | null>(null);
const [renameClass, setRenameFace] = useState<string | null>(null);
const handleDeleteFace = useCallback(
(name: string) => {
@ -390,9 +390,9 @@ function LibrarySelector({
const handleSetOpen = useCallback(
(open: boolean) => {
setRenameFace(open ? renameFace : null);
setRenameFace(open ? renameClass : null);
},
[renameFace],
[renameClass],
);
return (
@ -428,15 +428,15 @@ function LibrarySelector({
</Dialog>
<TextEntryDialog
open={!!renameFace}
open={!!renameClass}
setOpen={handleSetOpen}
title={t("renameCategory.title")}
description={t("renameCategory.desc", { name: renameFace })}
description={t("renameCategory.desc", { name: renameClass })}
onSave={(newName) => {
onRename(renameFace!, newName);
onRename(renameClass!, newName);
setRenameFace(null);
}}
defaultValue={renameFace || ""}
defaultValue={renameClass || ""}
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
regexErrorMessage={t("description.invalidName")}
/>
@ -484,10 +484,10 @@ function LibrarySelector({
className="group flex items-center justify-between"
>
<div
className="flex-grow cursor-pointer"
className="flex-grow cursor-pointer capitalize"
onClick={() => setPageToggle(id)}
>
{id}
{id.replaceAll("_", " ")}
<span className="ml-2 text-muted-foreground">
({dataset?.[id].length})
</span>
@ -681,7 +681,9 @@ function TrainGrid({
<div className="rounded-b-lg bg-card p-3">
<div className="flex w-full flex-row items-center justify-between gap-2">
<div className="flex flex-col items-start text-xs text-primary-variant">
<div className="smart-capitalize">{data.label}</div>
<div className="smart-capitalize">
{data.label.replaceAll("_", " ")}
</div>
<div>{data.score}%</div>
</div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">