diff --git a/docker/main/Dockerfile b/docker/main/Dockerfile
index f58e3d5a9..959ff59d4 100644
--- a/docker/main/Dockerfile
+++ b/docker/main/Dockerfile
@@ -227,6 +227,9 @@ ENV OPENCV_FFMPEG_LOGLEVEL=8
 # Set HailoRT to disable logging
 ENV HAILORT_LOGGER_PATH=NONE
 
+# Log TensorFlow errors only
+ENV TF_CPP_MIN_LOG_LEVEL=3
+
 ENV PATH="/usr/local/go2rtc/bin:/usr/local/tempio/bin:/usr/local/nginx/sbin:${PATH}"
 
 # Install dependencies
diff --git a/docker/main/requirements-wheels.txt b/docker/main/requirements-wheels.txt
index 59cc1ab9c..624983eb4 100644
--- a/docker/main/requirements-wheels.txt
+++ b/docker/main/requirements-wheels.txt
@@ -11,6 +11,9 @@ joserfc == 1.0.*
 pathvalidate == 3.2.*
 markupsafe == 3.0.*
 python-multipart == 0.0.12
+# Classification Model Training
+tensorflow == 2.19.* ; platform_machine == 'aarch64'
+tensorflow-cpu == 2.19.* ; platform_machine == 'x86_64'
 # General
 mypy == 1.6.1
 onvif-zeep-async == 3.1.*
diff --git a/frigate/api/classification.py b/frigate/api/classification.py
index 81112933c..f2c6ac06b 100644
--- a/frigate/api/classification.py
+++ b/frigate/api/classification.py
@@ -7,7 +7,7 @@ import shutil
 from typing import Any
 
 import cv2
-from fastapi import APIRouter, Depends, Request, UploadFile
+from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
 from fastapi.responses import JSONResponse
 from pathvalidate import sanitize_filename
 from peewee import DoesNotExist
@@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import (
     RenameFaceBody,
 )
 from frigate.api.defs.tags import Tags
+from frigate.config import FrigateConfig
 from frigate.config.camera import DetectConfig
-from frigate.const import FACE_DIR
+from frigate.const import FACE_DIR, MODEL_CACHE_DIR
 from frigate.embeddings import EmbeddingsContext
 from frigate.models import Event
+from frigate.util.classification import train_classification_model
 from frigate.util.path import get_event_snapshot
 
 logger = logging.getLogger(__name__)
@@ -424,3 +426,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
         },
         status_code=500,
     )
+
+
+# custom classification training
+
+
+@router.post("/classification/{name}/train")
+async def train_configured_model(
+    request: Request, name: str, background_tasks: BackgroundTasks
+):
+    config: FrigateConfig = request.app.frigate_config
+
+    if name not in config.classification.custom:
+        return JSONResponse(
+            content=(
+                {
+                    "success": False,
+                    "message": f"{name} is not a known classification model.",
+                }
+            ),
+            status_code=404,
+        )
+
+    background_tasks.add_task(
+        train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
+    )
+    return JSONResponse(
+        content={"success": True, "message": "Started classification model training."},
+        status_code=200,
+    )
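The new endpoint schedules training as a FastAPI background task, so the request returns immediately while training runs out of band. A minimal sketch of invoking it, assuming Frigate's API is reachable at `http://frigate.local:5000` under the usual `/api` prefix and that a custom model named `bird` exists in the config (host, port, and model name are all illustrative assumptions):

```python
# Hypothetical client call; the host, port, and "bird" model name are assumptions.
import requests

resp = requests.post("http://frigate.local:5000/api/classification/bird/train")
print(resp.status_code, resp.json())
# Expected: 200 {'success': True, 'message': 'Started classification model training.'}
# An unconfigured name would instead return 404 with an explanatory message.
```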
diff --git a/frigate/config/classification.py b/frigate/config/classification.py
index 29568f5cd..c0584ce63 100644
--- a/frigate/config/classification.py
+++ b/frigate/config/classification.py
@@ -85,8 +85,7 @@ class CustomClassificationObjectConfig(FrigateBaseModel):
 
 class CustomClassificationConfig(FrigateBaseModel):
     enabled: bool = Field(default=True, title="Enable running the model.")
-    model_path: str = Field(title="Path to custom classification tflite model.")
-    labelmap_path: str = Field(title="Path to custom classification model labelmap.")
+    name: str | None = Field(default=None, title="Name of classification model.")
     object_config: CustomClassificationObjectConfig | None = Field(default=None)
     state_config: CustomClassificationStateConfig | None = Field(default=None)
 
diff --git a/frigate/config/config.py b/frigate/config/config.py
index 5bca436b6..d912a574d 100644
--- a/frigate/config/config.py
+++ b/frigate/config/config.py
@@ -706,6 +706,10 @@ class FrigateConfig(FrigateBaseModel):
             verify_objects_track(camera_config, labelmap_objects)
             verify_lpr_and_face(self, camera_config)
 
+        # set names on classification configs
+        for name, config in self.classification.custom.items():
+            config.name = name
+
         self.objects.parse_all_objects(self.cameras)
         self.model.create_colormap(sorted(self.objects.all_objects))
         self.model.check_and_load_plus_model(self.plus_api)
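With `model_path` and `labelmap_path` removed, model assets are now located by convention from the model's config key, which `FrigateConfig` copies onto each `CustomClassificationConfig` at load time. A sketch of the resulting layout, assuming `MODEL_CACHE_DIR` resolves to `/config/model_cache` (its value is not shown in this diff) and a hypothetical model named `bird`:

```python
# Illustrative only; the "bird" name and MODEL_CACHE_DIR value are assumptions.
import os

MODEL_CACHE_DIR = "/config/model_cache"  # assumed value of frigate.const.MODEL_CACHE_DIR
name = "bird"  # a key under classification.custom in the config

model_dir = os.path.join(MODEL_CACHE_DIR, name)
print(os.path.join(model_dir, "model.tflite"))  # model loaded by the processors
print(os.path.join(model_dir, "labelmap.txt"))  # labelmap loaded by the processors
print(os.path.join(model_dir, "dataset"))  # per-class images consumed by train_classification_model
```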
diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py
index cd99508c9..f94c2b28c 100644
--- a/frigate/data_processing/real_time/custom_classification.py
+++ b/frigate/data_processing/real_time/custom_classification.py
@@ -2,6 +2,7 @@
 
 import datetime
 import logging
+import os
 from typing import Any
 
 import cv2
@@ -14,6 +15,7 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
 from frigate.config.classification import CustomClassificationConfig
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
 from frigate.util.builtin import load_labels
 from frigate.util.object import box_overlaps, calculate_region
 
@@ -33,14 +35,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         self,
         config: FrigateConfig,
         model_config: CustomClassificationConfig,
-        name: str,
         requestor: InterProcessRequestor,
         metrics: DataProcessorMetrics,
     ):
         super().__init__(config, metrics)
         self.model_config = model_config
-        self.name = name
         self.requestor = requestor
+        self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
+        self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name)
         self.interpreter: Interpreter = None
         self.tensor_input_details: dict[str, Any] = None
         self.tensor_output_details: dict[str, Any] = None
@@ -50,13 +52,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
 
     def __build_detector(self) -> None:
         self.interpreter = Interpreter(
-            model_path=self.model_config.model_path,
+            model_path=os.path.join(self.model_dir, "model.tflite"),
             num_threads=2,
         )
         self.interpreter.allocate_tensors()
         self.tensor_input_details = self.interpreter.get_input_details()
         self.tensor_output_details = self.interpreter.get_output_details()
-        self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
+        self.labelmap = load_labels(
+            os.path.join(self.model_dir, "labelmap.txt"),
+            prefill=0,
+        )
 
     def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
         camera = frame_data.get("camera")
@@ -105,15 +110,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )
 
         rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
-        input = rgb[
+        frame = rgb[
             y:y2,
             x:x2,
         ]
 
-        if input.shape != (224, 224):
-            input = cv2.resize(input, (224, 224))
+        if frame.shape[:2] != (224, 224):
+            frame = cv2.resize(frame, (224, 224))
 
-        input = np.expand_dims(input, axis=0)
+        input = np.expand_dims(frame, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
         self.interpreter.invoke()
         res: np.ndarray = self.interpreter.get_tensor(
@@ -123,9 +128,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         best_id = np.argmax(probs)
         score = round(probs[best_id], 2)
 
+        write_classification_attempt(
+            self.train_dir,
+            cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
+            now,
+            self.labelmap[best_id],
+            score,
+        )
+
         if score >= camera_config.threshold:
             self.requestor.send_data(
-                f"{camera}/classification/{self.name}", self.labelmap[best_id]
+                f"{camera}/classification/{self.model_config.name}",
+                self.labelmap[best_id],
             )
 
     def handle_request(self, topic, request_data):
@@ -145,6 +159,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
     ):
         super().__init__(config, metrics)
         self.model_config = model_config
+        self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
+        self.train_dir = os.path.join(self.model_dir, "train")
         self.interpreter: Interpreter = None
         self.sub_label_publisher = sub_label_publisher
         self.tensor_input_details: dict[str, Any] = None
@@ -155,18 +171,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
 
     def __build_detector(self) -> None:
         self.interpreter = Interpreter(
-            model_path=self.model_config.model_path,
+            model_path=os.path.join(self.model_dir, "model.tflite"),
             num_threads=2,
         )
         self.interpreter.allocate_tensors()
         self.tensor_input_details = self.interpreter.get_input_details()
         self.tensor_output_details = self.interpreter.get_output_details()
-        self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
+        self.labelmap = load_labels(
+            os.path.join(self.model_dir, "labelmap.txt"),
+            prefill=0,
+        )
 
     def process_frame(self, obj_data, frame):
         if obj_data["label"] not in self.model_config.object_config.objects:
             return
 
+        now = datetime.datetime.now().timestamp()
         x, y, x2, y2 = calculate_region(
             frame.shape,
             obj_data["box"][0],
@@ -194,11 +214,19 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         )[0]
         probs = res / res.sum(axis=0)
         best_id = np.argmax(probs)
         score = round(probs[best_id], 2)
         previous_score = self.detected_objects.get(obj_data["id"], 0.0)
 
+        write_classification_attempt(
+            self.train_dir,
+            cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
+            now,
+            self.labelmap[best_id],
+            score,
+        )
+
         if score <= previous_score:
             logger.debug(
                 f"Score {score} is worse than previous score {previous_score}"
             )
             return
@@ -215,3 +243,28 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
     def expire_object(self, object_id, camera):
         if object_id in self.detected_objects:
             self.detected_objects.pop(object_id)
+
+
+def write_classification_attempt(
+    folder: str,
+    frame: np.ndarray,
+    timestamp: float,
+    label: str,
+    score: float,
+) -> None:
+    if "-" in label:
+        label = label.replace("-", "_")
+
+    file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp")
+    os.makedirs(folder, exist_ok=True)
+    cv2.imwrite(file, frame)
+
+    files = sorted(
+        filter(lambda f: f.endswith(".webp"), os.listdir(folder)),
+        key=lambda f: os.path.getctime(os.path.join(folder, f)),
+        reverse=True,
+    )
+
+    # delete the oldest classification image if the maximum is reached
+    if len(files) > 100:
+        os.unlink(os.path.join(folder, files[-1]))
diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py
index 25601f014..9a2378221 100644
--- a/frigate/embeddings/maintainer.py
+++ b/frigate/embeddings/maintainer.py
@@ -150,10 +150,10 @@ class EmbeddingMaintainer(threading.Thread):
                 )
             )
 
-        for name, model_config in self.config.classification.custom.items():
+        for model_config in self.config.classification.custom.values():
             self.realtime_processors.append(
                 CustomStateClassificationProcessor(
-                    self.config, model_config, name, self.requestor, self.metrics
+                    self.config, model_config, self.requestor, self.metrics
                 )
                 if model_config.state_config != None
                 else CustomObjectClassificationProcessor(
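`write_classification_attempt` records every inference attempt as a training candidate: state models write to `CLIPS_DIR/<name>`, object models to `MODEL_CACHE_DIR/<name>/train`, and each folder is pruned to the 100 most recent files. The timestamp, predicted label, and score are encoded in the filename, with `-` in labels rewritten to `_` so the fields stay unambiguous when split on `-`. A small sketch of recovering them (the filename below is made up):

```python
# Hypothetical attempt filename as produced by write_classification_attempt.
name = "1718000000.123456-no_bird-0.87.webp"

# Labels never contain "-" after the rewrite, so a plain split is safe.
timestamp, label, score = name.removesuffix(".webp").split("-")
print(float(timestamp), label, float(score))  # 1718000000.123456 no_bird 0.87
```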
diff --git a/frigate/util/classification.py b/frigate/util/classification.py
new file mode 100644
index 000000000..4ee5e1d54
--- /dev/null
+++ b/frigate/util/classification.py
@@ -0,0 +1,106 @@
+"""Util for classification models."""
+
+import os
+
+import cv2
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers, models, optimizers
+from tensorflow.keras.applications import MobileNetV2
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+
+BATCH_SIZE = 16
+EPOCHS = 50
+LEARNING_RATE = 0.001
+
+
+def generate_representative_dataset_factory(dataset_dir: str):
+    def generate_representative_dataset():
+        image_paths = []
+        for root, dirs, files in os.walk(dataset_dir):
+            for file in files:
+                if file.lower().endswith((".jpg", ".jpeg", ".png")):
+                    image_paths.append(os.path.join(root, file))
+
+        for path in image_paths[:300]:
+            img = cv2.imread(path)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            img = cv2.resize(img, (224, 224))
+            img_array = np.array(img, dtype=np.float32) / 255.0
+            img_array = img_array[None, ...]
+            yield [img_array]
+
+    return generate_representative_dataset
+
+
+def train_classification_model(model_dir: str) -> None:
+    """Train a classification model."""
+    dataset_dir = os.path.join(model_dir, "dataset")
+    num_classes = len(
+        [
+            d
+            for d in os.listdir(dataset_dir)
+            if os.path.isdir(os.path.join(dataset_dir, d))
+        ]
+    )
+
+    # Start with an ImageNet base model with 35% of the channels in each layer
+    base_model = MobileNetV2(
+        input_shape=(224, 224, 3),
+        include_top=False,
+        weights="imagenet",
+        alpha=0.35,
+    )
+    base_model.trainable = False  # Freeze pre-trained layers
+
+    model = models.Sequential(
+        [
+            base_model,
+            layers.GlobalAveragePooling2D(),
+            layers.Dense(128, activation="relu"),
+            layers.Dropout(0.3),
+            layers.Dense(num_classes, activation="softmax"),
+        ]
+    )
+
+    model.compile(
+        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
+        loss="categorical_crossentropy",
+        metrics=["accuracy"],
+    )
+
+    # create training set
+    datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
+    train_gen = datagen.flow_from_directory(
+        dataset_dir,
+        target_size=(224, 224),
+        batch_size=BATCH_SIZE,
+        class_mode="categorical",
+        subset="training",
+    )
+
+    # write labelmap
+    class_indices = train_gen.class_indices
+    index_to_class = {v: k for k, v in class_indices.items()}
+    sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
+    with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
+        for class_name in sorted_classes:
+            f.write(f"{class_name}\n")
+
+    # train the model
+    model.fit(train_gen, epochs=EPOCHS, verbose=0)
+
+    # convert model to tflite
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.representative_dataset = generate_representative_dataset_factory(
+        dataset_dir
+    )
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.uint8
+    converter.inference_output_type = tf.uint8
+    tflite_model = converter.convert()
+
+    # write model
+    with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
+        f.write(tflite_model)
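Because the converter fully quantizes the network, the file written here expects `uint8` input and produces `uint8` output, which matches how the realtime processors feed raw RGB crops without the 1/255 rescaling used during training. A minimal sketch of exercising the artifact, assuming a trained model exists for a hypothetical `bird` config and that `Interpreter` comes from `tflite_runtime` (the processors' own import is not shown in this diff):

```python
# Paths and the "bird" model name are assumptions for illustration.
import numpy as np
from tflite_runtime.interpreter import Interpreter

interpreter = Interpreter(model_path="/config/model_cache/bird/model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The quantized model takes a (1, 224, 224, 3) uint8 RGB image directly.
image = np.zeros((1, 224, 224, 3), dtype=np.uint8)
interpreter.set_tensor(input_details[0]["index"], image)
interpreter.invoke()

scores = interpreter.get_tensor(output_details[0]["index"])[0]
print(int(np.argmax(scores)))  # row index into labelmap.txt
```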