"""Train YOLO11n on the PPE dataset under data/dataset.""" from __future__ import annotations import argparse import shutil from pathlib import Path import yaml from core.paths import DATASET_DIR, MODELS_DIR, PROJECT_ROOT, RUNS_DIR from utils.config import load_config from utils.logger import get_logger log = get_logger("Training", "train") _TRAIN = load_config("core")["training"] EXPECTED_CLASSES = [ "boots", "gloves", "goggles", "helmet", "no-boots", "no-gloves", "no-goggles", "no-helmet", "no-vest", "vest", ] def fix_data_yaml(dataset_root: Path) -> Path: """Rewrite data.yaml with absolute paths for the current dataset location.""" yaml_path = dataset_root / "data.yaml" if not yaml_path.exists(): log.error(f"data.yaml not found at {yaml_path}") raise SystemExit(1) with open(yaml_path) as f: cfg = yaml.safe_load(f) changed = False for key, subdir in [("train", "train"), ("val", "valid"), ("test", "test")]: img_dir = dataset_root / subdir / "images" if img_dir.exists() and cfg.get(key) != str(img_dir): cfg[key] = str(img_dir) changed = True if "path" not in cfg or cfg["path"] != str(dataset_root): cfg["path"] = str(dataset_root) changed = True if changed: with open(yaml_path, "w") as f: yaml.dump(cfg, f, default_flow_style=False) log.info(f"Fixed data.yaml paths -> {yaml_path}") log.info(f"Classes ({cfg.get('nc', '?')}): {cfg.get('names', [])}") return yaml_path def main(): parser = argparse.ArgumentParser(description="Train Saqr PPE detector (YOLO11n)") parser.add_argument("--dataset", default=str(DATASET_DIR), help="Root folder containing data.yaml + train/valid/test") parser.add_argument("--epochs", type=int, default=_TRAIN["epochs"]) parser.add_argument("--imgsz", type=int, default=_TRAIN["imgsz"]) parser.add_argument("--batch", type=int, default=_TRAIN["batch"]) parser.add_argument("--model", default=_TRAIN["base_model"], help="Base YOLO model (auto-downloaded if not present)") parser.add_argument("--name", default=_TRAIN["run_name"]) parser.add_argument("--device", default=_TRAIN["device"]) args = parser.parse_args() dataset_root = Path(args.dataset) if not dataset_root.is_absolute(): dataset_root = PROJECT_ROOT / dataset_root if not dataset_root.exists(): log.error(f"Dataset folder not found: {dataset_root}") raise SystemExit(1) yaml_path = fix_data_yaml(dataset_root) from ultralytics import YOLO base = Path(args.model) if not base.is_absolute() and not base.exists(): candidate = MODELS_DIR / base.name if candidate.exists(): base = candidate log.info(f"Loading base model: {base}") model = YOLO(str(base)) log.info(f"Training | epochs={args.epochs} imgsz={args.imgsz} " f"batch={args.batch} device={args.device}") model.train( data=str(yaml_path), epochs=args.epochs, imgsz=args.imgsz, batch=args.batch, device=args.device, name=args.name, project=str(RUNS_DIR / "train"), exist_ok=True, ) MODELS_DIR.mkdir(parents=True, exist_ok=True) weights_dir = RUNS_DIR / "train" / args.name / "weights" for name in ("best.pt", "last.pt"): src = weights_dir / name dst = MODELS_DIR / f"saqr_{name}" if src.exists(): shutil.copy(src, dst) log.info(f"Saved: {dst}") metrics = model.val() log.info(f"mAP50={metrics.box.map50:.4f} mAP50-95={metrics.box.map:.4f}") log.info("Next: saqr --source 0") if __name__ == "__main__": main()