diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py
index f153b5b92..fb1d31e89 100644
--- a/frigate/data_processing/real_time/custom_classification.py
+++ b/frigate/data_processing/real_time/custom_classification.py
@@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         super().__init__(config, metrics)
         self.model_config = model_config
         self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
-        self.train_dir = os.path.join(self.model_dir, "train")
+        self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
         self.interpreter: Interpreter = None
         self.sub_label_publisher = sub_label_publisher
         self.tensor_input_details: dict[str, Any] = None
@@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             obj_data["box"][1],
             obj_data["box"][2],
             obj_data["box"][3],
-            224,
+            max(
+                obj_data["box"][1] - obj_data["box"][0],
+                obj_data["box"][3] - obj_data["box"][2],
+            ),
             1.0,
         )

         rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
-        input = rgb[
+        crop = rgb[
             y:y2,
             x:x2,
         ]

-        if input.shape != (224, 224):
-            input = cv2.resize(input, (224, 224))
+        if crop.shape != (224, 224):
+            crop = cv2.resize(crop, (224, 224))

-        input = np.expand_dims(input, axis=0)
+        input = np.expand_dims(crop, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
         self.interpreter.invoke()
         res: np.ndarray = self.interpreter.get_tensor(
@@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

         write_classification_attempt(
             self.train_dir,
-            cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
+            cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
             now,
             self.labelmap[best_id],
             score,
@@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             logger.debug(f"Score {score} is worse than previous score {previous_score}")
             return

-        self.sub_label_publisher.publish(
-            EventMetadataTypeEnum.sub_label,
-            (obj_data["id"], self.labelmap[best_id], score),
-        )
+        sub_label = self.labelmap[best_id]
         self.detected_objects[obj_data["id"]] = score

+        if sub_label != "none":
+            self.sub_label_publisher.publish(
+                EventMetadataTypeEnum.sub_label,
+                (obj_data["id"], sub_label, score),
+            )
+
     def handle_request(self, topic, request_data):
         if topic == EmbeddingsRequestEnum.reload_classification_model.value:
             if request_data.get("model_name") == self.model_config.name:
diff --git a/frigate/util/classification.py b/frigate/util/classification.py
index 3c030a986..6eab829f2 100644
--- a/frigate/util/classification.py
+++ b/frigate/util/classification.py
@@ -49,6 +49,7 @@ def __train_classification_model(model_name: str) -> bool:
     from tensorflow.keras.applications import MobileNetV2
     from tensorflow.keras.preprocessing.image import ImageDataGenerator

+    logger.info(f"Kicking off classification training for {model_name}.")
    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
    model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
    num_classes = len(
diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx
index 7cb8ca156..f86ced19a 100644
--- a/web/src/components/overlay/ClassificationSelectionDialog.tsx
+++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx
@@ -82,7 +82,7 @@ export default function ClassificationSelectionDialog({
   );

   // control
-  const [newFace, setNewFace] = useState(false);
+  const [newClass, setNewClass] = useState(false);

   // components
   const Selector = isDesktop ? DropdownMenu : Drawer;
@@ -98,10 +98,10 @@ export default function ClassificationSelectionDialog({
   return (
-      {newFace && (
+      {newClass && (
          onCategorizeImage(newCat)}
        />
@@ -130,7 +130,7 @@
-          onClick={() => setNewFace(true)}
+          onClick={() => setNewClass(true)}
         >
           {t("createCategory.new")}
@@ -142,7 +142,7 @@
           onClick={() => onCategorizeImage(category)}
         >
-          {category}
+          {category.replaceAll("_", " ")}
        ))}
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index 1f62a4f53..ea265bd51 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -375,7 +375,7 @@ function LibrarySelector({
 }: LibrarySelectorProps) {
   const { t } = useTranslation(["views/classificationModel"]);
   const [confirmDelete, setConfirmDelete] = useState(null);
-  const [renameFace, setRenameFace] = useState(null);
+  const [renameClass, setRenameFace] = useState(null);

   const handleDeleteFace = useCallback(
     (name: string) => {
@@ -390,9 +390,9 @@ function LibrarySelector({

   const handleSetOpen = useCallback(
     (open: boolean) => {
-      setRenameFace(open ? renameFace : null);
+      setRenameFace(open ? renameClass : null);
     },
-    [renameFace],
+    [renameClass],
   );

   return (
@@ -428,15 +428,15 @@ function LibrarySelector({
        {
-          onRename(renameFace!, newName);
+          onRename(renameClass!, newName);
           setRenameFace(null);
         }}
-        defaultValue={renameFace || ""}
+        defaultValue={renameClass || ""}
         regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u}
         regexErrorMessage={t("description.invalidName")}
       />
@@ -484,10 +484,10 @@ function LibrarySelector({
           className="group flex items-center justify-between"
         >
            setPageToggle(id)}
          >
-            {id}
+            {id.replaceAll("_", " ")}
             ({dataset?.[id].length})
@@ -681,7 +681,9 @@ function TrainGrid({
-            {data.label}
+
+              {data.label.replaceAll("_", " ")}
+
             {data.score}%