Saqr/saqr/core/detection.py

57 lines
1.5 KiB
Python

"""YOLO inference and PPE class tables."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple
from ultralytics import YOLO
# Status labels exported by this module (presumably the overall PPE
# compliance verdict; they are assigned by code outside this module).
STATUSES = ("SAFE", "PARTIAL", "UNSAFE")

# All PPE class names, alphabetically sorted. NOTE(review): presumably this
# mirrors the trained model's class table — confirm against `model.names`.
CLASS_ORDER = [
    "boots",
    "gloves",
    "goggles",
    "helmet",
    "no-boots",
    "no-gloves",
    "no-goggles",
    "no-helmet",
    "no-vest",
    "vest",
]

# Membership table used to keep only PPE-class detections.
PPE_SET = {*CLASS_ORDER}

# Each positive PPE class mapped to its "no-" counterpart
# (insertion order: helmet, vest, boots, gloves, goggles).
POSITIVE_TO_NEGATIVE = {
    item: f"no-{item}" for item in ("helmet", "vest", "boots", "gloves", "goggles")
}

# Order in which positive PPE items are listed for display.
PPE_DISPLAY_ORDER = ["helmet", "vest", "gloves", "goggles", "boots"]
@dataclass
class PPEItem:
    """One PPE-class detection as built by ``collect_detections``."""

    # Class name looked up in the model's ``names`` table.
    label: str
    # Detection confidence reported by the model.
    conf: float
    # Box corners (x1, y1, x2, y2) from the model's xyxy output, truncated to int.
    bbox: Tuple[int, int, int, int]
# Extra keyword arguments splatted into every YOLO call made by
# collect_detections. Always mutated in place (never rebound), so every
# holder of a reference sees the current settings.
_INFER_KWARGS: Dict = {"device": "cpu", "half": False, "imgsz": 640}


def set_inference_config(*, device: str, half: bool, imgsz: int) -> None:
    """Replace the device, half-precision flag and image size used for inference.

    All three settings are required keyword-only arguments. The shared
    settings dict is updated in place.
    """
    _INFER_KWARGS["device"] = device
    _INFER_KWARGS["half"] = half
    _INFER_KWARGS["imgsz"] = imgsz


def get_inference_config() -> Dict:
    """Return a shallow copy of the current inference settings.

    Editing the returned dict has no effect on later inference calls.
    """
    return _INFER_KWARGS.copy()
def collect_detections(frame, model: YOLO, conf: float) -> List[PPEItem]:
    """Run YOLO on a single frame and return only PPE-class detections.

    Args:
        frame: Image source passed straight to ``model``.
        model: Loaded Ultralytics ``YOLO`` model; its ``names`` table maps
            class ids to class names.
        conf: Minimum confidence threshold forwarded to the predictor.

    Returns:
        One ``PPEItem`` per detected box whose class name is in ``PPE_SET``,
        in the order the predictor reported them. ``bbox`` holds the box's
        ``(x1, y1, x2, y2)`` corners truncated to ``int``. The list is empty
        when nothing (or nothing PPE-related) was detected.
    """
    result = model(frame, conf=conf, verbose=False, **_INFER_KWARGS)[0]
    boxes = result.boxes
    # Results.boxes is None when the model produces no boxes at all (e.g. a
    # classification model); iterating it would raise TypeError.
    if boxes is None:
        return []

    # Convert each tensor to Python lists once rather than calling
    # int()/float() per element: on GPU every per-element conversion forces
    # its own device-to-host sync. int()/float() below truncate exactly as
    # the per-element conversions did.
    class_ids = boxes.cls.tolist()
    scores = boxes.conf.tolist()
    corners = boxes.xyxy.tolist()
    names = model.names

    items: List[PPEItem] = []
    for cls_id, score, xyxy in zip(class_ids, scores, corners):
        label = names[int(cls_id)]
        if label not in PPE_SET:
            continue
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        items.append(
            PPEItem(label=label, conf=float(score), bbox=(x1, y1, x2, y2))
        )
    return items