112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
"""Train YOLO11n on the PPE dataset under data/dataset."""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
from saqr.core.paths import DATASET_DIR, MODELS_DIR, PROJECT_ROOT, RUNS_DIR
|
|
from saqr.utils.logger import get_logger
|
|
|
|
# Module-level logger used by every function in this script.
# NOTE(review): the two arguments are presumably a display name and a
# log-file/channel key — confirm against saqr.utils.logger.get_logger.
log = get_logger("Training", "train")
|
|
|
|
# The ten PPE class labels, in alphabetical order: five item labels plus
# their "no-" counterparts.
# NOTE(review): presumably meant to mirror `names` in the dataset's data.yaml;
# nothing in the visible part of this file reads it — confirm before relying
# on it for validation.
EXPECTED_CLASSES = [
    "boots",
    "gloves",
    "goggles",
    "helmet",
    "no-boots",
    "no-gloves",
    "no-goggles",
    "no-helmet",
    "no-vest",
    "vest",
]
|
|
|
|
|
|
def fix_data_yaml(dataset_root: Path) -> Path:
    """Rewrite data.yaml with absolute paths for the current dataset location.

    The ``train`` / ``val`` / ``test`` entries are pointed at
    ``<dataset_root>/{train,valid,test}/images`` (only for sub-folders that
    exist on disk) and ``path`` is set to ``dataset_root``. The file is
    written back only when at least one entry actually changed.

    Args:
        dataset_root: Folder that contains ``data.yaml``.

    Returns:
        Path to the (possibly rewritten) ``data.yaml``.

    Raises:
        SystemExit: With status 1 if ``data.yaml`` is missing, or if its top
            level is not a mapping (for example, the file is empty).
    """
    yaml_path = dataset_root / "data.yaml"
    if not yaml_path.exists():
        log.error(f"data.yaml not found at {yaml_path}")
        raise SystemExit(1)

    # Explicit encoding: do not depend on the platform's locale default.
    with open(yaml_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load returns None for an empty document and can return a list or a
    # scalar for a malformed one; everything below needs a mapping, so fail
    # with a clear message instead of an AttributeError on cfg.get().
    if not isinstance(cfg, dict):
        log.error(f"data.yaml is empty or not a mapping: {yaml_path}")
        raise SystemExit(1)

    changed = False

    # Note the split folder for validation is "valid", but the YAML key is "val".
    for key, subdir in (("train", "train"), ("val", "valid"), ("test", "test")):
        img_dir = dataset_root / subdir / "images"
        # Splits that do not exist on disk keep whatever value they had.
        if img_dir.exists() and cfg.get(key) != str(img_dir):
            cfg[key] = str(img_dir)
            changed = True

    root_str = str(dataset_root)
    if cfg.get("path") != root_str:
        cfg["path"] = root_str
        changed = True

    if changed:
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(cfg, f, default_flow_style=False)
        log.info(f"Fixed data.yaml paths -> {yaml_path}")

    log.info(f"Classes ({cfg.get('nc', '?')}): {cfg.get('names', [])}")
    return yaml_path
|
|
|
|
|
|
def main():
    """Command-line entry point: train the detector, export weights, validate.

    Steps, in order:
      1. Parse the CLI options.
      2. Resolve the dataset folder (a relative path is taken from
         PROJECT_ROOT) and refresh its data.yaml through fix_data_yaml().
      3. Load the base YOLO weights; a relative path that does not exist is
         looked up by file name under MODELS_DIR before being passed on as-is.
      4. Train, with the run written under RUNS_DIR/train/<name>.
      5. Copy best.pt / last.pt (whichever exist) into MODELS_DIR as
         saqr_best.pt / saqr_last.pt.
      6. Run validation and log the two mAP figures.

    Raises:
        SystemExit: With status 1 when the dataset folder (or, via
            fix_data_yaml, its data.yaml) is missing.
    """
    cli = argparse.ArgumentParser(description="Train Saqr PPE detector (YOLO11n)")
    cli.add_argument(
        "--dataset",
        default=str(DATASET_DIR),
        help="Root folder containing data.yaml + train/valid/test",
    )
    cli.add_argument("--epochs", type=int, default=100)
    cli.add_argument("--imgsz", type=int, default=640)
    cli.add_argument("--batch", type=int, default=16)
    cli.add_argument(
        "--model",
        default="yolo11n.pt",
        help="Base YOLO model (auto-downloaded if not present)",
    )
    cli.add_argument("--name", default="saqr_det")
    cli.add_argument("--device", default="0")
    opts = cli.parse_args()

    requested = Path(opts.dataset)
    dataset_root = requested if requested.is_absolute() else PROJECT_ROOT / requested
    if not dataset_root.exists():
        log.error(f"Dataset folder not found: {dataset_root}")
        raise SystemExit(1)

    data_yaml = fix_data_yaml(dataset_root)

    # Imported here rather than at module top, so the argument and dataset
    # checks above run before ultralytics is loaded.
    from ultralytics import YOLO

    weights = Path(opts.model)
    if not (weights.is_absolute() or weights.exists()):
        # Fall back to a same-named file in the project's models folder.
        local_copy = MODELS_DIR / weights.name
        if local_copy.exists():
            weights = local_copy
    log.info(f"Loading base model: {weights}")
    model = YOLO(str(weights))

    log.info(
        f"Training | epochs={opts.epochs} imgsz={opts.imgsz} "
        f"batch={opts.batch} device={opts.device}"
    )
    run_root = RUNS_DIR / "train"
    train_options = {
        "data": str(data_yaml),
        "epochs": opts.epochs,
        "imgsz": opts.imgsz,
        "batch": opts.batch,
        "device": opts.device,
        "name": opts.name,
        "project": str(run_root),
        "exist_ok": True,
    }
    model.train(**train_options)

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    produced_dir = run_root / opts.name / "weights"
    for filename in ("best.pt", "last.pt"):
        trained = produced_dir / filename
        if not trained.exists():
            continue
        exported = MODELS_DIR / f"saqr_{filename}"
        shutil.copy(trained, exported)
        log.info(f"Saved: {exported}")

    box_metrics = model.val().box
    log.info(f"mAP50={box_metrics.map50:.4f} mAP50-95={box_metrics.map:.4f}")
    log.info("Next: saqr --source 0")
|
|
|
|
|
|
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|