148 lines
4.5 KiB
Python
148 lines
4.5 KiB
Python
"""
Saqr - PPE Detection | Simple Detection (no tracking)
======================================================

Single-pass YOLO inference: draws PPE boxes on each frame, with no person tracking.
Green = PPE worn, Red = PPE missing, Blue = any other class.

Usage:
    python detect.py --source 0
    python detect.py --source image.jpg --model models/saqr_best.pt
"""
|
|
|
|
import argparse
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
from logger import get_logger
|
|
|
|
# Module logger from the project's logging helper; used for error/warning output.
log = get_logger("Inference", "detect")
|
|
|
|
# Default inference settings shared by every model call. main() rebinds this
# from the CLI arguments before any model is run.
_INFER_KWARGS: dict = dict(device="cpu", half=False, imgsz=640)

# PPE classes that mean the item is worn; each has a "no-<item>" violation twin.
COMPLIANT = {"helmet", "vest", "boots", "gloves", "goggles"}
VIOLATION = {f"no-{item}" for item in COMPLIANT}

# Drawing colours in OpenCV's BGR channel order.
GREEN = (0, 200, 0)
RED = (0, 0, 220)
BLUE = (200, 100, 0)
WHITE = (255, 255, 255)
|
|
|
|
|
|
def box_color(label: str):
    """Return the BGR colour used to draw a detection with this ``label``.

    Violation classes (``VIOLATION``) are drawn in ``RED`` and compliant
    classes (``COMPLIANT``) in ``GREEN``. Anything else falls back to ``BLUE``.
    The violation set is checked first.
    """
    for group, colour in ((VIOLATION, RED), (COMPLIANT, GREEN)):
        if label in group:
            return colour
    return BLUE
|
|
|
|
|
|
def draw_boxes(frame, results, model):
    """Annotate ``frame`` in place with every detection in ``results``.

    Each box gets a 2 px outline in its label's colour (see ``box_color``).
    A filled tag reading ``"<label> <conf>"`` sits at the box's top-left
    corner. The tag's baseline is clamped so the tag is not clipped off the
    top edge of the frame.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, thickness = 0.48, 1

    for det in results.boxes:
        name = model.names[int(det.cls)]
        score = float(det.conf)
        left, top, right, bottom = map(int, det.xyxy[0])
        colour = box_color(name)

        cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)

        caption = f"{name} {score:.2f}"
        (text_w, text_h), _baseline = cv2.getTextSize(caption, font, scale, thickness)
        # Push the tag down when the box hugs the top edge.
        base_y = max(top, text_h + 6)
        tag_top_left = (left, base_y - text_h - 4)
        tag_bottom_right = (left + text_w + 4, base_y)
        cv2.rectangle(frame, tag_top_left, tag_bottom_right, colour, -1)
        cv2.putText(frame, caption, (left + 2, base_y - 3),
                    font, scale, WHITE, thickness, cv2.LINE_AA)
|
|
|
|
|
|
def run_video(model, source, conf):
    """Run live PPE detection on a webcam or video source and display it.

    Args:
        model: Loaded ``YOLO`` model; its ``names`` mapping supplies box labels.
        source: Webcam index as a digit string (e.g. ``"0"``), or any path/URL
            that ``cv2.VideoCapture`` can open.
        conf: Minimum confidence threshold forwarded to the model.

    Keys: ``q`` quits; ``s`` saves the current annotated frame to
    ``detect_saved.jpg`` in the working directory. Logs an error and returns
    if the source cannot be opened. The capture is released and the OpenCV
    windows are closed even if inference raises or the user presses Ctrl+C.
    """
    cap = cv2.VideoCapture(int(source) if source.isdigit() else source)
    if not cap.isOpened():
        log.error(f"Cannot open: {source}")
        cap.release()
        return

    print("Running - q to quit, s to save.")
    try:
        # perf_counter is monotonic, so FPS is not skewed by wall-clock changes.
        prev = time.perf_counter()
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # no frame grabbed: end of stream or read failure
            results = model(frame, conf=conf, verbose=False, **_INFER_KWARGS)[0]
            draw_boxes(frame, results, model)

            # Sample the clock once so the FPS value and the next frame's
            # reference point are the same instant.
            now = time.perf_counter()
            fps = 1.0 / max(now - prev, 1e-9)
            prev = now
            cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, WHITE, 2, cv2.LINE_AA)

            cv2.imshow("Saqr Detect", frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            if key == ord("s"):
                # imwrite reports failure via its return value, not an exception.
                if cv2.imwrite("detect_saved.jpg", frame):
                    print("Saved: detect_saved.jpg")
                else:
                    log.error("Cannot write: detect_saved.jpg")
    finally:
        cap.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
def run_image(model, path, conf):
    """Detect PPE in a single image, save the annotated copy and show it.

    Args:
        model: Loaded ``YOLO`` model; its ``names`` mapping supplies box labels.
        path: Image file path readable by ``cv2.imread``.
        conf: Minimum confidence threshold forwarded to the model.

    The annotated image is written to ``<input stem>_detect.jpg`` in the
    current working directory. It is then displayed until a key is pressed.
    Logs an error and returns if the image cannot be read.
    """
    frame = cv2.imread(path)
    if frame is None:
        log.error(f"Cannot read: {path}")
        return

    # Apply the same --device/--half/--imgsz settings as video mode.
    results = model(frame, conf=conf, verbose=False, **_INFER_KWARGS)[0]
    draw_boxes(frame, results, model)

    out = Path(path).stem + "_detect.jpg"
    # imwrite signals failure through its return value, not an exception.
    if cv2.imwrite(out, frame):
        print(f"Saved: {out}")
    else:
        log.error(f"Cannot write: {out}")

    cv2.imshow("Saqr Detect", frame)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
|
|
|
|
|
|
def _configure_inference(args):
    """Rebind the module-level ``_INFER_KWARGS`` from parsed CLI ``args``.

    Falls back to CPU when torch is importable but reports no CUDA device.
    Disables FP16 whenever the final device is the CPU.
    """
    global _INFER_KWARGS
    _INFER_KWARGS = {"device": args.device, "half": args.half, "imgsz": args.imgsz}

    try:
        import torch
    except ImportError:
        torch = None  # let the inference backend choose/validate the device

    if torch is not None and not torch.cuda.is_available() and args.device != "cpu":
        log.warning("CUDA unavailable - falling back to CPU")
        _INFER_KWARGS["device"] = "cpu"
        _INFER_KWARGS["half"] = False

    # Half precision targets GPUs; requesting it on CPU (e.g. "--device cpu
    # --half") is disabled rather than passed through.
    if str(_INFER_KWARGS["device"]).lower() == "cpu" and _INFER_KWARGS["half"]:
        log.warning("FP16 requested on CPU - disabling --half")
        _INFER_KWARGS["half"] = False


def _resolve_model_path(model_arg):
    """Locate the model file, first relative to this script, then as given.

    Raises:
        SystemExit: with code 1 (after logging) if neither location exists.
    """
    model_path = Path(__file__).parent / model_arg
    if not model_path.exists():
        model_path = Path(model_arg)
    if not model_path.exists():
        log.error(f"Model not found: {model_arg}")
        raise SystemExit(1)
    return model_path


def _is_video_source(src):
    """Return True if ``src`` should be handled by the video loop.

    That covers webcam indices ("0"), network stream URLs and common
    video-file extensions.
    """
    if src.isdigit():
        return True
    if src.lower().startswith(("rtsp://", "rtmp://", "http://", "https://")):
        return True
    return Path(src).suffix.lower() in {".mp4", ".avi", ".mov", ".mkv", ".webm", ".m4v"}


def main():
    """CLI entry point: parse arguments, load the model and run detection.

    Video-like sources (webcam index, stream URL, video file) go to
    ``run_video``. Existing files of any other type go to ``run_image``.

    Raises:
        SystemExit: with code 1 if the model or the source cannot be found.
    """
    parser = argparse.ArgumentParser(description="Saqr simple PPE detection")
    parser.add_argument("--source", default="0",
                        help="Webcam index, video file, stream URL or image path")
    parser.add_argument("--model", default="models/saqr_best.pt")
    parser.add_argument("--conf", type=float, default=0.35)
    parser.add_argument("--device", default="0", help="'cpu', '0', 'cuda:0'")
    parser.add_argument("--half", action="store_true", help="FP16 inference")
    parser.add_argument("--imgsz", type=int, default=320, help="Inference size")
    args = parser.parse_args()

    _configure_inference(args)
    model = YOLO(str(_resolve_model_path(args.model)))

    src = args.source
    if _is_video_source(src):
        run_video(model, src, args.conf)
    elif Path(src).exists():
        run_image(model, src, args.conf)
    else:
        log.error(f"Source not found: {src}")
        raise SystemExit(1)
|
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|