479 lines
17 KiB
Python
479 lines
17 KiB
Python
"""Saqr PySide6 desktop GUI for live PPE compliance monitoring."""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PySide6.QtCore import Qt, QThread, Signal, Slot
|
|
from PySide6.QtGui import QColor, QImage, QPixmap
|
|
from PySide6.QtWidgets import (
|
|
QApplication, QComboBox, QDoubleSpinBox, QFileDialog, QGridLayout,
|
|
QGroupBox, QHBoxLayout, QLabel, QMainWindow, QMessageBox, QPushButton,
|
|
QSpinBox, QTextEdit, QVBoxLayout, QWidget,
|
|
)
|
|
from ultralytics import YOLO
|
|
|
|
from saqr.core.capture import save_track_image, setup_capture_dirs
|
|
from saqr.core.compliance import split_wearing_missing
|
|
from saqr.core.detection import STATUSES, collect_detections
|
|
from saqr.core.drawing import draw_counters, draw_track
|
|
from saqr.core.events import EventLogger, emit_event, write_result_csv
|
|
from saqr.core.grouping import group_detections_to_people
|
|
from saqr.core.model import resolve_model_path
|
|
from saqr.core.paths import EVENTS_CSV, MODELS_DIR, PROJECT_ROOT, RESULT_CSV
|
|
from saqr.core.tracking import PersonTracker
|
|
from saqr.utils.logger import get_logger
|
|
|
|
# Module-level logger from the project's get_logger helper; SaqrWindow uses it to record session starts.
log = get_logger("Inference", "gui")
|
|
|
|
|
|
def list_cameras(max_idx: int = 10, probe_count: int = 4) -> List[str]:
    """Return the camera sources that look usable on this machine.

    First looks for V4L2 device nodes ``/dev/video0`` .. ``/dev/video{max_idx-1}``.
    If none exist (e.g. on non-Linux hosts), it falls back to probing OpenCV
    camera indices ``0`` .. ``probe_count-1``. It keeps the indices whose
    capture opens.

    Args:
        max_idx: Number of ``/dev/videoN`` nodes to check.
        probe_count: Number of OpenCV camera indices to probe in the fallback.

    Returns:
        Device paths or index strings. The list is never empty: ``["0"]`` is
        returned when nothing was found, so the UI always has an entry.
    """
    sources: List[str] = [
        dev for dev in (f"/dev/video{i}" for i in range(max_idx)) if Path(dev).exists()
    ]
    if not sources:
        for i in range(probe_count):
            cap = cv2.VideoCapture(i)
            try:
                if cap.isOpened():
                    sources.append(str(i))
            finally:
                # Release every probed handle, opened or not, so nothing leaks.
                cap.release()
    return sources or ["0"]
|
|
|
|
|
|
def open_camera(source: str, width: int = 640, height: int = 480, fps: int = 30):
    """Open *source* with OpenCV and request low-latency capture settings.

    Source handling:

    - ``/dev/videoN`` paths are opened through the V4L2 backend.
    - Purely numeric strings are treated as camera indices.
    - Anything else (file paths, URLs) is handed to OpenCV unchanged.

    When the capture opens, it is asked for a one-frame buffer, MJPG encoding,
    and the requested resolution and frame rate. The backend is free not to
    honour every one of these settings.

    Returns:
        The ``cv2.VideoCapture``. Callers must check ``isOpened()`` themselves.
    """
    if source.startswith("/dev/video"):
        capture = cv2.VideoCapture(source, cv2.CAP_V4L2)
    else:
        target = int(source) if source.isdigit() else source
        capture = cv2.VideoCapture(target)

    if not capture.isOpened():
        return capture

    # Applied in this exact order: buffer size, codec, width, height, fps.
    settings = (
        (cv2.CAP_PROP_BUFFERSIZE, 1),
        (cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG")),
        (cv2.CAP_PROP_FRAME_WIDTH, width),
        (cv2.CAP_PROP_FRAME_HEIGHT, height),
        (cv2.CAP_PROP_FPS, fps),
    )
    for prop, value in settings:
        capture.set(prop, value)
    return capture
|
|
|
|
|
|
class DetectionWorker(QThread):
    """Background thread running PPE detection and person tracking on a video source.

    Signals:
        frame_ready(np.ndarray, list): the annotated frame (a copy of the
            captured frame with tracks and counters drawn) and the tracks
            visible in it.
        event_fired(str): human-readable log lines. These cover session
            start/end, ``NEW`` / ``STATUS_CHANGE`` track events and
            ``[ERROR]`` messages.
        stats_updated(dict): the number of visible tracks per status in
            ``STATUSES``, plus ``"fps"`` and ``"tracks"`` (the visible count).

    Usage: call :meth:`configure`, then ``start()``. :meth:`stop` asks the loop
    to exit after the frame it is currently processing.
    """

    frame_ready = Signal(np.ndarray, list)
    event_fired = Signal(str)
    stats_updated = Signal(dict)

    def __init__(self, parent=None):
        super().__init__(parent)
        self._running = False
        self.model: Optional[YOLO] = None
        # Path of the model currently held in self.model; used to skip reloads.
        self._last_model = ""
        self.source = "0"
        self.conf = 0.35
        self.max_missing = 90
        self.match_distance = 250.0
        self.status_confirm = 5
        self.capture_dirs: Dict[str, Path] = {}

    def configure(self, model_path: str, source: str, conf: float,
                  max_missing: int, match_dist: float, status_confirm: int):
        """Store session parameters and load the YOLO model if the path changed.

        This runs in the calling thread, so a model load blocks that thread.
        It also prepares the capture directories via ``setup_capture_dirs()``.
        Exceptions from model loading or directory setup propagate to the caller.
        """
        self.source = source
        self.conf = conf
        self.max_missing = max_missing
        self.match_distance = match_dist
        self.status_confirm = status_confirm
        self.capture_dirs = setup_capture_dirs()

        model_key = str(model_path)
        if self.model is None or model_key != self._last_model:
            self.model = YOLO(model_path)
            self._last_model = model_key

    def run(self):
        """Thread entry point: open the source and process frames until stopped."""
        self._running = True
        if self.model is None:
            self.event_fired.emit("[ERROR] No model loaded - call configure() first")
            return

        cap = open_camera(self.source)
        try:
            if not cap.isOpened():
                self.event_fired.emit(f"[ERROR] Cannot open camera: {self.source}")
                return
            ok, first = cap.read()
            if not ok:
                self.event_fired.emit("[ERROR] Cannot read first frame")
                return
            self._run_session(cap, first)
        finally:
            # Release the device on every exit path, including unexpected errors.
            cap.release()

    def _run_session(self, cap, frame: np.ndarray) -> None:
        """Run the read/detect/emit loop, starting from an already-read *frame*."""
        event_logger = EventLogger(EVENTS_CSV)
        tracker = PersonTracker(
            event_logger=event_logger,
            max_missing=self.max_missing,
            match_distance=self.match_distance,
            status_confirm_frames=self.status_confirm,
        )

        self.event_fired.emit(f"Session started | source={self.source}")
        prev = time.time()
        frame_idx = 0

        try:
            while self._running and not self.isInterruptionRequested():
                frame_idx += 1
                annotated = frame.copy()

                try:
                    visible = self._process_frame(
                        frame, annotated, frame_idx, tracker, event_logger
                    )
                except Exception as e:
                    # Keep the session alive: report the failure and still
                    # publish the tracker's current view for this frame.
                    self.event_fired.emit(f"[ERROR] Frame {frame_idx}: {e}")
                    visible = tracker.visible_tracks()

                now_t = time.time()
                fps = 1.0 / max(now_t - prev, 1e-9)
                prev = now_t

                draw_counters(annotated, visible, fps)
                self.frame_ready.emit(annotated, visible)
                self.stats_updated.emit(self._status_counts(visible, fps))

                ok, frame = cap.read()
                if not ok:
                    break
        finally:
            # Persist the final summary even if the loop exits via an exception.
            write_result_csv(list(tracker.tracks.values()), RESULT_CSV)

        self.event_fired.emit("Session ended.")

    def _process_frame(self, frame, annotated, frame_idx, tracker, event_logger) -> list:
        """Detect, track, log events and draw tracks for one frame.

        Draws onto *annotated* in place and returns the tracks that are
        visible after this frame's tracker update.
        """
        h, w = frame.shape[:2]
        detections = collect_detections(frame, self.model, self.conf)
        candidates = group_detections_to_people(detections, w, h)
        created, changed = tracker.update(candidates, frame_idx)
        visible = tracker.visible_tracks()

        created_ids = {t.track_id for t in created}
        event_ids = created_ids | {t.track_id for t in changed}

        for track in visible:
            save_track_image(frame, track, self.capture_dirs)
            if track.track_id in event_ids:
                ev_type = "NEW" if track.track_id in created_ids else "STATUS_CHANGE"
                wearing, missing, _unknown = split_wearing_missing(track.items)
                self.event_fired.emit(
                    f"ID {track.track_id:04d} | {ev_type} | {track.status} | "
                    f"W: {', '.join(wearing) or 'none'} | "
                    f"M: {', '.join(missing) or 'none'}"
                )
                emit_event(track, event_logger, ev_type)
            draw_track(annotated, track)

        if frame_idx % 30 == 0:
            # Periodic checkpoint of tracker.tracks to the result CSV.
            write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
        return visible

    @staticmethod
    def _status_counts(visible, fps: float) -> dict:
        """Build the stats_updated payload for the visible tracks."""
        counts = {s: 0 for s in STATUSES}
        for t in visible:
            # Use .get() so an unexpected status cannot raise KeyError and kill the thread.
            counts[t.status] = counts.get(t.status, 0) + 1
        counts["fps"] = fps
        counts["tracks"] = len(visible)
        return counts

    def stop(self):
        """Ask the loop to exit after the current frame. Does not block."""
        self._running = False
        # run() re-arms _running when it begins, so a stop() landing between
        # start() and run() would be lost. Qt's interruption flag is only reset
        # by start(), so raise it too.
        self.requestInterruption()
|
|
|
|
|
|
def cv_to_qpixmap(frame: np.ndarray, max_w: int = 960, max_h: int = 720) -> QPixmap:
    """Convert an OpenCV image to a QPixmap that fits inside ``max_w`` x ``max_h``.

    Expects uint8 data in one of these layouts:

    - grayscale ``(H, W)`` or ``(H, W, 1)``
    - BGR ``(H, W, 3)``
    - BGRA ``(H, W, 4)``

    The aspect ratio is preserved and smooth scaling is used. If either bound
    is not positive (e.g. a label that has not been laid out yet), the pixmap
    is returned unscaled instead of as a null pixmap.
    """
    channels = 1 if frame.ndim == 2 else frame.shape[2]
    if channels == 1:
        rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif channels == 4:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    h, w = rgb.shape[:2]
    # QImage only wraps rgb's buffer. QPixmap.fromImage copies the pixels while
    # rgb is still alive, so the returned pixmap does not reference freed memory.
    qimg = QImage(rgb.data, w, h, rgb.strides[0], QImage.Format.Format_RGB888)
    pix = QPixmap.fromImage(qimg)
    if max_w <= 0 or max_h <= 0:
        return pix
    return pix.scaled(max_w, max_h, Qt.AspectRatioMode.KeepAspectRatio,
                      Qt.TransformationMode.SmoothTransformation)
|
|
|
|
|
|
class SaqrWindow(QMainWindow):
    """Main window with model/camera/parameter controls, live video and an event log.

    Each Start press creates a fresh :class:`DetectionWorker` and configures it
    from the controls. Its signals are connected to the video label, the stats
    labels and the event log.
    """

    def __init__(self, default_model: str = "saqr_best.pt", default_source: str = "0"):
        super().__init__()
        self.setWindowTitle("Saqr - PPE Safety Tracking")
        self.setMinimumSize(1200, 700)
        self._default_model = default_model
        self._default_source = default_source

        self.worker: Optional[DetectionWorker] = None
        self._build_ui()
        self._scan_cameras()

    # ----- UI construction -------------------------------------------------

    def _build_ui(self):
        """Assemble the three-column layout: controls | video | event log."""
        central = QWidget()
        self.setCentralWidget(central)
        main_layout = QHBoxLayout(central)
        main_layout.addWidget(self._build_controls_panel())
        main_layout.addWidget(self._build_video_panel(), stretch=1)
        main_layout.addWidget(self._build_log_panel())
        self.statusBar().showMessage("Ready - load a model and start detection")

    def _build_controls_panel(self) -> QWidget:
        """Left column: model picker, camera picker, parameters, buttons and stats."""
        left = QVBoxLayout()

        model_grp = QGroupBox("Model")
        model_lay = QVBoxLayout(model_grp)
        self.model_label = QLabel(self._default_model)
        self.model_label.setWordWrap(True)
        btn_model = QPushButton("Browse...")
        btn_model.clicked.connect(self._browse_model)
        model_lay.addWidget(self.model_label)
        model_lay.addWidget(btn_model)
        left.addWidget(model_grp)

        cam_grp = QGroupBox("Camera Source")
        cam_lay = QVBoxLayout(cam_grp)
        self.cam_combo = QComboBox()
        btn_refresh = QPushButton("Refresh")
        btn_refresh.clicked.connect(self._scan_cameras)
        cam_lay.addWidget(self.cam_combo)
        cam_lay.addWidget(btn_refresh)
        left.addWidget(cam_grp)

        left.addWidget(self._build_params_group())

        btn_lay = QHBoxLayout()
        self.btn_start = QPushButton("Start")
        self.btn_start.setStyleSheet("background-color: #2ecc71; color: white; font-weight: bold; padding: 8px;")
        self.btn_start.clicked.connect(self._start)
        self.btn_stop = QPushButton("Stop")
        self.btn_stop.setStyleSheet("background-color: #e74c3c; color: white; font-weight: bold; padding: 8px;")
        self.btn_stop.clicked.connect(self._stop)
        self.btn_stop.setEnabled(False)
        btn_lay.addWidget(self.btn_start)
        btn_lay.addWidget(self.btn_stop)
        left.addLayout(btn_lay)

        left.addWidget(self._build_stats_group())
        left.addStretch()

        panel = QWidget()
        panel.setLayout(left)
        panel.setFixedWidth(260)
        return panel

    def _build_params_group(self) -> QGroupBox:
        """Spin boxes for the detector/tracker parameters passed to configure()."""
        param_grp = QGroupBox("Parameters")
        grid = QGridLayout(param_grp)

        self.conf_spin = QDoubleSpinBox()
        self.conf_spin.setRange(0.1, 0.9)
        self.conf_spin.setSingleStep(0.05)
        self.conf_spin.setValue(0.35)

        self.missing_spin = QSpinBox()
        self.missing_spin.setRange(10, 300)
        self.missing_spin.setValue(90)

        self.dist_spin = QDoubleSpinBox()
        self.dist_spin.setRange(50, 500)
        self.dist_spin.setSingleStep(10)
        self.dist_spin.setValue(250)

        self.confirm_spin = QSpinBox()
        self.confirm_spin.setRange(1, 20)
        self.confirm_spin.setValue(5)

        rows = (
            ("Confidence:", self.conf_spin),
            ("Max Missing:", self.missing_spin),
            ("Match Dist:", self.dist_spin),
            ("Confirm Frames:", self.confirm_spin),
        )
        for row, (caption, spin) in enumerate(rows):
            grid.addWidget(QLabel(caption), row, 0)
            grid.addWidget(spin, row, 1)
        return param_grp

    def _build_stats_group(self) -> QGroupBox:
        """Live counters that _on_stats refreshes from the worker."""
        stats_grp = QGroupBox("Live Status")
        grid = QGridLayout(stats_grp)
        self.lbl_fps = QLabel("FPS: -")
        self.lbl_safe = QLabel("SAFE: 0")
        self.lbl_partial = QLabel("PARTIAL: 0")
        self.lbl_unsafe = QLabel("UNSAFE: 0")
        self.lbl_tracks = QLabel("TRACKS: 0")

        for label, color in (
            (self.lbl_safe, "#27ae60"),
            (self.lbl_partial, "#f39c12"),
            (self.lbl_unsafe, "#e74c3c"),
            (self.lbl_tracks, "#3498db"),
        ):
            label.setStyleSheet(f"color: {color}; font-weight: bold; font-size: 14px;")

        grid.addWidget(self.lbl_fps, 0, 0)
        grid.addWidget(self.lbl_tracks, 0, 1)
        grid.addWidget(self.lbl_safe, 1, 0)
        grid.addWidget(self.lbl_partial, 1, 1)
        grid.addWidget(self.lbl_unsafe, 2, 0, 1, 2)
        return stats_grp

    def _build_video_panel(self) -> QWidget:
        """Centre column: the label that displays annotated frames."""
        self.video_label = QLabel("No camera feed")
        self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.video_label.setStyleSheet(
            "background-color: #1a1a2e; color: #666; font-size: 18px; border-radius: 8px;"
        )
        self.video_label.setMinimumSize(640, 480)
        centre = QVBoxLayout()
        centre.addWidget(self.video_label)
        panel = QWidget()
        panel.setLayout(centre)
        return panel

    def _build_log_panel(self) -> QWidget:
        """Right column: event log with Clear and Export buttons."""
        log_grp = QGroupBox("Event Log")
        log_lay = QVBoxLayout(log_grp)
        self.event_log = QTextEdit()
        self.event_log.setReadOnly(True)
        self.event_log.setMaximumWidth(380)
        self.event_log.setStyleSheet(
            "background-color: #0d1117; color: #c9d1d9; font-family: monospace; font-size: 11px;"
        )
        log_lay.addWidget(self.event_log)

        btn_clear = QPushButton("Clear Log")
        btn_clear.clicked.connect(self.event_log.clear)
        log_lay.addWidget(btn_clear)

        btn_export = QPushButton("Export CSV Report")
        btn_export.clicked.connect(self._export_csv)
        log_lay.addWidget(btn_export)

        right = QVBoxLayout()
        right.addWidget(log_grp)
        panel = QWidget()
        panel.setLayout(right)
        panel.setFixedWidth(380)
        return panel

    # ----- actions -------------------------------------------------------------

    def _scan_cameras(self):
        """Repopulate the source list and select the default source.

        A default source that the scan did not find (a file, URL or unlisted
        index) is appended so it stays selectable.
        """
        self.cam_combo.clear()
        self.cam_combo.addItems(list_cameras())
        idx = self.cam_combo.findText(self._default_source)
        if idx < 0 and self._default_source:
            self.cam_combo.addItem(self._default_source)
            idx = self.cam_combo.count() - 1
        if idx >= 0:
            self.cam_combo.setCurrentIndex(idx)

    def _browse_model(self):
        """Let the user pick a ``.pt`` file; its path replaces the model label text."""
        path, _ = QFileDialog.getOpenFileName(
            self, "Select YOLO Model", str(MODELS_DIR), "Model Files (*.pt)"
        )
        if path:
            self.model_label.setText(path)

    def _start(self):
        """Create, configure and start a DetectionWorker from the current controls."""
        if self.worker is not None and self.worker.isRunning():
            # The previous session has not wound down yet. Replacing self.worker
            # now could drop the last reference to a running QThread.
            QMessageBox.information(self, "Busy", "The previous session is still stopping.")
            return

        try:
            model_path = resolve_model_path(self.model_label.text())
        except FileNotFoundError as e:
            QMessageBox.critical(self, "Error", str(e))
            return

        source = self.cam_combo.currentText()

        worker = DetectionWorker()
        try:
            worker.configure(
                model_path=str(model_path),
                source=source,
                conf=self.conf_spin.value(),
                max_missing=self.missing_spin.value(),
                match_dist=self.dist_spin.value(),
                status_confirm=self.confirm_spin.value(),
            )
        except Exception as e:  # model load / capture-dir setup at the UI boundary
            QMessageBox.critical(self, "Error", f"Could not start detection: {e}")
            return

        worker.frame_ready.connect(self._on_frame)
        worker.event_fired.connect(self._on_event)
        worker.stats_updated.connect(self._on_stats)
        worker.finished.connect(self._on_finished)
        self.worker = worker
        worker.start()

        self.btn_start.setEnabled(False)
        self.btn_stop.setEnabled(True)
        self.statusBar().showMessage(f"Running | source={source} | conf={self.conf_spin.value()}")
        log.info(f"GUI session started | source={source}")

    def _stop(self):
        """Ask the worker to stop and wait up to 3 s for it to finish."""
        worker = self.worker
        if worker is not None and worker.isRunning():
            worker.stop()
            if not worker.wait(3000):
                # The loop only checks its stop flag between frames. Keep Start
                # disabled until `finished` fires (_on_finished re-enables it).
                self.statusBar().showMessage("Stopping...")
                return
        self.btn_start.setEnabled(True)
        self.btn_stop.setEnabled(False)
        self.statusBar().showMessage("Stopped")

    # ----- worker slots ------------------------------------------------------

    @Slot(np.ndarray, list)
    def _on_frame(self, frame, visible):
        """Show an annotated frame scaled to the current size of the video label."""
        pix = cv_to_qpixmap(frame, self.video_label.width(), self.video_label.height())
        self.video_label.setPixmap(pix)

    @Slot(str)
    def _on_event(self, msg):
        """Append a timestamped, colour-coded line to the event log and scroll to it."""
        import html

        ts = datetime.now().strftime("%H:%M:%S")
        # Errors take priority so they are always red. "UNSAFE" is tested before
        # "SAFE" because it contains that substring.
        if "ERROR" in msg or "UNSAFE" in msg:
            color = "#f85149"
        elif "SAFE" in msg:
            color = "#3fb950"
        elif "PARTIAL" in msg:
            color = "#d29922"
        else:
            color = "#c9d1d9"
        # Escape the message: it is embedded in HTML and may contain '<' or '&'
        # (exception text, file paths).
        self.event_log.append(f'<span style="color:{color}">[{ts}] {html.escape(msg)}</span>')
        scrollbar = self.event_log.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    @Slot(dict)
    def _on_stats(self, stats):
        """Refresh the Live Status labels from a stats_updated payload."""
        self.lbl_fps.setText(f"FPS: {stats.get('fps', 0):.1f}")
        self.lbl_safe.setText(f"SAFE: {stats.get('SAFE', 0)}")
        self.lbl_partial.setText(f"PARTIAL: {stats.get('PARTIAL', 0)}")
        self.lbl_unsafe.setText(f"UNSAFE: {stats.get('UNSAFE', 0)}")
        self.lbl_tracks.setText(f"TRACKS: {stats.get('tracks', 0)}")

    @Slot()
    def _on_finished(self):
        """Reset the buttons once the worker thread has exited."""
        self.btn_start.setEnabled(True)
        self.btn_stop.setEnabled(False)
        self.statusBar().showMessage("Session ended")

    def _export_csv(self):
        """Ask for a destination and export the report via manager_cli's export_csv."""
        path, _ = QFileDialog.getSaveFileName(
            self, "Export CSV",
            str(PROJECT_ROOT / f"ppe_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"),
            "CSV Files (*.csv)"
        )
        if not path:
            return
        try:
            from saqr.apps.manager_cli import export_csv, load_photos
            export_csv(load_photos(), Path(path))
        except Exception as e:  # report any export failure to the user
            QMessageBox.critical(self, "Export failed", str(e))
            self._on_event(f"[ERROR] Export failed: {e}")
            return
        self._on_event(f"Exported: {path}")

    def closeEvent(self, event):
        """Stop the worker before the window closes."""
        self._stop()
        if self.worker is not None and self.worker.isRunning():
            # Deleting a QThread that is still running crashes the program, so
            # block until the loop finishes its current frame and exits.
            # NOTE(review): this can hang if cap.read() never returns.
            self.worker.wait()
        event.accept()
|
|
|
|
|
|
def _dark_palette():
    """Build the dark colour palette that main() applies to the whole application."""
    from PySide6.QtGui import QPalette

    role = QPalette.ColorRole
    palette = QPalette()
    for color_role, rgb in (
        (role.Window, (30, 30, 46)),
        (role.WindowText, (205, 214, 244)),
        (role.Base, (24, 24, 37)),
        (role.AlternateBase, (30, 30, 46)),
        (role.Text, (205, 214, 244)),
        (role.Button, (49, 50, 68)),
        (role.ButtonText, (205, 214, 244)),
        (role.Highlight, (137, 180, 250)),
        (role.HighlightedText, (30, 30, 46)),
    ):
        palette.setColor(color_role, QColor(*rgb))
    return palette


def main(argv: Optional[List[str]] = None):
    """Parse CLI options, theme the application, show the window and run Qt.

    Args:
        argv: Arguments after the program name. Defaults to ``sys.argv[1:]``.
            Passing them explicitly lets callers launch the GUI
            programmatically.

    Exits the process with the Qt event loop's return code.
    """
    import argparse

    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description="Saqr PPE GUI")
    parser.add_argument(
        "--model", default="saqr_best.pt",
        help="YOLO weights (.pt) shown in the Model box; resolved with resolve_model_path on Start",
    )
    parser.add_argument(
        "--source", default="0",
        help="camera index, /dev/videoN device, or another OpenCV source such as a file path or URL",
    )
    args = parser.parse_args(argv)

    # Qt receives the program name plus the same arguments argparse accepted.
    app = QApplication(sys.argv[:1] + list(argv))
    app.setStyle("Fusion")
    app.setPalette(_dark_palette())

    win = SaqrWindow(default_model=args.model, default_source=args.source)
    win.show()
    sys.exit(app.exec())
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the GUI when this file is executed directly as a script.
    main()
|