# Saqr/apps/saqr_cli.py
"""Saqr PPE tracking CLI — orchestrates capture → pipeline → display/stream."""
from __future__ import annotations
import argparse
import time
from pathlib import Path
from typing import Dict
import cv2
from ultralytics import YOLO
from core.camera import RealSenseCapture, open_capture
from core.capture import setup_capture_dirs, setup_snapshot_dirs
from core.detection import set_inference_config
from core.drawing import draw_counters
from core.events import EventLogger, write_result_csv
from core.model import resolve_model_path
from core.paths import EVENTS_CSV, RESULT_CSV
from core.pipeline import process_frame
from core.streaming import start_stream_server, update_stream_frame
from core.tracking import PersonTracker
from utils.config import load_config
from utils.logger import get_logger
log = get_logger("Inference", "saqr")
_CORE = load_config("core")
def run_video(model, source, conf, capture_dirs: Dict[str, Path], show_gui,
              csv_interval, max_missing, match_distance, status_confirm_frames,
              *,
              snapshot_dirs=None, max_distance_m: float = 0.0,
              stream_port: int = 0):
    """Run the detection + tracking loop over a live source or video file.

    Args:
        model: Loaded YOLO model, forwarded to ``process_frame``.
        source: Capture source (webcam index, "realsense[:SERIAL]", device
            path, or video file path) — passed to ``open_capture``.
        conf: Detection confidence threshold.
        capture_dirs: Output directories from ``setup_capture_dirs``.
        show_gui: Show the OpenCV preview window and handle q/s hotkeys.
        csv_interval: Write result.csv every N frames (0 = only on exit).
        max_missing: Frames a track may go unseen before removal (tracker tuning).
        match_distance: Max association distance for the tracker.
        status_confirm_frames: Frames needed to confirm a status change.
        snapshot_dirs: Optional snapshot output dirs (None disables snapshots).
        max_distance_m: RealSense-only depth cutoff in meters (0 = off).
        stream_port: If > 0, serve annotated frames on this port.
    """
    cap = open_capture(source)
    if not cap.isOpened():
        log.error(f"Cannot open source: {source}")
        return
    # Validate the stream up front so dead sources fail fast instead of
    # entering the loop with an empty frame.
    ok, first = cap.read()
    if not ok or first is None or first.size == 0:
        log.error(f"Cannot read first frame from source: {source}")
        cap.release()
        return
    is_realsense = isinstance(cap, RealSenseCapture) and cap.has_depth
    # 0.001 is a placeholder depth scale; it is only meaningful when a depth
    # frame is actually supplied (RealSense path).
    depth_scale = cap.depth_scale if is_realsense else 0.001
    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(
        event_logger=event_logger,
        max_missing=max_missing,
        match_distance=match_distance,
        status_confirm_frames=status_confirm_frames,
    )
    if stream_port > 0:
        start_stream_server(stream_port)
    log.info(
        f"Session started | source={source} depth={is_realsense} "
        f"max_distance_m={max_distance_m if max_distance_m > 0 else 'off'} "
        f"csv_interval={csv_interval}"
    )
    if show_gui:
        print("Running - press q to quit, s to save frame.")
    prev = time.time()
    frame_idx = 0
    frame = first
    try:
        while True:
            frame_idx += 1
            depth_frame = cap.latest_depth if is_realsense else None
            write_csv_this_frame = csv_interval > 0 and (frame_idx % csv_interval == 0)
            try:
                annotated, visible = process_frame(
                    frame, model, tracker, frame_idx, conf,
                    capture_dirs, write_csv=write_csv_this_frame,
                    snapshot_dirs=snapshot_dirs,
                    depth_frame=depth_frame, depth_scale=depth_scale,
                    max_distance_m=max_distance_m,
                )
            except Exception as e:
                # Best-effort: keep the session alive on a per-frame failure
                # and fall back to showing the raw frame.
                log.exception(f"Frame error #{frame_idx}: {e}")
                annotated = frame
                visible = tracker.visible_tracks()
            now_t = time.time()
            fps = 1.0 / max(now_t - prev, 1e-9)  # guard against zero delta
            prev = now_t
            draw_counters(annotated, visible, fps)
            if stream_port > 0:
                update_stream_frame(annotated)
            if show_gui:
                cv2.imshow("Saqr PPE Tracking", annotated)
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                if key == ord("s"):
                    cv2.imwrite("saved_frame.jpg", annotated)
                    log.info("Frame saved: saved_frame.jpg")
            ret, frame = cap.read()
            if not ret:
                break
    finally:
        # FIX: run cleanup even if the loop raises outside the per-frame try
        # (e.g. in draw_counters, imshow, or cap.read). Previously an escaped
        # exception leaked the capture device and skipped the final CSV.
        cap.release()
        if show_gui:
            cv2.destroyAllWindows()
        # Always write final state on exit so the last tracked people are recorded.
        write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
        # next_id - 1 presumably equals the number of tracks ever created
        # (IDs appear to start at 1) — TODO confirm in PersonTracker.
        log.info(f"Session ended | frames={frame_idx} tracks_created={tracker.next_id - 1}")
def run_image(model, path, conf, capture_dirs: Dict[str, Path], show_gui, snapshot_dirs=None):
    """Run a single-image detection pass and save the annotated result.

    Args:
        model: Loaded YOLO model, forwarded to ``process_frame``.
        path: Path to the input image.
        conf: Detection confidence threshold.
        capture_dirs: Output directories from ``setup_capture_dirs``.
        show_gui: Display the annotated image until a key is pressed.
        snapshot_dirs: Optional snapshot output dirs (None disables snapshots).

    The result is written next to the current working directory as
    ``<stem>_saqr.jpg``.
    """
    frame = cv2.imread(path)
    if frame is None:
        log.error(f"Cannot read image: {path}")
        return
    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(event_logger=event_logger)
    annotated, visible = process_frame(
        frame, model, tracker, 1, conf, capture_dirs,
        snapshot_dirs=snapshot_dirs,
    )
    draw_counters(annotated, visible, 0.0)  # single image: FPS is meaningless
    out = Path(path).stem + "_saqr.jpg"
    # FIX: cv2.imwrite returns False on failure (bad path, no codec); the
    # previous code unconditionally logged success.
    if cv2.imwrite(out, annotated):
        log.info(f"Result saved: {out}")
    else:
        log.error(f"Result saved: {out} FAILED (cv2.imwrite returned False)")
    if show_gui:
        cv2.imshow("Saqr PPE Tracking", annotated)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
def main():
    """CLI entry point: parse args, configure inference, dispatch to run_video/run_image."""
    parser = argparse.ArgumentParser(description="Saqr PPE detection with tracking")
    # Config sections providing CLI defaults.
    det = _CORE["detection"]
    trk = _CORE["tracking"]
    cam = _CORE["camera"]
    cap_cfg = _CORE["capture"]
    parser.add_argument("--source", default=cam["default_source"],
                        help="0/1=webcam, realsense, realsense:SERIAL, /dev/videoX, or video path")
    parser.add_argument("--model", default=det["default_model"],
                        help="Trained YOLO weights (resolved under data/models/ by default)")
    parser.add_argument("--conf", type=float, default=det["conf"])
    parser.add_argument("--max-missing", type=int, default=trk["max_missing"])
    parser.add_argument("--match-distance", type=float, default=trk["match_distance"])
    parser.add_argument("--status-confirm-frames", type=int, default=trk["status_confirm_frames"])
    parser.add_argument("--max-distance-m", type=float, default=det.get("max_distance_m", 0.0),
                        help="RealSense-only: drop candidates beyond this depth (0 = off)")
    parser.add_argument("--headless", action="store_true")
    parser.add_argument("--stream", type=int, default=0, metavar="PORT")
    parser.add_argument("--csv-interval", type=int, default=trk.get("csv_write_every_n_frames", 30),
                        help="Write result.csv every N frames (0 = only on exit)")
    parser.add_argument("--csv-on-exit", action="store_true",
                        help="Alias for --csv-interval 0")
    parser.add_argument("--no-snapshots", action="store_true",
                        help="Disable full-frame snapshot on transitions")
    parser.add_argument("--device", default=det["device"])
    # NOTE(review): store_true with a config default means --half cannot be
    # disabled from the CLI when the config enables it — preserved as-is.
    parser.add_argument("--half", action="store_true", default=det["half"])
    parser.add_argument("--imgsz", type=int, default=det["imgsz"])
    args = parser.parse_args()
    set_inference_config(device=args.device, half=args.half, imgsz=args.imgsz)
    # Torch is optional at runtime; degrade to CPU with a warning if CUDA
    # (or torch itself) is unavailable.
    try:
        import torch
        if torch.cuda.is_available():
            log.info(f"CUDA available: {torch.cuda.get_device_name(0)} | "
                     f"torch={torch.__version__} | cuda={torch.version.cuda}")
        else:
            log.warning("CUDA not available - running on CPU (slow)")
            if args.device != "cpu":
                log.warning(f"Falling back to CPU (you requested device={args.device})")
                set_inference_config(device="cpu", half=False, imgsz=args.imgsz)
    except ImportError:
        log.warning("PyTorch not found")
    capture_dirs = setup_capture_dirs()
    snapshot_dirs = None
    if cap_cfg.get("save_event_snapshot", True) and not args.no_snapshots:
        snapshot_dirs = setup_snapshot_dirs()
    csv_interval = 0 if args.csv_on_exit else max(0, args.csv_interval)
    try:
        model_path = resolve_model_path(args.model)
    except FileNotFoundError as e:
        log.error(str(e))
        log.error("Train first: python -m apps.train_cli --dataset data/dataset")
        raise SystemExit(1)
    log.info(f"Loading model: {model_path}")
    model = YOLO(str(model_path))
    log.info(f"Classes: {list(model.names.values())}")
    # FIX: the config default_source may be a non-string (e.g. webcam index 0);
    # normalize to str so .isdigit()/.lower() below cannot raise AttributeError.
    source = str(args.source)
    is_live = (source.isdigit()
               or source.lower().startswith("realsense")
               or source.startswith("/dev/video"))
    is_video_file = source.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
    if is_live or is_video_file:
        run_video(
            model, source, args.conf, capture_dirs,
            show_gui=not args.headless,
            csv_interval=csv_interval,
            max_missing=args.max_missing,
            match_distance=args.match_distance,
            status_confirm_frames=args.status_confirm_frames,
            snapshot_dirs=snapshot_dirs,
            max_distance_m=args.max_distance_m,
            stream_port=args.stream,
        )
    elif Path(source).exists():
        run_image(model, source, args.conf, capture_dirs,
                  show_gui=not args.headless, snapshot_dirs=snapshot_dirs)
    else:
        log.error(f"Source not found: {source}")
        raise SystemExit(1)
if __name__ == "__main__":
main()