Improve object classification (#18908)

* UI improvements

* Improve image cropping and model saving

* Improve naming

* Add logs for training

* Improve model labeling

* Don't set sub label when the object classification result is "none"

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-06-27 06:28:40 -06:00 committed by GitHub
parent 4b18d54d3d
commit 53315342c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 26 deletions

View File

@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
super().__init__(config, metrics) super().__init__(config, metrics)
self.model_config = model_config self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(self.model_dir, "train") self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None self.interpreter: Interpreter = None
self.sub_label_publisher = sub_label_publisher self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None self.tensor_input_details: dict[str, Any] = None
@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
obj_data["box"][1], obj_data["box"][1],
obj_data["box"][2], obj_data["box"][2],
obj_data["box"][3], obj_data["box"][3],
224, max(
obj_data["box"][1] - obj_data["box"][0],
obj_data["box"][3] - obj_data["box"][2],
),
1.0, 1.0,
) )
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[ crop = rgb[
y:y2, y:y2,
x:x2, x:x2,
] ]
if input.shape != (224, 224): if crop.shape != (224, 224):
input = cv2.resize(input, (224, 224)) crop = cv2.resize(crop, (224, 224))
input = np.expand_dims(input, axis=0) input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke() self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor( res: np.ndarray = self.interpreter.get_tensor(
@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
now, now,
self.labelmap[best_id], self.labelmap[best_id],
score, score,
@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
logger.debug(f"Score {score} is worse than previous score {previous_score}") logger.debug(f"Score {score} is worse than previous score {previous_score}")
return return
self.sub_label_publisher.publish( sub_label = self.labelmap[best_id]
EventMetadataTypeEnum.sub_label,
(obj_data["id"], self.labelmap[best_id], score),
)
self.detected_objects[obj_data["id"]] = score self.detected_objects[obj_data["id"]] = score
if sub_label != "none":
self.sub_label_publisher.publish(
EventMetadataTypeEnum.sub_label,
(obj_data["id"], sub_label, score),
)
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
if topic == EmbeddingsRequestEnum.reload_classification_model.value: if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name: if request_data.get("model_name") == self.model_config.name:

View File

@ -49,6 +49,7 @@ def __train_classification_model(model_name: str) -> bool:
from tensorflow.keras.applications import MobileNetV2 from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, model_name) model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
num_classes = len( num_classes = len(

View File

@ -82,7 +82,7 @@ export default function ClassificationSelectionDialog({
); );
// control // control
const [newFace, setNewFace] = useState(false); const [newClass, setNewClass] = useState(false);
// components // components
const Selector = isDesktop ? DropdownMenu : Drawer; const Selector = isDesktop ? DropdownMenu : Drawer;
@ -98,10 +98,10 @@ export default function ClassificationSelectionDialog({
return ( return (
<div className={className ?? ""}> <div className={className ?? ""}>
{newFace && ( {newClass && (
<TextEntryDialog <TextEntryDialog
open={true} open={true}
setOpen={setNewFace} setOpen={setNewClass}
title={t("createCategory.new")} title={t("createCategory.new")}
onSave={(newCat) => onCategorizeImage(newCat)} onSave={(newCat) => onCategorizeImage(newCat)}
/> />
@ -130,7 +130,7 @@ export default function ClassificationSelectionDialog({
> >
<SelectorItem <SelectorItem
className="flex cursor-pointer gap-2 smart-capitalize" className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => setNewFace(true)} onClick={() => setNewClass(true)}
> >
<LuPlus /> <LuPlus />
{t("createCategory.new")} {t("createCategory.new")}
@ -142,7 +142,7 @@ export default function ClassificationSelectionDialog({
onClick={() => onCategorizeImage(category)} onClick={() => onCategorizeImage(category)}
> >
<MdCategory /> <MdCategory />
{category} {category.replaceAll("_", " ")}
</SelectorItem> </SelectorItem>
))} ))}
</div> </div>

View File

@ -375,7 +375,7 @@ function LibrarySelector({
}: LibrarySelectorProps) { }: LibrarySelectorProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
const [confirmDelete, setConfirmDelete] = useState<string | null>(null); const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
const [renameFace, setRenameFace] = useState<string | null>(null); const [renameClass, setRenameFace] = useState<string | null>(null);
const handleDeleteFace = useCallback( const handleDeleteFace = useCallback(
(name: string) => { (name: string) => {
@ -390,9 +390,9 @@ function LibrarySelector({
const handleSetOpen = useCallback( const handleSetOpen = useCallback(
(open: boolean) => { (open: boolean) => {
setRenameFace(open ? renameFace : null); setRenameFace(open ? renameClass : null);
}, },
[renameFace], [renameClass],
); );
return ( return (
@ -428,15 +428,15 @@ function LibrarySelector({
</Dialog> </Dialog>
<TextEntryDialog <TextEntryDialog
open={!!renameFace} open={!!renameClass}
setOpen={handleSetOpen} setOpen={handleSetOpen}
title={t("renameCategory.title")} title={t("renameCategory.title")}
description={t("renameCategory.desc", { name: renameFace })} description={t("renameCategory.desc", { name: renameClass })}
onSave={(newName) => { onSave={(newName) => {
onRename(renameFace!, newName); onRename(renameClass!, newName);
setRenameFace(null); setRenameFace(null);
}} }}
defaultValue={renameFace || ""} defaultValue={renameClass || ""}
regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u} regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
regexErrorMessage={t("description.invalidName")} regexErrorMessage={t("description.invalidName")}
/> />
@ -484,10 +484,10 @@ function LibrarySelector({
className="group flex items-center justify-between" className="group flex items-center justify-between"
> >
<div <div
className="flex-grow cursor-pointer" className="flex-grow cursor-pointer capitalize"
onClick={() => setPageToggle(id)} onClick={() => setPageToggle(id)}
> >
{id} {id.replaceAll("_", " ")}
<span className="ml-2 text-muted-foreground"> <span className="ml-2 text-muted-foreground">
({dataset?.[id].length}) ({dataset?.[id].length})
</span> </span>
@ -681,7 +681,9 @@ function TrainGrid({
<div className="rounded-b-lg bg-card p-3"> <div className="rounded-b-lg bg-card p-3">
<div className="flex w-full flex-row items-center justify-between gap-2"> <div className="flex w-full flex-row items-center justify-between gap-2">
<div className="flex flex-col items-start text-xs text-primary-variant"> <div className="flex flex-col items-start text-xs text-primary-variant">
<div className="smart-capitalize">{data.label}</div> <div className="smart-capitalize">
{data.label.replaceAll("_", " ")}
</div>
<div>{data.score}%</div> <div>{data.score}%</div>
</div> </div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4"> <div className="flex flex-row items-start justify-end gap-5 md:gap-4">