Mirror of https://github.com/blakeblackshear/frigate.git
Face detection (#14544)

* Add support for face detection
* Add support for detecting faces during registration
* Set body size to be larger
* Undo
parent fbf067b582, commit 491542c01f
@@ -246,6 +246,8 @@ http {
     proxy_no_cache $should_not_cache;
     add_header X-Cache-Status $upstream_cache_status;
 
+    client_max_body_size 10M;
+
     location /api/vod/ {
         include auth_request.conf;
         proxy_pass http://frigate_api/vod/;
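nginx's default client_max_body_size is 1M, which is too small for the base64-encoded snapshots uploaded during face registration; the bump to 10M leaves headroom. A rough sketch of the arithmetic (sizes are illustrative):

import base64
import os

# a ~1.5 MB snapshot grows ~33% under base64: past the 1M default,
# comfortably under the new 10M cap
raw = os.urandom(1_500_000)  # stand-in for an uploaded snapshot
encoded = base64.b64encode(raw)
print(len(encoded))  # ~2,000,000 bytes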
@@ -8,6 +8,9 @@ class EventsSubLabelBody(BaseModel):
     subLabelScore: Optional[float] = Field(
         title="Score for sub label", default=None, gt=0.0, le=1.0
     )
+    camera: Optional[str] = Field(
+        title="Camera this object is detected on.", default=None
+    )
 
 
 class EventsDescriptionBody(BaseModel):
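A minimal standalone sketch of the updated body model, useful for checking what the endpoint now accepts. The subLabel field is assumed from context (only subLabelScore and camera appear in the diff), and pydantic v2 is assumed:

from typing import Optional

from pydantic import BaseModel, Field


class EventsSubLabelBody(BaseModel):
    subLabel: str = Field(title="Sub label")  # assumed; not shown in the diff
    subLabelScore: Optional[float] = Field(
        title="Score for sub label", default=None, gt=0.0, le=1.0
    )
    camera: Optional[str] = Field(
        title="Camera this object is detected on.", default=None
    )


body = EventsSubLabelBody(subLabel="jane", subLabelScore=0.92, camera="front_door")
print(body.model_dump())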
@@ -901,22 +901,41 @@ def set_sub_label(
     try:
         event: Event = Event.get(Event.id == event_id)
     except DoesNotExist:
+        if not body.camera:
             return JSONResponse(
-            content=({"success": False, "message": "Event " + event_id + " not found"}),
+                content=(
+                    {
+                        "success": False,
+                        "message": "Event "
+                        + event_id
+                        + " not found and camera is not provided.",
+                    }
+                ),
                 status_code=404,
             )
 
+        event = None
+
+    if request.app.detected_frames_processor:
+        tracked_obj: TrackedObject = (
+            request.app.detected_frames_processor.camera_states[
+                event.camera if event else body.camera
+            ].tracked_objects.get(event_id)
+        )
+    else:
+        tracked_obj = None
+
+    if not event and not tracked_obj:
+        return JSONResponse(
+            content=(
+                {"success": False, "message": "Event " + event_id + " not found."}
+            ),
+            status_code=404,
+        )
+
     new_sub_label = body.subLabel
     new_score = body.subLabelScore
 
-    if not event.end_time:
-        # update tracked object
-        tracked_obj: TrackedObject = (
-            request.app.detected_frames_processor.camera_states[
-                event.camera
-            ].tracked_objects.get(event.id)
-        )
-
-        if tracked_obj:
+    if tracked_obj:
         tracked_obj.obj_data["sub_label"] = (new_sub_label, new_score)
@@ -925,6 +944,7 @@ def set_sub_label(
             data=Timeline.data.update({"sub_label": (new_sub_label, new_score)})
         ).where(Timeline.source_id == event_id).execute()
 
+    if event:
         event.sub_label = new_sub_label
 
         if new_score:
@@ -933,6 +953,7 @@ def set_sub_label(
             event.data = data
 
         event.save()
 
     return JSONResponse(
         content=(
             {
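With the camera field in place, callers can set a sub label on an event that has not yet been written to the database: the handler falls back to the in-memory tracked object for that camera. A hedged sketch of such a call (host, event id, and values are illustrative):

import requests

resp = requests.post(
    "http://localhost:5000/api/events/1700000000.123456-ab12cd/sub_label",
    json={"camera": "front_door", "subLabel": "jane", "subLabelScore": 0.92},
)
print(resp.status_code, resp.json())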
@@ -129,7 +129,8 @@ class Embeddings:
             model_name="facenet",
             model_file="facenet.onnx",
             download_urls={
-                "facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx"
+                "facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx",
+                "facedet.onnx": "https://github.com/opencv/opencv_zoo/raw/refs/heads/main/models/face_detection_yunet/face_detection_yunet_2023mar_int8.onnx",
             },
             model_size="large",
             model_type=ModelTypeEnum.face,
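The second URL adds the YuNet face-detection model alongside the facenet embedding model; the maintainer later loads it from /config/model_cache/facenet/facedet.onnx. A sketch of fetching it manually, assuming that same cache path:

import pathlib
import urllib.request

url = (
    "https://github.com/opencv/opencv_zoo/raw/refs/heads/main/models/"
    "face_detection_yunet/face_detection_yunet_2023mar_int8.onnx"
)
dest = pathlib.Path("/config/model_cache/facenet/facedet.onnx")
dest.parent.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(url, dest)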
@@ -72,6 +72,19 @@ class EmbeddingMaintainer(threading.Thread):
         self.tracked_events: dict[str, list[any]] = {}
         self.genai_client = get_genai_client(config)
 
+    @property
+    def face_detector(self) -> cv2.FaceDetectorYN:
+        # Lazily create the classifier.
+        if "face_detector" not in self.__dict__:
+            self.__dict__["face_detector"] = cv2.FaceDetectorYN.create(
+                "/config/model_cache/facenet/facedet.onnx",
+                config="",
+                input_size=(320, 320),
+                score_threshold=0.8,
+                nms_threshold=0.3,
+            )
+        return self.__dict__["face_detector"]
+
     def run(self) -> None:
         """Maintain a SQLite-vec database for semantic search."""
         while not self.stop_event.is_set():
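The property caches the detector in the instance __dict__ so the ONNX model is only loaded on first access. The same pattern in isolation, as a sketch unrelated to frigate's classes:

class LazyExample:
    @property
    def detector(self) -> str:
        # build once on first access, then reuse the cached instance;
        # the property (a data descriptor) still intercepts every lookup,
        # which is why it checks __dict__ itself
        if "detector" not in self.__dict__:
            self.__dict__["detector"] = "expensive object"
        return self.__dict__["detector"]


obj = LazyExample()
assert obj.detector is obj.detector  # second access returns the cached value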
@@ -90,7 +103,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_requests(self) -> None:
         """Process embeddings requests"""
 
-        def _handle_request(topic: str, data: str) -> str:
+        def _handle_request(topic: str, data: dict[str, any]) -> str:
             try:
                 if topic == EmbeddingsRequestEnum.embed_description.value:
                     return serialize(
@@ -110,12 +123,34 @@ class EmbeddingMaintainer(threading.Thread):
                         self.embeddings.text_embedding([data])[0], pack=False
                     )
                 elif topic == EmbeddingsRequestEnum.register_face.value:
+                    if data.get("cropped"):
                         self.embeddings.embed_face(
                             data["face_name"],
                             base64.b64decode(data["image"]),
                             upsert=True,
                         )
-                return None
+                        return True
+                    else:
+                        img = cv2.imdecode(
+                            np.frombuffer(
+                                base64.b64decode(data["image"]), dtype=np.uint8
+                            ),
+                            cv2.IMREAD_COLOR,
+                        )
+                        face_box = self._detect_face(img)
+
+                        if not face_box:
+                            return False
+
+                        face = img[face_box[1] : face_box[3], face_box[0] : face_box[2]]
+                        ret, webp = cv2.imencode(
+                            ".webp", face, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
+                        )
+                        self.embeddings.embed_face(
+                            data["face_name"], webp.tobytes(), upsert=True
+                        )
+
+                        return False
             except Exception as e:
                 logger.error(f"Unable to handle embeddings request {e}")
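The uncropped registration path decodes the uploaded image, finds the largest face, and stores a webp crop of it. A self-contained sketch of that decode round trip, mirroring _handle_request (image content and names are illustrative):

import base64

import cv2
import numpy as np

img = np.zeros((240, 320, 3), dtype=np.uint8)  # stand-in for an uploaded photo
ok, jpg = cv2.imencode(".jpg", img)
payload = {"face_name": "jane", "image": base64.b64encode(jpg.tobytes()).decode()}

# server side, as in the diff:
decoded = cv2.imdecode(
    np.frombuffer(base64.b64decode(payload["image"]), dtype=np.uint8),
    cv2.IMREAD_COLOR,
)
assert decoded.shape == img.shape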
@@ -277,7 +312,7 @@ class EmbeddingMaintainer(threading.Thread):
         if event_id:
             self.handle_regenerate_description(event_id, source)
 
-    def _search_face(self, query_embedding: bytes) -> list:
+    def _search_face(self, query_embedding: bytes) -> list[tuple[str, float]]:
         """Search for the face most closely matching the embedding."""
         sql_query = f"""
             SELECT
@@ -289,6 +324,29 @@ class EmbeddingMaintainer(threading.Thread):
         """
         return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
 
+    def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
+        """Detect faces in input image."""
+        self.face_detector.setInputSize((input.shape[1], input.shape[0]))
+        faces = self.face_detector.detect(input)
+
+        if faces[1] is None:
+            return None
+
+        face = None
+
+        for _, potential_face in enumerate(faces[1]):
+            raw_bbox = potential_face[0:4].astype(np.uint16)
+            x: int = max(raw_bbox[0], 0)
+            y: int = max(raw_bbox[1], 0)
+            w: int = raw_bbox[2]
+            h: int = raw_bbox[3]
+            bbox = (x, y, x + w, y + h)
+
+            if face is None or area(bbox) > area(face):
+                face = bbox
+
+        return face
+
     def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
         """Look for faces in image."""
         id = obj_data["id"]
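cv2.FaceDetectorYN.detect() returns a (retval, faces) pair where faces is an N x 15 float array per detection: x, y, w, h, ten landmark coordinates, and a confidence score. The loop converts each x/y/w/h row to a (left, top, right, bottom) box and keeps the largest. The selection logic in isolation, with an area helper assumed to match frigate's util:

def area(box: tuple[int, int, int, int]) -> int:
    # (left, top, right, bottom) -> pixel area; body assumed, not from the diff
    return (box[2] - box[0]) * (box[3] - box[1])


boxes = [(10, 10, 50, 50), (0, 0, 120, 160)]
largest = max(boxes, key=area)
assert largest == (0, 0, 120, 160)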
@@ -309,8 +367,23 @@ class EmbeddingMaintainer(threading.Thread):
         face: Optional[dict[str, any]] = None
 
         if self.requires_face_detection:
-            # TODO run cv2 face detection
-            pass
+            logger.debug("Running manual face detection.")
+            person_box = obj_data.get("box")
+
+            if not person_box:
+                return None
+
+            rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
+            left, top, right, bottom = person_box
+            person = rgb[top:bottom, left:right]
+            face = self._detect_face(person)
+
+            if not face:
+                logger.debug("Detected no faces for person object.")
+                return
+
+            face_frame = person[face[1] : face[3], face[0] : face[2]]
+            face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
         else:
             # don't run for object without attributes
             if not obj_data.get("current_attributes"):
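The raw frame arrives as I420 YUV, where the chroma planes sit below the luma plane, so the YUV buffer has 1.5x the rows of the RGB image it converts to. A quick shape check of the conversion used above:

import cv2
import numpy as np

width, height = 640, 480
yuv = np.zeros((height * 3 // 2, width), dtype=np.uint8)  # I420 plane layout
rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
assert rgb.shape == (height, width, 3)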
@@ -332,16 +405,16 @@ class EmbeddingMaintainer(threading.Thread):
         face_box = face.get("box")
 
         # check that face is valid
-        if (
-            not face_box
-            or area(face_box) < self.config.semantic_search.face_recognition.min_area
-        ):
+        if not face_box or area(face_box) < self.config.face_recognition.min_area:
             logger.debug(f"Invalid face box {face}")
             return
 
         face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
-        face_frame = face_frame[face_box[1] : face_box[3], face_box[0] : face_box[2]]
-        ret, jpg = cv2.imencode(
+        face_frame = face_frame[
+            face_box[1] : face_box[3], face_box[0] : face_box[2]
+        ]
+
+        ret, webp = cv2.imencode(
             ".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
         )
@@ -349,12 +422,13 @@ class EmbeddingMaintainer(threading.Thread):
             logger.debug("Not processing face due to error creating cropped image.")
             return
 
-        embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
+        embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
         query_embedding = serialize(embedding)
         best_faces = self._search_face(query_embedding)
         logger.debug(f"Detected best faces for person as: {best_faces}")
 
         if not best_faces or len(best_faces) < REQUIRED_FACES:
+            logger.debug(f"{len(best_faces)} < {REQUIRED_FACES} min required faces.")
             return
 
         sub_label = str(best_faces[0][0]).split("-")[0]
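Registered face embeddings appear to be stored under ids like "jane-0", "jane-1", so splitting on "-" recovers the person's name to use as the sub label. A sketch (the id scheme is inferred from the matching logic, not stated in the diff):

best_faces = [("jane-0", 0.08), ("jane-2", 0.12)]  # (id, distance) rows
sub_label = str(best_faces[0][0]).split("-")[0]
assert sub_label == "jane"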
@@ -363,28 +437,34 @@ class EmbeddingMaintainer(threading.Thread):
         for face in best_faces:
             score = 1.0 - face[1]
 
-            if face[0] != sub_label:
+            if face[0].split("-")[0] != sub_label:
                 logger.debug("Detected multiple faces, result is not valid.")
-                return None
+                return
 
             avg_score += score
 
-        avg_score = avg_score / REQUIRED_FACES
+        avg_score = round(avg_score / REQUIRED_FACES, 2)
 
-        if avg_score < self.config.semantic_search.face_recognition.threshold or (
+        if avg_score < self.config.face_recognition.threshold or (
             id in self.detected_faces and avg_score <= self.detected_faces[id]
         ):
             logger.debug(
-                "Detected face does not score higher than threshold / previous face."
+                f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
            )
-            return None
+            return
 
-        self.detected_faces[id] = avg_score
-        requests.post(
+        resp = requests.post(
             f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label",
-            json={"subLabel": sub_label, "subLabelScore": avg_score},
+            json={
+                "camera": obj_data.get("camera"),
+                "subLabel": sub_label,
+                "subLabelScore": avg_score,
+            },
         )
 
+        if resp.status_code == 200:
+            self.detected_faces[id] = avg_score
+
     def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
         """Return jpg thumbnail of a region of the frame."""
         frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
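Distances returned by the vector search are converted to scores (1 - distance) and averaged over REQUIRED_FACES before the threshold check, and the cached best score is now only updated when the sub_label POST succeeds. The scoring arithmetic in isolation (the REQUIRED_FACES value is assumed for illustration):

REQUIRED_FACES = 2  # assumed value, not from the diff
best_faces = [("jane-0", 0.08), ("jane-1", 0.12)]  # (id, distance)

avg_score = round(sum(1.0 - dist for _, dist in best_faces) / REQUIRED_FACES, 2)
assert avg_score == 0.9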