# Saqr/saqr/apps/saqr_cli.py
"""Saqr PPE tracking CLI — orchestrates capture → pipeline → display/stream."""
from __future__ import annotations
import argparse
import time
from pathlib import Path
from typing import Dict
import cv2
from ultralytics import YOLO
from saqr.core.camera import open_capture
from saqr.core.capture import setup_capture_dirs
from saqr.core.detection import set_inference_config
from saqr.core.drawing import draw_counters
from saqr.core.events import EventLogger, write_result_csv
from saqr.core.model import resolve_model_path
from saqr.core.paths import EVENTS_CSV, RESULT_CSV
from saqr.core.pipeline import process_frame
from saqr.core.streaming import start_stream_server, update_stream_frame
from saqr.core.tracking import PersonTracker
from saqr.utils.logger import get_logger
log = get_logger("Inference", "saqr")
def run_video(model, source, conf, capture_dirs: Dict[str, Path], show_gui, csv_every_frame,
              max_missing, match_distance, status_confirm_frames, stream_port=0):
    """Run the detection/tracking pipeline over a live or file video source.

    Reads frames from *source*, runs ``process_frame`` on each, overlays
    counters, and optionally displays a GUI window and/or publishes frames
    to an MJPEG-style stream server.

    Args:
        model: Loaded YOLO model passed through to ``process_frame``.
        source: Capture spec understood by ``open_capture`` (webcam index,
            realsense spec, /dev/videoX, or a video file path).
        conf: Detection confidence threshold.
        capture_dirs: Output directories prepared by ``setup_capture_dirs``.
        show_gui: When True, show an OpenCV window ('q' quits, 's' saves a frame).
        csv_every_frame: When True, ``process_frame`` writes CSV each frame.
        max_missing / match_distance / status_confirm_frames: PersonTracker tuning.
        stream_port: If > 0, start the stream server on this port.

    On exit (normal break, source exhausted, or any exception such as
    KeyboardInterrupt) the capture is released, windows are destroyed,
    and the per-track result CSV is written — see the ``finally`` block.
    """
    cap = open_capture(source)
    if not cap.isOpened():
        log.error(f"Cannot open source: {source}")
        return
    ok, first = cap.read()
    if not ok or first is None or first.size == 0:
        log.error(f"Cannot read first frame from source: {source}")
        cap.release()
        return
    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(
        event_logger=event_logger,
        max_missing=max_missing,
        match_distance=match_distance,
        status_confirm_frames=status_confirm_frames,
    )
    if stream_port > 0:
        start_stream_server(stream_port)
    log.info(f"Session started | source={source}")
    if show_gui:
        print("Running - press q to quit, s to save frame.")
    prev = time.time()
    frame_idx = 0
    frame = first
    # try/finally guarantees capture release and result-CSV persistence even
    # when the loop is interrupted (Ctrl-C on a live camera, unexpected error).
    # Previously an exception here leaked the capture and lost the session CSV.
    try:
        while True:
            frame_idx += 1
            try:
                annotated, visible = process_frame(
                    frame, model, tracker, frame_idx, conf,
                    capture_dirs, write_csv=csv_every_frame,
                )
            except Exception as e:
                # A single bad frame must not kill the session: show the raw
                # frame and keep the tracker's last known visible tracks.
                log.exception(f"Frame error #{frame_idx}: {e}")
                annotated = frame
                visible = tracker.visible_tracks()
            now_t = time.time()
            fps = 1.0 / max(now_t - prev, 1e-9)  # guard against zero dt
            prev = now_t
            draw_counters(annotated, visible, fps)
            if stream_port > 0:
                update_stream_frame(annotated)
            if show_gui:
                cv2.imshow("Saqr PPE Tracking", annotated)
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                if key == ord("s"):
                    cv2.imwrite("saved_frame.jpg", annotated)
                    log.info("Frame saved: saved_frame.jpg")
            ret, frame = cap.read()
            if not ret:
                break
    finally:
        cap.release()
        if show_gui:
            cv2.destroyAllWindows()
        write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
        log.info(f"Session ended | frames={frame_idx} tracks_created={tracker.next_id - 1}")
def run_image(model, path, conf, capture_dirs: Dict[str, Path], show_gui):
    """Run PPE detection on a single image file.

    Saves an annotated copy next to the current working directory as
    ``<stem>_saqr.jpg`` and, when *show_gui* is set, displays the result
    until a key is pressed.
    """
    img = cv2.imread(path)
    if img is None:
        log.error(f"Cannot read image: {path}")
        return
    tracker = PersonTracker(event_logger=EventLogger(EVENTS_CSV))
    annotated, visible = process_frame(img, model, tracker, 1, conf, capture_dirs)
    draw_counters(annotated, visible, 0.0)  # static image: report 0 FPS
    out_name = Path(path).stem + "_saqr.jpg"
    cv2.imwrite(out_name, annotated)
    log.info(f"Result saved: {out_name}")
    if not show_gui:
        return
    cv2.imshow("Saqr PPE Tracking", annotated)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    """CLI entry point: parse arguments, configure inference, load the model,
    then dispatch to the video or image pipeline based on the source spec."""
    ap = argparse.ArgumentParser(description="Saqr PPE detection with tracking")
    ap.add_argument("--source", default="0",
                    help="0/1=webcam, realsense, realsense:SERIAL, /dev/videoX, or video path")
    ap.add_argument("--model", default="saqr_best.pt",
                    help="Trained YOLO weights (resolved under data/models/ by default)")
    ap.add_argument("--conf", type=float, default=0.35)
    ap.add_argument("--max-missing", type=int, default=90)
    ap.add_argument("--match-distance", type=float, default=250.0)
    ap.add_argument("--status-confirm-frames", type=int, default=5)
    ap.add_argument("--headless", action="store_true")
    ap.add_argument("--stream", type=int, default=0, metavar="PORT")
    ap.add_argument("--csv-on-exit", action="store_true")
    ap.add_argument("--device", default="0")
    ap.add_argument("--half", action="store_true")
    ap.add_argument("--imgsz", type=int, default=320)
    args = ap.parse_args()

    set_inference_config(device=args.device, half=args.half, imgsz=args.imgsz)

    # Probe CUDA availability; fall back to CPU inference when the requested
    # device can't be used. Missing torch is reported but not fatal here.
    try:
        import torch
    except ImportError:
        log.warning("PyTorch not found")
    else:
        if torch.cuda.is_available():
            log.info(f"CUDA available: {torch.cuda.get_device_name(0)} | "
                     f"torch={torch.__version__} | cuda={torch.version.cuda}")
        else:
            log.warning("CUDA not available - running on CPU (slow)")
            if args.device != "cpu":
                log.warning(f"Falling back to CPU (you requested device={args.device})")
                set_inference_config(device="cpu", half=False, imgsz=args.imgsz)

    capture_dirs = setup_capture_dirs()
    try:
        model_path = resolve_model_path(args.model)
    except FileNotFoundError as e:
        log.error(str(e))
        log.error("Train first: saqr-train --dataset data/dataset")
        raise SystemExit(1)
    log.info(f"Loading model: {model_path}")
    model = YOLO(str(model_path))
    log.info(f"Classes: {list(model.names.values())}")

    src = args.source
    # Live sources (cameras) and video files both go through run_video.
    live = (src.isdigit()
            or src.lower().startswith("realsense")
            or src.startswith("/dev/video"))
    video_file = src.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
    if live or video_file:
        run_video(
            model, src, args.conf, capture_dirs,
            show_gui=not args.headless,
            csv_every_frame=not args.csv_on_exit,
            max_missing=args.max_missing,
            match_distance=args.match_distance,
            status_confirm_frames=args.status_confirm_frames,
            stream_port=args.stream,
        )
        return
    if Path(src).exists():
        # Anything else that exists on disk is treated as a still image.
        run_image(model, src, args.conf, capture_dirs, show_gui=not args.headless)
        return
    log.error(f"Source not found: {src}")
    raise SystemExit(1)
# Allow running directly (`python saqr_cli.py ...`) as well as via console script.
if __name__ == "__main__":
    main()