"""Saqr PPE tracking CLI — orchestrates capture → pipeline → display/stream.""" from __future__ import annotations import argparse import time from pathlib import Path from typing import Dict import cv2 from ultralytics import YOLO from saqr.core.camera import open_capture from saqr.core.capture import setup_capture_dirs from saqr.core.detection import set_inference_config from saqr.core.drawing import draw_counters from saqr.core.events import EventLogger, write_result_csv from saqr.core.model import resolve_model_path from saqr.core.paths import EVENTS_CSV, RESULT_CSV from saqr.core.pipeline import process_frame from saqr.core.streaming import start_stream_server, update_stream_frame from saqr.core.tracking import PersonTracker from saqr.utils.logger import get_logger log = get_logger("Inference", "saqr") def run_video(model, source, conf, capture_dirs: Dict[str, Path], show_gui, csv_every_frame, max_missing, match_distance, status_confirm_frames, stream_port=0): cap = open_capture(source) if not cap.isOpened(): log.error(f"Cannot open source: {source}") return ok, first = cap.read() if not ok or first is None or first.size == 0: log.error(f"Cannot read first frame from source: {source}") cap.release() return event_logger = EventLogger(EVENTS_CSV) tracker = PersonTracker( event_logger=event_logger, max_missing=max_missing, match_distance=match_distance, status_confirm_frames=status_confirm_frames, ) if stream_port > 0: start_stream_server(stream_port) log.info(f"Session started | source={source}") if show_gui: print("Running - press q to quit, s to save frame.") prev = time.time() frame_idx = 0 frame = first while True: frame_idx += 1 try: annotated, visible = process_frame( frame, model, tracker, frame_idx, conf, capture_dirs, write_csv=csv_every_frame, ) except Exception as e: log.exception(f"Frame error #{frame_idx}: {e}") annotated = frame visible = tracker.visible_tracks() now_t = time.time() fps = 1.0 / max(now_t - prev, 1e-9) prev = now_t draw_counters(annotated, visible, fps) if stream_port > 0: update_stream_frame(annotated) if show_gui: cv2.imshow("Saqr PPE Tracking", annotated) key = cv2.waitKey(1) & 0xFF if key == ord("q"): break if key == ord("s"): cv2.imwrite("saved_frame.jpg", annotated) log.info("Frame saved: saved_frame.jpg") ret, frame = cap.read() if not ret: break cap.release() if show_gui: cv2.destroyAllWindows() write_result_csv(list(tracker.tracks.values()), RESULT_CSV) log.info(f"Session ended | frames={frame_idx} tracks_created={tracker.next_id - 1}") def run_image(model, path, conf, capture_dirs: Dict[str, Path], show_gui): frame = cv2.imread(path) if frame is None: log.error(f"Cannot read image: {path}") return event_logger = EventLogger(EVENTS_CSV) tracker = PersonTracker(event_logger=event_logger) annotated, visible = process_frame(frame, model, tracker, 1, conf, capture_dirs) draw_counters(annotated, visible, 0.0) out = Path(path).stem + "_saqr.jpg" cv2.imwrite(out, annotated) log.info(f"Result saved: {out}") if show_gui: cv2.imshow("Saqr PPE Tracking", annotated) cv2.waitKey(0) cv2.destroyAllWindows() def main(): parser = argparse.ArgumentParser(description="Saqr PPE detection with tracking") parser.add_argument("--source", default="0", help="0/1=webcam, realsense, realsense:SERIAL, /dev/videoX, or video path") parser.add_argument("--model", default="saqr_best.pt", help="Trained YOLO weights (resolved under data/models/ by default)") parser.add_argument("--conf", type=float, default=0.35) parser.add_argument("--max-missing", type=int, default=90) parser.add_argument("--match-distance", type=float, default=250.0) parser.add_argument("--status-confirm-frames", type=int, default=5) parser.add_argument("--headless", action="store_true") parser.add_argument("--stream", type=int, default=0, metavar="PORT") parser.add_argument("--csv-on-exit", action="store_true") parser.add_argument("--device", default="0") parser.add_argument("--half", action="store_true") parser.add_argument("--imgsz", type=int, default=320) args = parser.parse_args() set_inference_config(device=args.device, half=args.half, imgsz=args.imgsz) try: import torch if torch.cuda.is_available(): log.info(f"CUDA available: {torch.cuda.get_device_name(0)} | " f"torch={torch.__version__} | cuda={torch.version.cuda}") else: log.warning("CUDA not available - running on CPU (slow)") if args.device != "cpu": log.warning(f"Falling back to CPU (you requested device={args.device})") set_inference_config(device="cpu", half=False, imgsz=args.imgsz) except ImportError: log.warning("PyTorch not found") capture_dirs = setup_capture_dirs() try: model_path = resolve_model_path(args.model) except FileNotFoundError as e: log.error(str(e)) log.error("Train first: saqr-train --dataset data/dataset") raise SystemExit(1) log.info(f"Loading model: {model_path}") model = YOLO(str(model_path)) log.info(f"Classes: {list(model.names.values())}") source = args.source is_live = (source.isdigit() or source.lower().startswith("realsense") or source.startswith("/dev/video")) is_video_file = source.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")) if is_live or is_video_file: run_video( model, source, args.conf, capture_dirs, show_gui=not args.headless, csv_every_frame=not args.csv_on_exit, max_missing=args.max_missing, match_distance=args.match_distance, status_confirm_frames=args.status_confirm_frames, stream_port=args.stream, ) elif Path(source).exists(): run_image(model, source, args.conf, capture_dirs, show_gui=not args.headless) else: log.error(f"Source not found: {source}") raise SystemExit(1) if __name__ == "__main__": main()