Implement API to train classification models (#18475)

This commit is contained in:
Nicolas Mowen 2025-05-29 17:51:32 -06:00 committed by GitHub
parent 2bd6fa53fe
commit 20e0addae1
8 changed files with 219 additions and 19 deletions

View File

@@ -227,6 +227,9 @@ ENV OPENCV_FFMPEG_LOGLEVEL=8
# Set HailoRT to disable logging
ENV HAILORT_LOGGER_PATH=NONE
# TensorFlow error only
ENV TF_CPP_MIN_LOG_LEVEL=3
ENV PATH="/usr/local/go2rtc/bin:/usr/local/tempio/bin:/usr/local/nginx/sbin:${PATH}"
# Install dependencies
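TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ runtime the first time the library is imported, so it must be present in the environment before that import; baking it into the image, as above, guarantees this for every process. A minimal sketch of the same effect from Python (with 3, only FATAL messages from the C++ side remain):

import os

# Must be set before the first tensorflow import to take effect.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

import tensorflow as tf  # C++ INFO/WARNING/ERROR logs are now suppressed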

View File

@@ -11,6 +11,9 @@ joserfc == 1.0.*
pathvalidate == 3.2.*
markupsafe == 3.0.*
python-multipart == 0.0.12
# Classification Model Training
tensorflow == 2.19.* ; platform_machine == 'aarch64'
tensorflow-cpu == 2.19.* ; platform_machine == 'x86_64'
# General
mypy == 1.6.1
onvif-zeep-async == 3.1.*
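The "; platform_machine == '...'" suffixes are PEP 508 environment markers, which pip evaluates at install time: aarch64 builds get the plain tensorflow wheel, while x86_64 builds get the lighter tensorflow-cpu wheel. A quick sketch of how such a marker evaluates, using the packaging library (the requirement string is copied from the diff):

from packaging.requirements import Requirement

req = Requirement("tensorflow-cpu == 2.19.* ; platform_machine == 'x86_64'")
# True on an x86_64 host, False elsewhere, so pip skips this wheel on arm64.
print(req.marker.evaluate())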

View File

@@ -7,7 +7,7 @@ import shutil
from typing import Any
import cv2
from fastapi import APIRouter, Depends, Request, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename
from peewee import DoesNotExist
@@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import (
RenameFaceBody,
)
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.const import FACE_DIR
from frigate.const import FACE_DIR, MODEL_CACHE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@@ -424,3 +426,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
},
status_code=500,
)
# custom classification training
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
background_tasks.add_task(
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
)
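A hypothetical client call for the new route (host, port, and the model name "bird" are illustrative; externally the API is typically served under /api):

import requests

resp = requests.post("http://frigate.local:5000/api/classification/bird/train")
print(resp.status_code, resp.json())
# 200 {'success': True, 'message': 'Started classification model training.'}
# 404 if "bird" is not a configured custom classification model

Because the work is handed to FastAPI's BackgroundTasks, the response returns immediately and training runs only after the response has been sent.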

View File

@@ -85,8 +85,7 @@ class CustomClassificationObjectConfig(FrigateBaseModel):
class CustomClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable running the model.")
model_path: str = Field(title="Path to custom classification tflite model.")
labelmap_path: str = Field(title="Path to custom classification model labelmap.")
name: str | None = Field(default=None, title="Name of classification model.")
object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: CustomClassificationStateConfig | None = Field(default=None)

View File

@@ -706,6 +706,10 @@ class FrigateConfig(FrigateBaseModel):
verify_objects_track(camera_config, labelmap_objects)
verify_lpr_and_face(self, camera_config)
# set names on classification configs
for name, config in self.classification.custom.items():
config.name = name
self.objects.parse_all_objects(self.cameras)
self.model.create_colormap(sorted(self.objects.all_objects))
self.model.check_and_load_plus_model(self.plus_api)
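Together with the config change above (model_path and labelmap_path removed, an optional name added), this loop makes the dict key the single source of truth for a model's identity; every path is derived from it later. The pattern in miniature, with a hypothetical pydantic model standing in for CustomClassificationConfig:

from pydantic import BaseModel, Field

class ClassificationModelConfig(BaseModel):  # illustrative stand-in
    name: str | None = Field(default=None)

custom = {"bird": ClassificationModelConfig(), "package": ClassificationModelConfig()}
for name, config in custom.items():
    config.name = name  # the dict key becomes the model's canonical name

assert custom["bird"].name == "bird"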

View File

@@ -2,6 +2,7 @@
import datetime
import logging
import os
from typing import Any
import cv2
@@ -14,6 +15,7 @@ from frigate.comms.event_metadata_updater import (
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
from frigate.util.builtin import load_labels
from frigate.util.object import box_overlaps, calculate_region
@@ -33,14 +35,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self,
config: FrigateConfig,
model_config: CustomClassificationConfig,
name: str,
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
self.model_config = model_config
self.name = name
self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name)
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
@@ -50,13 +52,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
def __build_detector(self) -> None:
self.interpreter = Interpreter(
model_path=self.model_config.model_path,
model_path=os.path.join(self.model_dir, "model.tflite"),
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
self.labelmap = load_labels(
os.path.join(self.model_dir, "labelmap.txt"),
prefill=0,
)
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
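Both processors now resolve their artifacts from the model cache instead of user-supplied paths: model.tflite and labelmap.txt under MODEL_CACHE_DIR/<name>, which is exactly where the training endpoint writes them. A minimal stand-in for the labelmap loading, assuming (as the trainer below confirms) one class name per line, indexed by line number:

def read_labelmap(path: str) -> dict[int, str]:
    # Assumed equivalent of load_labels(path, prefill=0) for labelmap.txt
    # files produced by train_classification_model below.
    with open(path) as f:
        return {i: line.strip() for i, line in enumerate(f)}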
@@ -105,15 +110,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[
frame = rgb[
y:y2,
x:x2,
]
if input.shape != (224, 224):
input = cv2.resize(input, (224, 224))
if frame.shape[:2] != (224, 224):
frame = cv2.resize(frame, (224, 224))
input = np.expand_dims(input, axis=0)
input = np.expand_dims(frame, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor(
@@ -123,9 +128,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
best_id = np.argmax(probs)
score = round(probs[best_id], 2)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
)
if score >= camera_config.threshold:
self.requestor.send_data(
f"{camera}/classification/{self.name}", self.labelmap[best_id]
f"{camera}/classification/{self.model_config.name}",
self.labelmap[best_id],
)
def handle_request(self, topic, request_data):
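The inference path above as a standalone sketch, useful for exercising an exported model outside Frigate (the paths and the tflite_runtime import are assumptions; Frigate may construct its Interpreter differently):

import os

import cv2
import numpy as np
from tflite_runtime.interpreter import Interpreter  # or tf.lite.Interpreter

model_dir = "/config/model_cache/bird"  # illustrative
interpreter = Interpreter(model_path=os.path.join(model_dir, "model.tflite"), num_threads=2)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

crop = cv2.cvtColor(cv2.imread("crop.jpg"), cv2.COLOR_BGR2RGB)  # any saved crop
crop = cv2.resize(crop, (224, 224))  # uint8 input matches the quantized model
interpreter.set_tensor(input_details[0]["index"], np.expand_dims(crop, axis=0))
interpreter.invoke()
res = interpreter.get_tensor(output_details[0]["index"])[0]
probs = res / res.sum(axis=0)
best_id = int(np.argmax(probs))
print(best_id, round(float(probs[best_id]), 2))

Note that every attempt is also written to the train directory regardless of score, so the folder accumulates labeled-by-best-guess samples for the next training run.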
@@ -145,6 +159,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
):
super().__init__(config, metrics)
self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(self.model_dir, "train")
self.interpreter: Interpreter = None
self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None
@@ -155,18 +171,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def __build_detector(self) -> None:
self.interpreter = Interpreter(
model_path=self.model_config.model_path,
model_path=os.path.join(self.model_dir, "model.tflite"),
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
self.labelmap = load_labels(
os.path.join(self.model_dir, "labelmap.txt"),
prefill=0,
)
def process_frame(self, obj_data, frame):
if obj_data["label"] not in self.model_config.object_config.objects:
return
now = datetime.datetime.now().timestamp()
x, y, x2, y2 = calculate_region(
frame.shape,
obj_data["box"][0],
@@ -194,11 +214,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)[0]
probs = res / res.sum(axis=0)
best_id = np.argmax(probs)
score = round(probs[best_id], 2)
previous_score = self.detected_objects.get(obj_data["id"], 0.0)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
)
if score <= previous_score:
logger.debug(f"Score {score} is worse than previous score {previous_score}")
return
@@ -215,3 +241,29 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def expire_object(self, object_id, camera):
if object_id in self.detected_objects:
self.detected_objects.pop(object_id)
def write_classification_attempt(
folder: str,
frame: np.ndarray,
timestamp: float,
label: str,
score: float,
) -> None:
if "-" in label:
label = label.replace("-", "_")
file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp")
os.makedirs(folder, exist_ok=True)
cv2.imwrite(file, frame)
files = sorted(
filter(lambda f: (f.endswith(".webp")), os.listdir(folder)),
key=lambda f: os.path.getctime(os.path.join(folder, f)),
reverse=True,
)
# delete the oldest classification attempt image if the maximum is reached
if len(files) > 100:
os.unlink(os.path.join(folder, files[-1]))
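Illustrative use of the helper above: each attempt lands as <timestamp>-<label>-<score>.webp, dashes in labels are normalized to underscores so the filename stays parseable, and the folder is capped at 100 images with oldest-first eviction:

import time

import numpy as np

placeholder = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in BGR frame
write_classification_attempt("/tmp/attempts", placeholder, time.time(), "crow", 0.87)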

View File

@@ -150,10 +150,10 @@ class EmbeddingMaintainer(threading.Thread):
)
)
for name, model_config in self.config.classification.custom.items():
for model_config in self.config.classification.custom.values():
self.realtime_processors.append(
CustomStateClassificationProcessor(
self.config, model_config, name, self.requestor, self.metrics
self.config, model_config, self.requestor, self.metrics
)
if model_config.state_config is not None
else CustomObjectClassificationProcessor(

View File

@@ -0,0 +1,108 @@
"""Util for classification models."""
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 0.001
def generate_representative_dataset_factory(dataset_dir: str):
def generate_representative_dataset():
image_paths = []
for root, dirs, files in os.walk(dataset_dir):
for file in files:
if file.lower().endswith((".jpg", ".jpeg", ".png")):
image_paths.append(os.path.join(root, file))
for path in image_paths[:300]:
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))
img_array = np.array(img, dtype=np.float32) / 255.0
img_array = img_array[None, ...]
yield [img_array]
return generate_representative_dataset
def train_classification_model(model_dir: str) -> bool:
"""Train a classification model."""
dataset_dir = os.path.join(model_dir, "dataset")
num_classes = len(
[
d
for d in os.listdir(dataset_dir)
if os.path.isdir(os.path.join(dataset_dir, d))
]
)
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
input_shape=(224, 224, 3),
include_top=False,
weights="imagenet",
alpha=0.35,
)
base_model.trainable = False # Freeze pre-trained layers
model = models.Sequential(
[
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(128, activation="relu"),
layers.Dropout(0.3),
layers.Dense(num_classes, activation="softmax"),
]
)
model.compile(
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
loss="categorical_crossentropy",
metrics=["accuracy"],
)
# create training set
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
train_gen = datagen.flow_from_directory(
dataset_dir,
target_size=(224, 224),
batch_size=BATCH_SIZE,
class_mode="categorical",
subset="training",
)
# write labelmap
class_indices = train_gen.class_indices
index_to_class = {v: k for k, v in class_indices.items()}
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
for class_name in sorted_classes:
f.write(f"{class_name}\n")
# train the model
model.fit(train_gen, epochs=EPOCHS, verbose=0)
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = generate_representative_dataset_factory(
dataset_dir
)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# write model
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
f.write(tflite_model)
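A hypothetical end-to-end run of the trainer (the directory layout is the one the code expects: one subfolder per class under <model_dir>/dataset):

# /config/model_cache/bird/dataset/crow/*.jpg    <- illustrative layout
# /config/model_cache/bird/dataset/pigeon/*.jpg
train_classification_model("/config/model_cache/bird")
# writes labelmap.txt and model.tflite into /config/model_cache/bird

The design keeps training cheap enough for a background task: the MobileNetV2 base (alpha=0.35, ImageNet weights) is frozen so only the small dense head is fitted, and full-integer quantization with a representative dataset yields the uint8 model that the realtime processors load.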