Saqr/train.py
2026-04-12 19:05:32 +04:00

119 lines
3.7 KiB
Python

"""
Saqr - PPE Detection | Training Script
=========================================
Train a YOLO11n object detector on a PPE dataset (10 classes).

Classes: helmet, no-helmet, vest, no-vest, boots, no-boots,
         gloves, no-gloves, goggles, no-goggles

Usage:
    python train.py --dataset dataset
    python train.py --dataset dataset --epochs 50 --batch 8
"""
import argparse
import shutil
from pathlib import Path
import yaml
from logger import get_logger
# Module-level logger used by the functions in this script.
# NOTE(review): the second argument ("train") is handed straight to the
# project's get_logger — presumably a log-file/channel name; confirm in logger.py.
log = get_logger("Training", "train")

# Class names the PPE dataset is expected to contain, in alphabetical order.
EXPECTED_CLASSES = [
    "boots",
    "gloves",
    "goggles",
    "helmet",
    "no-boots",
    "no-gloves",
    "no-goggles",
    "no-helmet",
    "no-vest",
    "vest",
]
def fix_data_yaml(dataset_root: Path) -> Path:
    """Ensure data.yaml has correct absolute paths for each split.

    Rewrites ``<dataset_root>/data.yaml`` in place so that the ``path`` key
    and the ``train`` / ``val`` / ``test`` keys point at absolute directories
    on this machine. The local layout is expected to be ``train/images``,
    ``valid/images`` and ``test/images`` (the ``val`` key maps to the
    ``valid`` folder). A split whose images folder does not exist locally is
    left unchanged, and a warning is logged if data.yaml declares it.

    The file is rewritten only when at least one value actually changes, and
    the existing key order is preserved. Class names are compared against
    ``EXPECTED_CLASSES``. Any mismatch is logged as a warning and does not
    stop training.

    Args:
        dataset_root: Dataset folder containing ``data.yaml``; may be relative.

    Returns:
        Absolute path to the (possibly rewritten) ``data.yaml``.

    Raises:
        SystemExit: If ``data.yaml`` is missing, empty, or not a YAML mapping.
    """
    # Resolve first so the paths written below are absolute even when the
    # caller passes a relative folder, as the contract above promises.
    dataset_root = Path(dataset_root).resolve()
    yaml_path = dataset_root / "data.yaml"
    if not yaml_path.exists():
        log.error(f"data.yaml not found at {yaml_path}")
        raise SystemExit(1)

    with open(yaml_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty file loads as None, and a list or scalar has no .get().
    # Fail with a clear message instead of an AttributeError further down.
    if not isinstance(cfg, dict):
        log.error(f"data.yaml is empty or not a mapping: {yaml_path}")
        raise SystemExit(1)

    changed = False
    for key, subdir in (("train", "train"), ("val", "valid"), ("test", "test")):
        img_dir = dataset_root / subdir / "images"
        if img_dir.exists():
            if cfg.get(key) != str(img_dir):
                cfg[key] = str(img_dir)
                changed = True
        elif key in cfg:
            # The split is declared but has no local folder. Keep the entry
            # as-is, but report it so a stale path is noticed early.
            log.warning(f"'{key}' split folder not found at {img_dir}; "
                        f"keeping existing value {cfg[key]!r}")

    if cfg.get("path") != str(dataset_root):
        cfg["path"] = str(dataset_root)
        changed = True

    if changed:
        with open(yaml_path, "w", encoding="utf-8") as f:
            # PyYAML sorts keys by default; keep the user's original order.
            yaml.dump(cfg, f, default_flow_style=False, sort_keys=False)
        log.info(f"Fixed data.yaml paths -> {yaml_path}")

    names = cfg.get("names", [])
    log.info(f"Classes ({cfg.get('nc', '?')}): {names}")

    # YOLO data files list names either as a list or as an {index: name} mapping.
    found = list(names.values()) if isinstance(names, dict) else list(names or [])
    missing = sorted(set(EXPECTED_CLASSES) - set(found))
    unexpected = sorted(set(found) - set(EXPECTED_CLASSES))
    if missing or unexpected:
        log.warning(f"Class names differ from EXPECTED_CLASSES: "
                    f"missing={missing} unexpected={unexpected}")
    return yaml_path
def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for one training run."""
    parser = argparse.ArgumentParser(description="Train Saqr PPE detector (YOLO11n)")
    parser.add_argument("--dataset", default="dataset",
                        help="Root folder containing data.yaml + train/valid/test")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--model", default="yolo11n.pt",
                        help="Base YOLO model (auto-downloaded if not present)")
    parser.add_argument("--name", default="saqr_det")
    parser.add_argument("--device", default="0",
                        help="Training device: 'cpu', '0', 'cuda:0', etc.")
    return parser.parse_args()


def _export_weights(weights_dir: Path, models_dir: Path) -> list:
    """Copy ``best.pt`` / ``last.pt`` from *weights_dir* into *models_dir*.

    Each file is saved as ``saqr_<name>``, e.g. ``saqr_best.pt``. Files that
    do not exist are skipped, and a warning is logged if neither is found.

    Returns:
        The destination paths that were actually written.
    """
    models_dir.mkdir(parents=True, exist_ok=True)
    saved = []
    for name in ("best.pt", "last.pt"):
        src = weights_dir / name
        if not src.exists():
            continue
        dst = models_dir / f"saqr_{name}"
        shutil.copy(src, dst)
        log.info(f"Saved: {dst}")
        saved.append(dst)
    if not saved:
        log.warning(f"No best.pt/last.pt found in {weights_dir}; "
                    f"nothing copied to {models_dir}")
    return saved


def main():
    """Train the Saqr PPE detector, export its weights and report metrics.

    Steps:
      1. Parse CLI options (see ``_parse_args``).
      2. Check the dataset folder, resolved relative to this script's
         directory, and normalise its ``data.yaml`` via ``fix_data_yaml``.
      3. Train from the base YOLO model. The run directory is
         ``runs/train/<name>`` next to this script and is reused on re-runs
         (``exist_ok=True``).
      4. Copy ``best.pt`` / ``last.pt`` into ``models/`` as
         ``saqr_best.pt`` / ``saqr_last.pt``.
      5. Run ``model.val()`` on the same dataset and log mAP50 / mAP50-95.

    Raises:
        SystemExit: If the dataset folder or its data.yaml is missing.
    """
    args = _parse_args()

    root = Path(__file__).parent
    # A relative --dataset is interpreted relative to this script, not the CWD.
    dataset_root = root / args.dataset
    if not dataset_root.exists():
        log.error(f"Dataset folder not found: {dataset_root}")
        raise SystemExit(1)

    yaml_path = fix_data_yaml(dataset_root)

    # Imported here, after argument parsing and the dataset checks have passed.
    from ultralytics import YOLO

    log.info(f"Loading base model: {args.model}")
    model = YOLO(args.model)

    log.info(f"Training | epochs={args.epochs} imgsz={args.imgsz} "
             f"batch={args.batch} device={args.device}")
    model.train(
        data=str(yaml_path),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device,
        name=args.name,
        project=str(root / "runs" / "train"),
        exist_ok=True,
    )

    # Prefer the run directory the trainer reports rather than rebuilding it
    # by hand. Fall back to the expected project/name layout if the trainer
    # does not expose save_dir.
    save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
    if save_dir is None:
        save_dir = root / "runs" / "train" / args.name
    _export_weights(Path(save_dir) / "weights", root / "models")

    # Pass the dataset and settings explicitly so validation matches this run
    # instead of depending on whatever overrides the model object carries.
    metrics = model.val(data=str(yaml_path), imgsz=args.imgsz, device=args.device)
    log.info(f"mAP50={metrics.box.map50:.4f} mAP50-95={metrics.box.map:.4f}")
    log.info("Next: python saqr.py --source 0")
if __name__ == "__main__":
    # Script entry point: importing this module does not start a training run.
    main()