Mirror of https://github.com/blakeblackshear/frigate.git, synced 2025-08-08 13:51:01 +02:00
Refactor TensorRT (#18643)
* Combine base and arm trt detectors
* Remove unused deps for amd64 build
* Add missing packages and cleanup ldconfig
* Expand packages for tensorflow model training
* Cleanup
* Refactor training to not reserve memory
parent 2692eb4830, commit ad5076f645
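The last bullet is the substantive refactor: instead of training a model inside the long-lived classification processors (which leaves TensorFlow's CPU/GPU allocations resident after training finishes), training now runs in a short-lived child process and the processors are simply asked to reload the finished model. A minimal sketch of that pattern, with hypothetical names (run_training, kickoff) standing in for the helpers shown in the diff below:

import multiprocessing as mp


def run_training(model_name: str) -> None:
    # Placeholder for the actual work: build the Keras model, fit it on the
    # dataset, and write the converted .tflite artifact to disk. Importing and
    # using TensorFlow only inside this function keeps its memory in the child.
    pass


def kickoff(model_name: str) -> None:
    # Run the trainer in a separate process and wait for it; when the child
    # exits, the operating system reclaims all of its CPU / GPU memory.
    proc = mp.Process(target=run_training, args=(model_name,), name=f"training:{model_name}")
    proc.start()
    proc.join()
    # The long-lived workers can now be asked to reload the new weights.


if __name__ == "__main__":
    kickoff("example_model")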
@@ -13,6 +13,7 @@ nvidia_cusolver_cu12==11.6.3.*; platform_machine == 'x86_64'
 nvidia_cusparse_cu12==12.5.1.*; platform_machine == 'x86_64'
 nvidia_nccl_cu12==2.23.4; platform_machine == 'x86_64'
 nvidia_nvjitlink_cu12==12.5.82; platform_machine == 'x86_64'
+tensorflow==2.19.*; platform_machine == 'x86_64'
 onnx==1.16.*; platform_machine == 'x86_64'
 onnxruntime-gpu==1.22.*; platform_machine == 'x86_64'
 protobuf==3.20.3; platform_machine == 'x86_64'
@@ -12,7 +12,7 @@ class EmbeddingsRequestEnum(Enum):
     # audio
     transcribe_audio = "transcribe_audio"
     # custom classification
-    train_classification = "train_classification"
+    reload_classification_model = "reload_classification_model"
     # face
     clear_face_classifier = "clear_face_classifier"
     recognize_face = "recognize_face"
@@ -3,7 +3,6 @@
 import datetime
 import logging
 import os
-import threading
 from typing import Any

 import cv2
@@ -17,10 +16,8 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
 from frigate.config.classification import CustomClassificationConfig
-from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
-from frigate.types import ModelStatusTypesEnum
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
 from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
-from frigate.util.classification import train_classification_model
 from frigate.util.object import box_overlaps, calculate_region

 from ..types import DataProcessorMetrics
@@ -72,18 +69,6 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )
         self.classifications_per_second.start()

-    def __retrain_model(self) -> None:
-        train_classification_model(self.model_config.name)
-        self.__build_detector()
-        self.requestor.send_data(
-            UPDATE_MODEL_STATE,
-            {
-                "model": self.model_config.name,
-                "state": ModelStatusTypesEnum.complete,
-            },
-        )
-        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
-
     def __update_metrics(self, duration: float) -> None:
         self.classifications_per_second.update()
         self.inference_speed.update(duration)
@@ -172,19 +157,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         )

     def handle_request(self, topic, request_data):
-        if topic == EmbeddingsRequestEnum.train_classification.value:
+        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
             if request_data.get("model_name") == self.model_config.name:
-                self.requestor.send_data(
-                    UPDATE_MODEL_STATE,
-                    {
-                        "model": self.model_config.name,
-                        "state": ModelStatusTypesEnum.training,
-                    },
+                self.__build_detector()
+                logger.info(
+                    f"Successfully loaded updated model for {self.model_config.name}"
                 )
-                threading.Thread(target=self.__retrain_model).start()
                 return {
                     "success": True,
-                    "message": f"Began training {self.model_config.name} model.",
+                    "message": f"Loaded {self.model_config.name} model.",
                 }
         else:
             return None
@@ -232,18 +213,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             prefill=0,
         )

-    def __retrain_model(self) -> None:
-        train_classification_model(self.model_config.name)
-        self.__build_detector()
-        self.requestor.send_data(
-            UPDATE_MODEL_STATE,
-            {
-                "model": self.model_config.name,
-                "state": ModelStatusTypesEnum.complete,
-            },
-        )
-        logger.info(f"Successfully loaded updated model for {self.model_config.name}")
-
     def __update_metrics(self, duration: float) -> None:
         self.classifications_per_second.update()
         self.inference_speed.update(duration)
@@ -307,19 +276,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             self.detected_objects[obj_data["id"]] = score

     def handle_request(self, topic, request_data):
-        if topic == EmbeddingsRequestEnum.train_classification.value:
+        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
             if request_data.get("model_name") == self.model_config.name:
-                self.requestor.send_data(
-                    UPDATE_MODEL_STATE,
-                    {
-                        "model": self.model_config.name,
-                        "state": ModelStatusTypesEnum.training,
-                    },
+                logger.info(
+                    f"Successfully loaded updated model for {self.model_config.name}"
                 )
-                threading.Thread(target=self.__retrain_model).start()
                 return {
                     "success": True,
-                    "message": f"Began training {self.model_config.name} model.",
+                    "message": f"Loaded {self.model_config.name} model.",
                 }
         else:
             return None
@@ -20,8 +20,8 @@ class DataProcessorMetrics:
     alpr_pps: Synchronized
     yolov9_lpr_speed: Synchronized
     yolov9_lpr_pps: Synchronized
-    classification_speeds: dict[str, Synchronized]
-    classification_cps: dict[str, Synchronized]
+    classification_speeds: dict[str, Synchronized] = {}
+    classification_cps: dict[str, Synchronized] = {}

     def __init__(self, custom_classification_models: list[str]):
         self.image_embeddings_speed = mp.Value("d", 0.0)
@@ -36,8 +36,6 @@ class DataProcessorMetrics:
         self.yolov9_lpr_pps = mp.Value("d", 0.0)

         if custom_classification_models:
-            self.classification_speeds = {}
-            self.classification_cps = {}
             for key in custom_classification_models:
                 self.classification_speeds[key] = mp.Value("d", 0.0)
                 self.classification_cps[key] = mp.Value("d", 0.0)
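For context, the metrics touched above are plain multiprocessing.Value slots keyed by model name, so any process that holds the DataProcessorMetrics instance can read or update them. A standalone illustration (not Frigate code; the model names are made up):

import multiprocessing as mp

model_names = ["bird", "package"]
classification_speeds = {name: mp.Value("d", 0.0) for name in model_names}

# Update one model's slot under its lock, then read everything back.
with classification_speeds["bird"].get_lock():
    classification_speeds["bird"].value = 0.012  # e.g. an inference duration

print({name: v.value for name, v in classification_speeds.items()})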
@@ -22,6 +22,7 @@ from frigate.data_processing.types import DataProcessorMetrics
 from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event, Recordings
 from frigate.util.builtin import serialize
+from frigate.util.classification import kickoff_model_training
 from frigate.util.services import listen

 from .maintainer import EmbeddingMaintainer
@@ -302,9 +303,12 @@ class EmbeddingsContext:
         return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})

     def start_classification_training(self, model_name: str) -> dict[str, Any]:
-        return self.requestor.send_data(
-            EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
-        )
+        threading.Thread(
+            target=kickoff_model_training,
+            args=(self.requestor, model_name),
+            daemon=True,
+        ).start()
+        return {"success": True, "message": f"Began training {model_name} model."}

     def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
         return self.requestor.send_data(
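With this change, start_classification_training no longer round-trips a train_classification request to the embeddings process; it starts kickoff_model_training on a daemon thread and acknowledges immediately. A hypothetical call site, assuming an already-constructed EmbeddingsContext named context and a made-up model name:

# `context` is assumed to be an existing EmbeddingsContext instance.
result = context.start_classification_training("my_custom_model")
print(result)  # {'success': True, 'message': 'Began training my_custom_model model.'}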
@@ -10,7 +10,11 @@ from tensorflow.keras import layers, models, optimizers
 from tensorflow.keras.applications import MobileNetV2
 from tensorflow.keras.preprocessing.image import ImageDataGenerator

-from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
+from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
+from frigate.comms.inter_process import InterProcessRequestor
+from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util import Process

 BATCH_SIZE = 16
 EPOCHS = 50
@@ -18,7 +22,7 @@ LEARNING_RATE = 0.001


 @staticmethod
-def generate_representative_dataset_factory(dataset_dir: str):
+def __generate_representative_dataset_factory(dataset_dir: str):
     def generate_representative_dataset():
         image_paths = []
         for root, dirs, files in os.walk(dataset_dir):
@@ -38,7 +42,7 @@ def generate_representative_dataset_factory(dataset_dir: str):


 @staticmethod
-def train_classification_model(model_name: str) -> bool:
+def __train_classification_model(model_name: str) -> bool:
     """Train a classification model."""
     dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
     model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
@@ -107,7 +111,7 @@ def train_classification_model(model_name: str) -> bool:
     # convert model to tflite
     converter = tf.lite.TFLiteConverter.from_keras_model(model)
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.representative_dataset = generate_representative_dataset_factory(
+    converter.representative_dataset = __generate_representative_dataset_factory(
         dataset_dir
     )
     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
@@ -122,3 +126,42 @@ def train_classification_model(model_name: str) -> bool:
     # restore original stdout / stderr
     sys.stdout = original_stdout
     sys.stderr = original_stderr
+
+
+@staticmethod
+def kickoff_model_training(
+    embeddingRequestor: EmbeddingsRequestor, model_name: str
+) -> None:
+    requestor = InterProcessRequestor()
+    requestor.send_data(
+        UPDATE_MODEL_STATE,
+        {
+            "model": model_name,
+            "state": ModelStatusTypesEnum.training,
+        },
+    )
+
+    # run training in sub process so that
+    # tensorflow will free CPU / GPU memory
+    # upon training completion
+    training_process = Process(
+        target=__train_classification_model,
+        name=f"model_training:{model_name}",
+        args=(model_name,),
+    )
+    training_process.start()
+    training_process.join()
+
+    # reload model and mark training as complete
+    embeddingRequestor.send_data(
+        EmbeddingsRequestEnum.reload_classification_model.value,
+        {"model_name": model_name},
+    )
+    requestor.send_data(
+        UPDATE_MODEL_STATE,
+        {
+            "model": model_name,
+            "state": ModelStatusTypesEnum.complete,
+        },
+    )
+    requestor.stop()