Saqr/saqr/apps/train_cli.py

112 lines
3.6 KiB
Python

"""Train YOLO11n on the PPE dataset under data/dataset."""
from __future__ import annotations
import argparse
import shutil
from pathlib import Path
import yaml
from saqr.core.paths import DATASET_DIR, MODELS_DIR, PROJECT_ROOT, RUNS_DIR
from saqr.utils.logger import get_logger
log = get_logger("Training", "train")

# Class names data.yaml is expected to declare: five PPE items plus a
# "no-" variant of each, listed alphabetically.
EXPECTED_CLASSES = [
    "boots",
    "gloves",
    "goggles",
    "helmet",
    "no-boots",
    "no-gloves",
    "no-goggles",
    "no-helmet",
    "no-vest",
    "vest",
]
def fix_data_yaml(dataset_root: Path) -> Path:
    """Rewrite data.yaml with absolute paths for the current dataset location.

    For each split (``train``, ``val``, ``test``) whose image folder
    ``<dataset_root>/<subdir>/images`` exists, the YAML entry is set to that
    absolute path, and ``path`` is set to *dataset_root*. The file is only
    rewritten when at least one value changed, and the existing key order is
    preserved.

    The declared class names are compared with ``EXPECTED_CLASSES``. A
    mismatch, including a different order (which changes class indices), is
    logged as a warning but does not stop training.

    Args:
        dataset_root: Folder containing ``data.yaml`` and the split folders.

    Returns:
        Path of the ``data.yaml`` file.

    Raises:
        SystemExit: If ``data.yaml`` is missing or does not contain a mapping.
    """
    yaml_path = dataset_root / "data.yaml"
    if not yaml_path.exists():
        log.error(f"data.yaml not found at {yaml_path}")
        raise SystemExit(1)

    with open(yaml_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty file loads as None, and a scalar or list is not a dataset
    # config. Both would otherwise crash below with an opaque AttributeError.
    if not isinstance(cfg, dict):
        log.error(f"data.yaml at {yaml_path} must be a mapping, got {type(cfg).__name__}")
        raise SystemExit(1)

    changed = False
    # Each pair maps a YAML split key to its folder name on disk. The
    # validation split lives in "valid/".
    for key, subdir in (("train", "train"), ("val", "valid"), ("test", "test")):
        img_dir = dataset_root / subdir / "images"
        if img_dir.exists() and cfg.get(key) != str(img_dir):
            cfg[key] = str(img_dir)
            changed = True
    if cfg.get("path") != str(dataset_root):
        cfg["path"] = str(dataset_root)
        changed = True

    if changed:
        with open(yaml_path, "w", encoding="utf-8") as f:
            # sort_keys=False keeps the author's key order instead of
            # alphabetising the whole file on every rewrite.
            yaml.dump(cfg, f, default_flow_style=False, sort_keys=False)
        log.info(f"Fixed data.yaml paths -> {yaml_path}")

    log.info(f"Classes ({cfg.get('nc', '?')}): {cfg.get('names', [])}")

    # Sanity-check the class list against what this project expects.
    # `names` may be a list, or an {index: name} mapping.
    raw_names = cfg.get("names") or []
    if isinstance(raw_names, dict):
        names = [raw_names[i] for i in sorted(raw_names)]
    elif isinstance(raw_names, (list, tuple)):
        names = list(raw_names)
    else:
        names = []
    if names != EXPECTED_CLASSES:
        log.warning(
            f"Class names in {yaml_path} differ from the expected PPE classes "
            f"{EXPECTED_CLASSES}; class indices may not match downstream code."
        )
    nc = cfg.get("nc")
    if nc is not None and nc != len(names):
        log.warning(f"data.yaml declares nc={nc} but lists {len(names)} names")

    return yaml_path
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the training CLI arguments (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description="Train Saqr PPE detector (YOLO11n)")
    parser.add_argument("--dataset", default=str(DATASET_DIR),
                        help="Root folder containing data.yaml + train/valid/test")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--model", default="yolo11n.pt",
                        help="Base YOLO model (auto-downloaded if not present)")
    parser.add_argument("--name", default="saqr_det")
    parser.add_argument("--device", default="0")
    return parser.parse_args(argv)


def _resolve_dataset_root(raw: str) -> Path:
    """Return the dataset folder as an absolute path; exit(1) if it is missing.

    Relative paths are interpreted against PROJECT_ROOT, not the CWD.
    """
    dataset_root = Path(raw)
    if not dataset_root.is_absolute():
        dataset_root = PROJECT_ROOT / dataset_root
    if not dataset_root.exists():
        log.error(f"Dataset folder not found: {dataset_root}")
        raise SystemExit(1)
    return dataset_root


def _resolve_base_model(raw: str) -> Path:
    """Locate the base weights, preferring a copy in MODELS_DIR.

    If *raw* is relative and does not exist relative to the CWD, but
    ``MODELS_DIR/<basename>`` does exist, that file is used. Otherwise *raw*
    is passed through unchanged; per the --model help text, YOLO downloads
    known model names itself.
    """
    base = Path(raw)
    if not base.is_absolute() and not base.exists():
        candidate = MODELS_DIR / base.name
        if candidate.exists():
            base = candidate
    return base


def _export_weights(weights_dir: Path) -> None:
    """Copy ``best.pt`` / ``last.pt`` from *weights_dir* to ``MODELS_DIR/saqr_*.pt``."""
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    for name in ("best.pt", "last.pt"):
        src = weights_dir / name
        dst = MODELS_DIR / f"saqr_{name}"
        if not src.exists():
            # Report the miss: otherwise any saqr_*.pt left over from an
            # earlier run would silently remain in place.
            log.warning(f"Checkpoint not found, not copied: {src}")
            continue
        shutil.copy(src, dst)
        log.info(f"Saved: {dst}")


def main(argv: list[str] | None = None) -> None:
    """Train the Saqr PPE detector and export its checkpoints to MODELS_DIR.

    Steps: parse the CLI, normalise data.yaml paths, train from the chosen
    base model, and copy best/last weights into MODELS_DIR. Finally, run
    validation and log mAP50 / mAP50-95.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) reads ``sys.argv[1:]``, so existing callers are unaffected.

    Raises:
        SystemExit: On invalid arguments or a missing dataset / data.yaml.
    """
    args = _parse_args(argv)
    dataset_root = _resolve_dataset_root(args.dataset)
    yaml_path = fix_data_yaml(dataset_root)

    # Deferred import, so argument and dataset errors are reported
    # before ultralytics is imported.
    from ultralytics import YOLO

    base = _resolve_base_model(args.model)
    log.info(f"Loading base model: {base}")
    model = YOLO(str(base))

    log.info(f"Training | epochs={args.epochs} imgsz={args.imgsz} "
             f"batch={args.batch} device={args.device}")
    run_project = RUNS_DIR / "train"
    model.train(
        data=str(yaml_path),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device,
        name=args.name,
        project=str(run_project),
        # Reuse <project>/<name>. The weights path computed below relies on
        # the run directory having exactly that name.
        exist_ok=True,
    )

    _export_weights(run_project / args.name / "weights")

    # Re-run validation to report the final metrics of the trained model.
    metrics = model.val()
    log.info(f"mAP50={metrics.box.map50:.4f} mAP50-95={metrics.box.map:.4f}")
    log.info("Next: saqr --source 0")
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # unless main() itself raises SystemExit.
    raise SystemExit(main())