Cleanup config validation (#11235)

* Fix reading model config dict

* Fix irrelevant warnings

* Fix tests
This commit is contained in:
Nicolas Mowen 2024-05-04 09:15:03 -06:00 committed by GitHub
parent 51dcdd6f4b
commit 2dd5b893a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 21 deletions

View File

@ -1505,29 +1505,26 @@ class FrigateConfig(FrigateBaseModel):
for key, detector in config.detectors.items(): for key, detector in config.detectors.items():
adapter = TypeAdapter(DetectorConfig) adapter = TypeAdapter(DetectorConfig)
model_dict = ( model_dict = (
detector if isinstance(detector, dict) else detector.model_dump() detector
if isinstance(detector, dict)
else detector.model_dump(warnings="none")
) )
detector_config: DetectorConfig = adapter.validate_python(model_dict) detector_config: DetectorConfig = adapter.validate_python(model_dict)
if detector_config.model is None: if detector_config.model is None:
detector_config.model = config.model detector_config.model = config.model.model_copy()
else: else:
model = detector_config.model path = detector_config.model.path
schema = ModelConfig.model_json_schema()["properties"] detector_config.model = config.model.model_copy()
if ( detector_config.model.path = path
model.width != schema["width"]["default"]
or model.height != schema["height"]["default"] if "path" not in model_dict or len(model_dict.keys()) > 1:
or model.labelmap_path is not None
or model.labelmap
or model.input_tensor != schema["input_tensor"]["default"]
or model.input_pixel_format
!= schema["input_pixel_format"]["default"]
):
logger.warning( logger.warning(
"Customizing more than a detector model path is unsupported." "Customizing more than a detector model path is unsupported."
) )
merged_model = deep_merge( merged_model = deep_merge(
detector_config.model.model_dump(exclude_unset=True), detector_config.model.model_dump(exclude_unset=True, warnings="none"),
config.model.model_dump(exclude_unset=True), config.model.model_dump(exclude_unset=True, warnings="none"),
) )
if "path" not in merged_model: if "path" not in merged_model:

View File

@ -82,7 +82,7 @@ class TestConfig(unittest.TestCase):
}, },
"edgetpu": { "edgetpu": {
"type": "edgetpu", "type": "edgetpu",
"model": {"path": "/edgetpu_model.tflite", "width": 160}, "model": {"path": "/edgetpu_model.tflite"},
}, },
"openvino": { "openvino": {
"type": "openvino", "type": "openvino",
@ -112,11 +112,6 @@ class TestConfig(unittest.TestCase):
assert runtime_config.detectors["edgetpu"].model.path == "/edgetpu_model.tflite" assert runtime_config.detectors["edgetpu"].model.path == "/edgetpu_model.tflite"
assert runtime_config.detectors["openvino"].model.path == "/etc/hosts" assert runtime_config.detectors["openvino"].model.path == "/etc/hosts"
assert runtime_config.model.width == 512
assert runtime_config.detectors["cpu"].model.width == 320
assert runtime_config.detectors["edgetpu"].model.width == 160
assert runtime_config.detectors["openvino"].model.width == 512
def test_invalid_mqtt_config(self): def test_invalid_mqtt_config(self):
config = { config = {
"mqtt": {"host": "mqtt", "user": "test"}, "mqtt": {"host": "mqtt", "user": "test"},