Sanad/motion/teaching.py

276 lines
10 KiB
Python

"""Teaching mode — safe hold → limp arms → record joint positions.
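Ported from G1_Lootah/Manual_Recorder/g1_teach_v4_stable.py.

Sequence:
  1. Safe hold (SAFE_HOLD_SEC, default 3 s): arms rigid at the current pose,
     waist locked.
  2. Teach phase: arms go limp (KP = TEACH_ARM_KP, default 0) and the user
     moves them by hand; joint positions are recorded at REPLAY_HZ.
  3. Return home: smooth interpolation back to the home pose
     (arm_home.jsonl), or to the starting pose if no home pose is available.
  4. Save: writes JSONL to data/motions/<name>.jsonl.

Without the Unitree SDK or an initialized arm controller, a simulation path
records zero-pose frames instead.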
"""
from __future__ import annotations
import json
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
from Project.Sanad.config import G1_NUM_MOTOR, MOTIONS_DIR, REPLAY_HZ
from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.event_bus import bus
from Project.Sanad.core.logger import get_logger
log = get_logger("teaching")

# Tunables from the ("motion", "teaching") config section.  Every value is
# coerced to float here so a malformed entry fails loudly at import time
# instead of mid-session while the robot is moving, and so ints or numeric
# strings in the config behave the same as floats.
_T = _cfg_section("motion", "teaching")
SAFE_HOLD_SEC = float(_T.get("safe_hold_sec", 3.0))  # rigid hold before the arms are released
WAIST_KP = float(_T.get("waist_kp", 60.0))  # stiffness for body/waist motors (index < 15)
WAIST_KD = float(_T.get("waist_kd", 4.0))  # damping for body/waist motors
HOLD_ARM_KP = float(_T.get("hold_arm_kp", 60.0))  # arm stiffness during the safe hold
HOLD_ARM_KD = float(_T.get("hold_arm_kd", 4.0))  # arm damping during the safe hold
TEACH_ARM_KP = float(_T.get("teach_arm_kp", 0.0))  # limp — no stiffness
TEACH_ARM_KD = float(_T.get("teach_arm_kd", 2.0))  # damping only
# Optional Unitree SDK.  This module itself only reads the _HAS_SDK flag;
# when the SDK is missing, teaching sessions take the simulation path.
try:
    from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_
    from unitree_sdk2py.utils.crc import CRC
except ImportError:
    _HAS_SDK = False
else:
    _HAS_SDK = True
class TeachingSession:
    """Records a teaching session (one at a time).

    A session runs on a daemon worker thread started by :meth:`start`.  On
    real hardware (SDK importable and ``arm._initialized`` true) the worker
    holds the current pose, releases the arms while recording
    ``arm._get_current_q()`` once per ``1 / REPLAY_HZ`` seconds, and then
    interpolates back to the home pose.  Without hardware it records
    zero-pose frames instead.  When the worker ends (duration elapsed,
    :meth:`stop` called, or a crash), the frames are written to
    ``MOTIONS_DIR / f"{name}.jsonl"``.

    Locking: ``_lock`` serializes :meth:`start` / :meth:`stop`, and
    ``_finalize_lock`` makes :meth:`_finalize` run once per session.
    ``_finalize`` never takes ``_lock``, so :meth:`stop` can hold ``_lock``
    while joining the worker without deadlocking against it.
    """

    # Motors with an index below this are body/waist joints, held stiff at
    # the starting pose in every phase; the remaining motors are the arms.
    _BODY_JOINT_COUNT = 15
    # motor_cmd slot whose ``q`` is forced to 1.0 in every published command.
    # Presumably the arm_sdk enable weight (Unitree G1 "kNotUsedJoint" slot)
    # — confirm against the SDK.
    _ARM_SDK_WEIGHT_IDX = 29
    # Interpolation steps (one per control tick) for the glide back home.
    _RETURN_HOME_STEPS = 180
    # Seconds stop() waits for the worker before finalizing on its own.
    _STOP_JOIN_TIMEOUT = 10.0

    def __init__(self, arm_controller):
        self._arm = arm_controller
        # Serializes start()/stop() and guards the _recording check-and-set.
        self._lock = threading.Lock()
        self._recording = False
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._name = ""
        self._frames: list[dict[str, Any]] = []
        self._phase = "idle"  # idle | holding | teaching | returning | done
        self._started_at = 0.0
        # Ensures each session is saved exactly once (by the worker or stop()).
        self._finalized = False
        self._finalize_lock = threading.Lock()
        self._final_result: dict[str, Any] | None = None

    @property
    def is_recording(self) -> bool:
        """True from :meth:`start` until the session has been finalized."""
        return self._recording

    def start(self, name: str, duration_sec: float = 15.0) -> dict[str, Any]:
        """Begin a teaching session on a background worker thread.

        Args:
            name: Motion name; the recording is saved as ``<name>.jsonl`` in
                ``MOTIONS_DIR``.  Must be a bare, non-empty file stem.
            duration_sec: Length of the recording phase (excludes the safe
                hold and the return-home glide).

        Returns:
            ``{"recording": True, "name": name, "duration_sec": duration_sec}``.

        Raises:
            ValueError: If *name* is empty or contains path components.
            RuntimeError: If a session is already active.
        """
        # The name becomes a file name under MOTIONS_DIR: reject anything that
        # could escape that directory or produce a bare ".jsonl" file.
        if not name or name in (".", "..") or Path(name).name != name:
            raise ValueError(f"Invalid motion name: {name!r}")
        with self._lock:
            if self._recording:
                raise RuntimeError("Teaching session already active.")
            self._recording = True
            self._finalized = False
            self._final_result = None
            self._name = name
            self._frames = []
            self._stop_event.clear()
            self._phase = "holding"
            self._started_at = time.monotonic()
            self._thread = threading.Thread(
                target=self._run, args=(name, duration_sec), daemon=True
            )
            self._thread.start()
        log.info("Teaching started: %s (%.0fs)", name, duration_sec)
        bus.emit_sync("motion.teaching_started", name=name, duration_sec=duration_sec)
        return {"recording": True, "name": name, "duration_sec": duration_sec}

    def stop(self) -> dict[str, Any]:
        """Stop the active session and return its result summary.

        Signals the worker and waits up to ``_STOP_JOIN_TIMEOUT`` seconds for
        it (on hardware this includes the return-home glide).  Normally the
        worker has already finalized, so this returns the cached result; if
        the worker is still running or died without finalizing, the frames
        are saved here instead.

        ``_lock`` is held throughout so no new session can start (and be
        finalized by this call) in the meantime.  This is deadlock-free only
        because :meth:`_finalize` never acquires ``_lock``.

        Raises:
            RuntimeError: If no session is active.
        """
        with self._lock:
            if not self._recording:
                raise RuntimeError("No teaching session active.")
            self._stop_event.set()
            worker = self._thread
            if worker is not None:
                worker.join(timeout=self._STOP_JOIN_TIMEOUT)
                if worker.is_alive():
                    log.warning(
                        "Teaching worker still running after %.0fs; finalizing from stop()",
                        self._STOP_JOIN_TIMEOUT,
                    )
            return self._finalize()

    def _run(self, name: str, duration_sec: float):
        """Worker-thread entry point: run the session, then always finalize.

        Uses the hardware path when the SDK imported and the arm controller
        reports ``_initialized``, otherwise the simulation path.  Exceptions
        are logged, and :meth:`_finalize` runs in ``finally`` so frames
        recorded before a crash are still saved.
        """
        interval = 1.0 / REPLAY_HZ
        arm = self._arm
        try:
            if _HAS_SDK and arm._initialized:
                self._run_hardware(name, duration_sec, interval)
            else:
                self._run_simulation(name, duration_sec, interval)
        except Exception:
            log.exception("Teaching session crashed")
        finally:
            # Always finalize from the worker thread — stop() will see _finalized=True.
            self._finalize()

    def _run_hardware(self, name: str, duration_sec: float, interval: float):
        """Real hardware teaching: hold → limp → record → home.

        Publishes one command per *interval* through the arm controller's
        private DDS handles (``_low_cmd``, ``_crc``, ``_arm_pub``).
        ``arm._disable_sdk()`` runs on every exit path (normal completion, a
        stop during the safe hold, or an exception) so an aborted session
        does not leave the arm SDK engaged.
        """
        arm = self._arm
        low_cmd = arm._low_cmd
        crc = arm._crc
        initial_q = arm._get_current_q()
        # Body/waist targets for the whole session: a copy of the start pose.
        waist_lock = list(initial_q)
        try:
            # Phase 1: safe hold — every joint stiff at the starting pose.
            self._phase = "holding"
            hold_end = time.monotonic() + SAFE_HOLD_SEC
            log.info("Safe hold (%.1fs) — arms rigid", SAFE_HOLD_SEC)
            while time.monotonic() < hold_end and not self._stop_event.is_set():
                self._publish(low_cmd, crc, initial_q, initial_q, HOLD_ARM_KP, HOLD_ARM_KD)
                time.sleep(interval)
            if self._stop_event.is_set():
                # Stopped before the arms were released: they are still at
                # the start pose, so there is nothing to record or undo.
                return

            # Phase 2: teaching — arms limp, body locked, record frames.
            self._phase = "teaching"
            log.info("Arms released — move them now! Recording at %d Hz", int(REPLAY_HZ))
            t0 = time.monotonic()
            while not self._stop_event.is_set():
                elapsed = time.monotonic() - t0
                if elapsed >= duration_sec:
                    break
                current_q = arm._get_current_q()
                # Arm targets follow the measured pose, so with TEACH_ARM_KP
                # (0 by default) only damping resists the user's hand.
                self._publish(low_cmd, crc, waist_lock, current_q, TEACH_ARM_KP, TEACH_ARM_KD)
                # Store a copy in case the controller reuses its q buffer.
                self._frames.append({"t": round(elapsed, 4), "q": list(current_q)})
                time.sleep(interval)

            # Phase 3: glide back home (also runs after stop() in phase 2).
            self._phase = "returning"
            self._return_home(initial_q, waist_lock, interval)
        finally:
            arm._disable_sdk()

    def _publish(self, low_cmd, crc, body_q, arm_q, arm_kp: float, arm_kd: float) -> None:
        """Fill and publish one low-level command covering ``G1_NUM_MOTOR`` motors.

        Motors below ``_BODY_JOINT_COUNT`` track ``body_q`` with the waist
        gains; the rest track ``arm_q`` with ``arm_kp`` / ``arm_kd``.  Both
        position sequences are indexed by motor number.
        """
        for i in range(G1_NUM_MOTOR):
            cmd = low_cmd.motor_cmd[i]
            cmd.mode = 1
            cmd.dq = 0.0
            cmd.tau = 0.0
            if i < self._BODY_JOINT_COUNT:
                cmd.q, cmd.kp, cmd.kd = body_q[i], WAIST_KP, WAIST_KD
            else:
                cmd.q, cmd.kp, cmd.kd = arm_q[i], arm_kp, arm_kd
        low_cmd.motor_cmd[self._ARM_SDK_WEIGHT_IDX].q = 1.0
        low_cmd.crc = crc.Crc(low_cmd)
        self._arm._arm_pub.Write(low_cmd)

    def _return_home(self, initial_q, waist_lock, interval: float) -> None:
        """Interpolate from the last recorded pose to the home pose.

        Uses *initial_q* as the home pose when ``_load_home_q()`` yields
        nothing, and as the start pose when no frames were recorded.  Sends
        ``_RETURN_HOME_STEPS`` frames, one per *interval*.
        """
        # Deferred import — presumably avoids an import cycle with
        # arm_controller (confirm).
        from Project.Sanad.motion.arm_controller import _load_home_q, _lerp_q

        arm = self._arm
        home_q = _load_home_q() or initial_q
        last_q = self._frames[-1]["q"] if self._frames else initial_q
        steps = self._RETURN_HOME_STEPS
        for step in range(1, steps + 1):
            arm._send_frame(_lerp_q(last_q, home_q, step / steps), waist_lock)
            time.sleep(interval)

    def _run_simulation(self, name: str, duration_sec: float, interval: float):
        """Simulation: record zero-pose frames for *duration_sec*.

        The safe hold is shortened to at most 1 s, and the return phase is a
        fixed 0.5 s pause.
        """
        self._phase = "holding"
        time.sleep(min(SAFE_HOLD_SEC, 1.0))  # shortened in sim
        self._phase = "teaching"
        t0 = time.monotonic()
        log.info("[SIM] Teaching — recording for %.0fs", duration_sec)
        while not self._stop_event.is_set():
            elapsed = time.monotonic() - t0
            if elapsed >= duration_sec:
                break
            self._frames.append({"t": round(elapsed, 4), "q": [0.0] * G1_NUM_MOTOR})
            time.sleep(interval)
        self._phase = "returning"
        time.sleep(0.5)

    def _finalize(self) -> dict[str, Any]:
        """Save the session's frames, end the session, and return a summary.

        Idempotent: the first call (normally from the worker, or from
        :meth:`stop` if the worker did not exit in time) writes the file and
        caches the result; later calls return the cached result.

        Deliberately does not take ``_lock``, because :meth:`stop` holds
        ``_lock`` while joining the worker that calls this method.

        A failed write is logged and reported through an ``"error"`` key
        rather than raised, so a save failure cannot leave the session stuck
        in the recording state.

        Returns:
            ``{"name", "frames", "path", "duration_sec"}``, plus ``"size_kb"``
            when a file was written or ``"error"`` when writing failed.
        """
        with self._finalize_lock:
            if self._finalized:
                return self._final_result or {
                    "name": self._name, "frames": len(self._frames),
                    "path": "", "duration_sec": 0,
                }
            self._phase = "done"
            # Snapshot so the frame count and the file agree even if a stuck
            # worker is still appending while stop() finalizes.
            frames = list(self._frames)
            result: dict[str, Any] = {"name": self._name, "frames": len(frames)}
            try:
                result.update(self._save_frames(frames))
            except Exception as exc:
                log.exception("Failed to save teaching session %r", self._name)
                result.update(path="", duration_sec=0, error=str(exc))
            self._finalized = True
            self._final_result = result
            # Set "idle" before clearing _recording: start() only proceeds once
            # _recording is False, so a new session's phase is never clobbered.
            self._phase = "idle"
            self._recording = False
        bus.emit_sync("motion.teaching_finished", name=result.get("name"), frames=result.get("frames"))
        return result

    def _save_frames(self, frames: list[dict[str, Any]]) -> dict[str, Any]:
        """Atomically write *frames* to ``MOTIONS_DIR/<name>.jsonl``.

        The first line is a ``{"meta": {"hz": ..., "motors": ...}}`` header,
        followed by one JSON object per frame.  Returns the ``path`` /
        ``duration_sec`` / ``size_kb`` result fields, or ``path=""`` and
        ``duration_sec=0`` (and writes nothing) when *frames* is empty.
        """
        if not frames:
            return {"path": "", "duration_sec": 0}
        MOTIONS_DIR.mkdir(parents=True, exist_ok=True)
        out_path = MOTIONS_DIR / f"{self._name}.jsonl"
        lines = [json.dumps({"meta": {"hz": REPLAY_HZ, "motors": G1_NUM_MOTOR}})]
        lines.extend(json.dumps(frame) for frame in frames)
        content = ("\n".join(lines) + "\n").encode("utf-8")
        # Atomic write: temp file in the target directory, flushed to disk,
        # then os.replace, so readers never see a half-written motion file.
        fd, tmp = tempfile.mkstemp(
            prefix=f".{out_path.name}.", suffix=".tmp",
            dir=str(out_path.parent),
        )
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(content)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, out_path)
        except Exception:
            try:
                os.unlink(tmp)
            except OSError:
                pass
            raise
        duration = frames[-1]["t"]
        size_kb = round(out_path.stat().st_size / 1024, 1)
        log.info("Teaching saved: %s (%d frames, %.1fs)", out_path.name, len(frames), duration)
        return {"path": str(out_path), "duration_sec": round(duration, 2), "size_kb": size_kb}

    def status(self) -> dict[str, Any]:
        """Return a snapshot of the session state (read without locking)."""
        elapsed = time.monotonic() - self._started_at if self._recording else 0
        return {
            "recording": self._recording,
            "phase": self._phase,
            "name": self._name,
            "elapsed_sec": round(elapsed, 1),
            "frames_recorded": len(self._frames),
        }