Refactor TensorRT (#18643)

* Combine base and arm trt detectors

* Remove unused deps for amd64 build

* Add missing packages and cleanup ldconfig

* Expand packages for tensorflow model training

* Cleanup

* Refactor training to not reserve memory
Nicolas Mowen 2025-06-09 08:25:33 -06:00
parent f41495df86
commit b35e89928a
6 changed files with 68 additions and 58 deletions
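
The core of the refactor is the last bullet: training previously ran on a thread inside the long-lived classification processors, so the CPU/GPU memory TensorFlow reserves was never returned. Training now runs in a short-lived child process that is joined and exits, releasing that memory before the updated model is reloaded. A minimal sketch of the pattern, with illustrative names rather than Frigate's actual code:

import multiprocessing as mp


def _train(model_name: str) -> None:
    # heavy framework imports belong only in the child process; whatever
    # memory the framework reserves is returned to the OS when it exits
    # (e.g. `import tensorflow as tf` would go here)
    print(f"training {model_name} in pid {mp.current_process().pid}")


def train_in_subprocess(model_name: str) -> None:
    proc = mp.Process(
        target=_train, args=(model_name,), name=f"model_training:{model_name}"
    )
    proc.start()
    proc.join()  # wait for completion; the child's memory is freed on exit
    print(f"{model_name}: training done, safe to reload the model")


if __name__ == "__main__":
    train_in_subprocess("example_model")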

View File

@@ -13,6 +13,7 @@ nvidia_cusolver_cu12==11.6.3.*; platform_machine == 'x86_64'
 nvidia_cusparse_cu12==12.5.1.*; platform_machine == 'x86_64'
 nvidia_nccl_cu12==2.23.4; platform_machine == 'x86_64'
 nvidia_nvjitlink_cu12==12.5.82; platform_machine == 'x86_64'
+tensorflow==2.19.*; platform_machine == 'x86_64'
 onnx==1.16.*; platform_machine == 'x86_64'
 onnxruntime-gpu==1.22.*; platform_machine == 'x86_64'
 protobuf==3.20.3; platform_machine == 'x86_64'

View File

@@ -12,7 +12,7 @@ class EmbeddingsRequestEnum(Enum):
     # audio
     transcribe_audio = "transcribe_audio"
     # custom classification
-    train_classification = "train_classification"
+    reload_classification_model = "reload_classification_model"
     # face
     clear_face_classifier = "clear_face_classifier"
     recognize_face = "recognize_face"

View File

@@ -3,7 +3,6 @@
 import datetime
 import logging
 import os
-import threading
 from typing import Any
 
 import cv2
@@ -17,10 +16,8 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
 from frigate.config.classification import CustomClassificationConfig
-from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
-from frigate.types import ModelStatusTypesEnum
 from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
-from frigate.util.classification import train_classification_model
 from frigate.util.object import box_overlaps, calculate_region
 
 from ..types import DataProcessorMetrics
@@ -72,18 +69,6 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )
         self.classifications_per_second.start()
 
-    def __retrain_model(self) -> None:
-        train_classification_model(self.model_config.name)
-        self.__build_detector()
-        self.requestor.send_data(
-            UPDATE_MODEL_STATE,
-            {
-                "model": self.model_config.name,
-                "state": ModelStatusTypesEnum.complete,
-            },
-        )
-        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
-
     def __update_metrics(self, duration: float) -> None:
         self.classifications_per_second.update()
         self.inference_speed.update(duration)
@@ -172,19 +157,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )
 
     def handle_request(self, topic, request_data):
-        if topic == EmbeddingsRequestEnum.train_classification.value:
+        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
             if request_data.get("model_name") == self.model_config.name:
-                self.requestor.send_data(
-                    UPDATE_MODEL_STATE,
-                    {
-                        "model": self.model_config.name,
-                        "state": ModelStatusTypesEnum.training,
-                    },
-                )
-                threading.Thread(target=self.__retrain_model).start()
+                self.__build_detector()
+                logger.info(
+                    f"Successfully loaded updated model for {self.model_config.name}"
+                )
                 return {
                     "success": True,
-                    "message": f"Began training {self.model_config.name} model.",
+                    "message": f"Loaded {self.model_config.name} model.",
                 }
         else:
             return None
@@ -232,18 +213,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             prefill=0,
         )
 
-    def __retrain_model(self) -> None:
-        train_classification_model(self.model_config.name)
-        self.__build_detector()
-        self.requestor.send_data(
-            UPDATE_MODEL_STATE,
-            {
-                "model": self.model_config.name,
-                "state": ModelStatusTypesEnum.complete,
-            },
-        )
-        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
-
     def __update_metrics(self, duration: float) -> None:
         self.classifications_per_second.update()
         self.inference_speed.update(duration)
@@ -307,19 +276,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             self.detected_objects[obj_data["id"]] = score
 
     def handle_request(self, topic, request_data):
-        if topic == EmbeddingsRequestEnum.train_classification.value:
+        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
             if request_data.get("model_name") == self.model_config.name:
-                self.requestor.send_data(
-                    UPDATE_MODEL_STATE,
-                    {
-                        "model": self.model_config.name,
-                        "state": ModelStatusTypesEnum.training,
-                    },
-                )
-                threading.Thread(target=self.__retrain_model).start()
+                logger.info(
+                    f"Successfully loaded updated model for {self.model_config.name}"
+                )
                 return {
                     "success": True,
-                    "message": f"Began training {self.model_config.name} model.",
+                    "message": f"Loaded {self.model_config.name} model.",
                 }
         else:
             return None

View File

@@ -20,8 +20,8 @@ class DataProcessorMetrics:
     alpr_pps: Synchronized
    yolov9_lpr_speed: Synchronized
     yolov9_lpr_pps: Synchronized
-    classification_speeds: dict[str, Synchronized]
-    classification_cps: dict[str, Synchronized]
+    classification_speeds: dict[str, Synchronized] = {}
+    classification_cps: dict[str, Synchronized] = {}
 
     def __init__(self, custom_classification_models: list[str]):
         self.image_embeddings_speed = mp.Value("d", 0.0)
@@ -36,8 +36,6 @@ class DataProcessorMetrics:
         self.yolov9_lpr_pps = mp.Value("d", 0.0)
 
         if custom_classification_models:
-            self.classification_speeds = {}
-            self.classification_cps = {}
             for key in custom_classification_models:
                 self.classification_speeds[key] = mp.Value("d", 0.0)
                 self.classification_cps[key] = mp.Value("d", 0.0)
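
The classification_speeds and classification_cps dicts hold multiprocessing Value objects ("d" doubles in shared memory), so a processor running in a worker process and the parent process read and write the same numbers. A rough standalone sketch of that mechanism, using placeholder names:

import multiprocessing as mp


def record_speed(metric, seconds: float) -> None:
    # runs in the child process; updates shared memory under the lock
    with metric.get_lock():
        metric.value = seconds


if __name__ == "__main__":
    inference_speed = mp.Value("d", 0.0)  # one shared double per model
    p = mp.Process(target=record_speed, args=(inference_speed, 0.042))
    p.start()
    p.join()
    print(inference_speed.value)  # 0.042, visible from the parent process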

View File

@@ -22,6 +22,7 @@ from frigate.data_processing.types import DataProcessorMetrics
 from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event, Recordings
 from frigate.util.builtin import serialize
+from frigate.util.classification import kickoff_model_training
 from frigate.util.services import listen
 
 from .maintainer import EmbeddingMaintainer
@@ -302,9 +303,12 @@ class EmbeddingsContext:
         return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})
 
     def start_classification_training(self, model_name: str) -> dict[str, Any]:
-        return self.requestor.send_data(
-            EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
-        )
+        threading.Thread(
+            target=kickoff_model_training,
+            args=(self.requestor, model_name),
+            daemon=True,
+        ).start()
+        return {"success": True, "message": f"Began training {model_name} model."}
 
     def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
         return self.requestor.send_data(
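
start_classification_training is now fire-and-forget: the kickoff runs on a daemon thread and the caller gets an immediate success response instead of blocking on the embeddings process. A small generic sketch of that shape, with illustrative names rather than the real Frigate API:

import threading
import time


def kickoff(job_name: str) -> None:
    # stands in for kickoff_model_training; this is the slow part
    time.sleep(1)
    print(f"{job_name} finished")


def start_job(job_name: str) -> dict:
    threading.Thread(target=kickoff, args=(job_name,), daemon=True).start()
    # respond immediately; the daemon thread keeps working in the background
    return {"success": True, "message": f"Began training {job_name} model."}


if __name__ == "__main__":
    print(start_job("example_model"))
    time.sleep(2)  # keep the main thread alive so the daemon can finish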

View File

@@ -10,7 +10,11 @@ from tensorflow.keras import layers, models, optimizers
 from tensorflow.keras.applications import MobileNetV2
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 
-from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
+from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
+from frigate.comms.inter_process import InterProcessRequestor
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util import Process
 
 BATCH_SIZE = 16
 EPOCHS = 50
@@ -18,7 +22,7 @@ LEARNING_RATE = 0.001
 
 
 @staticmethod
-def generate_representative_dataset_factory(dataset_dir: str):
+def __generate_representative_dataset_factory(dataset_dir: str):
     def generate_representative_dataset():
         image_paths = []
         for root, dirs, files in os.walk(dataset_dir):
@@ -38,7 +42,7 @@ def generate_representative_dataset_factory(dataset_dir: str):
 
 
 @staticmethod
-def train_classification_model(model_name: str) -> bool:
+def __train_classification_model(model_name: str) -> bool:
     """Train a classification model."""
     dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
     model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
@@ -107,7 +111,7 @@ def train_classification_model(model_name: str) -> bool:
     # convert model to tflite
     converter = tf.lite.TFLiteConverter.from_keras_model(model)
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.representative_dataset = generate_representative_dataset_factory(
+    converter.representative_dataset = __generate_representative_dataset_factory(
         dataset_dir
     )
     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
@@ -122,3 +126,42 @@ def train_classification_model(model_name: str) -> bool:
         # restore original stdout / stderr
         sys.stdout = original_stdout
         sys.stderr = original_stderr
+
+
+@staticmethod
+def kickoff_model_training(
+    embeddingRequestor: EmbeddingsRequestor, model_name: str
+) -> None:
+    requestor = InterProcessRequestor()
+    requestor.send_data(
+        UPDATE_MODEL_STATE,
+        {
+            "model": model_name,
+            "state": ModelStatusTypesEnum.training,
+        },
+    )
+
+    # run training in sub process so that
+    # tensorflow will free CPU / GPU memory
+    # upon training completion
+    training_process = Process(
+        target=__train_classification_model,
+        name=f"model_training:{model_name}",
+        args=(model_name,),
+    )
+    training_process.start()
+    training_process.join()
+
+    # reload model and mark training as complete
+    embeddingRequestor.send_data(
+        EmbeddingsRequestEnum.reload_classification_model.value,
+        {"model_name": model_name},
+    )
+    requestor.send_data(
+        UPDATE_MODEL_STATE,
+        {
+            "model": model_name,
+            "state": ModelStatusTypesEnum.complete,
+        },
+    )
+    requestor.stop()
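
After training, the Keras model is converted to a fully int8-quantized TFLite model, which is why the representative-dataset factory exists: the converter needs sample inputs to calibrate activation ranges. A hedged standalone sketch of that conversion step, where random calibration data and an untrained MobileNetV2 stand in for the real dataset and trained model:

import numpy as np
import tensorflow as tf


def representative_dataset():
    # a handful of sample inputs lets the converter calibrate int8 ranges
    for _ in range(8):
        yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]


model = tf.keras.applications.MobileNetV2(
    weights=None, input_shape=(224, 224, 3), classes=2
)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()
print(f"{len(tflite_model)} bytes of int8 tflite")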