From 2298ca740ce526e8a430d3e44df79d74c6fd2813 Mon Sep 17 00:00:00 2001 From: Alexander Smirnov Date: Wed, 8 Apr 2020 20:53:58 +0100 Subject: [PATCH] features added: - GPU via TensorRT - CPU: TensorFlow Lite -> Tensorflow --- .dockerignore | 4 + .gitignore | 6 +- Dockerfile | 32 ++-- Dockerfile.gpu | 183 ++++++++++++++++++++++ README.md | 16 +- benchmark.py | 0 detect_objects_gpu.py | 14 ++ engine.py | 87 +++++++++++ frigate/edgetpu.py | 100 +++++++----- frigate/tensorflowcpu.py | 51 +++++++ frigate/tensorrtgpu.py | 96 ++++++++++++ frigate/util.py | 0 frigate/video.py | 0 plugin/CMakeLists.txt | 47 ++++++ plugin/FlattenConcat.cpp | 320 +++++++++++++++++++++++++++++++++++++++ requirements.txt | 13 ++ 16 files changed, 914 insertions(+), 55 deletions(-) mode change 100755 => 100644 Dockerfile create mode 100644 Dockerfile.gpu mode change 100755 => 100644 benchmark.py create mode 100644 detect_objects_gpu.py create mode 100644 engine.py create mode 100644 frigate/tensorflowcpu.py create mode 100644 frigate/tensorrtgpu.py mode change 100755 => 100644 frigate/util.py mode change 100755 => 100644 frigate/video.py create mode 100644 plugin/CMakeLists.txt create mode 100644 plugin/FlattenConcat.cpp create mode 100644 requirements.txt diff --git a/.dockerignore b/.dockerignore index d77da2b48..361432ee5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,5 +2,9 @@ README.md diagram.png .gitignore debug +build +venv* +.idea config/ +docker-compose.* *.pyc \ No newline at end of file diff --git a/.gitignore b/.gitignore index bce858dc0..103c12259 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ *.pyc debug +build +venv* .vscode -config/config.yml \ No newline at end of file +.idea +config/config.yml +docker-compose.* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile old mode 100755 new mode 100644 index cb81052af..e5aae5568 --- a/Dockerfile +++ b/Dockerfile @@ -9,22 +9,20 @@ RUN apt -qq update && apt -qq install --no-install-recommends -y \ build-essential \ gnupg wget unzip \ # libcap-dev \ - && add-apt-repository ppa:deadsnakes/ppa -y \ && apt -qq install --no-install-recommends -y \ - python3.7 \ - python3.7-dev \ + python3-dev \ python3-pip \ ffmpeg \ # VAAPI drivers for Intel hardware accel libva-drm2 libva2 i965-va-driver vainfo \ - && python3.7 -m pip install -U wheel setuptools \ - && python3.7 -m pip install -U \ + && python3 -m pip install -U wheel pip setuptools \ + && python3 -m pip install -U \ opencv-python-headless \ # python-prctl \ numpy \ imutils \ scipy \ - && python3.7 -m pip install -U \ + && python3 -m pip install -U \ Flask \ paho-mqtt \ PyYAML \ @@ -37,23 +35,27 @@ RUN apt -qq update && apt -qq install --no-install-recommends -y \ && apt -qq install --no-install-recommends -y \ libedgetpu1-max \ ## Tensorflow lite (python 3.7 only) - && wget -q https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp37-cp37m-linux_x86_64.whl \ - && python3.7 -m pip install tflite_runtime-2.1.0.post1-cp37-cp37m-linux_x86_64.whl \ - && rm tflite_runtime-2.1.0.post1-cp37-cp37m-linux_x86_64.whl \ + && wget -q https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp36-cp36m-linux_x86_64.whl \ + && python3 -m pip install tflite_runtime-2.1.0.post1-cp36-cp36m-linux_x86_64.whl \ + && rm tflite_runtime-2.1.0.post1-cp36-cp36m-linux_x86_64.whl \ && rm -rf /var/lib/apt/lists/* \ - && (apt-get autoremove -y; apt-get autoclean -y) + && (apt-get autoremove -y; apt-get autoclean -y) \ + ## Tensorflow + && python3 -m pip install tensorflow==1.15.2 # get model and labels 
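# The EdgeTPU model and label map below are unchanged; the CPU detector now uses the
# frozen TensorFlow graph extracted below as /cpu_model.pb (loaded by frigate/tensorflowcpu.py).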
RUN wget -q https://github.com/google-coral/edgetpu/raw/master/test_data/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite -O /edgetpu_model.tflite --trust-server-names RUN wget -q https://dl.google.com/coral/canned_models/coco_labels.txt -O /labelmap.txt --trust-server-names -RUN wget -q https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip -O /cpu_model.zip && \ - unzip /cpu_model.zip detect.tflite -d / && \ - mv /detect.tflite /cpu_model.tflite && \ - rm /cpu_model.zip +RUN wget -q http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz -O /cpu_model.tar.gz && \ + tar -xf /cpu_model.tar.gz -C / ssd_mobilenet_v1_coco_2018_01_28/frozen_inference_graph.pb --strip-components 1 && \ + mv /frozen_inference_graph.pb /cpu_model.pb && \ + rm /cpu_model.tar.gz WORKDIR /opt/frigate/ ADD frigate frigate/ COPY detect_objects.py . COPY benchmark.py . -CMD ["python3.7", "-u", "detect_objects.py"] +ENV TF_CPP_MIN_LOG_LEVEL 2 + +CMD ["python3", "-u", "detect_objects.py"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 000000000..168aed934 --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,183 @@ +FROM frigate AS base + +# +# CUDA 10.2 base +# +# https://gitlab.com/nvidia/container-images/cuda/blob/master/dist/ubuntu18.04/10.2/base/Dockerfile +# +RUN apt-get update && apt-get install -y --no-install-recommends \ +gnupg2 curl ca-certificates && \ + curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ + echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ + echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ + apt-get purge --autoremove -y curl && \ +rm -rf /var/lib/apt/lists/* + +ENV CUDA_VERSION 10.2.89 +LABEL com.nvidia.cuda.version="${CUDA_VERSION}" + +ENV CUDA_PKG_VERSION 10-2=$CUDA_VERSION-1 + +# For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a +RUN apt-get update && apt-get install -y --no-install-recommends \ + cuda-cudart-$CUDA_PKG_VERSION \ +cuda-compat-10-2 && \ +ln -s cuda-10.2 /usr/local/cuda && \ + rm -rf /var/lib/apt/lists/* + +# Required for nvidia-docker v1 +RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \ + echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 + +# nvidia-container-runtime +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=396,driver<397 brand=tesla,driver>=410,driver<411 brand=tesla,driver>=418,driver<419" + +# +# CUDA 10.2 runtime +# +# https://gitlab.com/nvidia/container-images/cuda/blob/master/dist/ubuntu18.04/10.2/runtime/Dockerfile +# +ENV NCCL_VERSION 2.5.6 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + cuda-libraries-$CUDA_PKG_VERSION \ +cuda-nvtx-$CUDA_PKG_VERSION \ +libcublas10=10.2.2.89-1 \ +libnccl2=$NCCL_VERSION-1+cuda10.2 && \ + apt-mark hold libnccl2 && \ + rm -rf /var/lib/apt/lists/* + +# +# cuDNN 7.6.5.32 runtime +# +# https://gitlab.com/nvidia/container-images/cuda/blob/master/dist/ubuntu18.04/10.2/runtime/cudnn7/Dockerfile +# +ENV CUDNN_VERSION 
7.6.5.32 +LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libcudnn7=$CUDNN_VERSION-1+cuda10.2 \ +&& \ + apt-mark hold libcudnn7 && \ + rm -rf /var/lib/apt/lists/* + +# +# TensorRT 6.0.1.8 +# +# https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-601/tensorrt-install-guide/index.html#maclearn-net-repo-install +# +ENV TENSORRT_VERSION 6.0.1 +LABEL com.nvidia.tensorrt.version="${TENSORRT_VERSION}" + +RUN version=$TENSORRT_VERSION-1+cuda10.2 && \ + apt-get update && apt-get install -y --no-install-recommends \ + libnvinfer6=${version} \ + libnvonnxparsers6=${version} libnvparsers6=${version} \ + libnvinfer-plugin6=${version} \ + python3-libnvinfer=${version} \ +&& \ + apt-mark hold \ + libnvinfer6 \ + libnvonnxparsers6 libnvparsers6 \ + libnvinfer-plugin6 \ + python3-libnvinfer \ +&& \ + rm -rf /var/lib/apt/lists/* + +# +# Use a previous stage as a new temporary stage for building libraries +# +FROM base AS builder + +# +# CUDA 10.2 devel +# +# https://gitlab.com/nvidia/container-images/cuda/blob/master/dist/ubuntu18.04/10.2/devel/Dockerfile +# +RUN apt-get update && apt-get install -y --no-install-recommends \ + cuda-nvml-dev-$CUDA_PKG_VERSION \ + cuda-command-line-tools-$CUDA_PKG_VERSION \ +cuda-libraries-dev-$CUDA_PKG_VERSION \ + cuda-minimal-build-$CUDA_PKG_VERSION \ + libnccl-dev=$NCCL_VERSION-1+cuda10.2 \ +libcublas-dev=10.2.2.89-1 \ +&& \ + rm -rf /var/lib/apt/lists/* + +ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs + +# +# cuDNN 7.6.5.32 devel +# +# https://gitlab.com/nvidia/container-images/cuda/blob/master/dist/ubuntu18.04/10.2/devel/cudnn7/Dockerfile +# +RUN apt-get update && apt-get install -y --no-install-recommends \ + libcudnn7-dev=$CUDNN_VERSION-1+cuda10.2 \ +&& \ + rm -rf /var/lib/apt/lists/* + +# +# TensorRT 6.0.1.8 devel +# +# https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-601/tensorrt-install-guide/index.html#maclearn-net-repo-install +# +RUN version=$TENSORRT_VERSION-1+cuda10.2 && \ + apt-get update && apt-get install -y --no-install-recommends \ + libnvinfer-dev=${version} \ + libnvonnxparsers-dev=${version} libnvparsers-dev=${version} \ + libnvinfer-plugin-dev=${version} \ +&& \ + apt-mark hold \ + libnvinfer-dev \ + libnvonnxparsers-dev libnvparsers-dev \ + libnvinfer-plugin-dev \ +&& \ + rm -rf /var/lib/apt/lists/* + +# Install PyCUDA +RUN python3 -m pip install pycuda \ + && python3 -m pip wheel --wheel-dir install pycuda + +# Install Cmake +ENV CMAKE_VERSION 3.14.4 + +RUN cd /tmp && \ + wget https://github.com/Kitware/CMake/releases/download/v$CMAKE_VERSION/cmake-$CMAKE_VERSION-Linux-x86_64.sh && \ + chmod +x cmake-$CMAKE_VERSION-Linux-x86_64.sh && \ + ./cmake-$CMAKE_VERSION-Linux-x86_64.sh --prefix=/usr/local --exclude-subdir --skip-license && \ + rm ./cmake-$CMAKE_VERSION-Linux-x86_64.sh + +# Build plugin +ADD plugin plugin/ +RUN mkdir -p build \ + && cd build \ + && cmake ../plugin \ + && make \ + && cd .. 
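+
+# The CMake build above produces build/libflattenconcat.so from plugin/FlattenConcat.cpp;
+# the runtime stage below copies it into /usr/lib, where engine.load_plugins() loads it
+# via ctypes before the UFF model is parsed or a serialized engine is deserialized.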
+
+#
+# Copy libraries to the final image
+#
+FROM base AS result
+
+COPY --from=builder /opt/frigate/install install
+COPY --from=builder /opt/frigate/build/libflattenconcat.so /usr/lib
+
+RUN python3 -m pip install install/* \
+    && rm -r install
+
+# Get UFF model
+RUN wget -q https://github.com/dusty-nv/jetson-inference/releases/download/model-mirror-190618/SSD-Mobilenet-v2.tar.gz -O /gpu_model.tar.gz && \
+    tar -xf /gpu_model.tar.gz -C / SSD-Mobilenet-v2/ssd_mobilenet_v2_coco.uff --strip-components 1 && \
+    mv /ssd_mobilenet_v2_coco.uff /gpu_model.uff && \
+    rm /gpu_model.tar.gz
+
+COPY engine.py .
+COPY detect_objects_gpu.py .
+
+CMD ["python3", "-u", "detect_objects_gpu.py"]
diff --git a/README.md b/README.md
index 33c84db23..405a4197d 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 # Frigate - Realtime Object Detection for IP Cameras
-Uses OpenCV and Tensorflow to perform realtime object detection locally for IP cameras. Designed for integration with HomeAssistant or others via MQTT.
+Uses OpenCV and TensorFlow/TensorRT to perform realtime object detection locally for IP cameras. Designed for integration with HomeAssistant or others via MQTT.
 
-Use of a [Google Coral USB Accelerator](https://coral.withgoogle.com/products/accelerator/) is optional, but highly recommended. On my Intel i7 processor, I can process 2-3 FPS with the CPU. The Coral can process 100+ FPS with very low CPU load.
+Use of a [Google Coral USB Accelerator](https://coral.withgoogle.com/products/accelerator/) or an [Nvidia CUDA GPU](https://developer.nvidia.com/cuda-gpus) is optional, but highly recommended. On my Intel i7 processor, I can process 24 FPS with the CPU. A budget entry-level GPU processes 64 FPS, and a powerful GPU or the Coral can process 100+ FPS with very low CPU load.
 
 - Leverages multiprocessing heavily with an emphasis on realtime over processing every frame
 - Uses a very low overhead motion detection to determine where to run object detection
@@ -29,6 +29,10 @@ docker run --rm \
 blakeblackshear/frigate:stable
 ```
 
+To run the GPU-accelerated `frigate-gpu` Docker image, use the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-(Native-GPU-Support)).
+If your GPU supports half precision (also known as FP16), you can boost performance by enabling this mode as follows:
+`docker run --gpus all --env TRT_FLOAT_PRECISION=16 ...`
+
 Example docker-compose:
 ```yaml
   frigate:
@@ -47,6 +51,8 @@ Example docker-compose:
       FRIGATE_RTSP_PASSWORD: "password"
 ```
 
+Please note that native GPU support has not landed in docker-compose [yet](https://github.com/docker/compose/issues/6691).
+
 A `config.yml` file must exist in the `config` directory. See example [here](config/config.example.yml) and device specific info can be found [here](docs/DEVICES.md).
 
 Access the mjpeg stream at `http://localhost:5000/<camera_name>` and the best snapshot for any object type at `http://localhost:5000/<camera_name>/<object_name>/best.jpg`
@@ -118,10 +124,12 @@ sensor:
       unit_of_measurement: 'ms'
 ```
 ## Using a custom model
-Models for both CPU and EdgeTPU (Coral) are bundled in the image. You can use your own models with volume mounts:
-- CPU Model: `/cpu_model.tflite`
+Models for CPU/GPU and EdgeTPU (Coral) are bundled in the images.
You can use your own models with volume mounts: +- CPU Model: `/cpu_model.pb` +- GPU Model: `/gpu_model.uff` - EdgeTPU Model: `/edgetpu_model.tflite` - Labels: `/labelmap.txt` ## Tips - Lower the framerate of the video feed on the camera to reduce the CPU usage for capturing the feed +- Choose smaller camera resolution as the images are resized to the shape of the model 300x300 anyway diff --git a/benchmark.py b/benchmark.py old mode 100755 new mode 100644 diff --git a/detect_objects_gpu.py b/detect_objects_gpu.py new file mode 100644 index 000000000..9ea0a7057 --- /dev/null +++ b/detect_objects_gpu.py @@ -0,0 +1,14 @@ +import os +import subprocess + +if __name__ == '__main__': + if not os.path.isfile('/gpu_model.buf'): + engine = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'engine.py') + subprocess.run(['python3', '-u', engine, + '-i', '/gpu_model.uff', + '-o', '/gpu_model.buf', + '-p', os.getenv('TRT_FLOAT_PRECISION', '32') + ], check=True) + + from detect_objects import main as detect_objects_main + detect_objects_main() diff --git a/engine.py b/engine.py new file mode 100644 index 000000000..bde5819e8 --- /dev/null +++ b/engine.py @@ -0,0 +1,87 @@ +import ctypes +import argparse +import sys +import os +import tensorrt as trt + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) + + +def model_input_shape(): + return 3, 300, 300 + + +def build_engine(uff_model_path, trt_engine_datatype=trt.DataType.FLOAT, batch_size=1): + with trt.Builder(TRT_LOGGER) as builder, \ + builder.create_network() as network, \ + trt.UffParser() as parser: + builder.max_workspace_size = 1 << 30 + builder.max_batch_size = batch_size + if trt_engine_datatype == trt.DataType.HALF: + builder.fp16_mode = True + + parser.register_input("Input", model_input_shape()) + parser.register_output("MarkOutput_0") + parser.parse(uff_model_path, network) + + return builder.build_cuda_engine(network) + + +def save_engine(engine, engine_dest_path): + os.makedirs(os.path.dirname(engine_dest_path), exist_ok=True) + buf = engine.serialize() + with open(engine_dest_path, 'wb') as f: + f.write(buf) + + +def load_engine(trt_runtime, engine_path): + with open(engine_path, 'rb') as f: + engine_data = f.read() + engine = trt_runtime.deserialize_cuda_engine(engine_data) + return engine + + +def load_plugins(): + trt.init_libnvinfer_plugins(TRT_LOGGER, '') + + try: + ctypes.CDLL('libflattenconcat.so') + except Exception as e: + print("Error: {}\n{}".format(e, "Make sure FlattenConcat custom plugin layer is provided")) + sys.exit(1) + + +TRT_PRECISION_TO_DATATYPE = { + 16: trt.DataType.HALF, + 32: trt.DataType.FLOAT +} + +if __name__ == '__main__': + # Define script command line arguments + parser = argparse.ArgumentParser(description='Utility to build TensorRT engine prior to inference.') + parser.add_argument('-i', "--input", + dest='uff_model_path', metavar='UFF_MODEL_PATH', required=True, + help='preprocessed TensorFlow model in UFF format') + parser.add_argument('-p', '--precision', type=int, choices=[32, 16], default=32, + help='desired TensorRT float precision to build an engine with') + parser.add_argument('-b', '--batch_size', type=int, default=1, + help='max TensorRT engine batch size') + parser.add_argument("-o", "--output", dest='trt_engine_path', + help="path of the output file", + default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "engine.buf")) + + # Parse arguments passed + args = parser.parse_args() + + load_plugins() + + # Using supplied .uff file alongside with UffParser build TensorRT engine + 
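+    # Example invocation, as issued by detect_objects_gpu.py on first start
+    # (pass -p 16 to build an FP16 engine on GPUs that support it):
+    #   python3 engine.py -i /gpu_model.uff -o /gpu_model.buf -p 32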
print("Building TensorRT engine. This may take few minutes.") + trt_engine = build_engine( + uff_model_path=args.uff_model_path, + trt_engine_datatype=TRT_PRECISION_TO_DATATYPE[args.precision], + batch_size=args.batch_size) + + # Save the engine to file + save_engine(trt_engine, args.trt_engine_path) + print("TensorRT engine saved to {}".format(args.trt_engine_path)) diff --git a/frigate/edgetpu.py b/frigate/edgetpu.py index b9a28976d..87758d660 100644 --- a/frigate/edgetpu.py +++ b/frigate/edgetpu.py @@ -7,6 +7,12 @@ import pyarrow.plasma as plasma import tflite_runtime.interpreter as tflite from tflite_runtime.interpreter import load_delegate from frigate.util import EventsPerSecond, listen +from frigate.tensorflowcpu import ObjectDetector as CPUObjectDetector +try: + import pycuda.driver as cuda + from frigate.tensorrtgpu import ObjectDetector as GPUObjectDetector +except ImportError: + pass def load_labels(path, encoding='utf-8'): """Loads labels from file (with or without index numbers). @@ -28,26 +34,22 @@ def load_labels(path, encoding='utf-8'): return {index: line.strip() for index, line in enumerate(lines)} class ObjectDetector(): - def __init__(self): - edge_tpu_delegate = None - try: - edge_tpu_delegate = load_delegate('libedgetpu.so.1.0') - except ValueError: - print("No EdgeTPU detected. Falling back to CPU.") - - if edge_tpu_delegate is None: - self.interpreter = tflite.Interpreter( - model_path='/cpu_model.tflite') - else: - self.interpreter = tflite.Interpreter( - model_path='/edgetpu_model.tflite', - experimental_delegates=[edge_tpu_delegate]) - + def __init__(self, edge_tpu_delegate): + self.interpreter = tflite.Interpreter( + model_path='/edgetpu_model.tflite', + experimental_delegates=[edge_tpu_delegate]) + self.interpreter.allocate_tensors() self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() - + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + def detect_raw(self, tensor_input): self.interpreter.set_tensor(self.tensor_input_details[0]['index'], tensor_input) self.interpreter.invoke() @@ -57,34 +59,62 @@ class ObjectDetector(): detections = np.zeros((20,6), np.float32) for i, score in enumerate(scores): + if i == detections.shape[0]: + break detections[i] = [label_codes[i], score, boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]] - + return detections +def create_object_detector(): + edge_tpu_delegate = None + try: + edge_tpu_delegate = load_delegate('libedgetpu.so.1.0') + except ValueError: + pass + + if edge_tpu_delegate is not None: + return ObjectDetector(edge_tpu_delegate) + + gpu_device_count = 0 + try: + cuda.init() + gpu_device_count = cuda.Device.count() + except (RuntimeError, TypeError, NameError): + pass + except cuda.RuntimeError: + pass + + if gpu_device_count > 0: + print("No EdgeTPU detected. Falling back to GPU.") + return GPUObjectDetector() + + print("No EdgeTPU or GPU detected. 
Falling back to CPU.") + return CPUObjectDetector() + def run_detector(detection_queue, avg_speed, start): print(f"Starting detection process: {os.getpid()}") listen() plasma_client = plasma.connect("/tmp/plasma") - object_detector = ObjectDetector() - while True: - object_id_str = detection_queue.get() - object_id_hash = hashlib.sha1(str.encode(object_id_str)) - object_id = plasma.ObjectID(object_id_hash.digest()) - object_id_out = plasma.ObjectID(hashlib.sha1(str.encode(f"out-{object_id_str}")).digest()) - input_frame = plasma_client.get(object_id, timeout_ms=0) + with create_object_detector() as object_detector: + while True: + object_id_str = detection_queue.get() + object_id_hash = hashlib.sha1(str.encode(object_id_str)) + object_id = plasma.ObjectID(object_id_hash.digest()) + object_id_out = plasma.ObjectID(hashlib.sha1(str.encode(f"out-{object_id_str}")).digest()) + input_frame = plasma_client.get(object_id, timeout_ms=0) - if input_frame is plasma.ObjectNotAvailable: - continue + if input_frame is plasma.ObjectNotAvailable: + continue - # detect and put the output in the plasma store - start.value = datetime.datetime.now().timestamp() - plasma_client.put(object_detector.detect_raw(input_frame), object_id_out) - duration = datetime.datetime.now().timestamp()-start.value - start.value = 0.0 + # detect and put the output in the plasma store + start.value = datetime.datetime.now().timestamp() + plasma_client.put(object_detector.detect_raw(input_frame), object_id_out) + duration = datetime.datetime.now().timestamp()-start.value + start.value = 0.0 + + avg_speed.value = (avg_speed.value*9 + duration)/10 - avg_speed.value = (avg_speed.value*9 + duration)/10 - class EdgeTPUProcess(): def __init__(self): self.detection_queue = mp.SimpleQueue() @@ -114,7 +144,7 @@ class RemoteObjectDetector(): self.fps = EventsPerSecond() self.plasma_client = plasma.connect("/tmp/plasma") self.detection_queue = detection_queue - + def detect(self, tensor_input, threshold=.4): detections = [] @@ -139,4 +169,4 @@ class RemoteObjectDetector(): )) self.plasma_client.delete([object_id_frame, object_id_detections]) self.fps.update() - return detections \ No newline at end of file + return detections diff --git a/frigate/tensorflowcpu.py b/frigate/tensorflowcpu.py new file mode 100644 index 000000000..8cf787010 --- /dev/null +++ b/frigate/tensorflowcpu.py @@ -0,0 +1,51 @@ +import numpy as np +import tensorflow as tf + + +class ObjectDetector(): + def __init__(self): + self.detection_graph = tf.Graph() + with self.detection_graph.as_default(): + od_graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile('/cpu_model.pb', 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + + config = tf.compat.v1.ConfigProto( + device_count={'GPU': 0} + ) + self.sess = tf.compat.v1.Session( + graph=self.detection_graph, + config=config) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def detect_raw(self, tensor_input): + ops = self.detection_graph.get_operations() + all_tensor_names = {output.name for op in ops for output in op.outputs} + tensor_dict = {} + for key in ['detection_boxes', 'detection_scores', 'detection_classes']: + tensor_name = key + ':0' + if tensor_name in all_tensor_names: + tensor_dict[key] = self.detection_graph.get_tensor_by_name(tensor_name) + + image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0') + output_dict = self.sess.run(tensor_dict, + 
feed_dict={image_tensor: tensor_input}) + + boxes = output_dict['detection_boxes'][0] + label_codes = output_dict['detection_classes'][0] - 1 + scores = output_dict['detection_scores'][0] + + detections = np.zeros((20, 6), np.float32) + for i, score in enumerate(scores): + if i == detections.shape[0]: + break + detections[i] = [label_codes[i], score, boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]] + + return detections diff --git a/frigate/tensorrtgpu.py b/frigate/tensorrtgpu.py new file mode 100644 index 000000000..2c1d62b7b --- /dev/null +++ b/frigate/tensorrtgpu.py @@ -0,0 +1,96 @@ +import numpy as np +import pycuda.driver as cuda +import tensorrt as trt +import engine +from collections import namedtuple +from pycuda.tools import make_default_context +from pycuda.tools import clear_context_caches + +HostDeviceMem = namedtuple('HostDeviceMem', 'host device') + + +class ObjectDetector(): + def __init__(self): + self.context = make_default_context() + self.device = self.context.get_device() + + engine.load_plugins() + + self.trt_runtime = trt.Runtime(engine.TRT_LOGGER) + self.trt_engine = engine.load_engine(self.trt_runtime, '/gpu_model.buf') + + self._allocate_buffers() + self.execution_context = self.trt_engine.create_execution_context() + + input_volume = trt.volume(engine.model_input_shape()) + self.numpy_array = np.zeros((self.trt_engine.max_batch_size, input_volume)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.context.pop() + self.context = None + + clear_context_caches() + + def detect_raw(self, tensor_input): + # HWC -> CHW + img_np = tensor_input.transpose((0, 3, 1, 2)) + # Normalize to [-1.0, 1.0] interval (expected by model) + img_np = (2.0 / 255.0) * img_np - 1.0 + img_np = img_np.ravel() + + np.copyto(self.inputs[0].host, img_np) + detection_out, keep_count_out = self._do_inference() + + detections = np.zeros((20, 6), np.float32) + for i in range(int(keep_count_out[0])): + if i == detections.shape[0]: + break + pred_start_idx = i * 7 + label = detection_out[pred_start_idx + 1] - 1 + score = detection_out[pred_start_idx + 2] + xmin = detection_out[pred_start_idx + 3] + ymin = detection_out[pred_start_idx + 4] + xmax = detection_out[pred_start_idx + 5] + ymax = detection_out[pred_start_idx + 6] + detections[i] = [label, score, ymin, xmin, ymax, xmax] + + return detections + + def _do_inference(self): + [cuda.memcpy_htod_async(inp.device, inp.host, self.stream) for inp in self.inputs] + self.execution_context.execute_async(batch_size=self.trt_engine.max_batch_size, + bindings=self.bindings, + stream_handle=self.stream.handle) + [cuda.memcpy_dtoh_async(out.host, out.device, self.stream) for out in self.outputs] + self.stream.synchronize() + return [out.host for out in self.outputs] + + def _allocate_buffers(self): + self.inputs = [] + self.outputs = [] + self.bindings = [] + self.stream = cuda.Stream() + + # NMS implementation in TRT 6 only supports DataType.FLOAT + binding_to_type = {"Input": np.float32, + "NMS": np.float32, + "NMS_1": np.int32} + for binding in self.trt_engine: + size = trt.volume(self.trt_engine.get_binding_shape(binding)) * self.trt_engine.max_batch_size + dtype = binding_to_type[str(binding)] + + # Allocate host and device buffers + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + + # Append the device buffer to device bindings. + self.bindings.append(int(device_mem)) + + # Append to the appropriate list. 
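+            # "Input" is the image tensor; the "NMS" output holds the detection tuples
+            # and "NMS_1" the number of detections kept, as unpacked in detect_raw().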
+ if self.trt_engine.binding_is_input(binding): + self.inputs.append(HostDeviceMem(host_mem, device_mem)) + else: + self.outputs.append(HostDeviceMem(host_mem, device_mem)) diff --git a/frigate/util.py b/frigate/util.py old mode 100755 new mode 100644 diff --git a/frigate/video.py b/frigate/video.py old mode 100755 new mode 100644 diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt new file mode 100644 index 000000000..010b6dbf8 --- /dev/null +++ b/plugin/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.2 FATAL_ERROR) + +project(FlattenConcat LANGUAGES CXX) + +# Enable all compile warnings +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-long-long -pedantic -Wno-deprecated-declarations") +# Use C++11 +set (CMAKE_CXX_STANDARD 11) + +# Sets variable to a value if variable is unset. +macro(set_ifndef var val) + if (NOT ${var}) + set(${var} ${val}) + endif() + message(STATUS "Configurable variable ${var} set to ${${var}}") +endmacro() + +# -------- CONFIGURATION -------- +find_package(CUDA REQUIRED) + +set_ifndef(TRT_LIB /usr/lib/x86_64-linux-gnu) +set_ifndef(TRT_INCLUDE /usr/include/x86_64-linux-gnu) +set_ifndef(CUDA_ROOT /usr/local/cuda) + +# Find dependencies: +message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n") + +# TensorRT's nvinfer lib +find_library(_NVINFER_LIB nvinfer HINTS ${TRT_LIB} PATH_SUFFIXES lib lib64) +set_ifndef(NVINFER_LIB ${_NVINFER_LIB}) + +# cuBLAS +find_library(_CUBLAS_LIB cublas HINTS ${CUDA_ROOT} PATH_SUFFIXES lib lib64) +set_ifndef(CUBLAS_LIB ${_CUBLAS_LIB}) + +# CUDA include dir +find_path(_CUDA_INC_DIR cuda_runtime_api.h HINTS ${CUDA_ROOT} PATH_SUFFIXES include) +set_ifndef(CUDA_INC_DIR ${_CUDA_INC_DIR}) + +# -------- BUILDING -------- +include_directories(${TRT_INCLUDE} ${CUDA_INC_DIR}) +add_library(flattenconcat MODULE + ${CMAKE_SOURCE_DIR}/FlattenConcat.cpp +) + +# Link TensorRT's nvinfer lib +target_link_libraries(flattenconcat PRIVATE ${NVINFER_LIB} ${CUBLAS_LIB}) diff --git a/plugin/FlattenConcat.cpp b/plugin/FlattenConcat.cpp new file mode 100644 index 000000000..b387c4b3b --- /dev/null +++ b/plugin/FlattenConcat.cpp @@ -0,0 +1,320 @@ +/* + * The TensorFlow SSD graph has some operations that are currently not supported in TensorRT. + * Using a preprocessor on the graph, multiple operations in the graph are combined into a + * single custom operation which is implemented as a plugin layer in TensorRT. The preprocessor + * stitches all nodes within a namespace into one custom node. + * + * The plugin called `FlattenConcat` is used to flatten each input and then concatenate the + * results. This is applied to the location and confidence data before it is fed to the post + * processor step. + * + * Loading FlattenConcat plugin library using CDLL has a side effect of loading FlattenConcat + * plugin into internal TensorRT plugin registry: the latter has FlattenConcat shipped with + * TensorRT, while we load own version. There are subtle differences between built-in + * FlattenConcat and this one. + * + * The pre-trained TensorFlow model has been converted to UFF format using this FlattenConcat + * plugin and we have to stick to it when building a TensorRT inference engine. To avoid collision + * with built-in plugin of the same name of version "1" we set version "B" and load it the last. 
+ */
+
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <cublas_v2.h>
+
+#include "NvInferPlugin.h"
+
+// Macro for calling GPU functions
+#define CHECK(status)                             \
+    do                                            \
+    {                                             \
+        auto ret = (status);                      \
+        if (ret != 0)                             \
+        {                                         \
+            std::cout << "Cuda failure: " << ret; \
+            abort();                              \
+        }                                         \
+    } while (0)
+
+using namespace nvinfer1;
+
+namespace
+{
+const char* FLATTENCONCAT_PLUGIN_VERSION{"B"};
+const char* FLATTENCONCAT_PLUGIN_NAME{"FlattenConcat_TRT"};
+}
+
+// Flattens all input tensors and concats their flattened version together
+// along the major non-batch dimension, i.e. axis = 1
+class FlattenConcat : public IPluginV2
+{
+public:
+    // Ordinary ctor, plugin not yet configured for particular inputs/output
+    FlattenConcat() {}
+
+    // Ctor for clone()
+    FlattenConcat(const int* flattenedInputSize, int numInputs, int flattenedOutputSize)
+        : mFlattenedOutputSize(flattenedOutputSize)
+    {
+        for (int i = 0; i < numInputs; ++i)
+            mFlattenedInputSize.push_back(flattenedInputSize[i]);
+    }
+
+    // Ctor for loading from serialized byte array
+    FlattenConcat(const void* data, size_t length)
+    {
+        const char* d = reinterpret_cast<const char*>(data);
+        const char* a = d;
+
+        size_t numInputs = read<size_t>(d);
+        for (size_t i = 0; i < numInputs; ++i)
+        {
+            mFlattenedInputSize.push_back(read<int>(d));
+        }
+        mFlattenedOutputSize = read<int>(d);
+
+        assert(d == a + length);
+    }
+
+    int getNbOutputs() const override
+    {
+        // We always return one output
+        return 1;
+    }
+
+    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
+    {
+        // At least one input
+        assert(nbInputDims >= 1);
+        // We only have one output, so it doesn't
+        // make sense to check index != 0
+        assert(index == 0);
+
+        size_t flattenedOutputSize = 0;
+        int inputVolume = 0;
+
+        for (int i = 0; i < nbInputDims; ++i)
+        {
+            // We only support NCHW. And inputs Dims are without batch num.
+            assert(inputs[i].nbDims == 3);
+
+            inputVolume = inputs[i].d[0] * inputs[i].d[1] * inputs[i].d[2];
+            flattenedOutputSize += inputVolume;
+        }
+
+        return DimsCHW(flattenedOutputSize, 1, 1);
+    }
+
+    int initialize() override
+    {
+        // Called on engine initialization, we initialize cuBLAS library here,
+        // since we'll be using it for inference
+        CHECK(cublasCreate(&mCublas));
+        return 0;
+    }
+
+    void terminate() override
+    {
+        // Called on engine destruction, we destroy cuBLAS data structures,
+        // which were created in initialize()
+        CHECK(cublasDestroy(mCublas));
+    }
+
+    size_t getWorkspaceSize(int maxBatchSize) const override
+    {
+        // The operation is done in place, it doesn't use GPU memory
+        return 0;
+    }
+
+    int enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream) override
+    {
+        // Does the actual concat of inputs, which is just
+        // copying all inputs bytes to output byte array
+        size_t inputOffset = 0;
+        float* output = reinterpret_cast<float*>(outputs[0]);
+        cublasSetStream(mCublas, stream);
+
+        for (size_t i = 0; i < mFlattenedInputSize.size(); ++i)
+        {
+            const float* input = reinterpret_cast<const float*>(inputs[i]);
+            for (int batchIdx = 0; batchIdx < batchSize; ++batchIdx)
+            {
+                CHECK(cublasScopy(mCublas, mFlattenedInputSize[i],
+                                  input + batchIdx * mFlattenedInputSize[i], 1,
+                                  output + (batchIdx * mFlattenedOutputSize + inputOffset), 1));
+            }
+            inputOffset += mFlattenedInputSize[i];
+        }
+
+        return 0;
+    }
+
+    size_t getSerializationSize() const override
+    {
+        // Returns FlattenConcat plugin serialization size
+        size_t size = sizeof(mFlattenedInputSize[0]) * mFlattenedInputSize.size()
+            + sizeof(mFlattenedOutputSize)
+            + sizeof(size_t); // For serializing mFlattenedInputSize vector size
+        return size;
+    }
+
+    void serialize(void* buffer) const override
+    {
+        // Serializes FlattenConcat plugin into byte array
+
+        // Cast buffer to char* and save its beginning to a,
+        // (since value of d will be changed during write)
+        char* d = reinterpret_cast<char*>(buffer);
+        char* a = d;
+
+        size_t numInputs = mFlattenedInputSize.size();
+
+        // Write FlattenConcat fields into buffer
+        write(d, numInputs);
+        for (size_t i = 0; i < numInputs; ++i)
+        {
+            write(d, mFlattenedInputSize[i]);
+        }
+        write(d, mFlattenedOutputSize);
+
+        // Sanity check - checks if d is offset
+        // from a by exactly the size of serialized plugin
+        assert(d == a + getSerializationSize());
+    }
+
+    void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override
+    {
+        // We only support one output
+        assert(nbOutputs == 1);
+
+        // Reset plugin private data structures
+        mFlattenedInputSize.clear();
+        mFlattenedOutputSize = 0;
+
+        // For each input we save its size, we also validate it
+        for (int i = 0; i < nbInputs; ++i)
+        {
+            int inputVolume = 0;
+
+            // We only support NCHW. And inputs Dims are without batch num.
+            assert(inputs[i].nbDims == 3);
+
+            // All inputs dimensions along non concat axis should be same
+            for (size_t dim = 1; dim < 3; dim++)
+            {
+                assert(inputs[i].d[dim] == inputs[0].d[dim]);
+            }
+
+            // Size of flattened input
+            inputVolume = inputs[i].d[0] * inputs[i].d[1] * inputs[i].d[2];
+            mFlattenedInputSize.push_back(inputVolume);
+            mFlattenedOutputSize += mFlattenedInputSize[i];
+        }
+    }
+
+    bool supportsFormat(DataType type, PluginFormat format) const override
+    {
+        return (type == DataType::kFLOAT && format == PluginFormat::kNCHW);
+    }
+
+    const char* getPluginType() const override { return FLATTENCONCAT_PLUGIN_NAME; }
+
+    const char* getPluginVersion() const override { return FLATTENCONCAT_PLUGIN_VERSION; }
+
+    void destroy() override {}
+
+    IPluginV2* clone() const override
+    {
+        return new FlattenConcat(mFlattenedInputSize.data(), mFlattenedInputSize.size(), mFlattenedOutputSize);
+    }
+
+    void setPluginNamespace(const char* pluginNamespace) override
+    {
+        mPluginNamespace = pluginNamespace;
+    }
+
+    const char* getPluginNamespace() const override
+    {
+        return mPluginNamespace.c_str();
+    }
+
+private:
+    template <typename T>
+    void write(char*& buffer, const T& val) const
+    {
+        *reinterpret_cast<T*>(buffer) = val;
+        buffer += sizeof(T);
+    }
+
+    template <typename T>
+    T read(const char*& buffer)
+    {
+        T val = *reinterpret_cast<const T*>(buffer);
+        buffer += sizeof(T);
+        return val;
+    }
+
+    // Number of elements in each plugin input, flattened
+    std::vector<int> mFlattenedInputSize;
+    // Number of elements in output, flattened
+    int mFlattenedOutputSize{0};
+    // cuBLAS library handle
+    cublasHandle_t mCublas;
+    // We're not using TensorRT namespaces in
+    // this sample, so it's just an empty string
+    std::string mPluginNamespace = "";
+};
+
+// PluginCreator boilerplate code for FlattenConcat plugin
+class FlattenConcatPluginCreator : public IPluginCreator
+{
+public:
+    FlattenConcatPluginCreator()
+    {
+        mFC.nbFields = 0;
+        mFC.fields = 0;
+    }
+
+    ~FlattenConcatPluginCreator() {}
+
+    const char* getPluginName() const override { return FLATTENCONCAT_PLUGIN_NAME; }
+
+    const char* getPluginVersion() const override { return FLATTENCONCAT_PLUGIN_VERSION; }
+
+    const PluginFieldCollection* getFieldNames() override { return &mFC; }
+
+    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
+    {
+        return new FlattenConcat();
+    }
+
+    IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
+    {
+        return new FlattenConcat(serialData, serialLength);
+    }
+
+    void setPluginNamespace(const char* pluginNamespace) override
+    {
+        mPluginNamespace = pluginNamespace;
+    }
+
+    const char* getPluginNamespace() const override
+    {
+        return mPluginNamespace.c_str();
+    }
+
+private:
+    static PluginFieldCollection mFC;
+    static std::vector<PluginField> mPluginAttributes;
+    std::string mPluginNamespace = "";
+};
+
+PluginFieldCollection FlattenConcatPluginCreator::mFC{};
+std::vector<PluginField> FlattenConcatPluginCreator::mPluginAttributes;
+
+REGISTER_TENSORRT_PLUGIN(FlattenConcatPluginCreator);
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 000000000..e0d317d15
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+numpy
+imutils
+scipy
+Flask
+paho-mqtt
+PyYAML
+matplotlib
+pyarrow
+pycuda
+tensorrt
+opencv-python
+tensorflow
+tflite_runtime
\ No newline at end of file