81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
"""Cluster PPE detections into per-person candidates."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Tuple
|
|
|
|
from saqr.core.detection import PPEItem
|
|
from saqr.core.geometry import box_distance, expand_bbox, merge_boxes
|
|
|
|
|
|
@dataclass
class PersonCandidate:
    """A cluster of PPE detections that is treated as one person.

    Attributes:
        bbox: Four-int bounding box of the whole cluster. ``should_merge``
            unpacks it as ``(x1, y1, x2, y2)``.
        items: PPE label -> confidence. The grouping code keeps the highest
            confidence seen for each label.
        detections: Every individual detection folded into this cluster.
    """

    # Cluster extent; grown as detections/candidates are merged in.
    bbox: Tuple[int, int, int, int]
    # One confidence per label (not per detection).
    items: Dict[str, float]
    # default_factory gives every instance its own list, so candidates never
    # share (and cross-contaminate) a single mutable default.
    detections: List[PPEItem] = field(default_factory=list)
|
|
|
|
|
|
def should_merge(
    candidate: PersonCandidate,
    item: PPEItem,
    *,
    x_scale: float = 1.2,
    x_pad: float = 40,
    y_scale: float = 1.8,
    y_pad: float = 50,
) -> bool:
    """Return True when *item* is close enough to *candidate* to join it.

    The test compares box centres. The allowed horizontal offset is
    ``max(candidate_width, item_width) * x_scale + x_pad``. The allowed
    vertical offset is ``max(candidate_height, item_height) * y_scale + y_pad``.
    Both limits are inclusive.

    Args:
        candidate: Cluster being grown; its ``bbox`` is unpacked as
            ``(x1, y1, x2, y2)``.
        item: Detection under consideration; its ``bbox`` has the same layout.
        x_scale, x_pad: Horizontal tolerance factor and additive slack. The
            slack is in the same units as the boxes (presumably pixels — confirm).
        y_scale, y_pad: Vertical tolerance factor and additive slack. The
            defaults are looser than horizontal, presumably because PPE items
            of one person (helmet, vest, boots) stack vertically — confirm.

    The defaults reproduce the previously hard-coded thresholds exactly.
    """
    cx1, cy1, cx2, cy2 = candidate.bbox
    ix1, iy1, ix2, iy2 = item.bbox

    # Tolerances scale with the larger of the two boxes, so a small item next
    # to a big cluster (or vice versa) is judged by the bigger extent.
    max_dx = max(cx2 - cx1, ix2 - ix1) * x_scale + x_pad
    max_dy = max(cy2 - cy1, iy2 - iy1) * y_scale + y_pad

    dx = abs((ix1 + ix2) / 2 - (cx1 + cx2) / 2)
    dy = abs((iy1 + iy2) / 2 - (cy1 + cy2) / 2)
    return dx <= max_dx and dy <= max_dy
|
|
|
|
|
|
def _absorb_into_candidate(
    target: PersonCandidate,
    bbox: Tuple[int, int, int, int],
    items: Dict[str, float],
    detections: List[PPEItem],
) -> None:
    """Fold a box, its per-label confidences and its detections into *target*."""
    target.bbox = merge_boxes(target.bbox, bbox)
    for label, conf in items.items():
        # Keep only the strongest confidence seen for each label.
        target.items[label] = max(target.items.get(label, 0.0), conf)
    target.detections.extend(detections)


def _seed_candidates(detections: List[PPEItem]) -> List[PersonCandidate]:
    """Greedily attach each detection to the first matching candidate, else start one."""
    candidates: List[PersonCandidate] = []
    for item in detections:
        # First candidate (in creation order) that accepts the item, if any.
        target = next((c for c in candidates if should_merge(c, item)), None)
        if target is None:
            candidates.append(PersonCandidate(
                bbox=item.bbox,
                items={item.label: item.conf},
                detections=[item],
            ))
        else:
            _absorb_into_candidate(target, item.bbox, {item.label: item.conf}, [item])
    return candidates


def _consolidate_candidates(
    candidates: List[PersonCandidate],
    merge_ratio: float,
) -> List[PersonCandidate]:
    """Repeatedly merge nearby candidates until a full pass makes no merge."""
    changed = True
    while changed:
        changed = False
        kept: List[PersonCandidate] = []
        for person in candidates:
            for prev in kept:
                pw = prev.bbox[2] - prev.bbox[0]
                ph = prev.bbox[3] - prev.bbox[1]
                # The threshold is sized from the already-kept candidate only,
                # so the check is asymmetric and depends on ordering.
                if box_distance(prev.bbox, person.bbox) <= max(pw, ph) * merge_ratio:
                    _absorb_into_candidate(prev, person.bbox, person.items, person.detections)
                    # prev grew; earlier kept candidates may now be in range,
                    # so another pass is needed.
                    changed = True
                    break
            else:
                kept.append(person)
        candidates = kept
    return candidates


def group_detections_to_people(
    detections: List[PPEItem],
    w: int,
    h: int,
    *,
    merge_ratio: float = 0.55,
) -> List[PersonCandidate]:
    """Cluster PPE detections into per-person candidates.

    Processing happens in three stages:

    1. Seeding. Each detection joins the first existing candidate for which
       ``should_merge`` is true; otherwise it starts a new candidate.
    2. Consolidation. A candidate is folded into an earlier kept candidate
       when the two boxes are within ``merge_ratio * max(width, height)`` of
       that kept candidate, as measured by ``box_distance``. Passes repeat
       until no merge happens.
    3. Expansion. Each final box is replaced by ``expand_bbox(bbox, w, h)``.

    Both clustering stages take the first match, so the result depends on
    the order of *detections*.

    Args:
        detections: Detections exposing ``bbox``, ``label`` and ``conf``.
            They are not modified.
        w: Forwarded to ``expand_bbox``. Presumably the frame width used to
            clamp the expanded box — confirm in ``saqr.core.geometry``.
        h: Forwarded to ``expand_bbox``. Presumably the frame height — confirm.
        merge_ratio: Consolidation distance factor. The default of 0.55
            reproduces the previous hard-coded behaviour.

    Returns:
        Newly created ``PersonCandidate`` objects. The list is empty when
        *detections* is empty.
    """
    if not detections:
        return []

    candidates = _consolidate_candidates(_seed_candidates(detections), merge_ratio)

    for cand in candidates:
        cand.bbox = expand_bbox(cand.bbox, w, h)

    return candidates
|