disallow extra keys in config

This commit is contained in:
Blake Blackshear 2021-09-04 16:56:01 -05:00
parent 8109445fdd
commit e8eb3125a5
2 changed files with 37 additions and 31 deletions

View File

@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import yaml import yaml
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Extra, Field, validator
from pydantic.fields import PrivateAttr from pydantic.fields import PrivateAttr
from frigate.const import BASE_DIR, CACHE_DIR, RECORD_DIR from frigate.const import BASE_DIR, CACHE_DIR, RECORD_DIR
@ -29,18 +29,23 @@ DEFAULT_TRACKED_OBJECTS = ["person"]
DEFAULT_DETECTORS = {"cpu": {"type": "cpu"}} DEFAULT_DETECTORS = {"cpu": {"type": "cpu"}}
class FrigateBaseModel(BaseModel):
class Config:
extra = Extra.forbid
class DetectorTypeEnum(str, Enum): class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu" edgetpu = "edgetpu"
cpu = "cpu" cpu = "cpu"
class DetectorConfig(BaseModel): class DetectorConfig(FrigateBaseModel):
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type") type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
device: str = Field(default="usb", title="Device Type") device: str = Field(default="usb", title="Device Type")
num_threads: int = Field(default=3, title="Number of detection threads") num_threads: int = Field(default=3, title="Number of detection threads")
class MqttConfig(BaseModel): class MqttConfig(FrigateBaseModel):
host: str = Field(title="MQTT Host") host: str = Field(title="MQTT Host")
port: int = Field(default=1883, title="MQTT Port") port: int = Field(default=1883, title="MQTT Port")
topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix") topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix")
@ -60,7 +65,7 @@ class MqttConfig(BaseModel):
return v return v
class RetainConfig(BaseModel): class RetainConfig(FrigateBaseModel):
default: int = Field(default=10, title="Default retention period.") default: int = Field(default=10, title="Default retention period.")
objects: Dict[str, int] = Field( objects: Dict[str, int] = Field(
default_factory=dict, title="Object retention period." default_factory=dict, title="Object retention period."
@ -68,7 +73,7 @@ class RetainConfig(BaseModel):
# DEPRECATED: Will eventually be removed # DEPRECATED: Will eventually be removed
class ClipsConfig(BaseModel): class ClipsConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Save clips.") enabled: bool = Field(default=False, title="Save clips.")
max_seconds: int = Field(default=300, title="Maximum clip duration.") max_seconds: int = Field(default=300, title="Maximum clip duration.")
pre_capture: int = Field(default=5, title="Seconds to capture before event starts.") pre_capture: int = Field(default=5, title="Seconds to capture before event starts.")
@ -85,7 +90,7 @@ class ClipsConfig(BaseModel):
) )
class RecordConfig(BaseModel): class RecordConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable record on all cameras.") enabled: bool = Field(default=False, title="Enable record on all cameras.")
retain_days: int = Field(default=0, title="Recording retention period in days.") retain_days: int = Field(default=0, title="Recording retention period in days.")
events: ClipsConfig = Field( events: ClipsConfig = Field(
@ -93,7 +98,7 @@ class RecordConfig(BaseModel):
) )
class MotionConfig(BaseModel): class MotionConfig(FrigateBaseModel):
threshold: int = Field( threshold: int = Field(
default=25, default=25,
title="Motion detection threshold (1-255).", title="Motion detection threshold (1-255).",
@ -146,9 +151,10 @@ class RuntimeMotionConfig(MotionConfig):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = Extra.ignore
class DetectConfig(BaseModel): class DetectConfig(FrigateBaseModel):
height: int = Field(default=720, title="Height of the stream for the detect role.") height: int = Field(default=720, title="Height of the stream for the detect role.")
width: int = Field(default=1280, title="Width of the stream for the detect role.") width: int = Field(default=1280, title="Width of the stream for the detect role.")
fps: int = Field( fps: int = Field(
@ -160,7 +166,7 @@ class DetectConfig(BaseModel):
) )
class FilterConfig(BaseModel): class FilterConfig(FrigateBaseModel):
min_area: int = Field( min_area: int = Field(
default=0, title="Minimum area of bounding box for object to be counted." default=0, title="Minimum area of bounding box for object to be counted."
) )
@ -201,8 +207,10 @@ class RuntimeFilterConfig(FilterConfig):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = Extra.ignore
# this uses the base model because the color is an extra attribute
class ZoneConfig(BaseModel): class ZoneConfig(BaseModel):
filters: Dict[str, FilterConfig] = Field( filters: Dict[str, FilterConfig] = Field(
default_factory=dict, title="Zone filters." default_factory=dict, title="Zone filters."
@ -244,7 +252,7 @@ class ZoneConfig(BaseModel):
self._contour = np.array([]) self._contour = np.array([])
class ObjectConfig(BaseModel): class ObjectConfig(FrigateBaseModel):
track: List[str] = Field(default=DEFAULT_TRACKED_OBJECTS, title="Objects to track.") track: List[str] = Field(default=DEFAULT_TRACKED_OBJECTS, title="Objects to track.")
filters: Optional[Dict[str, FilterConfig]] = Field(title="Object filters.") filters: Optional[Dict[str, FilterConfig]] = Field(title="Object filters.")
mask: Union[str, List[str]] = Field(default="", title="Object mask.") mask: Union[str, List[str]] = Field(default="", title="Object mask.")
@ -256,7 +264,7 @@ class BirdseyeModeEnum(str, Enum):
continuous = "continuous" continuous = "continuous"
class BirdseyeConfig(BaseModel): class BirdseyeConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable birdseye view.") enabled: bool = Field(default=True, title="Enable birdseye view.")
width: int = Field(default=1280, title="Birdseye width.") width: int = Field(default=1280, title="Birdseye width.")
height: int = Field(default=720, title="Birdseye height.") height: int = Field(default=720, title="Birdseye height.")
@ -303,7 +311,7 @@ RECORD_FFMPEG_OUTPUT_ARGS_DEFAULT = [
] ]
class FfmpegOutputArgsConfig(BaseModel): class FfmpegOutputArgsConfig(FrigateBaseModel):
detect: Union[str, List[str]] = Field( detect: Union[str, List[str]] = Field(
default=DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT, default=DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT,
title="Detect role FFmpeg output arguments.", title="Detect role FFmpeg output arguments.",
@ -318,7 +326,7 @@ class FfmpegOutputArgsConfig(BaseModel):
) )
class FfmpegConfig(BaseModel): class FfmpegConfig(FrigateBaseModel):
global_args: Union[str, List[str]] = Field( global_args: Union[str, List[str]] = Field(
default=FFMPEG_GLOBAL_ARGS_DEFAULT, title="Global FFmpeg arguments." default=FFMPEG_GLOBAL_ARGS_DEFAULT, title="Global FFmpeg arguments."
) )
@ -340,7 +348,7 @@ class CameraRoleEnum(str, Enum):
detect = "detect" detect = "detect"
class CameraInput(BaseModel): class CameraInput(FrigateBaseModel):
path: str = Field(title="Camera input path.") path: str = Field(title="Camera input path.")
roles: List[CameraRoleEnum] = Field(title="Roles assigned to this input.") roles: List[CameraRoleEnum] = Field(title="Roles assigned to this input.")
global_args: Union[str, List[str]] = Field( global_args: Union[str, List[str]] = Field(
@ -371,7 +379,7 @@ class CameraFfmpegConfig(FfmpegConfig):
return v return v
class SnapshotsConfig(BaseModel): class SnapshotsConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Snapshots enabled.") enabled: bool = Field(default=False, title="Snapshots enabled.")
clean_copy: bool = Field( clean_copy: bool = Field(
default=True, title="Create a clean copy of the snapshot image." default=True, title="Create a clean copy of the snapshot image."
@ -399,7 +407,7 @@ class SnapshotsConfig(BaseModel):
) )
class ColorConfig(BaseModel): class ColorConfig(FrigateBaseModel):
red: int = Field(default=255, ge=0, le=255, title="Red") red: int = Field(default=255, ge=0, le=255, title="Red")
green: int = Field(default=255, ge=0, le=255, title="Green") green: int = Field(default=255, ge=0, le=255, title="Green")
blue: int = Field(default=255, ge=0, le=255, title="Blue") blue: int = Field(default=255, ge=0, le=255, title="Blue")
@ -417,7 +425,7 @@ class TimestampEffectEnum(str, Enum):
shadow = "shadow" shadow = "shadow"
class TimestampStyleConfig(BaseModel): class TimestampStyleConfig(FrigateBaseModel):
position: TimestampPositionEnum = Field( position: TimestampPositionEnum = Field(
default=TimestampPositionEnum.tl, title="Timestamp position." default=TimestampPositionEnum.tl, title="Timestamp position."
) )
@ -427,7 +435,7 @@ class TimestampStyleConfig(BaseModel):
effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.") effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.")
class CameraMqttConfig(BaseModel): class CameraMqttConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Send image over MQTT.") enabled: bool = Field(default=True, title="Send image over MQTT.")
timestamp: bool = Field(default=True, title="Add timestamp to MQTT image.") timestamp: bool = Field(default=True, title="Add timestamp to MQTT image.")
bounding_box: bool = Field(default=True, title="Add bounding box to MQTT image.") bounding_box: bool = Field(default=True, title="Add bounding box to MQTT image.")
@ -445,16 +453,16 @@ class CameraMqttConfig(BaseModel):
) )
class RtmpConfig(BaseModel): class RtmpConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="RTMP restreaming enabled.") enabled: bool = Field(default=True, title="RTMP restreaming enabled.")
class CameraLiveConfig(BaseModel): class CameraLiveConfig(FrigateBaseModel):
height: int = Field(default=720, title="Live camera view height") height: int = Field(default=720, title="Live camera view height")
quality: int = Field(default=8, ge=1, le=31, title="Live camera view quality") quality: int = Field(default=8, ge=1, le=31, title="Live camera view quality")
class CameraConfig(BaseModel): class CameraConfig(FrigateBaseModel):
name: Optional[str] = Field(title="Camera name.") name: Optional[str] = Field(title="Camera name.")
ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.") ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.")
best_image_timeout: int = Field( best_image_timeout: int = Field(
@ -590,13 +598,13 @@ class CameraConfig(BaseModel):
return [part for part in cmd if part != ""] return [part for part in cmd if part != ""]
class DatabaseConfig(BaseModel): class DatabaseConfig(FrigateBaseModel):
path: str = Field( path: str = Field(
default=os.path.join(BASE_DIR, "frigate.db"), title="Database path." default=os.path.join(BASE_DIR, "frigate.db"), title="Database path."
) )
class ModelConfig(BaseModel): class ModelConfig(FrigateBaseModel):
width: int = Field(default=320, title="Object detection model input width.") width: int = Field(default=320, title="Object detection model input width.")
height: int = Field(default=320, title="Object detection model input height.") height: int = Field(default=320, title="Object detection model input height.")
labelmap: Dict[int, str] = Field( labelmap: Dict[int, str] = Field(
@ -636,7 +644,7 @@ class LogLevelEnum(str, Enum):
critical = "critical" critical = "critical"
class LoggerConfig(BaseModel): class LoggerConfig(FrigateBaseModel):
default: LogLevelEnum = Field( default: LogLevelEnum = Field(
default=LogLevelEnum.info, title="Default logging level." default=LogLevelEnum.info, title="Default logging level."
) )
@ -645,7 +653,7 @@ class LoggerConfig(BaseModel):
) )
class FrigateConfig(BaseModel): class FrigateConfig(FrigateBaseModel):
mqtt: MqttConfig = Field(title="MQTT Configuration.") mqtt: MqttConfig = Field(title="MQTT Configuration.")
database: DatabaseConfig = Field( database: DatabaseConfig = Field(
default_factory=DatabaseConfig, title="Database configuration." default_factory=DatabaseConfig, title="Database configuration."

View File

@ -962,7 +962,7 @@ class TestConfig(unittest.TestCase):
config = { config = {
"mqtt": {"host": "mqtt"}, "mqtt": {"host": "mqtt"},
"timestamp_style": {"position": "bl", "scale": 1.5}, "timestamp_style": {"position": "bl"},
"cameras": { "cameras": {
"back": { "back": {
"ffmpeg": { "ffmpeg": {
@ -981,7 +981,6 @@ class TestConfig(unittest.TestCase):
runtime_config = frigate_config.runtime_config runtime_config = frigate_config.runtime_config
assert runtime_config.cameras["back"].timestamp_style.position == "bl" assert runtime_config.cameras["back"].timestamp_style.position == "bl"
assert runtime_config.cameras["back"].timestamp_style.scale == 1.5
def test_default_timestamp_style(self): def test_default_timestamp_style(self):
@ -1005,14 +1004,13 @@ class TestConfig(unittest.TestCase):
runtime_config = frigate_config.runtime_config runtime_config = frigate_config.runtime_config
assert runtime_config.cameras["back"].timestamp_style.position == "tl" assert runtime_config.cameras["back"].timestamp_style.position == "tl"
assert runtime_config.cameras["back"].timestamp_style.scale == 1.0
def test_global_timestamp_style_merge(self): def test_global_timestamp_style_merge(self):
config = { config = {
"mqtt": {"host": "mqtt"}, "mqtt": {"host": "mqtt"},
"rtmp": {"enabled": False}, "rtmp": {"enabled": False},
"timestamp_style": {"position": "br", "scale": 2.0}, "timestamp_style": {"position": "br", "thickness": 2},
"cameras": { "cameras": {
"back": { "back": {
"ffmpeg": { "ffmpeg": {
@ -1023,7 +1021,7 @@ class TestConfig(unittest.TestCase):
}, },
] ]
}, },
"timestamp_style": {"position": "bl", "scale": 1.5}, "timestamp_style": {"position": "bl", "thickness": 4},
} }
}, },
} }
@ -1032,7 +1030,7 @@ class TestConfig(unittest.TestCase):
runtime_config = frigate_config.runtime_config runtime_config = frigate_config.runtime_config
assert runtime_config.cameras["back"].timestamp_style.position == "bl" assert runtime_config.cameras["back"].timestamp_style.position == "bl"
assert runtime_config.cameras["back"].timestamp_style.scale == 1.5 assert runtime_config.cameras["back"].timestamp_style.thickness == 4
if __name__ == "__main__": if __name__ == "__main__":