# blakeblackshear.frigate/frigate/detectors/detector_config.py
# 2023-04-30 13:32:36 -05:00
#
# 148 lines
# 4.6 KiB
# Python

import hashlib
import json
import logging
from enum import Enum
import os
from typing import Dict, List, Optional, Tuple, Union, Literal
import requests
import matplotlib.pyplot as plt
from pydantic import BaseModel, Extra, Field, validator
from pydantic.fields import PrivateAttr
from frigate.plus import PlusApi
from frigate.util import load_labels
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Pixel color formats selectable for ModelConfig.input_pixel_format.
# The str mixin makes every member compare equal to its plain string value.
class PixelFormatEnum(str, Enum):
    rgb = "rgb"
    bgr = "bgr"
    yuv = "yuv"
# Input tensor layouts selectable for ModelConfig.input_tensor.
# The str mixin makes every member compare equal to its plain string value.
class InputTensorEnum(str, Enum):
    nchw = "nchw"
    nhwc = "nhwc"
# Model families selectable for ModelConfig.model_type.
# The str mixin makes every member compare equal to its plain string value.
class ModelTypeEnum(str, Enum):
    ssd = "ssd"
    yolox = "yolox"
    yolov5 = "yolov5"
    yolov8 = "yolov8"
class ModelConfig(BaseModel):
    """Settings describing the object detection model a detector loads.

    Besides the validated fields, three private attributes are maintained:

    * ``_merged_labelmap`` - built in ``__init__`` and rebuilt by
      ``check_and_load_plus_model``.
    * ``_colormap`` - starts empty and is filled by ``create_colormap``.
    * ``_model_hash`` - only set once ``compute_model_hash`` has run.
    """

    path: Optional[str] = Field(title="Custom Object detection model path.")
    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
    width: int = Field(default=320, title="Object detection model input width.")
    height: int = Field(default=320, title="Object detection model input height.")
    labelmap: Dict[int, str] = Field(
        default_factory=dict, title="Labelmap customization."
    )
    input_tensor: InputTensorEnum = Field(
        default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
    )
    input_pixel_format: PixelFormatEnum = Field(
        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
    )
    model_type: ModelTypeEnum = Field(
        default=ModelTypeEnum.ssd, title="Object Detection Model Type"
    )
    _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
    # Keyed by label name: create_colormap() stores entries under label strings.
    _colormap: Dict[str, Tuple[int, int, int]] = PrivateAttr()
    _model_hash: str = PrivateAttr()

    @property
    def merged_labelmap(self) -> Dict[int, str]:
        """Label map after applying the ``labelmap`` overrides."""
        return self._merged_labelmap

    @property
    def colormap(self) -> Dict[str, Tuple[int, int, int]]:
        """Mapping of label name to the color picked by ``create_colormap``."""
        return self._colormap

    @property
    def model_hash(self) -> str:
        """Hex digest stored by ``compute_model_hash``.

        The attribute has no default, so ``compute_model_hash`` must be called
        before this property is read.
        """
        return self._model_hash

    def __init__(self, **config):
        """Validate ``config`` and build the initial merged label map."""
        super().__init__(**config)

        # Use the validated fields instead of the raw ``config`` dict: an
        # explicit ``labelmap_path: None`` then falls back to the default file,
        # and the override keys are the int-coerced ones, matching what
        # check_and_load_plus_model() merges. Overrides win over file entries.
        self._merged_labelmap = {
            **load_labels(self.labelmap_path or "/labelmap.txt"),
            **self.labelmap,
        }
        self._colormap = {}

    def check_and_load_plus_model(
        self, plus_api: PlusApi, detector: Optional[str] = None
    ) -> None:
        """Resolve a ``plus://<model_id>`` path to a locally cached model.

        Does nothing unless ``path`` starts with ``plus://``. Otherwise ``path``
        is rewritten to ``/config/model_cache/<model_id>``, the model file and
        its ``.json`` info file are fetched when they are not cached yet, and
        the width, height, input tensor, pixel format, model type and merged
        label map are overwritten from the model info.

        Args:
            plus_api: Client providing ``get_model_download_url`` and
                ``get_model_info``.
            detector: Detector type to check against the model's
                ``supportedDetectors``; the check is skipped when falsy.

        Raises:
            ValueError: If ``detector`` is not listed in ``supportedDetectors``.
            requests.HTTPError: If the model download returns an error status.
        """
        prefix = "plus://"
        if not self.path or not self.path.startswith(prefix):
            return

        model_id = self.path[len(prefix) :]
        self.path = f"/config/model_cache/{model_id}"
        model_info_path = f"{self.path}.json"

        # download the model if it doesn't exist
        if not os.path.isfile(self.path):
            download_url = plus_api.get_model_download_url(model_id)
            # (connect, read) timeouts so an unreachable host cannot hang forever.
            response = requests.get(download_url, timeout=(10, 60))
            # Fail before writing: an error body saved here would be treated as
            # a cached model on every later start.
            response.raise_for_status()
            os.makedirs(os.path.dirname(self.path), exist_ok=True)
            with open(self.path, "wb") as f:
                f.write(response.content)

        # download the model info if it doesn't exist
        if not os.path.isfile(model_info_path):
            model_info = plus_api.get_model_info(model_id)
            with open(model_info_path, "w", encoding="utf-8") as f:
                json.dump(model_info, f)
        else:
            with open(model_info_path, "r", encoding="utf-8") as f:
                model_info = json.load(f)

        if detector and detector not in model_info["supportedDetectors"]:
            raise ValueError(f"Model does not support detector type of {detector}")

        # Plain assignment: the values are stored exactly as found in model_info.
        self.width = model_info["width"]
        self.height = model_info["height"]
        self.input_tensor = model_info["inputShape"]
        self.input_pixel_format = model_info["pixelFormat"]
        self.model_type = model_info["type"]

        # JSON object keys are strings, so convert them to int label ids; the
        # configured ``labelmap`` overrides still take precedence.
        self._merged_labelmap = {
            **{int(key): val for key, val in model_info["labelMap"].items()},
            **self.labelmap,
        }

    def compute_model_hash(self) -> None:
        """Store the MD5 hex digest of the model file in ``_model_hash``.

        When no path is configured or the file does not exist, the digest of
        the placeholder ``b"unknown"`` is stored instead of raising, so that
        ``model_hash`` stays readable.
        """
        if not self.path or not os.path.isfile(self.path):
            logger.warning(
                "Model file %s not found, using placeholder model hash", self.path
            )
            self._model_hash = hashlib.md5(b"unknown").hexdigest()
            return

        file_hash = hashlib.md5()
        with open(self.path, "rb") as f:
            # Read in chunks so the whole model is never held in memory.
            while chunk := f.read(8192):
                file_hash.update(chunk)
        self._model_hash = file_hash.hexdigest()

    def create_colormap(self, enabled_labels: set[str]) -> None:
        """Assign an RGB color from the ``tab10`` colormap to each enabled label.

        Entries are added to the existing colormap; labels stored by an earlier
        call are kept.

        Args:
            enabled_labels: Label names that need a color.
        """
        if not enabled_labels:
            return

        # A set of str iterates in an order that can differ between interpreter
        # runs; sorting keeps each label's color stable across restarts.
        ordered_labels = sorted(enabled_labels)
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated
        # in matplotlib 3.7 and later removed.
        cmap = plt.get_cmap("tab10", len(ordered_labels))
        for index, label in enumerate(ordered_labels):
            # cmap() yields RGBA floats in [0, 1]; keep RGB scaled to 0-255 ints.
            self._colormap[label] = tuple(
                int(round(255 * c)) for c in cmap(index)[:3]
            )

    class Config:
        extra = Extra.forbid
class BaseDetectorConfig(BaseModel):
    """Base settings shared by every detector configuration.

    Unknown keys are accepted (``Extra.allow``) and arbitrary field types are
    permitted.
    """

    # the type field must be defined in all subclasses
    type: str = Field(default="cpu", title="Detector Type")
    # The default is None, so the annotation is Optional to say so explicitly.
    model: Optional[ModelConfig] = Field(
        default=None, title="Detector specific model configuration."
    )

    class Config:
        extra = Extra.allow
        arbitrary_types_allowed = True