mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-07-30 13:48:07 +02:00
Refactor TensorRT (#18643)
* Combine base and arm trt detectors * Remove unused deps for amd64 build * Add missing packages and cleanup ldconfig * Expand packages for tensorflow model training * Cleanup * Refactor training to not reserve memory
This commit is contained in:
parent
9a5162752c
commit
28fba7122d
@ -6,24 +6,29 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
# Globally set pip break-system-packages option to avoid having to specify it every time
|
||||
ARG PIP_BREAK_SYSTEM_PACKAGES=1
|
||||
|
||||
FROM tensorrt-base AS frigate-tensorrt
|
||||
FROM wheels AS trt-wheels
|
||||
ARG PIP_BREAK_SYSTEM_PACKAGES
|
||||
ENV TRT_VER=8.6.1
|
||||
|
||||
# Install TensorRT wheels
|
||||
COPY docker/tensorrt/requirements-amd64.txt /requirements-tensorrt.txt
|
||||
RUN pip3 install -U -r /requirements-tensorrt.txt && ldconfig
|
||||
COPY docker/main/requirements-wheels.txt /requirements-wheels.txt
|
||||
RUN pip3 wheel --wheel-dir=/trt-wheels -c /requirements-wheels.txt -r /requirements-tensorrt.txt
|
||||
|
||||
FROM deps AS frigate-tensorrt
|
||||
ARG PIP_BREAK_SYSTEM_PACKAGES
|
||||
|
||||
RUN --mount=type=bind,from=trt-wheels,source=/trt-wheels,target=/deps/trt-wheels \
|
||||
pip3 uninstall -y onnxruntime-openvino tensorflow-cpu \
|
||||
&& pip3 install -U /deps/trt-wheels/*.whl
|
||||
|
||||
COPY --from=rootfs / /
|
||||
COPY docker/tensorrt/detector/rootfs/etc/ld.so.conf.d /etc/ld.so.conf.d
|
||||
RUN ldconfig
|
||||
|
||||
WORKDIR /opt/frigate/
|
||||
COPY --from=rootfs / /
|
||||
|
||||
# Dev Container w/ TRT
|
||||
FROM devcontainer AS devcontainer-trt
|
||||
|
||||
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
|
||||
COPY --from=trt-deps /usr/local/src/tensorrt_demos /usr/local/src/tensorrt_demos
|
||||
COPY --from=trt-deps /usr/local/cuda-12.1 /usr/local/cuda
|
||||
COPY docker/tensorrt/detector/rootfs/ /
|
||||
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
|
||||
RUN --mount=type=bind,from=trt-wheels,source=/trt-wheels,target=/deps/trt-wheels \
|
||||
pip3 install -U /deps/trt-wheels/*.whl
|
||||
|
@ -2,8 +2,60 @@
|
||||
|
||||
# https://askubuntu.com/questions/972516/debian-frontend-environment-variable
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG TRT_BASE=nvcr.io/nvidia/tensorrt:23.12-py3
|
||||
|
||||
# Build TensorRT-specific library
|
||||
FROM ${TRT_BASE} AS trt-deps
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG COMPUTE_LEVEL
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y git build-essential cuda-nvcc-* cuda-nvtx-* libnvinfer-dev libnvinfer-plugin-dev libnvparsers-dev libnvonnxparsers-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN --mount=type=bind,source=docker/tensorrt/detector/tensorrt_libyolo.sh,target=/tensorrt_libyolo.sh \
|
||||
/tensorrt_libyolo.sh
|
||||
|
||||
# COPY required individual CUDA deps
|
||||
RUN mkdir -p /usr/local/cuda-deps
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
cp /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcurand.so.* /usr/local/cuda-deps/ && \
|
||||
cp /usr/local/cuda-12.3/targets/x86_64-linux/lib/libnvrtc.so.* /usr/local/cuda-deps/ && \
|
||||
cd /usr/local/cuda-deps/ && \
|
||||
for lib in libnvrtc.so.*; do \
|
||||
if [[ "$lib" =~ libnvrtc.so\.([0-9]+\.[0-9]+\.[0-9]+) ]]; then \
|
||||
version="${BASH_REMATCH[1]}"; \
|
||||
ln -sf "libnvrtc.so.$version" libnvrtc.so; \
|
||||
fi; \
|
||||
done && \
|
||||
for lib in libcurand.so.*; do \
|
||||
if [[ "$lib" =~ libcurand.so\.([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+) ]]; then \
|
||||
version="${BASH_REMATCH[1]}"; \
|
||||
ln -sf "libcurand.so.$version" libcurand.so; \
|
||||
fi; \
|
||||
done; \
|
||||
fi
|
||||
|
||||
# Frigate w/ TensorRT Support as separate image
|
||||
FROM deps AS tensorrt-base
|
||||
|
||||
#Disable S6 Global timeout
|
||||
ENV S6_CMD_WAIT_FOR_SERVICES_MAXTIME=0
|
||||
|
||||
# COPY TensorRT Model Generation Deps
|
||||
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
|
||||
COPY --from=trt-deps /usr/local/src/tensorrt_demos /usr/local/src/tensorrt_demos
|
||||
|
||||
# COPY Individual CUDA deps folder
|
||||
COPY --from=trt-deps /usr/local/cuda-deps /usr/local/cuda
|
||||
|
||||
COPY docker/tensorrt/detector/rootfs/ /
|
||||
ENV YOLO_MODELS=""
|
||||
|
||||
HEALTHCHECK --start-period=600s --start-interval=5s --interval=15s --timeout=5s --retries=3 \
|
||||
CMD curl --fail --silent --show-error http://127.0.0.1:5000/api/version || exit 1
|
||||
|
||||
FROM ${BASE_IMAGE} AS build-wheels
|
||||
ARG DEBIAN_FRONTEND
|
||||
|
||||
|
@ -1,57 +0,0 @@
|
||||
# syntax=docker/dockerfile:1.6
|
||||
|
||||
# https://askubuntu.com/questions/972516/debian-frontend-environment-variable
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG TRT_BASE=nvcr.io/nvidia/tensorrt:23.12-py3
|
||||
|
||||
# Build TensorRT-specific library
|
||||
FROM ${TRT_BASE} AS trt-deps
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG COMPUTE_LEVEL
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y git build-essential cuda-nvcc-* cuda-nvtx-* libnvinfer-dev libnvinfer-plugin-dev libnvparsers-dev libnvonnxparsers-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN --mount=type=bind,source=docker/tensorrt/detector/tensorrt_libyolo.sh,target=/tensorrt_libyolo.sh \
|
||||
/tensorrt_libyolo.sh
|
||||
|
||||
# COPY required individual CUDA deps
|
||||
RUN mkdir -p /usr/local/cuda-deps
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
cp /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcurand.so.* /usr/local/cuda-deps/ && \
|
||||
cp /usr/local/cuda-12.3/targets/x86_64-linux/lib/libnvrtc.so.* /usr/local/cuda-deps/ && \
|
||||
cd /usr/local/cuda-deps/ && \
|
||||
for lib in libnvrtc.so.*; do \
|
||||
if [[ "$lib" =~ libnvrtc.so\.([0-9]+\.[0-9]+\.[0-9]+) ]]; then \
|
||||
version="${BASH_REMATCH[1]}"; \
|
||||
ln -sf "libnvrtc.so.$version" libnvrtc.so; \
|
||||
fi; \
|
||||
done && \
|
||||
for lib in libcurand.so.*; do \
|
||||
if [[ "$lib" =~ libcurand.so\.([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+) ]]; then \
|
||||
version="${BASH_REMATCH[1]}"; \
|
||||
ln -sf "libcurand.so.$version" libcurand.so; \
|
||||
fi; \
|
||||
done; \
|
||||
fi
|
||||
|
||||
# Frigate w/ TensorRT Support as separate image
|
||||
FROM deps AS tensorrt-base
|
||||
|
||||
#Disable S6 Global timeout
|
||||
ENV S6_CMD_WAIT_FOR_SERVICES_MAXTIME=0
|
||||
|
||||
# COPY TensorRT Model Generation Deps
|
||||
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
|
||||
COPY --from=trt-deps /usr/local/src/tensorrt_demos /usr/local/src/tensorrt_demos
|
||||
|
||||
# COPY Individual CUDA deps folder
|
||||
COPY --from=trt-deps /usr/local/cuda-deps /usr/local/cuda
|
||||
|
||||
COPY docker/tensorrt/detector/rootfs/ /
|
||||
ENV YOLO_MODELS=""
|
||||
|
||||
HEALTHCHECK --start-period=600s --start-interval=5s --interval=15s --timeout=5s --retries=3 \
|
||||
CMD curl --fail --silent --show-error http://127.0.0.1:5000/api/version || exit 1
|
@ -1,7 +1,6 @@
|
||||
/usr/local/lib
|
||||
/usr/local/cuda
|
||||
/usr/local/lib/python3.11/dist-packages/tensorrt
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cudnn/lib
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cuda_runtime/lib
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cublas/lib
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cufft/lib
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cufft/lib
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/curand/lib/
|
||||
/usr/local/lib/python3.11/dist-packages/nvidia/cuda_nvrtc/lib/
|
@ -1,17 +1,19 @@
|
||||
# NVidia TensorRT Support (amd64 only)
|
||||
--extra-index-url 'https://pypi.nvidia.com'
|
||||
numpy < 1.24; platform_machine == 'x86_64'
|
||||
tensorrt == 8.6.1; platform_machine == 'x86_64'
|
||||
tensorrt_bindings == 8.6.1; platform_machine == 'x86_64'
|
||||
cuda-python == 11.8.*; platform_machine == 'x86_64'
|
||||
cython == 3.0.*; platform_machine == 'x86_64'
|
||||
nvidia-cuda-runtime-cu12 == 12.1.*; platform_machine == 'x86_64'
|
||||
nvidia-cuda-runtime-cu11 == 11.8.*; platform_machine == 'x86_64'
|
||||
nvidia-cublas-cu11 == 11.11.3.6; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu11 == 8.6.0.*; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu12 == 9.5.0.*; platform_machine == 'x86_64'
|
||||
nvidia-cufft-cu11==10.*; platform_machine == 'x86_64'
|
||||
nvidia-cufft-cu12==11.*; platform_machine == 'x86_64'
|
||||
cython==3.0.*; platform_machine == 'x86_64'
|
||||
nvidia_cuda_cupti_cu12==12.5.82; platform_machine == 'x86_64'
|
||||
nvidia-cublas-cu12==12.5.3.*; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu12==9.3.0.*; platform_machine == 'x86_64'
|
||||
nvidia-cufft-cu12==11.2.3.*; platform_machine == 'x86_64'
|
||||
nvidia-curand-cu12==10.3.6.*; platform_machine == 'x86_64'
|
||||
nvidia_cuda_nvcc_cu12==12.5.82; platform_machine == 'x86_64'
|
||||
nvidia-cuda-nvrtc-cu12==12.5.82; platform_machine == 'x86_64'
|
||||
nvidia_cuda_runtime_cu12==12.5.82; platform_machine == 'x86_64'
|
||||
nvidia_cusolver_cu12==11.6.3.*; platform_machine == 'x86_64'
|
||||
nvidia_cusparse_cu12==12.5.1.*; platform_machine == 'x86_64'
|
||||
nvidia_nccl_cu12==2.23.4; platform_machine == 'x86_64'
|
||||
nvidia_nvjitlink_cu12==12.5.82; platform_machine == 'x86_64'
|
||||
tensorflow==2.19.*; platform_machine == 'x86_64'
|
||||
onnx==1.16.*; platform_machine == 'x86_64'
|
||||
onnxruntime-gpu==1.22.*; platform_machine == 'x86_64'
|
||||
protobuf==3.20.3; platform_machine == 'x86_64'
|
||||
|
@ -93,7 +93,8 @@ target "tensorrt" {
|
||||
context = "."
|
||||
contexts = {
|
||||
wget = "target:wget",
|
||||
tensorrt-base = "target:tensorrt-base",
|
||||
wheels = "target:wheels",
|
||||
deps = "target:deps",
|
||||
rootfs = "target:rootfs"
|
||||
}
|
||||
target = "frigate-tensorrt"
|
||||
|
@ -12,7 +12,7 @@ class EmbeddingsRequestEnum(Enum):
|
||||
# audio
|
||||
transcribe_audio = "transcribe_audio"
|
||||
# custom classification
|
||||
train_classification = "train_classification"
|
||||
reload_classification_model = "reload_classification_model"
|
||||
# face
|
||||
clear_face_classifier = "clear_face_classifier"
|
||||
recognize_face = "recognize_face"
|
||||
|
@ -3,7 +3,6 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
@ -17,10 +16,8 @@ from frigate.comms.event_metadata_updater import (
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.config.classification import CustomClassificationConfig
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
|
||||
from frigate.util.classification import train_classification_model
|
||||
from frigate.util.object import box_overlaps, calculate_region
|
||||
|
||||
from ..types import DataProcessorMetrics
|
||||
@ -72,18 +69,6 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
self.classifications_per_second.start()
|
||||
|
||||
def __retrain_model(self) -> None:
|
||||
train_classification_model(self.model_config.name)
|
||||
self.__build_detector()
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.complete,
|
||||
},
|
||||
)
|
||||
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
|
||||
|
||||
def __update_metrics(self, duration: float) -> None:
|
||||
self.classifications_per_second.update()
|
||||
self.inference_speed.update(duration)
|
||||
@ -172,19 +157,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
if topic == EmbeddingsRequestEnum.train_classification.value:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
self.__build_detector()
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
threading.Thread(target=self.__retrain_model).start()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Began training {self.model_config.name} model.",
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@ -232,18 +213,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
prefill=0,
|
||||
)
|
||||
|
||||
def __retrain_model(self) -> None:
|
||||
train_classification_model(self.model_config.name)
|
||||
self.__build_detector()
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.complete,
|
||||
},
|
||||
)
|
||||
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
|
||||
|
||||
def __update_metrics(self, duration: float) -> None:
|
||||
self.classifications_per_second.update()
|
||||
self.inference_speed.update(duration)
|
||||
@ -307,19 +276,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.detected_objects[obj_data["id"]] = score
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
if topic == EmbeddingsRequestEnum.train_classification.value:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
threading.Thread(target=self.__retrain_model).start()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Began training {self.model_config.name} model.",
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
@ -20,8 +20,8 @@ class DataProcessorMetrics:
|
||||
alpr_pps: Synchronized
|
||||
yolov9_lpr_speed: Synchronized
|
||||
yolov9_lpr_pps: Synchronized
|
||||
classification_speeds: dict[str, Synchronized]
|
||||
classification_cps: dict[str, Synchronized]
|
||||
classification_speeds: dict[str, Synchronized] = {}
|
||||
classification_cps: dict[str, Synchronized] = {}
|
||||
|
||||
def __init__(self, custom_classification_models: list[str]):
|
||||
self.image_embeddings_speed = mp.Value("d", 0.0)
|
||||
@ -36,8 +36,6 @@ class DataProcessorMetrics:
|
||||
self.yolov9_lpr_pps = mp.Value("d", 0.0)
|
||||
|
||||
if custom_classification_models:
|
||||
self.classification_speeds = {}
|
||||
self.classification_cps = {}
|
||||
for key in custom_classification_models:
|
||||
self.classification_speeds[key] = mp.Value("d", 0.0)
|
||||
self.classification_cps[key] = mp.Value("d", 0.0)
|
||||
|
@ -21,6 +21,7 @@ from frigate.data_processing.types import DataProcessorMetrics
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event, Recordings
|
||||
from frigate.util.builtin import serialize
|
||||
from frigate.util.classification import kickoff_model_training
|
||||
from frigate.util.services import listen
|
||||
|
||||
from .maintainer import EmbeddingMaintainer
|
||||
@ -293,9 +294,12 @@ class EmbeddingsContext:
|
||||
return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})
|
||||
|
||||
def start_classification_training(self, model_name: str) -> dict[str, Any]:
|
||||
return self.requestor.send_data(
|
||||
EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
|
||||
)
|
||||
threading.Thread(
|
||||
target=kickoff_model_training,
|
||||
args=(self.requestor, model_name),
|
||||
daemon=True,
|
||||
).start()
|
||||
return {"success": True, "message": f"Began training {model_name} model."}
|
||||
|
||||
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
|
||||
return self.requestor.send_data(
|
||||
|
@ -10,7 +10,11 @@ from tensorflow.keras import layers, models, optimizers
|
||||
from tensorflow.keras.applications import MobileNetV2
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util import Process
|
||||
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 50
|
||||
@ -18,7 +22,7 @@ LEARNING_RATE = 0.001
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_representative_dataset_factory(dataset_dir: str):
|
||||
def __generate_representative_dataset_factory(dataset_dir: str):
|
||||
def generate_representative_dataset():
|
||||
image_paths = []
|
||||
for root, dirs, files in os.walk(dataset_dir):
|
||||
@ -38,7 +42,7 @@ def generate_representative_dataset_factory(dataset_dir: str):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def train_classification_model(model_name: str) -> bool:
|
||||
def __train_classification_model(model_name: str) -> bool:
|
||||
"""Train a classification model."""
|
||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
||||
@ -107,7 +111,7 @@ def train_classification_model(model_name: str) -> bool:
|
||||
# convert model to tflite
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.representative_dataset = generate_representative_dataset_factory(
|
||||
converter.representative_dataset = __generate_representative_dataset_factory(
|
||||
dataset_dir
|
||||
)
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
@ -122,3 +126,42 @@ def train_classification_model(model_name: str) -> bool:
|
||||
# restore original stdout / stderr
|
||||
sys.stdout = original_stdout
|
||||
sys.stderr = original_stderr
|
||||
|
||||
|
||||
@staticmethod
|
||||
def kickoff_model_training(
|
||||
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
||||
) -> None:
|
||||
requestor = InterProcessRequestor()
|
||||
requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": model_name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
)
|
||||
|
||||
# run training in sub process so that
|
||||
# tensorflow will free CPU / GPU memory
|
||||
# upon training completion
|
||||
training_process = Process(
|
||||
target=__train_classification_model,
|
||||
name=f"model_training:{model_name}",
|
||||
args=(model_name,),
|
||||
)
|
||||
training_process.start()
|
||||
training_process.join()
|
||||
|
||||
# reload model and mark training as complete
|
||||
embeddingRequestor.send_data(
|
||||
EmbeddingsRequestEnum.reload_classification_model.value,
|
||||
{"model_name": model_name},
|
||||
)
|
||||
requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": model_name,
|
||||
"state": ModelStatusTypesEnum.complete,
|
||||
},
|
||||
)
|
||||
requestor.stop()
|
||||
|
Loading…
Reference in New Issue
Block a user