Saqr/saqr/core/grouping.py

81 lines
2.6 KiB
Python

"""Cluster PPE detections into per-person candidates."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
from saqr.core.detection import PPEItem
from saqr.core.geometry import box_distance, expand_bbox, merge_boxes
@dataclass
class PersonCandidate:
    """A hypothesised person, assembled by clustering nearby PPE detections.

    Grouping code mutates instances in place: ``bbox`` is widened as boxes are
    merged in, ``items`` keeps the best confidence seen per label, and
    ``detections`` accumulates every detection that contributed.
    """

    # Bounding box as (x1, y1, x2, y2).
    bbox: Tuple[int, int, int, int]
    # Highest confidence observed for each detection label within this cluster.
    items: Dict[str, float]
    # Detections merged into this candidate; each instance gets its own list.
    detections: List[PPEItem] = field(default_factory=list)
def should_merge(
    candidate: PersonCandidate,
    item: PPEItem,
    *,
    x_scale: float = 1.2,
    x_pad: float = 40,
    y_scale: float = 1.8,
    y_pad: float = 50,
) -> bool:
    """Return True if *item* is close enough to *candidate* to join its cluster.

    Closeness is judged on box centres, independently per axis. The horizontal
    centre offset must not exceed ``max(widths) * x_scale + x_pad``. The
    vertical offset must not exceed ``max(heights) * y_scale + y_pad``. Both
    comparisons are inclusive. The default vertical allowance is looser,
    presumably because gear worn by one person is spread vertically.
    NOTE(review): confirm the intent.

    Args:
        candidate: Existing cluster; only its ``bbox`` (x1, y1, x2, y2) is read.
        item: Detection to test; only its ``bbox`` (x1, y1, x2, y2) is read.
        x_scale: Multiplier on the larger of the two widths.
        x_pad: Constant added to the horizontal tolerance.
        y_scale: Multiplier on the larger of the two heights.
        y_pad: Constant added to the vertical tolerance.

    The pads are in the same units as the boxes (presumably pixels), so they
    do not scale with image resolution. The defaults reproduce the original
    hard-coded values.

    Returns:
        True when both the horizontal and vertical centre offsets are within
        tolerance.
    """
    cx1, cy1, cx2, cy2 = candidate.bbox
    ix1, iy1, ix2, iy2 = item.bbox

    # Tolerances grow with the larger of the two boxes on each axis.
    max_dx = max(cx2 - cx1, ix2 - ix1) * x_scale + x_pad
    max_dy = max(cy2 - cy1, iy2 - iy1) * y_scale + y_pad

    # Absolute offsets between box centres.
    dx = abs((ix1 + ix2) / 2 - (cx1 + cx2) / 2)
    dy = abs((iy1 + iy2) / 2 - (cy1 + cy2) / 2)
    return dx <= max_dx and dy <= max_dy
def _seed_candidates(detections: List[PPEItem]) -> List[PersonCandidate]:
    """Greedy first pass: attach each detection to the first accepting candidate.

    A detection joins the first existing candidate for which ``should_merge``
    returns True. Otherwise it starts a new candidate. The result therefore
    depends on the order of *detections*.
    """
    candidates: List[PersonCandidate] = []
    for item in detections:
        home = next((cand for cand in candidates if should_merge(cand, item)), None)
        if home is None:
            candidates.append(PersonCandidate(
                bbox=item.bbox,
                items={item.label: item.conf},
                detections=[item],
            ))
            continue
        home.bbox = merge_boxes(home.bbox, item.bbox)
        # Keep only the strongest confidence per label.
        home.items[item.label] = max(home.items.get(item.label, 0.0), item.conf)
        home.detections.append(item)
    return candidates


def _is_close(anchor: PersonCandidate, other: PersonCandidate, ratio: float) -> bool:
    """Return True if *other* is within ``ratio`` x *anchor*'s longer side.

    Distance is whatever ``box_distance`` reports between the two boxes. The
    threshold uses only *anchor*'s size, so the test is asymmetric.
    """
    width = anchor.bbox[2] - anchor.bbox[0]
    height = anchor.bbox[3] - anchor.bbox[1]
    return box_distance(anchor.bbox, other.bbox) <= max(width, height) * ratio


def _absorb(target: PersonCandidate, other: PersonCandidate) -> None:
    """Fold *other* into *target*.

    The bboxes are merged, the max confidence per label is kept, and the
    detection lists are concatenated.
    """
    target.bbox = merge_boxes(target.bbox, other.bbox)
    for label, conf in other.items.items():
        target.items[label] = max(target.items.get(label, 0.0), conf)
    target.detections.extend(other.detections)


def _merge_close_candidates(
    candidates: List[PersonCandidate], ratio: float
) -> List[PersonCandidate]:
    """Repeatedly merge candidates that ``_is_close`` to an earlier survivor.

    Passes repeat until one completes with no merges. Merging enlarges a box,
    which can bring it within range of a survivor it was previously too far
    from. Each pass that merges reduces the count by at least one, so the loop
    terminates.
    """
    changed = True
    while changed:
        changed = False
        survivors: List[PersonCandidate] = []
        for person in candidates:
            target = next(
                (prev for prev in survivors if _is_close(prev, person, ratio)),
                None,
            )
            if target is None:
                survivors.append(person)
            else:
                _absorb(target, person)
                changed = True
        candidates = survivors
    return candidates


def group_detections_to_people(
    detections: List[PPEItem],
    w: int,
    h: int,
    *,
    merge_ratio: float = 0.55,
) -> List[PersonCandidate]:
    """Cluster PPE detections into per-person candidates.

    Processing happens in three steps:

    1. Greedily seed candidates with ``should_merge``.
    2. Iteratively merge candidates whose ``box_distance`` is at most
       ``merge_ratio`` times the kept candidate's longer side.
    3. Pass each resulting bbox through ``expand_bbox(bbox, w, h)``.

    Grouping is order-dependent: the same detections in a different order may
    cluster differently.

    Args:
        detections: Detections to group. Each one's ``bbox``, ``label`` and
            ``conf`` are read.
        w: Forwarded to ``expand_bbox``. Presumably the frame width used to
            clamp the expanded box; TODO confirm.
        h: Forwarded to ``expand_bbox``. Presumably the frame height; TODO
            confirm.
        merge_ratio: Scale for the cluster-merge threshold. The default keeps
            the original behaviour.

    Returns:
        The final candidates, or ``[]`` when *detections* is empty.
    """
    if not detections:
        return []
    candidates = _merge_close_candidates(_seed_candidates(detections), merge_ratio)
    for cand in candidates:
        cand.bbox = expand_bbox(cand.bbox, w, h)
    return candidates