blakeblackshear.frigate/frigate/detectors/plugins/deepstack.py
Martin Weinelt ab50d0b006
Add isort and ruff linter (#6575)
* Add isort and ruff linter

Both linters are pretty common among modern Python codebases.

The isort tool provides stable sorting and grouping, as well as pruning
of unused imports.

Ruff is a modern linter that is very fast because it is written in Rust.
It can detect many common issues in a Python codebase.

Removes the pylint dev requirement, since ruff replaces it.

* treewide: fix issues detected by ruff

* treewide: fix bare except clauses

* .devcontainer: Set up isort

* treewide: optimize imports

* treewide: apply black

* treewide: make regex patterns raw strings

This is necessary for escape sequences to be properly recognized.
2023-05-29 05:31:17 -05:00

83 lines
2.8 KiB
Python

import io
import logging
import numpy as np
import requests
from PIL import Image
from pydantic import Field
from typing_extensions import Literal
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Identifier of this detector plugin: used as the config's `type` literal and
# as the detector class's `type_key`.
DETECTOR_KEY = "deepstack"
class DeepstackDetectorConfig(BaseDetectorConfig):
    # Configuration model for the "deepstack" detector type.

    # Discriminator: only the DETECTOR_KEY value is accepted for this model.
    type: Literal[DETECTOR_KEY]

    # URL that the detector posts each frame to.
    api_url: str = Field(
        title="DeepStack API URL",
        default="http://localhost:80/v1/vision/detection",
    )

    # Passed to the HTTP request as its timeout, in seconds.
    api_timeout: float = Field(
        title="DeepStack API timeout (in seconds)",
        default=0.1,
    )

    # Sent as the `api_key` form field; empty string by default.
    api_key: str = Field(
        title="DeepStack API key (if required)",
        default="",
    )
class DeepStack(DetectionApi):
    """Detector that delegates object detection to a DeepStack HTTP API.

    Each frame is JPEG-encoded and posted to ``api_url``. The predictions in
    the JSON response are converted into a fixed-size float32 array of shape
    ``(20, 6)`` whose rows are
    ``[label_index, confidence, y_min, x_min, y_max, x_max]``, with the box
    coordinates divided by the image height (y) or width (x). Unused rows
    stay zero.
    """

    type_key = DETECTOR_KEY

    # Number of rows in the array returned by detect_raw().
    _MAX_DETECTIONS = 20
    # A prediction below this confidence ends processing of the response.
    _MIN_CONFIDENCE = 0.4

    def __init__(self, detector_config: DeepstackDetectorConfig):
        """Copy the API settings and the merged label map from the config."""
        self.api_url = detector_config.api_url
        self.api_timeout = detector_config.api_timeout
        self.api_key = detector_config.api_key
        self.labels = detector_config.model.merged_labelmap

    def get_label_index(self, label_value):
        """Return the label-map index whose value matches ``label_value``.

        ``label_value`` is lower-cased before comparison (label-map values are
        compared as stored), and "truck" is looked up as "car".

        Returns:
            The matching key of ``self.labels``, or ``-1`` when no value
            matches.
        """
        wanted = label_value.lower()
        if wanted == "truck":
            wanted = "car"
        for index, value in self.labels.items():
            if value == wanted:
                return index
        return -1

    def detect_raw(self, tensor_input):
        """Run detection on one frame via the remote API.

        Args:
            tensor_input: Array-like frame; it is squeezed and cast to uint8
                before being handed to PIL (presumably a single image with a
                leading batch axis -- confirm against the caller).

        Returns:
            A ``(20, 6)`` float32 array as described on the class. An all-zero
            array is returned when the request fails, the response is not
            JSON, or the response carries no ``predictions``.

        Side effects:
            Sets ``self.w`` and ``self.h`` to the frame's width and height.
        """
        detections = np.zeros((self._MAX_DETECTIONS, 6), np.float32)

        image_data = np.squeeze(tensor_input).astype(np.uint8)
        image = Image.fromarray(image_data)
        self.w, self.h = image.size
        with io.BytesIO() as output:
            image.save(output, format="JPEG")
            image_bytes = output.getvalue()

        # A timeout, refused connection or non-JSON body must not take the
        # detector down; report it and return "no detections" for this frame.
        # ValueError covers JSON decode errors on requests versions where the
        # decode error is not a RequestException subclass.
        try:
            response = requests.post(
                self.api_url,
                data={"api_key": self.api_key},
                files={"image": image_bytes},
                timeout=self.api_timeout,
            )
            response_json = response.json()
        except (requests.exceptions.RequestException, ValueError) as ex:
            logger.error("Error calling deepstack API: %s", ex)
            return detections

        predictions = response_json.get("predictions")
        if predictions is None:
            logger.debug("Error in parsing response json: %s", response_json)
            return detections

        for i, detection in enumerate(predictions):
            # The output array has a fixed number of rows; writing past it
            # would raise IndexError, so surplus predictions are dropped.
            if i >= self._MAX_DETECTIONS:
                logger.debug("Break due to more than 20 predictions")
                break
            logger.debug("Response: %s", detection)
            # NOTE(review): both checks below stop at the first rejected
            # prediction rather than skipping it, so later predictions are
            # discarded too. Presumably the API returns predictions sorted by
            # descending confidence -- confirm before changing to `continue`.
            if detection["confidence"] < self._MIN_CONFIDENCE:
                logger.debug("Break due to confidence < 0.4")
                break
            label = self.get_label_index(detection["label"])
            if label < 0:
                logger.debug("Break due to unknown label")
                break
            detections[i] = [
                label,
                float(detection["confidence"]),
                detection["y_min"] / self.h,
                detection["x_min"] / self.w,
                detection["y_max"] / self.h,
                detection["x_max"] / self.w,
            ]

        return detections