""" Saqr - PPE Detection | Training Script ========================================= Train YOLO11n object detection on a PPE dataset (10 classes). Classes: helmet, no-helmet, vest, no-vest, boots, no-boots, gloves, no-gloves, goggles, no-goggles Usage: python train.py --dataset dataset python train.py --dataset dataset --epochs 50 --batch 8 """ import argparse import shutil from pathlib import Path import yaml from logger import get_logger log = get_logger("Training", "train") EXPECTED_CLASSES = [ "boots", "gloves", "goggles", "helmet", "no-boots", "no-gloves", "no-goggles", "no-helmet", "no-vest", "vest", ] def fix_data_yaml(dataset_root: Path) -> Path: """Ensure data.yaml has correct absolute paths for each split.""" yaml_path = dataset_root / "data.yaml" if not yaml_path.exists(): log.error(f"data.yaml not found at {yaml_path}") raise SystemExit(1) with open(yaml_path) as f: cfg = yaml.safe_load(f) changed = False for key, subdir in [("train", "train"), ("val", "valid"), ("test", "test")]: img_dir = dataset_root / subdir / "images" if img_dir.exists() and cfg.get(key) != str(img_dir): cfg[key] = str(img_dir) changed = True if "path" not in cfg or cfg["path"] != str(dataset_root): cfg["path"] = str(dataset_root) changed = True if changed: with open(yaml_path, "w") as f: yaml.dump(cfg, f, default_flow_style=False) log.info(f"Fixed data.yaml paths -> {yaml_path}") log.info(f"Classes ({cfg.get('nc', '?')}): {cfg.get('names', [])}") return yaml_path def main(): parser = argparse.ArgumentParser(description="Train Saqr PPE detector (YOLO11n)") parser.add_argument("--dataset", default="dataset", help="Root folder containing data.yaml + train/valid/test") parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--imgsz", type=int, default=640) parser.add_argument("--batch", type=int, default=16) parser.add_argument("--model", default="yolo11n.pt", help="Base YOLO model (auto-downloaded if not present)") parser.add_argument("--name", default="saqr_det") parser.add_argument("--device", default="0", help="Training device: 'cpu', '0', 'cuda:0', etc.") args = parser.parse_args() root = Path(__file__).parent dataset_root = root / args.dataset if not dataset_root.exists(): log.error(f"Dataset folder not found: {dataset_root}") raise SystemExit(1) yaml_path = fix_data_yaml(dataset_root) from ultralytics import YOLO log.info(f"Loading base model: {args.model}") model = YOLO(args.model) log.info(f"Training | epochs={args.epochs} imgsz={args.imgsz} " f"batch={args.batch} device={args.device}") model.train( data=str(yaml_path), epochs=args.epochs, imgsz=args.imgsz, batch=args.batch, device=args.device, name=args.name, project=str(root / "runs" / "train"), exist_ok=True, ) # Copy best/last weights to models/ models_dir = root / "models" models_dir.mkdir(exist_ok=True) weights_dir = root / "runs" / "train" / args.name / "weights" for name in ("best.pt", "last.pt"): src = weights_dir / name dst = models_dir / f"saqr_{name}" if src.exists(): shutil.copy(src, dst) log.info(f"Saved: {dst}") metrics = model.val() log.info(f"mAP50={metrics.box.map50:.4f} mAP50-95={metrics.box.map:.4f}") log.info("Next: python saqr.py --source 0") if __name__ == "__main__": main()