safe refactoring (#2552)

Co-authored-by: YS <ys@gm.com>
This commit is contained in:
Yuriy Sannikov 2021-12-31 20:59:43 +03:00 committed by Blake Blackshear
parent 9e987fdebc
commit 80627e4989
7 changed files with 36 additions and 36 deletions

View File

@ -13,8 +13,7 @@ from pydantic import BaseModel, Extra, Field, validator
from pydantic.fields import PrivateAttr
from frigate.const import BASE_DIR, CACHE_DIR, YAML_EXT
from frigate.edgetpu import load_labels
from frigate.util import create_mask, deep_merge
from frigate.util import create_mask, deep_merge, load_labels
logger = logging.getLogger(__name__)
@ -640,7 +639,7 @@ class ModelConfig(FrigateBaseModel):
return self._merged_labelmap
@property
def colormap(self) -> Dict[int, tuple[int, int, int]]:
def colormap(self) -> Dict[int, Tuple[int, int, int]]:
return self._colormap
def __init__(self, **config):

View File

@ -13,31 +13,11 @@ import tflite_runtime.interpreter as tflite
from setproctitle import setproctitle
from tflite_runtime.interpreter import load_delegate
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
logger = logging.getLogger(__name__)
def load_labels(path, encoding="utf-8"):
"""Loads labels from file (with or without index numbers).
Args:
path: path to label file.
encoding: label file encoding.
Returns:
Dictionary mapping indices to labels.
"""
with open(path, "r", encoding=encoding) as f:
lines = f.readlines()
if not lines:
return {}
if lines[0].split(" ", maxsplit=1)[0].isdigit():
pairs = [line.split(" ", maxsplit=1) for line in lines]
return {int(index): label.strip() for index, label in pairs}
else:
return {index: line.strip() for index, line in enumerate(lines)}
class ObjectDetector(ABC):
@abstractmethod
def detect(self, tensor_input, threshold=0.4):

View File

@ -359,9 +359,10 @@ def best(camera_name, label):
crop = bool(request.args.get("crop", 0, type=int))
if crop:
box = best_object.get("box", (0, 0, 300, 300))
box_size = 300
box = best_object.get("box", (0, 0, box_size, box_size))
region = calculate_region(
best_frame.shape, box[0], box[1], box[2], box[3], 1.1
best_frame.shape, box[0], box[1], box[2], box[3], box_size, multiplier=1.1
)
best_frame = best_frame[region[1] : region[3], region[0] : region[2]]

View File

@ -107,7 +107,7 @@ def create_mqtt_client(config: FrigateConfig, camera_metrics):
+ str(rc)
)
logger.info("MQTT connected")
logger.debug("MQTT connected")
client.subscribe(f"{mqtt_config.topic_prefix}/#")
client.publish(mqtt_config.topic_prefix + "/available", "online", retain=True)

View File

@ -18,12 +18,12 @@ import numpy as np
from frigate.config import CameraConfig, SnapshotsConfig, RecordConfig, FrigateConfig
from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
from frigate.edgetpu import load_labels
from frigate.util import (
SharedMemoryFrameManager,
calculate_region,
draw_box_with_label,
draw_timestamp,
load_labels,
)
logger = logging.getLogger(__name__)
@ -264,8 +264,9 @@ class TrackedObject:
if crop:
box = self.thumbnail_data["box"]
box_size = 300
region = calculate_region(
best_frame.shape, box[0], box[1], box[2], box[3], 1.1
best_frame.shape, box[0], box[1], box[2], box[3], box_size, multiplier=1.1
)
best_frame = best_frame[region[1] : region[3], region[0] : region[2]]

View File

@ -189,12 +189,12 @@ def draw_box_with_label(
)
def calculate_region(frame_shape, xmin, ymin, xmax, ymax, multiplier=2):
def calculate_region(frame_shape, xmin, ymin, xmax, ymax, model_size, multiplier=2):
# size is the longest edge and divisible by 4
size = int((max(xmax - xmin, ymax - ymin) * multiplier) // 4 * 4)
# dont go any smaller than 300
if size < 300:
size = 300
# dont go any smaller than the model_size
if size < model_size:
size = model_size
# x_offset is midpoint of bounding box minus half the size
x_offset = int((xmax - xmin) / 2.0 + xmin - size / 2.0)
@ -601,6 +601,24 @@ def add_mask(mask, mask_img):
)
cv2.fillPoly(mask_img, pts=[contour], color=(0))
def load_labels(path, encoding="utf-8"):
"""Loads labels from file (with or without index numbers).
Args:
path: path to label file.
encoding: label file encoding.
Returns:
Dictionary mapping indices to labels.
"""
with open(path, "r", encoding=encoding) as f:
lines = f.readlines()
if not lines:
return {}
if lines[0].split(" ", maxsplit=1)[0].isdigit():
pairs = [line.split(" ", maxsplit=1) for line in lines]
return {int(index): label.strip() for index, label in pairs}
else:
return {index: line.strip() for index, line in enumerate(lines)}
class FrameManager(ABC):
@abstractmethod

View File

@ -529,15 +529,16 @@ def process_frames(
# combine motion boxes with known locations of existing objects
combined_boxes = reduce_boxes(motion_boxes + tracked_object_boxes)
region_min_size = max(model_shape[0], model_shape[1])
# compute regions
regions = [
calculate_region(frame_shape, a[0], a[1], a[2], a[3], 1.2)
calculate_region(frame_shape, a[0], a[1], a[2], a[3], region_min_size, multiplier=1.2)
for a in combined_boxes
]
# consolidate regions with heavy overlap
regions = [
calculate_region(frame_shape, a[0], a[1], a[2], a[3], 1.0)
calculate_region(frame_shape, a[0], a[1], a[2], a[3], region_min_size, multiplier=1.0)
for a in reduce_boxes(regions, 0.4)
]
@ -596,7 +597,7 @@ def process_frames(
box = obj[2]
# calculate a new region that will hopefully get the entire object
region = calculate_region(
frame_shape, box[0], box[1], box[2], box[3]
frame_shape, box[0], box[1], box[2], box[3], region_min_size
)
regions.append(region)