Saqr/saqr.py

"""
Saqr - PPE Safety Tracking
===========================
Real-time PPE monitoring with person tracking.
Pipeline:
1. YOLO detection -> PPE bounding boxes (helmet, no-helmet, vest, ...)
2. Heuristic grouping -> cluster nearby PPE boxes into person candidates
3. Person tracker -> assign stable IDs across frames
4. Compliance check -> SAFE / PARTIAL / UNSAFE per person
5. Auto-capture -> save latest crop per tracked person
6. CSV logging -> result.csv (current state) + events.csv (audit log)
Compliance rules (helmet + vest focus):
SAFE = helmet AND vest detected, no violations
PARTIAL = only one of helmet / vest detected
UNSAFE = no-helmet or no-vest detected, or nothing detected
Usage:
python saqr.py --source 0 # webcam (OpenCV)
python saqr.py --source realsense # Intel RealSense D435I
python saqr.py --source 1 --model models/saqr_best.pt
python saqr.py --source video.mp4 --headless
"""
from __future__ import annotations
import argparse
import csv
import math
import shutil
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
import cv2
import numpy as np
from ultralytics import YOLO
from logger import get_logger
# Optional RealSense support
try:
    import pyrealsense2 as rs
    HAS_REALSENSE = True
except ImportError:
    HAS_REALSENSE = False
log = get_logger("Inference", "saqr")
# ── Paths ─────────────────────────────────────────────────────────────────────
ROOT = Path(__file__).resolve().parent
CAPTURES_DIR = ROOT / "captures"
RESULT_CSV = CAPTURES_DIR / "result.csv"
EVENTS_CSV = CAPTURES_DIR / "events.csv"
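# On-disk layout produced at runtime (see setup_capture_dirs / save_track_image below):
#   captures/result.csv                 current state of every track
#   captures/events.csv                 append-only audit log
#   captures/<STATUS>/track_NNNN.jpg    latest crop per tracked person (SAFE / PARTIAL / UNSAFE)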
# ── Colours ───────────────────────────────────────────────────────────────────
GREEN = (0, 200, 0)
YELLOW = (0, 200, 255)
RED = (0, 0, 220)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (120, 120, 120)
CYAN = (200, 200, 0)
# ── PPE class definitions ────────────────────────────────────────────────────
STATUSES = ("SAFE", "PARTIAL", "UNSAFE")
CLASS_ORDER = [
    "boots", "gloves", "goggles", "helmet",
    "no-boots", "no-gloves", "no-goggles", "no-helmet", "no-vest", "vest",
]
PPE_SET = set(CLASS_ORDER)
# Positive -> Negative mapping
POSITIVE_TO_NEGATIVE = {
    "helmet": "no-helmet",
    "vest": "no-vest",
    "boots": "no-boots",
    "gloves": "no-gloves",
    "goggles": "no-goggles",
}
PPE_DISPLAY_ORDER = ["helmet", "vest", "gloves", "goggles", "boots"]
# ── Data classes ──────────────────────────────────────────────────────────────
@dataclass
class PPEItem:
    label: str
    conf: float
    bbox: Tuple[int, int, int, int]  # x1, y1, x2, y2


@dataclass
class PersonCandidate:
    bbox: Tuple[int, int, int, int]
    items: Dict[str, float]  # label -> best confidence
    detections: List[PPEItem] = field(default_factory=list)


@dataclass
class Track:
    track_id: int
    bbox: Tuple[int, int, int, int]
    items: Dict[str, float]
    status: str
    last_seen_frame: int = 0
    last_seen_iso: str = ""
    created_iso: str = ""
    frames_missing: int = 0
    photo_path: Optional[Path] = None
    announced_status: Optional[str] = None
    event_count: int = 0
    pending_status: Optional[str] = None
    pending_count: int = 0
# ── Utilities ─────────────────────────────────────────────────────────────────
def now_iso() -> str:
    return datetime.now().isoformat(timespec="seconds")


def clamp_bbox(bbox, w, h):
    x1, y1, x2, y2 = bbox
    return max(0, x1), max(0, y1), min(w, x2), min(h, y2)


def expand_bbox(bbox, w, h, sx=0.8, sy=1.5):
    x1, y1, x2, y2 = bbox
    bw, bh = x2 - x1, y2 - y1
    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    nw, nh = int(bw * (1 + sx)), int(bh * (1 + sy))
    nx1 = max(0, cx - nw // 2)
    ny1 = max(0, cy - nh // 2)
    return nx1, ny1, min(w, nx1 + nw), min(h, ny1 + nh)


def merge_boxes(a, b):
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))


def box_center(bbox):
    return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)


def box_distance(a, b) -> float:
    ca, cb = box_center(a), box_center(b)
    return math.hypot(ca[0] - cb[0], ca[1] - cb[1])


def resolve_model_path(root: Path, model_arg: str) -> Path:
    """Find model weights with fallback: arg -> root/arg -> models/arg."""
    p = Path(model_arg)
    if p.exists():
        return p
    p = root / model_arg
    if p.exists():
        return p
    p = root / "models" / Path(model_arg).name
    if p.exists():
        return p
    raise FileNotFoundError(f"Model not found: {model_arg}")
# ── Detection ─────────────────────────────────────────────────────────────────
# Global inference config (set by main(), read by collect_detections)
_INFER_KWARGS: Dict = {"device": "cpu", "half": False, "imgsz": 640}
def collect_detections(frame, model: YOLO, conf: float) -> List[PPEItem]:
    """Run YOLO and return only PPE-class detections."""
    results = model(frame, conf=conf, verbose=False, **_INFER_KWARGS)[0]
    items = []
    for box in results.boxes:
        cls_id = int(box.cls)
        label = model.names[cls_id]
        if label not in PPE_SET:
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        items.append(PPEItem(label=label, conf=float(box.conf), bbox=(x1, y1, x2, y2)))
    return items
# ── Grouping: PPE items -> Person candidates ─────────────────────────────────
def should_merge(candidate: PersonCandidate, item: PPEItem) -> bool:
    """Heuristic: is this PPE item close enough to belong to the candidate?"""
    cx1, cy1, cx2, cy2 = candidate.bbox
    ix1, iy1, ix2, iy2 = item.bbox
    cw, ch = cx2 - cx1, cy2 - cy1
    iw, ih = ix2 - ix1, iy2 - iy1
    cxc, cyc = (cx1 + cx2) / 2, (cy1 + cy2) / 2
    ixc, iyc = (ix1 + ix2) / 2, (iy1 + iy2) / 2
    max_dx = max(cw, iw) * 1.2 + 40
    max_dy = max(ch, ih) * 1.8 + 50
    return abs(ixc - cxc) <= max_dx and abs(iyc - cyc) <= max_dy
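# should_merge, worked example (illustrative sizes): for a 100x200 px candidate and a 40x40 px
# item, max_dx = max(100, 40) * 1.2 + 40 = 160 and max_dy = max(200, 40) * 1.8 + 50 = 410,
# so the item is absorbed if its centre lies within that window around the candidate centre.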
def group_detections_to_people(detections: List[PPEItem], w: int, h: int) -> List[PersonCandidate]:
    """Cluster PPE detections into person candidates (2-pass merge)."""
    if not detections:
        return []
    # Pass 1: greedy grouping
    candidates: List[PersonCandidate] = []
    for item in detections:
        merged = False
        for cand in candidates:
            if should_merge(cand, item):
                cand.bbox = merge_boxes(cand.bbox, item.bbox)
                cand.items[item.label] = max(cand.items.get(item.label, 0.0), item.conf)
                cand.detections.append(item)
                merged = True
                break
        if not merged:
            candidates.append(PersonCandidate(
                bbox=item.bbox,
                items={item.label: item.conf},
                detections=[item],
            ))
    # Pass 2: merge nearby candidates
    again = True
    while again:
        again = False
        merged_list: List[PersonCandidate] = []
        for person in candidates:
            matched = False
            for prev in merged_list:
                pw = prev.bbox[2] - prev.bbox[0]
                ph = prev.bbox[3] - prev.bbox[1]
                dist = box_distance(prev.bbox, person.bbox)
                th = max(pw, ph) * 0.55
                if dist <= th:
                    prev.bbox = merge_boxes(prev.bbox, person.bbox)
                    for label, conf in person.items.items():
                        prev.items[label] = max(prev.items.get(label, 0.0), conf)
                    prev.detections.extend(person.detections)
                    again = True
                    matched = True
                    break
            if not matched:
                merged_list.append(person)
        candidates = merged_list
    # Expand each person bbox for better crop coverage
    for cand in candidates:
        cand.bbox = expand_bbox(cand.bbox, w, h)
    return candidates
# ── Status logic (helmet + vest focus) ────────────────────────────────────────
def status_from_items(items: Dict[str, float]) -> str:
    has_helmet = items.get("helmet", 0.0) > items.get("no-helmet", 0.0) and items.get("helmet", 0.0) > 0
    has_vest = items.get("vest", 0.0) > items.get("no-vest", 0.0) and items.get("vest", 0.0) > 0
    no_helmet = items.get("no-helmet", 0.0) > 0
    no_vest = items.get("no-vest", 0.0) > 0
    if no_helmet or no_vest:
        return "UNSAFE"
    if has_helmet and has_vest:
        return "SAFE"
    if has_helmet or has_vest:
        return "PARTIAL"
    return "UNSAFE"
def split_wearing_missing(items: Dict[str, float]) -> Tuple[List[str], List[str], List[str]]:
    wearing, missing, unknown = [], [], []
    for pos in PPE_DISPLAY_ORDER:
        neg = POSITIVE_TO_NEGATIVE[pos]
        pos_conf = items.get(pos, 0.0)
        neg_conf = items.get(neg, 0.0)
        if pos_conf > neg_conf and pos_conf > 0:
            wearing.append(pos)
        elif neg_conf >= pos_conf and neg_conf > 0:
            missing.append(pos)
        else:
            unknown.append(pos)
    return wearing, missing, unknown
# ── CSV Writers ───────────────────────────────────────────────────────────────
class EventLogger:
    FIELDS = ["timestamp", "track_id", "event_type", "status",
              "wearing", "missing", "unknown", "photo", "path"]

    def __init__(self, path: Path):
        self.path = path
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if not self.path.exists():
            with open(self.path, "w", newline="", encoding="utf-8") as f:
                csv.DictWriter(f, fieldnames=self.FIELDS).writeheader()

    def append(self, row: Dict[str, str]) -> None:
        with open(self.path, "a", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=self.FIELDS).writerow(row)
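# result.csv holds one row per track: wearing / missing / unknown are readable summaries,
# and each PPE class column is a 0/1 flag for whether that label appears in the track's
# current detections.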
def write_result_csv(tracks: List[Track], output: Path) -> None:
    output.parent.mkdir(parents=True, exist_ok=True)
    fields = ["photo", "track_id", "status", "last_seen",
              "wearing", "missing", "unknown", *CLASS_ORDER, "path"]
    rows = []
    for track in sorted(tracks, key=lambda t: t.track_id):
        wearing, missing, unknown = split_wearing_missing(track.items)
        row = {
            "photo": track.photo_path.name if track.photo_path else "",
            "track_id": track.track_id,
            "status": track.status,
            "last_seen": track.last_seen_iso,
            "wearing": ", ".join(wearing),
            "missing": ", ".join(missing),
            "unknown": ", ".join(unknown),
            "path": str(track.photo_path) if track.photo_path else "",
        }
        for cls in CLASS_ORDER:
            row[cls] = 1 if track.items.get(cls, 0.0) > 0 else 0
        rows.append(row)
    with open(output, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fields)
        w.writeheader()
        w.writerows(rows)
# ── Person Tracker ────────────────────────────────────────────────────────────
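# Matching strategy: each candidate is greedily assigned to the nearest existing track whose
# centre lies within match_distance pixels; unmatched candidates become new tracks. A status
# change must persist for status_confirm_frames consecutive frames before it is accepted, and
# a track is dropped after max_missing consecutive frames without a match.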
class PersonTracker:
    def __init__(
        self,
        event_logger: EventLogger,
        max_missing: int = 90,
        match_distance: float = 250.0,
        status_confirm_frames: int = 5,
    ):
        self.event_logger = event_logger
        self.max_missing = max_missing
        self.match_distance = match_distance
        self.status_confirm_frames = max(1, status_confirm_frames)
        self.tracks: Dict[int, Track] = {}
        self.next_id = 1

    def _new_track(self, person: PersonCandidate, frame_idx: int) -> Track:
        track = Track(
            track_id=self.next_id,
            bbox=person.bbox,
            items=dict(person.items),
            status=status_from_items(person.items),
            last_seen_frame=frame_idx,
            last_seen_iso=now_iso(),
            created_iso=now_iso(),
        )
        self.next_id += 1
        self.tracks[track.track_id] = track
        return track

    def _match(self, person: PersonCandidate, used: set[int]) -> Optional[Track]:
        best, best_dist = None, float("inf")
        for tid, track in self.tracks.items():
            if tid in used:
                continue
            dist = box_distance(track.bbox, person.bbox)
            if dist < best_dist and dist <= self.match_distance:
                best_dist = dist
                best = track
        return best

    def update(self, people: List[PersonCandidate], frame_idx: int):
        used: set[int] = set()
        created: List[Track] = []
        changed: List[Track] = []
        for person in people:
            track = self._match(person, used)
            if track is None:
                track = self._new_track(person, frame_idx)
                created.append(track)
            else:
                new_status = status_from_items(person.items)
                track.bbox = person.bbox
                track.items = dict(person.items)
                track.last_seen_frame = frame_idx
                track.last_seen_iso = now_iso()
                track.frames_missing = 0
                if new_status != track.status:
                    if track.pending_status == new_status:
                        track.pending_count += 1
                    else:
                        track.pending_status = new_status
                        track.pending_count = 1
                    if track.pending_count >= self.status_confirm_frames:
                        track.status = new_status
                        track.pending_status = None
                        track.pending_count = 0
                        changed.append(track)
                else:
                    track.pending_status = None
                    track.pending_count = 0
            used.add(track.track_id)
        # Age and prune missing tracks
        stale = []
        for tid, track in self.tracks.items():
            if tid not in used:
                track.frames_missing += 1
                if track.frames_missing > self.max_missing:
                    stale.append(tid)
        for tid in stale:
            del self.tracks[tid]
        return created, changed

    def visible_tracks(self) -> List[Track]:
        return [t for t in self.tracks.values() if t.frames_missing == 0]
# ── Track image + event ───────────────────────────────────────────────────────
def save_track_image(frame, track: Track, capture_dirs: Dict[str, Path]) -> Optional[Path]:
    h, w = frame.shape[:2]
    x1, y1, x2, y2 = clamp_bbox(track.bbox, w, h)
    if x2 <= x1 or y2 <= y1:
        return None
    crop = frame[y1:y2, x1:x2]
    if crop.size == 0:
        return None
    target = capture_dirs[track.status] / f"track_{track.track_id:04d}.jpg"
    # Remove the old image if the status folder changed
    if track.photo_path and track.photo_path != target and track.photo_path.exists():
        try:
            track.photo_path.unlink()
        except OSError:
            pass
    cv2.imwrite(str(target), crop)
    track.photo_path = target
    return target
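# Example console line emitted by emit_event below (values illustrative):
#   ID 0003 | STATUS_CHANGE | UNSAFE | wearing: vest | missing: helmet | unknown: gloves, goggles, boots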
def emit_event(
    track: Track,
    event_logger: EventLogger,
    event_type: str = "STATUS_CHANGE",
    force: bool = False,
) -> None:
    if track.photo_path is None:
        return
    if not force and track.announced_status == track.status:
        return
    wearing, missing, unknown = split_wearing_missing(track.items)
    msg = (
        f"ID {track.track_id:04d} | {event_type} | {track.status} | "
        f"wearing: {', '.join(wearing) or 'none'} | "
        f"missing: {', '.join(missing) or 'none'} | "
        f"unknown: {', '.join(unknown) or 'none'}"
    )
    print(msg, flush=True)
    event_logger.append({
        "timestamp": now_iso(),
        "track_id": str(track.track_id),
        "event_type": event_type,
        "status": track.status,
        "wearing": ", ".join(wearing),
        "missing": ", ".join(missing),
        "unknown": ", ".join(unknown),
        "photo": track.photo_path.name if track.photo_path else "",
        "path": str(track.photo_path) if track.photo_path else "",
    })
    track.announced_status = track.status
    track.event_count += 1
# ── Drawing ───────────────────────────────────────────────────────────────────
def status_color(status: str) -> Tuple:
    return {"SAFE": GREEN, "PARTIAL": YELLOW, "UNSAFE": RED}.get(status, GRAY)


def draw_track(frame, track: Track):
    x1, y1, x2, y2 = track.bbox
    color = status_color(track.status)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    wearing, missing, unknown = split_wearing_missing(track.items)
    line1 = f"ID {track.track_id:04d} {track.status}"
    w_str = ", ".join(wearing) if wearing else "none"
    m_str = ", ".join(missing) if missing else "-"
    line2 = f"W:{w_str} M:{m_str}"
    (tw1, th1), _ = cv2.getTextSize(line1, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
    (tw2, th2), _ = cv2.getTextSize(line2, cv2.FONT_HERSHEY_SIMPLEX, 0.40, 1)
    tw = max(tw1, tw2) + 8
    total_h = th1 + th2 + 12
    y_top = max(0, y1 - total_h - 2)
    cv2.rectangle(frame, (x1, y_top), (x1 + tw, y1), color, -1)
    cv2.putText(frame, line1, (x1 + 4, y_top + th1 + 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, WHITE, 1, cv2.LINE_AA)
    cv2.putText(frame, line2, (x1 + 4, y_top + th1 + th2 + 8),
                cv2.FONT_HERSHEY_SIMPLEX, 0.40, WHITE, 1, cv2.LINE_AA)
def draw_counters(frame, tracks: List[Track], fps: float):
    counts = {s: 0 for s in STATUSES}
    for t in tracks:
        counts[t.status] += 1
    lines = [
        (f"FPS: {fps:.1f}", WHITE),
        (f"SAFE {counts['SAFE']}", GREEN),
        (f"PARTIAL {counts['PARTIAL']}", YELLOW),
        (f"UNSAFE {counts['UNSAFE']}", RED),
        (f"TRACKS {len(tracks)}", CYAN),
    ]
    y = 24
    for text, color in lines:
        cv2.putText(frame, text, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, BLACK, 4, cv2.LINE_AA)
        cv2.putText(frame, text, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
        y += 28
# ── Frame processing ──────────────────────────────────────────────────────────
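# process_frame returns the annotated frame plus the currently visible tracks; as side effects
# it saves a crop for every visible track, appends NEW / STATUS_CHANGE events to events.csv,
# and rewrites result.csv when write_csv is True.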
def process_frame(
    frame,
    model: YOLO,
    tracker: PersonTracker,
    frame_idx: int,
    conf: float,
    capture_dirs: Dict[str, Path],
    write_csv: bool = True,
):
    annotated = frame.copy()
    h, w = annotated.shape[:2]
    detections = collect_detections(frame, model, conf)
    candidates = group_detections_to_people(detections, w, h)
    created, changed = tracker.update(candidates, frame_idx)
    visible = tracker.visible_tracks()
    created_ids = {t.track_id for t in created}
    changed_ids = {t.track_id for t in changed}
    event_ids = created_ids | changed_ids
    for track in visible:
        save_track_image(frame, track, capture_dirs)
        if track.track_id in event_ids:
            ev_type = "NEW" if track.track_id in created_ids else "STATUS_CHANGE"
            emit_event(track, tracker.event_logger, ev_type)
        draw_track(annotated, track)
    if write_csv:
        write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
    return annotated, visible
# ── MJPEG Stream Server (view on laptop browser) ─────────────────────────────
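# Viewing: open http://<device-ip>:<port>/ in a browser; the page embeds the /stream MJPEG
# endpoint. HTTPServer is single-threaded, so only one viewer is served at a time; additional
# clients wait until the first one disconnects.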
_stream_frame: Optional[bytes] = None
_stream_lock = threading.Lock()
class MJPEGHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-Type", "text/html")
            self.end_headers()
            self.wfile.write(b'<html><body style="margin:0;background:#000">'
                             b'<img src="/stream" style="width:100%;height:auto">'
                             b'</body></html>')
        elif self.path == "/stream":
            self.send_response(200)
            self.send_header("Content-Type", "multipart/x-mixed-replace; boundary=frame")
            self.end_headers()
            while True:
                with _stream_lock:
                    jpeg = _stream_frame
                if jpeg is None:
                    time.sleep(0.03)
                    continue
                try:
                    self.wfile.write(b"--frame\r\n"
                                     b"Content-Type: image/jpeg\r\n\r\n" + jpeg + b"\r\n")
                except BrokenPipeError:
                    break
        else:
            self.send_error(404)

    def log_message(self, format, *args):
        pass  # silence per-request logs
def start_stream_server(port: int = 8080):
    server = HTTPServer(("0.0.0.0", port), MJPEGHandler)
    t = threading.Thread(target=server.serve_forever, daemon=True)
    t.start()
    log.info(f"MJPEG stream server started on http://0.0.0.0:{port}")
    return server


def update_stream_frame(frame):
    global _stream_frame
    _, jpeg = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 70])
    with _stream_lock:
        _stream_frame = jpeg.tobytes()
# ── Camera / video ────────────────────────────────────────────────────────────
class RealSenseCapture:
    """Wraps pyrealsense2 pipeline with an OpenCV-like read() interface."""

    def __init__(self, width: int = 640, height: int = 480, fps: int = 30,
                 serial: Optional[str] = None):
        if not HAS_REALSENSE:
            raise RuntimeError("pyrealsense2 not installed")
        self.pipeline = rs.pipeline()
        cfg = rs.config()
        if serial:
            cfg.enable_device(serial)
        cfg.enable_stream(rs.stream.color, width, height, rs.format.bgr8, fps)
        self.profile = self.pipeline.start(cfg)
        self._open = True
        dev = self.profile.get_device()
        log.info(f"RealSense opened | {dev.get_info(rs.camera_info.name)} "
                 f"serial={dev.get_info(rs.camera_info.serial_number)} "
                 f"{width}x{height}@{fps}")

    def isOpened(self) -> bool:
        return self._open

    def read(self):
        if not self._open:
            return False, None
        try:
            frames = self.pipeline.wait_for_frames(timeout_ms=3000)
            color = frames.get_color_frame()
            if not color:
                return False, None
            return True, np.asanyarray(color.get_data())
        except Exception:
            return False, None

    def release(self):
        if self._open:
            self.pipeline.stop()
            self._open = False
def open_capture(source: str):
    # RealSense source: "realsense" or "realsense:SERIAL"
    if source.lower().startswith("realsense"):
        serial = None
        if ":" in source:
            serial = source.split(":", 1)[1]
        return RealSenseCapture(width=640, height=480, fps=30, serial=serial)
    if str(source).isdigit():
        idx = int(source)
        cap = cv2.VideoCapture(idx)
        if cap.isOpened():
            return cap
        cap = cv2.VideoCapture(idx, cv2.CAP_ANY)
        if cap.isOpened():
            return cap
        cap = cv2.VideoCapture(idx, cv2.CAP_V4L2)
        return cap
    # V4L2 device path
    if source.startswith("/dev/video"):
        cap = cv2.VideoCapture(source, cv2.CAP_V4L2)
        if cap.isOpened():
            cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
            cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
            cap.set(cv2.CAP_PROP_FPS, 30)
        return cap
    return cv2.VideoCapture(source)
def setup_capture_dirs(base: Path) -> Dict[str, Path]:
    dirs = {}
    for s in STATUSES:
        d = base / "captures" / s
        d.mkdir(parents=True, exist_ok=True)
        dirs[s] = d
    return dirs
def run_video(
    model: YOLO,
    source: str,
    conf: float,
    capture_dirs: Dict[str, Path],
    show_gui: bool,
    csv_every_frame: bool,
    max_missing: int,
    match_distance: float,
    status_confirm_frames: int,
    stream_port: int = 0,
) -> None:
    cap = open_capture(source)
    if not cap.isOpened():
        log.error(f"Cannot open source: {source}")
        return
    ok, first = cap.read()
    if not ok or first is None or first.size == 0:
        log.error(f"Cannot read first frame from source: {source}")
        cap.release()
        return
    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(
        event_logger=event_logger,
        max_missing=max_missing,
        match_distance=match_distance,
        status_confirm_frames=status_confirm_frames,
    )
    # Start MJPEG stream server if requested
    if stream_port > 0:
        start_stream_server(stream_port)
    log.info(f"Session started | source={source}")
    if show_gui:
        print("Running - press q to quit, s to save frame.")
    prev = time.time()
    frame_idx = 0
    frame = first
    while True:
        frame_idx += 1
        try:
            annotated, visible = process_frame(
                frame, model, tracker, frame_idx, conf,
                capture_dirs, write_csv=csv_every_frame,
            )
        except Exception as e:
            log.exception(f"Frame error #{frame_idx}: {e}")
            annotated = frame
            visible = tracker.visible_tracks()
        now_t = time.time()
        fps = 1.0 / max(now_t - prev, 1e-9)
        prev = now_t
        draw_counters(annotated, visible, fps)
        # Send to stream
        if stream_port > 0:
            update_stream_frame(annotated)
        if show_gui:
            cv2.imshow("Saqr PPE Tracking", annotated)
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            if key == ord("s"):
                cv2.imwrite("saved_frame.jpg", annotated)
                log.info("Frame saved: saved_frame.jpg")
        ret, frame = cap.read()
        if not ret:
            break
    cap.release()
    if show_gui:
        cv2.destroyAllWindows()
    # Final CSV write
    write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
    log.info(f"Session ended | frames={frame_idx} tracks_created={tracker.next_id - 1}")
def run_image(model: YOLO, path: str, conf: float, capture_dirs: Dict[str, Path], show_gui: bool):
    frame = cv2.imread(path)
    if frame is None:
        log.error(f"Cannot read image: {path}")
        return
    event_logger = EventLogger(EVENTS_CSV)
    tracker = PersonTracker(event_logger=event_logger)
    annotated, visible = process_frame(frame, model, tracker, 1, conf, capture_dirs)
    draw_counters(annotated, visible, 0.0)
    out = Path(path).stem + "_saqr.jpg"
    cv2.imwrite(out, annotated)
    log.info(f"Result saved: {out}")
    if show_gui:
        cv2.imshow("Saqr PPE Tracking", annotated)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
# ── CLI ───────────────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser(description="Saqr PPE detection with tracking")
    parser.add_argument("--source", default="0",
                        help="0/1=webcam, realsense, realsense:SERIAL, /dev/videoX, or video path")
    parser.add_argument("--model", default="models/saqr_best.pt",
                        help="Trained YOLO weights")
    parser.add_argument("--conf", type=float, default=0.35,
                        help="Detection confidence threshold")
    parser.add_argument("--max-missing", type=int, default=90,
                        help="Frames to keep a lost track alive")
    parser.add_argument("--match-distance", type=float, default=250.0,
                        help="Max pixel distance for track matching")
    parser.add_argument("--status-confirm-frames", type=int, default=5,
                        help="Frames needed to confirm a status change")
    parser.add_argument("--headless", action="store_true",
                        help="Disable OpenCV GUI window")
    parser.add_argument("--stream", type=int, default=0, metavar="PORT",
                        help="Start MJPEG stream on this port (e.g. --stream 8080)")
    parser.add_argument("--csv-on-exit", action="store_true",
                        help="Write result.csv only at session end")
    # GPU / inference tuning
    parser.add_argument("--device", default="0",
                        help="Device: 'cpu', '0' (first GPU), 'cuda:0', etc.")
    parser.add_argument("--half", action="store_true",
                        help="Enable FP16 inference (Jetson / RTX GPUs)")
    parser.add_argument("--imgsz", type=int, default=320,
                        help="Inference image size (320 fast, 640 accurate)")
    args = parser.parse_args()
    # ── Configure global inference kwargs ────────────────────────────────
    global _INFER_KWARGS
    _INFER_KWARGS = {
        "device": args.device,
        "half": args.half,
        "imgsz": args.imgsz,
    }
    # ── Log CUDA status ──────────────────────────────────────────────────
    try:
        import torch
        if torch.cuda.is_available():
            dev_name = torch.cuda.get_device_name(0)
            log.info(f"CUDA available: {dev_name} | torch={torch.__version__} | "
                     f"cuda={torch.version.cuda}")
        else:
            log.warning("CUDA not available - running on CPU (slow)")
            if args.device != "cpu":
                log.warning(f"Falling back to CPU (you requested device={args.device})")
                _INFER_KWARGS["device"] = "cpu"
                _INFER_KWARGS["half"] = False
    except ImportError:
        log.warning("PyTorch not found")
    log.info(f"Inference config: device={_INFER_KWARGS['device']} "
             f"half={_INFER_KWARGS['half']} imgsz={_INFER_KWARGS['imgsz']}")
    capture_dirs = setup_capture_dirs(ROOT)
    try:
        model_path = resolve_model_path(ROOT, args.model)
    except FileNotFoundError as e:
        log.error(str(e))
        log.error("Train first: python train.py --dataset dataset")
        raise SystemExit(1)
    log.info(f"Loading model: {model_path}")
    model = YOLO(str(model_path))
    log.info(f"Classes: {list(model.names.values())}")
    source = args.source
    is_live = (
        source.isdigit()
        or source.lower().startswith("realsense")
        or source.startswith("/dev/video")
    )
    is_video_file = source.lower().endswith(
        (".mp4", ".avi", ".mov", ".mkv", ".webm")
    )
    if is_live or is_video_file:
        run_video(
            model, source, args.conf, capture_dirs,
            show_gui=not args.headless,
            csv_every_frame=not args.csv_on_exit,
            max_missing=args.max_missing,
            match_distance=args.match_distance,
            status_confirm_frames=args.status_confirm_frames,
            stream_port=args.stream,
        )
    elif Path(source).exists():
        run_image(model, source, args.conf, capture_dirs, show_gui=not args.headless)
    else:
        log.error(f"Source not found: {source}")
        raise SystemExit(1)
if __name__ == "__main__":
    main()