diff --git a/docs/docs/configuration/advanced.md b/docs/docs/configuration/advanced.md index 74d8d222f..14dcb73de 100644 --- a/docs/docs/configuration/advanced.md +++ b/docs/docs/configuration/advanced.md @@ -110,10 +110,17 @@ detectors: ### `model` +If using a custom model, the width and height will need to be specified. + +The labelmap can be customized to your needs. A common reason to do this is to combine multiple object types that are easily confused when you don't need to be as granular such as car/truck. By default, truck is renamed to car because they are often confused. You cannot add new object types, but you can change the names of existing objects in the model. + ```yaml model: # Required: height of the trained model height: 320 # Required: width of the trained model width: 320 + # Optional: labelmap overrides + labelmap: + 7: car ``` diff --git a/docs/docs/configuration/objects.mdx b/docs/docs/configuration/objects.mdx index 3e95f9e83..a8608c286 100644 --- a/docs/docs/configuration/objects.mdx +++ b/docs/docs/configuration/objects.mdx @@ -4,13 +4,13 @@ title: Default available objects sidebar_label: Available objects --- -import labels from '../../../labelmap.txt'; +import labels from "../../../labelmap.txt"; By default, Frigate includes the following object models from the Google Coral test data. @@ -23,14 +23,3 @@ Models for both CPU and EdgeTPU (Coral) are bundled in the image. You can use yo - Labels: `/labelmap.txt` You also need to update the model width/height in the config if they differ from the defaults. - -### Customizing the Labelmap - -The labelmap can be customized to your needs. A common reason to do this is to combine multiple object types that are easily confused when you don't need to be as granular such as car/truck. You must retain the same number of labels, but you can change the names. To change: - -- Download the [COCO labelmap](https://dl.google.com/coral/canned_models/coco_labels.txt) -- Modify the label names as desired. For example, change `7 truck` to `7 car` -- Mount the new file at `/labelmap.txt` in the container with an additional volume - ``` - -v ./config/labelmap.txt:/labelmap.txt - ``` diff --git a/frigate/app.py b/frigate/app.py index 1d3d7bda5..b76c8146d 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -259,6 +259,7 @@ class FrigateApp: name, config, model_shape, + self.config.model.merged_labelmap, self.detection_queue, self.detection_out_events[name], self.detected_frames_queue, diff --git a/frigate/config.py b/frigate/config.py index 7a19ae25b..7e3426a34 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -13,6 +13,7 @@ from pydantic.fields import PrivateAttr import yaml from frigate.const import BASE_DIR, RECORD_DIR, CACHE_DIR +from frigate.edgetpu import load_labels from frigate.util import create_mask, deep_merge logger = logging.getLogger(__name__) @@ -615,6 +616,22 @@ class DatabaseConfig(BaseModel): class ModelConfig(BaseModel): width: int = Field(default=320, title="Object detection model input width.") height: int = Field(default=320, title="Object detection model input height.") + labelmap: Dict[int, str] = Field( + default_factory=dict, title="Labelmap customization." + ) + _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr() + + @property + def merged_labelmap(self) -> Dict[int, str]: + return self._merged_labelmap + + def __init__(self, **config): + super().__init__(**config) + + self._merged_labelmap = { + **load_labels("/labelmap.txt"), + **config.get("labelmap", {}), + } class LogLevelEnum(str, Enum): diff --git a/frigate/edgetpu.py b/frigate/edgetpu.py index 2ffd8c198..b426e40a6 100644 --- a/frigate/edgetpu.py +++ b/frigate/edgetpu.py @@ -231,7 +231,7 @@ class EdgeTPUProcess: class RemoteObjectDetector: def __init__(self, name, labels, detection_queue, event, model_shape): - self.labels = load_labels(labels) + self.labels = labels self.name = name self.fps = EventsPerSecond() self.detection_queue = detection_queue diff --git a/frigate/test/test_config.py b/frigate/test/test_config.py index 0625e2c28..af1c8a593 100644 --- a/frigate/test/test_config.py +++ b/frigate/test/test_config.py @@ -503,6 +503,86 @@ class TestConfig(unittest.TestCase): runtime_config = frigate_config.runtime_config assert round(runtime_config.cameras["back"].motion.contour_area) == 99 + def test_merge_labelmap(self): + + config = { + "mqtt": {"host": "mqtt"}, + "model": {"labelmap": {7: "truck"}}, + "cameras": { + "back": { + "ffmpeg": { + "inputs": [ + { + "path": "rtsp://10.0.0.1:554/video", + "roles": ["detect"], + }, + ] + }, + "height": 1080, + "width": 1920, + } + }, + } + + frigate_config = FrigateConfig(**config) + assert config == frigate_config.dict(exclude_unset=True) + + runtime_config = frigate_config.runtime_config + assert runtime_config.model.merged_labelmap[7] == "truck" + + def test_default_labelmap_empty(self): + + config = { + "mqtt": {"host": "mqtt"}, + "cameras": { + "back": { + "ffmpeg": { + "inputs": [ + { + "path": "rtsp://10.0.0.1:554/video", + "roles": ["detect"], + }, + ] + }, + "height": 1080, + "width": 1920, + } + }, + } + + frigate_config = FrigateConfig(**config) + assert config == frigate_config.dict(exclude_unset=True) + + runtime_config = frigate_config.runtime_config + assert runtime_config.model.merged_labelmap[0] == "person" + + def test_default_labelmap(self): + + config = { + "mqtt": {"host": "mqtt"}, + "model": {"width": 320, "height": 320}, + "cameras": { + "back": { + "ffmpeg": { + "inputs": [ + { + "path": "rtsp://10.0.0.1:554/video", + "roles": ["detect"], + }, + ] + }, + "height": 1080, + "width": 1920, + } + }, + } + + frigate_config = FrigateConfig(**config) + assert config == frigate_config.dict(exclude_unset=True) + + runtime_config = frigate_config.runtime_config + assert runtime_config.model.merged_labelmap[0] == "person" + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/frigate/video.py b/frigate/video.py index 23897dabd..7285dc0e5 100755 --- a/frigate/video.py +++ b/frigate/video.py @@ -318,6 +318,7 @@ def track_camera( name, config: CameraConfig, model_shape, + labelmap, detection_queue, result_connection, detected_objects_queue, @@ -344,7 +345,7 @@ def track_camera( motion_detector = MotionDetector(frame_shape, config.motion) object_detector = RemoteObjectDetector( - name, "/labelmap.txt", detection_queue, result_connection, model_shape + name, labelmap, detection_queue, result_connection, model_shape ) object_tracker = ObjectTracker(config.detect)