# File: Saqr/saqr/apps/detect_cli.py
# (file-viewer metadata: 136 lines, 4.1 KiB, Python)

"""Simple PPE detection without person tracking."""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import cv2
from ultralytics import YOLO
from saqr.core.detection import get_inference_config, set_inference_config
from saqr.core.model import resolve_model_path
from saqr.utils.logger import get_logger
# Module-level logger obtained from the project's get_logger helper.
log = get_logger("Inference", "detect")
# Labels that box_color draws in RED.
VIOLATION = {"no-helmet", "no-vest", "no-boots", "no-gloves", "no-goggles"}
# Labels that box_color draws in GREEN.
COMPLIANT = {"helmet", "vest", "boots", "gloves", "goggles"}
# Colour tuples handed to the OpenCV drawing calls below; OpenCV reads them
# in (B, G, R) channel order.
GREEN = (0, 200, 0)
RED = (0, 0, 220)
BLUE = (200, 100, 0)  # box_color fallback for labels in neither set above
WHITE = (255, 255, 255)  # caption text and FPS overlay
def box_color(label: str):
    """Return the drawing colour associated with a detection label.

    Labels listed in ``VIOLATION`` map to ``RED``, labels listed in
    ``COMPLIANT`` map to ``GREEN``, and every other label maps to ``BLUE``.
    """
    if label in VIOLATION:
        colour = RED
    elif label in COMPLIANT:
        colour = GREEN
    else:
        colour = BLUE
    return colour
def draw_boxes(frame, results, model):
    """Annotate ``frame`` in place with one labelled rectangle per detection.

    Args:
        frame: Image handed to the OpenCV drawing calls; modified in place.
        results: Object whose ``boxes`` attribute is iterated. Each item must
            expose ``cls``, ``conf`` and ``xyxy`` (callers in this file pass
            the first element returned by the model call).
        model: Object whose ``names`` maps a class id to its label string.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.48
    for det in results.boxes:
        name = model.names[int(det.cls)]
        score = float(det.conf)
        left, top, right, bottom = (int(v) for v in det.xyxy[0])
        colour = box_color(name)

        # Outline of the detection itself.
        cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)

        caption = f"{name} {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(caption, font, scale, 1)
        # Never place the caption's bottom edge higher than text_h + 6 px,
        # so the caption is not pushed above the top of the image.
        anchor_y = top if top >= text_h + 6 else text_h + 6

        # Filled background strip, then the caption text on top of it.
        cv2.rectangle(
            frame,
            (left, anchor_y - text_h - 4),
            (left + text_w + 4, anchor_y),
            colour,
            -1,
        )
        cv2.putText(
            frame,
            caption,
            (left + 2, anchor_y - 3),
            font,
            scale,
            WHITE,
            1,
            cv2.LINE_AA,
        )
def run_video(model, source, conf):
    """Run detection on a camera or video source and show the result live.

    Args:
        model: Detector called as ``model(frame, conf=..., verbose=False,
            **cfg)``; the first element of its return value is drawn with
            ``draw_boxes``. ``cfg`` comes from ``get_inference_config()``.
        source: Capture source as a string. An all-digit string is converted
            to an integer device index; anything else is passed to
            ``cv2.VideoCapture`` unchanged.
        conf: Confidence threshold forwarded to the model.

    While the window is focused, ``q`` quits and ``s`` writes the current
    annotated frame to ``detect_saved.jpg`` in the working directory
    (overwriting any earlier snapshot). The loop also ends when the capture
    stops delivering frames.
    """
    cap = cv2.VideoCapture(int(source) if source.isdigit() else source)
    if not cap.isOpened():
        log.error(f"Cannot open: {source}")
        return

    print("Running - q to quit, s to save.")
    infer_kw = get_inference_config()
    # perf_counter is monotonic, so the FPS figure cannot go negative or
    # spike when the wall clock is adjusted.
    prev = time.perf_counter()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model(frame, conf=conf, verbose=False, **infer_kw)[0]
            draw_boxes(frame, results, model)

            # Sample the clock once per frame so no time is lost between
            # measuring the interval and starting the next one.
            now = time.perf_counter()
            fps = 1.0 / max(now - prev, 1e-9)
            prev = now
            cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, WHITE, 2, cv2.LINE_AA)
            cv2.imshow("Saqr Detect", frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            if key == ord("s"):
                # imwrite reports failure through its return value.
                if cv2.imwrite("detect_saved.jpg", frame):
                    print("Saved: detect_saved.jpg")
                else:
                    log.error("Could not write: detect_saved.jpg")
    finally:
        # Always free the device and close the window, including when
        # inference raises or the user presses Ctrl-C.
        cap.release()
        cv2.destroyAllWindows()
def run_image(model, path, conf):
    """Run detection on a single image, save the annotated copy and show it.

    Args:
        model: Detector called as ``model(frame, conf=..., verbose=False,
            **cfg)``; the first element of its return value is drawn with
            ``draw_boxes``. ``cfg`` comes from ``get_inference_config()``,
            the same settings ``run_video`` uses.
        path: Path of the image to read.
        conf: Confidence threshold forwarded to the model.

    The annotated image is written to ``<stem>_detect.jpg`` in the current
    working directory, then displayed until a key is pressed.
    """
    frame = cv2.imread(str(path))
    if frame is None:
        log.error(f"Cannot read: {path}")
        return

    # Forward the shared inference settings so the options configured by the
    # CLI apply to images exactly as they do to video.
    infer_kw = get_inference_config()
    results = model(frame, conf=conf, verbose=False, **infer_kw)[0]
    draw_boxes(frame, results, model)

    out = Path(path).stem + "_detect.jpg"
    # imwrite reports failure through its return value.
    if cv2.imwrite(out, frame):
        print(f"Saved: {out}")
    else:
        log.error(f"Could not write: {out}")

    cv2.imshow("Saqr Detect", frame)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    """Parse command-line options, load the model and dispatch on the source.

    Dispatch rules for ``--source``:
        * all digits, a stream URL (contains ``://``), or a path ending in
          ``.mp4``/``.avi``/``.mov``/``.mkv`` -> ``run_video``
        * any other existing path -> ``run_image``
        * anything else -> error logged, exit status 1

    Raises:
        SystemExit: with status 1 when the model file cannot be resolved or
            the source does not exist.
    """
    parser = argparse.ArgumentParser(description="Saqr simple PPE detection")
    parser.add_argument("--source", default="0",
                        help="camera index, video file, stream URL or image path")
    parser.add_argument("--model", default="saqr_best.pt",
                        help="model weights name or path")
    parser.add_argument("--conf", type=float, default=0.35,
                        help="confidence threshold passed to the model")
    parser.add_argument("--device", default="0", help="'cpu', '0', 'cuda:0'")
    parser.add_argument("--half", action="store_true",
                        help="enable the 'half' inference setting")
    parser.add_argument("--imgsz", type=int, default=320,
                        help="inference image size")
    args = parser.parse_args()

    set_inference_config(device=args.device, half=args.half, imgsz=args.imgsz)

    # Keep the try body to the import alone so an ImportError raised by
    # anything else is not silently swallowed.
    try:
        import torch
    except ImportError:
        torch = None

    # Only fall back when a CUDA device was requested; an explicit
    # 'cpu' or 'mps' choice must not be overridden.
    wants_cuda = args.device not in {"cpu", "mps"}
    if torch is not None and wants_cuda and not torch.cuda.is_available():
        log.warning("CUDA unavailable - falling back to CPU")
        set_inference_config(device="cpu", half=False, imgsz=args.imgsz)

    try:
        model_path = resolve_model_path(args.model)
    except FileNotFoundError as e:
        log.error(str(e))
        raise SystemExit(1) from e
    model = YOLO(str(model_path))

    src = args.source
    video_suffixes = {".mp4", ".avi", ".mov", ".mkv"}
    # Stream URLs are not filesystem paths, so route them to the capture
    # loop before the existence check below can reject them.
    is_stream = "://" in src
    if src.isdigit() or is_stream or Path(src).suffix.lower() in video_suffixes:
        run_video(model, src, args.conf)
    elif Path(src).exists():
        run_image(model, src, args.conf)
    else:
        log.error(f"Source not found: {src}")
        raise SystemExit(1)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()