236 lines
8.7 KiB
Python
236 lines
8.7 KiB
Python
"""Saqr PPE tracking CLI — orchestrates capture → pipeline → display/stream."""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
from core.camera import RealSenseCapture, open_capture
|
|
from core.capture import setup_capture_dirs, setup_snapshot_dirs
|
|
from core.detection import set_inference_config
|
|
from core.drawing import draw_counters
|
|
from core.events import EventLogger, write_result_csv
|
|
from core.model import resolve_model_path
|
|
from core.paths import EVENTS_CSV, RESULT_CSV
|
|
from core.pipeline import process_frame
|
|
from core.streaming import start_stream_server, update_stream_frame
|
|
from core.tracking import PersonTracker
|
|
from utils.config import load_config
|
|
from utils.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's logging helper.
log = get_logger("Inference", "saqr")
# "core" configuration, loaded once at import time. main() reads its
# "detection", "tracking", "camera" and "capture" sections for CLI defaults.
_CORE = load_config("core")
|
|
|
|
|
|
def run_video(model, source, conf, capture_dirs: Dict[str, Path], show_gui,
              csv_interval, max_missing, match_distance, status_confirm_frames,
              *,
              snapshot_dirs=None, max_distance_m: float = 0.0,
              stream_port: int = 0):
    """Run the capture -> process -> display/stream loop on a video source.

    Args:
        model: Detection model, forwarded unchanged to ``process_frame``.
        source: Source specifier handed to ``open_capture``.
        conf: Confidence value forwarded to ``process_frame``.
        capture_dirs: Capture directories forwarded to ``process_frame``.
        show_gui: When true, show each annotated frame in an OpenCV window
            and react to the ``q`` (quit) and ``s`` (save frame) keys.
        csv_interval: When > 0, ``process_frame`` is asked to write the CSV
            on every frame whose index is a multiple of this value.
        max_missing: Forwarded to ``PersonTracker``.
        match_distance: Forwarded to ``PersonTracker``.
        status_confirm_frames: Forwarded to ``PersonTracker``.
        snapshot_dirs: Optional snapshot directories forwarded to
            ``process_frame``.
        max_distance_m: Forwarded to ``process_frame``; logged as ``off``
            when not positive.
        stream_port: When > 0, a stream server is started on this port and
            fed every annotated frame.

    Returns:
        None. Errors opening or reading the source are logged and the
        function returns early.

    The capture is released and the final result CSV is written even when
    the loop is left through Ctrl+C or an unexpected exception, so a
    headless session stopped with Ctrl+C still records its last state.
    """
    cap = open_capture(source)
    if not cap.isOpened():
        log.error(f"Cannot open source: {source}")
        return

    ok, first = cap.read()
    if not ok or first is None or first.size == 0:
        log.error(f"Cannot read first frame from source: {source}")
        cap.release()
        return

    # Depth is only used when the capture is a RealSense device with depth.
    is_realsense = isinstance(cap, RealSenseCapture) and cap.has_depth
    depth_scale = cap.depth_scale if is_realsense else 0.001

    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(
        event_logger=event_logger,
        max_missing=max_missing,
        match_distance=match_distance,
        status_confirm_frames=status_confirm_frames,
    )

    # Defined before the try so the final log line can always report it.
    frame_idx = 0

    try:
        if stream_port > 0:
            start_stream_server(stream_port)

        log.info(
            f"Session started | source={source} depth={is_realsense} "
            f"max_distance_m={max_distance_m if max_distance_m > 0 else 'off'} "
            f"csv_interval={csv_interval}"
        )
        if show_gui:
            print("Running - press q to quit, s to save frame.")

        # Monotonic clock: immune to wall-clock adjustments between frames.
        prev = time.perf_counter()
        frame = first

        while True:
            frame_idx += 1
            depth_frame = cap.latest_depth if is_realsense else None
            write_csv_this_frame = csv_interval > 0 and (frame_idx % csv_interval == 0)

            try:
                annotated, visible = process_frame(
                    frame, model, tracker, frame_idx, conf,
                    capture_dirs, write_csv=write_csv_this_frame,
                    snapshot_dirs=snapshot_dirs,
                    depth_frame=depth_frame, depth_scale=depth_scale,
                    max_distance_m=max_distance_m,
                )
            except Exception as e:
                # Best effort: one bad frame must not end the session, so
                # fall back to the raw frame and the tracker's current view.
                log.exception(f"Frame error #{frame_idx}: {e}")
                annotated = frame
                visible = tracker.visible_tracks()

            now_t = time.perf_counter()
            fps = 1.0 / max(now_t - prev, 1e-9)
            prev = now_t

            draw_counters(annotated, visible, fps)

            if stream_port > 0:
                update_stream_frame(annotated)

            if show_gui:
                cv2.imshow("Saqr PPE Tracking", annotated)
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                if key == ord("s"):
                    # imwrite reports failure through its return value.
                    if cv2.imwrite("saved_frame.jpg", annotated):
                        log.info("Frame saved: saved_frame.jpg")
                    else:
                        log.error("Failed to save frame: saved_frame.jpg")

            ret, frame = cap.read()
            # Treat a missing frame like end of stream, matching the
            # first-frame check, instead of crashing while drawing on None.
            if not ret or frame is None:
                break
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop a headless live session.
        log.info("Interrupted by user - shutting down")
    finally:
        cap.release()
        if show_gui:
            cv2.destroyAllWindows()

        # Always write final state on exit so the last tracked people are recorded.
        write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
        log.info(f"Session ended | frames={frame_idx} tracks_created={tracker.next_id - 1}")
|
|
|
|
|
|
def run_image(model, path, conf, capture_dirs: Dict[str, Path], show_gui, snapshot_dirs=None):
    """Process a single image file and save the annotated result.

    Args:
        model: Detection model, forwarded unchanged to ``process_frame``.
        path: Path of the image to read with ``cv2.imread``.
        conf: Confidence value forwarded to ``process_frame``.
        capture_dirs: Capture directories forwarded to ``process_frame``.
        show_gui: When true, show the annotated image and wait for a key.
        snapshot_dirs: Optional snapshot directories forwarded to
            ``process_frame``.

    Returns:
        None. An unreadable image is logged and the function returns early.

    The output is written as ``<input stem>_saqr.jpg`` relative to the
    current working directory.
    """
    frame = cv2.imread(path)
    if frame is None:
        log.error(f"Cannot read image: {path}")
        return

    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(event_logger=event_logger)

    # A still image is handled as frame number 1 of a one-frame session.
    annotated, visible = process_frame(
        frame, model, tracker, 1, conf, capture_dirs,
        snapshot_dirs=snapshot_dirs,
    )
    # FPS is meaningless for a single image, so 0.0 is drawn.
    draw_counters(annotated, visible, 0.0)

    out = Path(path).stem + "_saqr.jpg"
    # imwrite reports failure through its return value, so only log success
    # when the file was actually written.
    if cv2.imwrite(out, annotated):
        log.info(f"Result saved: {out}")
    else:
        log.error(f"Failed to save result: {out}")

    if show_gui:
        cv2.imshow("Saqr PPE Tracking", annotated)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
|
|
|
|
|
|
def main():
    """Parse CLI arguments, load the model and dispatch to the right runner.

    Defaults for most options come from the "core" configuration sections
    ``detection``, ``tracking``, ``camera`` and ``capture``.

    Live sources (a digit string, ``realsense...``, ``/dev/video...``) and
    files with a known video extension go to ``run_video``; any other
    existing path goes to ``run_image``.

    Raises:
        SystemExit: With status 1 when the model weights cannot be resolved
            or the source is neither live, a video file, nor an existing path.
    """
    parser = argparse.ArgumentParser(description="Saqr PPE detection with tracking")
    det = _CORE["detection"]
    trk = _CORE["tracking"]
    cam = _CORE["camera"]
    cap_cfg = _CORE["capture"]

    parser.add_argument("--source", default=cam["default_source"],
                        help="0/1=webcam, realsense, realsense:SERIAL, /dev/videoX, or video path")
    parser.add_argument("--model", default=det["default_model"],
                        help="Trained YOLO weights (resolved under data/models/ by default)")
    parser.add_argument("--conf", type=float, default=det["conf"])
    parser.add_argument("--max-missing", type=int, default=trk["max_missing"])
    parser.add_argument("--match-distance", type=float, default=trk["match_distance"])
    parser.add_argument("--status-confirm-frames", type=int, default=trk["status_confirm_frames"])
    parser.add_argument("--max-distance-m", type=float, default=det.get("max_distance_m", 0.0),
                        help="RealSense-only: drop candidates beyond this depth (0 = off)")
    parser.add_argument("--headless", action="store_true")
    parser.add_argument("--stream", type=int, default=0, metavar="PORT")
    parser.add_argument("--csv-interval", type=int, default=trk.get("csv_write_every_n_frames", 30),
                        help="Write result.csv every N frames (0 = only on exit)")
    parser.add_argument("--csv-on-exit", action="store_true",
                        help="Alias for --csv-interval 0")
    parser.add_argument("--no-snapshots", action="store_true",
                        help="Disable full-frame snapshot on transitions")
    parser.add_argument("--device", default=det["device"])
    parser.add_argument("--half", action="store_true", default=det["half"])
    parser.add_argument("--imgsz", type=int, default=det["imgsz"])
    args = parser.parse_args()

    set_inference_config(device=args.device, half=args.half, imgsz=args.imgsz)

    # Report the CUDA situation and fall back to CPU when it is unavailable.
    try:
        import torch
        if torch.cuda.is_available():
            log.info(f"CUDA available: {torch.cuda.get_device_name(0)} | "
                     f"torch={torch.__version__} | cuda={torch.version.cuda}")
        else:
            log.warning("CUDA not available - running on CPU (slow)")
            if args.device != "cpu":
                log.warning(f"Falling back to CPU (you requested device={args.device})")
                set_inference_config(device="cpu", half=False, imgsz=args.imgsz)
    except ImportError:
        log.warning("PyTorch not found")

    capture_dirs = setup_capture_dirs()
    snapshot_dirs = None
    # Snapshots need both the config switch and the absence of --no-snapshots.
    if cap_cfg.get("save_event_snapshot", True) and not args.no_snapshots:
        snapshot_dirs = setup_snapshot_dirs()

    # --csv-on-exit wins; negative intervals are clamped to 0 (exit only).
    csv_interval = 0 if args.csv_on_exit else max(0, args.csv_interval)

    try:
        model_path = resolve_model_path(args.model)
    except FileNotFoundError as e:
        log.error(str(e))
        log.error("Train first: python -m apps.train_cli --dataset data/dataset")
        raise SystemExit(1)

    log.info(f"Loading model: {model_path}")
    model = YOLO(str(model_path))
    log.info(f"Classes: {list(model.names.values())}")

    # argparse leaves non-string defaults unconverted, so a config value such
    # as the integer 0 would reach the string methods below as an int.
    source = str(args.source)
    is_live = (source.isdigit()
               or source.lower().startswith("realsense")
               or source.startswith("/dev/video"))
    is_video_file = source.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))

    if is_live or is_video_file:
        run_video(
            model, source, args.conf, capture_dirs,
            show_gui=not args.headless,
            csv_interval=csv_interval,
            max_missing=args.max_missing,
            match_distance=args.match_distance,
            status_confirm_frames=args.status_confirm_frames,
            snapshot_dirs=snapshot_dirs,
            max_distance_m=args.max_distance_m,
            stream_port=args.stream,
        )
    elif Path(source).exists():
        run_image(model, source, args.conf, capture_dirs,
                  show_gui=not args.headless, snapshot_dirs=snapshot_dirs)
    else:
        log.error(f"Source not found: {source}")
        raise SystemExit(1)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|