Saqr/saqr/gui/app.py

479 lines
17 KiB
Python

"""Saqr PySide6 desktop GUI for live PPE compliance monitoring."""
from __future__ import annotations

import html
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import cv2
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal, Slot
from PySide6.QtGui import QColor, QImage, QPixmap
from PySide6.QtWidgets import (
    QApplication, QComboBox, QDoubleSpinBox, QFileDialog, QGridLayout,
    QGroupBox, QHBoxLayout, QLabel, QMainWindow, QMessageBox, QPushButton,
    QSpinBox, QTextEdit, QVBoxLayout, QWidget,
)
from ultralytics import YOLO

from saqr.core.capture import save_track_image, setup_capture_dirs
from saqr.core.compliance import split_wearing_missing
from saqr.core.detection import STATUSES, collect_detections
from saqr.core.drawing import draw_counters, draw_track
from saqr.core.events import EventLogger, emit_event, write_result_csv
from saqr.core.grouping import group_detections_to_people
from saqr.core.model import resolve_model_path
from saqr.core.paths import EVENTS_CSV, MODELS_DIR, PROJECT_ROOT, RESULT_CSV
from saqr.core.tracking import PersonTracker
from saqr.utils.logger import get_logger
# Module-level logger for the GUI. The meaning of the two arguments is defined
# by saqr.utils.logger.get_logger (presumably a logger name plus a log
# channel/file tag - confirm there).
log = get_logger("Inference", "gui")
def list_cameras(max_idx: int = 10, fallback_probe: int = 4) -> List[str]:
    """Return candidate camera sources for the source picker.

    If any ``/dev/videoN`` node exists for ``N`` in ``[0, max_idx)``, those
    paths are returned (no device is opened). Otherwise OpenCV device indices
    ``[0, fallback_probe)`` are probed and the ones that open are returned as
    strings. The result is never empty: ``["0"]`` is returned when nothing
    was found.

    Note: on V4L2 systems a single physical camera may expose several nodes
    (e.g. a metadata node) - not every listed path is guaranteed to deliver
    frames.

    Args:
        max_idx: Number of ``/dev/videoN`` indices to check.
        fallback_probe: Number of OpenCV device indices to try when no
            ``/dev/video*`` node was found.

    Returns:
        List of source strings usable by ``open_camera``.
    """
    sources: List[str] = [
        dev for dev in (f"/dev/video{i}" for i in range(max_idx)) if Path(dev).exists()
    ]
    if sources:
        return sources
    for i in range(fallback_probe):
        cap = cv2.VideoCapture(i)
        try:
            if cap.isOpened():
                sources.append(str(i))
        finally:
            # Release unconditionally so a half-opened backend handle never lingers.
            cap.release()
    return sources or ["0"]
def open_camera(source: str, width: int = 640, height: int = 480, fps: int = 30):
    """Open *source* with OpenCV and, if it opened, apply capture settings.

    ``/dev/videoN`` paths are opened through the V4L2 backend, all-digit
    strings are treated as integer device indices, and anything else (file
    path, URL, ...) is handed to OpenCV unchanged. When the capture opened,
    a one-frame buffer, the MJPG FOURCC and the requested width/height/FPS are
    set (drivers may ignore any of these hints).

    Returns:
        The ``cv2.VideoCapture``, whether or not it opened - callers must check
        ``isOpened()``.
    """
    if source.startswith("/dev/video"):
        capture = cv2.VideoCapture(source, cv2.CAP_V4L2)
    else:
        capture = cv2.VideoCapture(int(source) if source.isdigit() else source)

    if not capture.isOpened():
        return capture

    # Applied in this exact order: buffer size, pixel format, then geometry/rate.
    settings = (
        (cv2.CAP_PROP_BUFFERSIZE, 1),
        (cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG")),
        (cv2.CAP_PROP_FRAME_WIDTH, width),
        (cv2.CAP_PROP_FRAME_HEIGHT, height),
        (cv2.CAP_PROP_FPS, fps),
    )
    for prop_id, value in settings:
        capture.set(prop_id, value)
    return capture
class DetectionWorker(QThread):
    """Background thread running PPE detection and person tracking on one source.

    Results reach the GUI thread only through Qt signals:

    * ``frame_ready(annotated, visible_tracks)`` - once per processed frame;
    * ``event_fired(message)`` - session start/end, ``NEW`` / ``STATUS_CHANGE``
      track events and ``[ERROR] ...`` messages;
    * ``stats_updated(counts)`` - per-status counts keyed by ``STATUSES`` plus
      ``"fps"`` and ``"tracks"``.

    Call :meth:`configure` before :meth:`start`, and :meth:`stop` to request a
    cooperative shutdown (the loop exits after the frame in progress). The
    worker can be started again once it has finished.
    """

    frame_ready = Signal(np.ndarray, list)
    event_fired = Signal(str)
    stats_updated = Signal(dict)

    # The per-track result CSV is rewritten every this-many frames during a session.
    RESULT_FLUSH_INTERVAL = 30

    def __init__(self, parent=None):
        super().__init__(parent)
        self._running = False
        self.model: Optional[YOLO] = None
        self._last_model = ""  # path of the weights currently held in self.model
        self.source = "0"
        self.conf = 0.35
        self.max_missing = 90
        self.match_distance = 250.0
        self.status_confirm = 5
        self.capture_dirs: Dict[str, Path] = {}

    def configure(self, model_path: str, source: str, conf: float,
                  max_missing: int, match_dist: float, status_confirm: int):
        """Store session parameters and (re)load the model if the path changed.

        This runs on the caller's thread: loading weights blocks it, and any
        exception raised by ``YOLO(model_path)`` or ``setup_capture_dirs()``
        propagates to the caller.
        """
        self.source = source
        self.conf = conf
        self.max_missing = max_missing
        self.match_distance = match_dist
        self.status_confirm = status_confirm
        self.capture_dirs = setup_capture_dirs()
        if self.model is None or str(model_path) != self._last_model:
            self.model = YOLO(model_path)
            self._last_model = str(model_path)

    def stop(self):
        """Ask the processing loop to exit; safe to call from any thread.

        Besides clearing ``_running`` this also raises a QThread interruption
        request. That request survives ``run()`` setting ``_running = True``,
        so a stop() issued right after start() - before run() begins - is not
        lost.
        """
        self._running = False
        self.requestInterruption()

    def _should_continue(self) -> bool:
        """Return True while no stop/interruption has been requested."""
        return self._running and not self.isInterruptionRequested()

    def run(self):
        """Thread entry point: process the source until it ends or stop() is called."""
        self._running = True
        if self.model is None:
            self.event_fired.emit("[ERROR] No model loaded - call configure() before start()")
            return
        cap = open_camera(self.source)
        try:
            self._run_session(cap)
        except Exception as e:  # top-level boundary of the worker thread
            log.exception("Detection worker crashed | source=%s", self.source)
            self.event_fired.emit(f"[ERROR] Worker crashed: {e}")
        finally:
            # Always free the capture device, even if the session raised.
            cap.release()

    def _run_session(self, cap):
        """Validate the capture, then run the frame loop with a fresh tracker."""
        if not cap.isOpened():
            self.event_fired.emit(f"[ERROR] Cannot open camera: {self.source}")
            return
        ok, frame = cap.read()
        if not ok:
            self.event_fired.emit("[ERROR] Cannot read first frame")
            return

        event_logger = EventLogger(EVENTS_CSV)
        tracker = PersonTracker(
            event_logger=event_logger,
            max_missing=self.max_missing,
            match_distance=self.match_distance,
            status_confirm_frames=self.status_confirm,
        )
        self.event_fired.emit(f"Session started | source={self.source}")
        try:
            prev = time.perf_counter()  # monotonic: immune to wall-clock jumps
            frame_idx = 0
            while self._should_continue():
                frame_idx += 1
                annotated = self._process_frame(frame, frame_idx, tracker, event_logger)
                now = time.perf_counter()
                fps = 1.0 / max(now - prev, 1e-9)
                prev = now
                self._publish(annotated, tracker.visible_tracks(), fps)
                ok, frame = cap.read()
                if not ok:
                    break
        finally:
            # Persist the final per-track summary even if the loop raised.
            write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
            self.event_fired.emit("Session ended.")

    def _process_frame(self, frame, frame_idx, tracker, event_logger):
        """Run detection + tracking on *frame* and return an annotated copy.

        Errors inside the per-frame pipeline are reported through
        ``event_fired`` and do not end the session; the returned copy then
        holds whatever was drawn before the error.
        """
        h, w = frame.shape[:2]
        annotated = frame.copy()
        try:
            detections = collect_detections(frame, self.model, self.conf)
            candidates = group_detections_to_people(detections, w, h)
            created, changed = tracker.update(candidates, frame_idx)
            created_ids = {t.track_id for t in created}
            event_ids = created_ids | {t.track_id for t in changed}
            for track in tracker.visible_tracks():
                save_track_image(frame, track, self.capture_dirs)
                if track.track_id in event_ids:
                    ev_type = "NEW" if track.track_id in created_ids else "STATUS_CHANGE"
                    wearing, missing, _unknown = split_wearing_missing(track.items)
                    self.event_fired.emit(
                        f"ID {track.track_id:04d} | {ev_type} | {track.status} | "
                        f"W: {', '.join(wearing) or 'none'} | "
                        f"M: {', '.join(missing) or 'none'}"
                    )
                    emit_event(track, event_logger, ev_type)
                draw_track(annotated, track)
            if frame_idx % self.RESULT_FLUSH_INTERVAL == 0:
                write_result_csv(list(tracker.tracks.values()), RESULT_CSV)
        except Exception as e:
            self.event_fired.emit(f"[ERROR] Frame {frame_idx}: {e}")
        return annotated

    def _publish(self, annotated, visible, fps):
        """Draw the counter overlay, then emit the frame and stats signals."""
        draw_counters(annotated, visible, fps)
        counts = {s: 0 for s in STATUSES}
        for t in visible:
            # .get(): a status outside STATUSES must not kill the worker thread.
            counts[t.status] = counts.get(t.status, 0) + 1
        counts["fps"] = fps
        counts["tracks"] = len(visible)
        self.frame_ready.emit(annotated, visible)
        self.stats_updated.emit(counts)
def cv_to_qpixmap(frame: np.ndarray, max_w: int = 960, max_h: int = 720) -> QPixmap:
    """Convert an OpenCV image into a QPixmap that fits within ``max_w`` x ``max_h``.

    Accepts single-channel grayscale (2-D or ``(h, w, 1)``), 3-channel BGR and
    4-channel BGRA input; alpha is dropped. The aspect ratio is preserved.
    If either bound is not positive (e.g. a label that has not been laid out
    yet), the unscaled pixmap is returned instead of a null one.
    ``Format_RGB888`` requires 8-bit channels, so *frame* is assumed to be
    uint8 (true for ``cv2.VideoCapture`` output) - TODO confirm for other inputs.
    """
    channels = 1 if frame.ndim == 2 else frame.shape[2]
    if channels == 1:
        rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif channels == 4:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    h, w = rgb.shape[:2]
    # QImage only wraps rgb's buffer; QPixmap.fromImage() copies it while rgb is
    # still alive. Use the real row stride instead of assuming channels * w.
    qimg = QImage(rgb.data, w, h, rgb.strides[0], QImage.Format.Format_RGB888)
    pix = QPixmap.fromImage(qimg)
    if max_w <= 0 or max_h <= 0:
        return pix
    return pix.scaled(max_w, max_h, Qt.AspectRatioMode.KeepAspectRatio,
                      Qt.TransformationMode.SmoothTransformation)
class SaqrWindow(QMainWindow):
    """Main application window for live PPE monitoring.

    Layout: a fixed-width control column (model, camera source, tracker
    parameters, start/stop, live counters), the video view in the centre and a
    fixed-width event log on the right. Detection runs in a single, reused
    :class:`DetectionWorker`; the window only reacts to its signals.
    """

    # Maximum number of lines kept in the event log; older lines are dropped so
    # a long monitoring session cannot grow memory without bound.
    MAX_LOG_LINES = 5000

    def __init__(self, default_model: str = "saqr_best.pt", default_source: str = "0"):
        super().__init__()
        self.setWindowTitle("Saqr - PPE Safety Tracking")
        self.setMinimumSize(1200, 700)
        self._default_model = default_model
        self._default_source = default_source
        self.worker: Optional[DetectionWorker] = None
        self._build_ui()
        self._scan_cameras()

    def _build_ui(self):
        """Create all widgets and arrange them in three columns."""
        central = QWidget()
        self.setCentralWidget(central)
        main_layout = QHBoxLayout(central)

        # --- Left column: controls and live counters ---
        left = QVBoxLayout()

        model_grp = QGroupBox("Model")
        model_lay = QVBoxLayout(model_grp)
        self.model_label = QLabel(self._default_model)
        self.model_label.setWordWrap(True)
        btn_model = QPushButton("Browse...")
        btn_model.clicked.connect(self._browse_model)
        model_lay.addWidget(self.model_label)
        model_lay.addWidget(btn_model)
        left.addWidget(model_grp)

        cam_grp = QGroupBox("Camera Source")
        cam_lay = QVBoxLayout(cam_grp)
        self.cam_combo = QComboBox()
        btn_refresh = QPushButton("Refresh")
        btn_refresh.clicked.connect(self._scan_cameras)
        cam_lay.addWidget(self.cam_combo)
        cam_lay.addWidget(btn_refresh)
        left.addWidget(cam_grp)

        param_grp = QGroupBox("Parameters")
        param_lay = QGridLayout(param_grp)
        param_lay.addWidget(QLabel("Confidence:"), 0, 0)
        self.conf_spin = QDoubleSpinBox()
        self.conf_spin.setRange(0.1, 0.9)
        self.conf_spin.setSingleStep(0.05)
        self.conf_spin.setValue(0.35)
        param_lay.addWidget(self.conf_spin, 0, 1)
        param_lay.addWidget(QLabel("Max Missing:"), 1, 0)
        self.missing_spin = QSpinBox()
        self.missing_spin.setRange(10, 300)
        self.missing_spin.setValue(90)
        param_lay.addWidget(self.missing_spin, 1, 1)
        param_lay.addWidget(QLabel("Match Dist:"), 2, 0)
        self.dist_spin = QDoubleSpinBox()
        self.dist_spin.setRange(50, 500)
        self.dist_spin.setSingleStep(10)
        self.dist_spin.setValue(250)
        param_lay.addWidget(self.dist_spin, 2, 1)
        param_lay.addWidget(QLabel("Confirm Frames:"), 3, 0)
        self.confirm_spin = QSpinBox()
        self.confirm_spin.setRange(1, 20)
        self.confirm_spin.setValue(5)
        param_lay.addWidget(self.confirm_spin, 3, 1)
        left.addWidget(param_grp)

        btn_lay = QHBoxLayout()
        self.btn_start = QPushButton("Start")
        self.btn_start.setStyleSheet("background-color: #2ecc71; color: white; font-weight: bold; padding: 8px;")
        self.btn_start.clicked.connect(self._start)
        self.btn_stop = QPushButton("Stop")
        self.btn_stop.setStyleSheet("background-color: #e74c3c; color: white; font-weight: bold; padding: 8px;")
        self.btn_stop.clicked.connect(self._stop)
        self.btn_stop.setEnabled(False)
        btn_lay.addWidget(self.btn_start)
        btn_lay.addWidget(self.btn_stop)
        left.addLayout(btn_lay)

        stats_grp = QGroupBox("Live Status")
        stats_lay = QGridLayout(stats_grp)
        self.lbl_fps = QLabel("FPS: -")
        self.lbl_safe = QLabel("SAFE: 0")
        self.lbl_partial = QLabel("PARTIAL: 0")
        self.lbl_unsafe = QLabel("UNSAFE: 0")
        self.lbl_tracks = QLabel("TRACKS: 0")
        self.lbl_safe.setStyleSheet("color: #27ae60; font-weight: bold; font-size: 14px;")
        self.lbl_partial.setStyleSheet("color: #f39c12; font-weight: bold; font-size: 14px;")
        self.lbl_unsafe.setStyleSheet("color: #e74c3c; font-weight: bold; font-size: 14px;")
        self.lbl_tracks.setStyleSheet("color: #3498db; font-weight: bold; font-size: 14px;")
        stats_lay.addWidget(self.lbl_fps, 0, 0)
        stats_lay.addWidget(self.lbl_tracks, 0, 1)
        stats_lay.addWidget(self.lbl_safe, 1, 0)
        stats_lay.addWidget(self.lbl_partial, 1, 1)
        stats_lay.addWidget(self.lbl_unsafe, 2, 0, 1, 2)
        left.addWidget(stats_grp)
        left.addStretch()

        # --- Centre column: video view ---
        centre = QVBoxLayout()
        self.video_label = QLabel("No camera feed")
        self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.video_label.setStyleSheet(
            "background-color: #1a1a2e; color: #666; font-size: 18px; border-radius: 8px;"
        )
        self.video_label.setMinimumSize(640, 480)
        centre.addWidget(self.video_label)

        # --- Right column: event log and export ---
        right = QVBoxLayout()
        log_grp = QGroupBox("Event Log")
        log_lay = QVBoxLayout(log_grp)
        self.event_log = QTextEdit()
        self.event_log.setReadOnly(True)
        self.event_log.document().setMaximumBlockCount(self.MAX_LOG_LINES)
        self.event_log.setMaximumWidth(380)
        self.event_log.setStyleSheet(
            "background-color: #0d1117; color: #c9d1d9; font-family: monospace; font-size: 11px;"
        )
        log_lay.addWidget(self.event_log)
        btn_clear = QPushButton("Clear Log")
        btn_clear.clicked.connect(self.event_log.clear)
        log_lay.addWidget(btn_clear)
        btn_export = QPushButton("Export CSV Report")
        btn_export.clicked.connect(self._export_csv)
        log_lay.addWidget(btn_export)
        right.addWidget(log_grp)

        left_widget = QWidget()
        left_widget.setLayout(left)
        left_widget.setFixedWidth(260)
        centre_widget = QWidget()
        centre_widget.setLayout(centre)
        right_widget = QWidget()
        right_widget.setLayout(right)
        right_widget.setFixedWidth(380)
        main_layout.addWidget(left_widget)
        main_layout.addWidget(centre_widget, stretch=1)
        main_layout.addWidget(right_widget)
        self.statusBar().showMessage("Ready - load a model and start detection")

    def _scan_cameras(self):
        """Refill the source combo and select the default source.

        If the default source (e.g. a video file or URL) was not discovered,
        it is appended so it can still be selected.
        """
        self.cam_combo.clear()
        self.cam_combo.addItems(list_cameras())
        idx = self.cam_combo.findText(self._default_source)
        if idx < 0:
            self.cam_combo.addItem(self._default_source)
            idx = self.cam_combo.count() - 1
        self.cam_combo.setCurrentIndex(idx)

    def _browse_model(self):
        """Let the user pick a ``.pt`` weights file, starting in MODELS_DIR."""
        path, _ = QFileDialog.getOpenFileName(
            self, "Select YOLO Model", str(MODELS_DIR), "Model Files (*.pt)"
        )
        if path:
            self.model_label.setText(path)

    def _ensure_worker(self) -> DetectionWorker:
        """Create the detection worker once and connect its signals.

        Reusing a single worker lets ``configure()`` keep an already-loaded
        model, and avoids dropping the last reference to a QThread that is
        still running (which aborts the process).
        """
        if self.worker is None:
            self.worker = DetectionWorker()
            self.worker.frame_ready.connect(self._on_frame)
            self.worker.event_fired.connect(self._on_event)
            self.worker.stats_updated.connect(self._on_stats)
            self.worker.finished.connect(self._on_finished)
        return self.worker

    def _start(self):
        """Configure the worker from the UI controls and start a session."""
        if self.worker is not None and self.worker.isRunning():
            # The previous session is still shutting down; _on_finished will
            # re-enable Start once it has.
            self.statusBar().showMessage("Previous session is still stopping - please wait")
            return
        try:
            model_path = resolve_model_path(self.model_label.text())
        except FileNotFoundError as e:
            QMessageBox.critical(self, "Error", str(e))
            return
        source = self.cam_combo.currentText()
        worker = self._ensure_worker()
        try:
            # NOTE: the model is loaded here, on the GUI thread.
            worker.configure(
                model_path=str(model_path),
                source=source,
                conf=self.conf_spin.value(),
                max_missing=self.missing_spin.value(),
                match_dist=self.dist_spin.value(),
                status_confirm=self.confirm_spin.value(),
            )
        except Exception as e:
            log.exception("Failed to configure detection worker")
            QMessageBox.critical(self, "Error", f"Failed to start detection:\n{e}")
            return
        worker.start()
        self.btn_start.setEnabled(False)
        self.btn_stop.setEnabled(True)
        self.statusBar().showMessage(f"Running | source={source} | conf={self.conf_spin.value()}")
        log.info(f"GUI session started | source={source}")

    def _stop(self):
        """Request the worker to stop and wait up to 3 s for it to finish."""
        if self.worker is not None and self.worker.isRunning():
            self.worker.stop()
            if not self.worker.wait(3000):
                # Still finishing (e.g. blocked reading the camera). Keep Start
                # disabled; _on_finished re-enables it when the thread exits.
                self.btn_stop.setEnabled(False)
                self.statusBar().showMessage("Stopping...")
                return
        self.btn_start.setEnabled(True)
        self.btn_stop.setEnabled(False)
        self.statusBar().showMessage("Stopped")

    @Slot(np.ndarray, list)
    def _on_frame(self, frame, visible):
        """Show an annotated frame scaled to the current video label size."""
        pix = cv_to_qpixmap(frame, self.video_label.width(), self.video_label.height())
        self.video_label.setPixmap(pix)

    @Slot(str)
    def _on_event(self, msg):
        """Append a timestamped, colour-coded line to the event log.

        Colour is chosen by substring: ERROR and UNSAFE red, SAFE green,
        PARTIAL amber, anything else the default text colour. The message is
        HTML-escaped so text such as ``<lambda>`` in exception messages is shown
        literally instead of being swallowed as markup.
        """
        ts = datetime.now().strftime("%H:%M:%S")
        if "ERROR" in msg or "UNSAFE" in msg:
            color = "#f85149"
        elif "SAFE" in msg:
            color = "#3fb950"
        elif "PARTIAL" in msg:
            color = "#d29922"
        else:
            color = "#c9d1d9"
        self.event_log.append(f'<span style="color:{color}">[{ts}] {html.escape(msg)}</span>')
        bar = self.event_log.verticalScrollBar()
        bar.setValue(bar.maximum())

    @Slot(dict)
    def _on_stats(self, stats):
        """Refresh the live counters from the worker's stats dict."""
        self.lbl_fps.setText(f"FPS: {stats.get('fps', 0):.1f}")
        self.lbl_safe.setText(f"SAFE: {stats.get('SAFE', 0)}")
        self.lbl_partial.setText(f"PARTIAL: {stats.get('PARTIAL', 0)}")
        self.lbl_unsafe.setText(f"UNSAFE: {stats.get('UNSAFE', 0)}")
        self.lbl_tracks.setText(f"TRACKS: {stats.get('tracks', 0)}")

    def _on_finished(self):
        """Reset the buttons once the worker thread has exited."""
        self.btn_start.setEnabled(True)
        self.btn_stop.setEnabled(False)
        self.statusBar().showMessage("Session ended")

    def _export_csv(self):
        """Ask for a destination and export the photo report via manager_cli."""
        path, _ = QFileDialog.getSaveFileName(
            self, "Export CSV",
            str(PROJECT_ROOT / f"ppe_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"),
            "CSV Files (*.csv)"
        )
        if not path:
            return
        try:
            from saqr.apps.manager_cli import export_csv, load_photos
            export_csv(load_photos(), Path(path))
        except Exception as e:
            log.exception("CSV export failed | path=%s", path)
            QMessageBox.critical(self, "Export failed", str(e))
            return
        self._on_event(f"Exported: {path}")

    def closeEvent(self, event):
        """Stop any running session before the window closes."""
        self._stop()
        event.accept()
def main():
    """Parse CLI options, apply the dark Fusion theme and run the Qt event loop.

    Options:
        --model   weights file passed to the window (default ``saqr_best.pt``)
        --source  default camera/video source (default ``"0"``)

    Exits the process with the Qt event loop's return code.
    """
    import argparse
    from PySide6.QtGui import QPalette

    cli = argparse.ArgumentParser(description="Saqr PPE GUI")
    cli.add_argument("--model", default="saqr_best.pt")
    cli.add_argument("--source", default="0")
    opts = cli.parse_args()

    qt_app = QApplication(sys.argv)
    qt_app.setStyle("Fusion")

    role = QPalette.ColorRole
    dark_theme = (
        (role.Window, (30, 30, 46)),
        (role.WindowText, (205, 214, 244)),
        (role.Base, (24, 24, 37)),
        (role.AlternateBase, (30, 30, 46)),
        (role.Text, (205, 214, 244)),
        (role.Button, (49, 50, 68)),
        (role.ButtonText, (205, 214, 244)),
        (role.Highlight, (137, 180, 250)),
        (role.HighlightedText, (30, 30, 46)),
    )
    palette = QPalette()
    for color_role, (r, g, b) in dark_theme:
        palette.setColor(color_role, QColor(r, g, b))
    qt_app.setPalette(palette)

    window = SaqrWindow(default_model=opts.model, default_source=opts.source)
    window.show()
    sys.exit(qt_app.exec())
# Launch the GUI when this module is executed as a script.
if __name__ == "__main__":
    main()