# GoWelcome/gowelcome/robot/webrtc_robot.py
# (source-viewer metadata: 515 lines, 21 KiB, Python)

"""WebRTC robot backend for the Go2 (the default ``webrtc`` transport).

Drives a Unitree Go2 over WebRTC using ``unitree_webrtc_connect`` -- the same
protocol the Unitree Go / Explore apps use, so it works on **AIR / PRO / EDU
over wifi** and can play the greeting **from the dog's own speaker** (AudioHub),
which the official DDS SDK cannot do on a Go2.

The library is fully asynchronous (asyncio + aiortc); GoWelcome is threaded and
synchronous. This class bridges the two: it runs the WebRTC connection on a
dedicated event-loop thread and exposes the synchronous :class:`RobotInterface`.

* High-rate commands (drive/gesture/posture) are sent **fire-and-forget** via
  ``publish_without_callback`` scheduled with ``loop.call_soon_threadsafe`` --
  the control loop never blocks and no response-future accumulates.
* Setup steps that need confirmation (connect, mode switch, avoidance enable,
  AudioHub upload) are awaited via ``run_coroutine_threadsafe(...).result()``.

Velocity convention (matches the rest of GoWelcome): vx forward+, vy left+,
vyaw CCW/left+ (rad/s).

All heavy deps (``unitree_webrtc_connect``, ``aiortc``, ``cv2``) are imported
lazily so this module imports fine off-robot.
"""
from __future__ import annotations

import asyncio
import json
import logging
import math
import os
import threading
from typing import TYPE_CHECKING, Optional, Tuple

import numpy as np

from gowelcome.robot.interface import GESTURES, RobotInterface

if TYPE_CHECKING:  # pragma: no cover - type hints only
    from config import GoWelcomeConfig
logger = logging.getLogger(__name__)

# Lookup table: GoWelcome gesture name -> key into the Unitree SPORT_CMD
# mapping.  Most entries are a straight re-capitalisation of the gesture name;
# "heart" is the exception and resolves to the "FingerHeart" action.
_GESTURE_TO_CMD = {
    # greeting / show gestures
    "hello": "Hello",
    "heart": "FingerHeart",
    "stretch": "Stretch",
    "dance1": "Dance1",
    "dance2": "Dance2",
    "scrape": "Scrape",
    "content": "Content",
    # sitting postures
    "sit": "Sit",
    "rise_sit": "RiseSit",
    # dog-play actions
    "wiggle": "WiggleHips",
    "wallow": "Wallow",
    "pounce": "FrontPounce",
}
def _clamp(value: float, lo: float, hi: float) -> float:
    """Limit ``value`` to the closed interval ``[lo, hi]``.

    Returns ``lo`` when ``value`` is below it, ``hi`` when above it, and
    ``value`` itself otherwise.  The lower bound is tested first.  A NaN
    ``value`` fails both comparisons and is therefore returned unchanged.
    """
    if value < lo:
        return lo
    if value > hi:
        return hi
    return value
class Go2WebRTCRobot(RobotInterface):
    """Drive a Go2 over WebRTC (app protocol) via ``unitree_webrtc_connect``.

    Construction is blocking: it builds the connection object, starts a
    dedicated asyncio loop on a daemon thread, connects, and then performs
    best-effort setup (motion mode, camera, AudioHub greeting, optional
    obstacle avoidance, balance-stand).  Only the connect step is fatal; every
    other setup step logs a warning on failure and construction continues.

    Args:
        cfg: Fully-populated :class:`~config.GoWelcomeConfig`.

    Raises:
        ImportError: if ``unitree_webrtc_connect`` / ``cv2`` are not installed.
        ValueError: if the WebRTC connection parameters are incomplete.
        Exception: whatever the connect step raises (including a timeout after
            30 s).  The loop thread is stopped before the error propagates.
    """

    def __init__(self, cfg: "GoWelcomeConfig") -> None:
        try:
            from unitree_webrtc_connect.webrtc_driver import UnitreeWebRTCConnection
            from unitree_webrtc_connect.constants import (
                DATA_CHANNEL_TYPE,
                OBSTACLES_AVOID_API,
                RTC_TOPIC,
                SPORT_CMD,
                WebRTCConnectionMethod,
            )
            from unitree_webrtc_connect.webrtc_audiohub import WebRTCAudioHub
        except ImportError as exc:  # pragma: no cover - off-robot path
            raise ImportError(
                "Go2WebRTCRobot requires 'unitree_webrtc_connect'. Install with: "
                "pip install unitree_webrtc_connect (plus: sudo apt install "
                "portaudio19-dev). For Go2 firmware >= 1.1.15 you also need the "
                "per-device AES key -- see the repo's examples/fetch_aes_key.py."
            ) from exc
        try:
            import cv2
        except ImportError as exc:  # pragma: no cover - off-robot path
            raise ImportError(
                "Go2WebRTCRobot requires OpenCV: pip install opencv-python"
            ) from exc

        self.cfg = cfg
        self._cv2 = cv2
        self._RTC_TOPIC = RTC_TOPIC
        self._SPORT_CMD = SPORT_CMD
        self._OA = OBSTACLES_AVOID_API
        self._REQUEST = DATA_CHANNEL_TYPE["REQUEST"]

        # --- camera frame buffer (written by the video callback) ------------
        self._frame: Optional[np.ndarray] = None
        self._frame_size: Tuple[int, int] = (0, 0)
        self._frame_lock = threading.Lock()

        # --- state ----------------------------------------------------------
        self._closing = False
        self._shutdown = False
        self._shutdown_lock = threading.Lock()
        self._avoidance_on = False
        self._audiohub = None
        self._greeting_uuid: Optional[str] = None
        self._stream_player = None  # keep a ref so the aiortc track isn't GC'd
        self._cmd_id = 0
        self._audio_method = (cfg.webrtc.audio_method or "audiohub").strip().lower()

        # --- build the connection object ------------------------------------
        self._conn = self._build_connection(
            cfg, UnitreeWebRTCConnection, WebRTCConnectionMethod
        )

        # --- start the asyncio loop on its own daemon thread ----------------
        self._loop = asyncio.new_event_loop()
        self._loop_thread = threading.Thread(
            target=self._run_loop, name="Go2WebRTCLoop", daemon=True
        )
        self._loop_thread.start()

        # --- connect (blocking) ---------------------------------------------
        try:
            self._run(self._conn.connect(), timeout=30.0)
        except BaseException:
            # Connect is the only fatal setup step.  Without this the loop
            # thread would keep running (and the loop stay open) behind an
            # object the caller never receives.
            self._stop_loop()
            raise
        logger.info("WebRTC connected (%s)", cfg.webrtc.connection_method)

        # --- ensure normal motion mode --------------------------------------
        try:
            self._run(
                self._conn.datachannel.pub_sub.publish_request_new(
                    RTC_TOPIC["MOTION_SWITCHER"],
                    {"api_id": 1002, "parameter": {"name": "normal"}},
                ),
                timeout=10.0,
            )
        except Exception as exc:  # pragma: no cover - hardware path
            logger.warning("motion-switcher to 'normal' failed: %s", exc)

        # --- camera on + frame callback (run on the loop thread; switching the
        #     video channel sends on the data channel) -------------------------
        async def _setup_video():
            self._conn.video.switchVideoChannel(True)
            self._conn.video.add_track_callback(self._on_video_track)

        try:
            self._run(_setup_video(), timeout=8.0)
        except Exception as exc:  # pragma: no cover - hardware path
            logger.warning("video channel setup failed: %s", exc)

        # --- audio backend (AudioHub) ---------------------------------------
        try:
            self._audiohub = WebRTCAudioHub(self._conn)
        except Exception as exc:  # pragma: no cover - hardware path
            logger.warning("AudioHub init failed: %s", exc)
            self._audiohub = None
        if self._audio_method == "audiohub":
            self._prepare_audiohub_greeting()

        # --- obstacle avoidance ---------------------------------------------
        if cfg.safety.use_lidar_avoidance:
            self.set_avoidance(True)

        # --- ready posture --------------------------------------------------
        self.balance_stand()
        logger.info(
            "Go2WebRTCRobot ready (audio=%s, avoidance=%s, dry_run=%s)",
            self._audio_method, self._avoidance_on, cfg.dry_run,
        )

    # ------------------------------------------------------------------ #
    # Connection construction
    # ------------------------------------------------------------------ #
    def _build_connection(self, cfg, UnitreeWebRTCConnection, WebRTCConnectionMethod):
        """Translate ``cfg.webrtc`` into a ``UnitreeWebRTCConnection``.

        The library classes are passed in because they are imported lazily in
        ``__init__``.

        Raises:
            ValueError: unknown connection method, ``localsta`` without an ip
                or serial number, or ``remote`` without serial number,
                username and password.
        """
        wc = cfg.webrtc
        method = (wc.connection_method or "localsta").strip().lower()
        kwargs = {"region": wc.region or "global", "device_type": "Go2"}
        if wc.aes_128_key:
            kwargs["aes_128_key"] = wc.aes_128_key
        if wc.serial_number:
            kwargs["serialNumber"] = wc.serial_number

        if method == "localsta":
            cm = WebRTCConnectionMethod.LocalSTA
            if wc.ip:
                kwargs["ip"] = wc.ip
            elif not wc.serial_number:
                raise ValueError(
                    "WebRTC localsta needs webrtc.ip or webrtc.serial_number "
                    "(pass --robot-ip <ip> or --serial <sn>)."
                )
        elif method == "localap":
            cm = WebRTCConnectionMethod.LocalAP
        elif method == "remote":
            cm = WebRTCConnectionMethod.Remote
            if not (wc.serial_number and wc.username and wc.password):
                raise ValueError(
                    "WebRTC remote needs webrtc.serial_number + username + password."
                )
            kwargs["username"] = wc.username
            kwargs["password"] = wc.password
        else:
            raise ValueError(f"Unknown webrtc.connection_method: {wc.connection_method!r}")
        return UnitreeWebRTCConnection(cm, **kwargs)

    # ------------------------------------------------------------------ #
    # asyncio bridge
    # ------------------------------------------------------------------ #
    def _run_loop(self) -> None:
        """Thread target: install the loop in this thread and run it until stopped."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def _stop_loop(self) -> None:
        """Stop the loop thread and close the loop.  Never raises."""
        try:
            self._loop.call_soon_threadsafe(self._loop.stop)
        except Exception as exc:  # pragma: no cover - e.g. loop already closed
            logger.debug("loop stop request failed: %s", exc)
        try:
            if self._loop_thread.is_alive():
                self._loop_thread.join(timeout=3.0)
        except Exception as exc:  # pragma: no cover - e.g. join from loop thread
            logger.debug("loop thread join failed: %s", exc)
        try:
            # close() is only legal once run_forever() has returned, so skip it
            # if the thread outlived the join timeout.
            if not self._loop_thread.is_alive() and not self._loop.is_closed():
                self._loop.close()
        except Exception as exc:  # pragma: no cover
            logger.debug("loop close failed: %s", exc)

    def _run(self, coro, timeout: float = 10.0):
        """Run ``coro`` on the loop thread and block for its result.

        Raises whatever the coroutine raises, or a timeout error when no
        result arrives within ``timeout`` seconds.  On timeout (or any other
        interruption of the wait) the coroutine is cancelled, so a step the
        caller has already given up on does not keep running on the loop.
        """
        fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        try:
            return fut.result(timeout)
        finally:
            if not fut.done():
                fut.cancel()

    def _run_bg(self, coro) -> None:
        """Fire-and-forget a coroutine on the loop thread (non-blocking)."""

        def _log_failure(done) -> None:
            # exception() raises CancelledError on a cancelled future, so the
            # cancelled case has to be filtered out first.
            if done.cancelled():
                return
            err = done.exception()
            if err is not None:
                logger.debug("bg coro error: %s", err)

        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
            fut.add_done_callback(_log_failure)
        except Exception as exc:  # pragma: no cover
            logger.debug("schedule bg coro failed: %s", exc)

    def _next_id(self) -> int:
        """Return the next request id (wraps around below 2**31)."""
        self._cmd_id = (self._cmd_id + 1) % 2147483000
        return self._cmd_id

    def _fire(self, topic: str, api_id: int, parameter=None) -> None:
        """Fire-and-forget a request command from any thread (no reply wait).

        Builds the same request envelope ``publish_request_new`` uses, then
        sends it via the no-future ``publish_without_callback`` scheduled on the
        loop thread -- so the control loop never blocks and no response-future
        leaks (critical for the high-rate Move path).

        Args:
            topic: RTC topic to publish on.
            api_id: Command id placed in the request header.
            parameter: Optional payload; a ``str`` is sent as-is, anything
                else is JSON-encoded.
        """
        payload = {
            "header": {"identity": {"id": self._next_id(), "api_id": api_id}},
            "parameter": "",
        }
        if parameter is not None:
            payload["parameter"] = (
                parameter if isinstance(parameter, str) else json.dumps(parameter)
            )

        def _send():
            try:
                self._conn.datachannel.pub_sub.publish_without_callback(
                    topic, payload, self._REQUEST
                )
            except Exception as exc:  # pragma: no cover - hardware path
                logger.debug("publish failed (api_id=%s): %s", api_id, exc)

        try:
            self._loop.call_soon_threadsafe(_send)
        except Exception as exc:  # pragma: no cover
            logger.debug("call_soon_threadsafe failed: %s", exc)

    # ------------------------------------------------------------------ #
    # Camera
    # ------------------------------------------------------------------ #
    async def _on_video_track(self, track) -> None:
        """Receive video frames into the shared buffer until shutdown.

        Each frame is converted to BGR, resized to the configured camera size
        when it differs, and published under ``_frame_lock``.  A failing
        ``recv`` ends the loop; a failing decode only drops that frame.
        """
        cv2 = self._cv2
        target = (self.cfg.camera.width, self.cfg.camera.height)
        while not self._closing:
            try:
                frame = await track.recv()
            except Exception as exc:  # pragma: no cover - track ended
                logger.debug("video recv ended: %s", exc)
                break
            try:
                img = frame.to_ndarray(format="bgr24")
                if (img.shape[1], img.shape[0]) != target:
                    img = cv2.resize(img, target)
                with self._frame_lock:
                    self._frame = img
                    self._frame_size = (img.shape[1], img.shape[0])
            except Exception as exc:  # pragma: no cover
                logger.debug("video decode failed: %s", exc)

    def get_frame(self) -> "Optional[np.ndarray]":
        """Return the latest camera frame, or ``None`` before the first one.

        The stored array is returned directly (no copy); the receiver replaces
        it with a new array per frame rather than writing into it.
        """
        with self._frame_lock:
            return self._frame

    def frame_size(self) -> "tuple[int, int]":
        """Return ``(width, height)`` of the latest frame; ``(0, 0)`` before any."""
        with self._frame_lock:
            return self._frame_size

    # ------------------------------------------------------------------ #
    # Locomotion
    # ------------------------------------------------------------------ #
    def drive(self, vx: float, vy: float, vyaw: float) -> None:
        """Send a velocity command, clamped to the configured safety limits.

        Non-finite input (NaN / inf) is replaced by a zero command.  In
        ``dry_run`` mode the intent is logged and zeros are sent.  The command
        goes through the obstacle-avoidance topic when avoidance is on,
        otherwise through the sport ``Move`` command.
        """
        limits = self.cfg.safety
        # Plain floats keep json.dumps happy when a caller passes numpy scalars
        # (np.float32 is not JSON serialisable).
        vx, vy, vyaw = float(vx), float(vy), float(vyaw)
        if not (math.isfinite(vx) and math.isfinite(vy) and math.isfinite(vyaw)):
            # NaN survives _clamp (all comparisons are false) and would be
            # serialised as invalid JSON; stopping is the safe reaction.
            logger.warning(
                "drive() got non-finite velocity vx=%s vy=%s vyaw=%s; sending 0",
                vx, vy, vyaw,
            )
            vx = vy = vyaw = 0.0
        vx = _clamp(vx, -limits.max_vx, limits.max_vx)
        vy = _clamp(vy, -limits.max_vy, limits.max_vy)
        vyaw = _clamp(vyaw, -limits.max_vyaw, limits.max_vyaw)
        if self.cfg.dry_run:
            logger.info(
                "[dry_run] drive intent vx=%.3f vy=%.3f vyaw=%.3f (sending 0)",
                vx, vy, vyaw,
            )
            vx = vy = vyaw = 0.0
        if self._avoidance_on:
            self._fire(self._RTC_TOPIC["OBSTACLES_AVOID"], self._OA["MOVE"],
                       {"x": vx, "y": vy, "yaw": vyaw, "mode": 0})
        else:
            self._fire(self._RTC_TOPIC["SPORT_MOD"], self._SPORT_CMD["Move"],
                       {"x": vx, "y": vy, "z": vyaw})

    def stop(self) -> None:
        """Command zero motion on whichever path ``drive`` currently uses."""
        if self._avoidance_on:
            self._fire(self._RTC_TOPIC["OBSTACLES_AVOID"], self._OA["MOVE"],
                       {"x": 0.0, "y": 0.0, "yaw": 0.0, "mode": 0})
        else:
            self._fire(self._RTC_TOPIC["SPORT_MOD"], self._SPORT_CMD["StopMove"])

    def set_avoidance(self, on: bool) -> None:
        """Enable or disable obstacle avoidance and wait for confirmation.

        Enabling switches avoidance on and then hands motion control to the
        API; disabling does the reverse, in the opposite order.  Any failure
        is logged and leaves ``_avoidance_on`` false, so ``drive`` falls back
        to the sport ``Move`` path.  Never raises.
        """
        oa = self._RTC_TOPIC["OBSTACLES_AVOID"]
        try:
            if on:
                self._run(self._conn.datachannel.pub_sub.publish_request_new(
                    oa, {"api_id": self._OA["SWITCH_SET"],
                         "parameter": {"enable": True}}), timeout=8.0)
                self._run(self._conn.datachannel.pub_sub.publish_request_new(
                    oa, {"api_id": self._OA["USE_REMOTE_COMMAND_FROM_API"],
                         "parameter": {"is_remote_commands_from_api": True}}), timeout=8.0)
                self._avoidance_on = True
                logger.info("LiDAR avoidance enabled (WebRTC)")
            else:
                self._run(self._conn.datachannel.pub_sub.publish_request_new(
                    oa, {"api_id": self._OA["USE_REMOTE_COMMAND_FROM_API"],
                         "parameter": {"is_remote_commands_from_api": False}}), timeout=8.0)
                self._run(self._conn.datachannel.pub_sub.publish_request_new(
                    oa, {"api_id": self._OA["SWITCH_SET"],
                         "parameter": {"enable": False}}), timeout=8.0)
                self._avoidance_on = False
                logger.info("LiDAR avoidance disabled (WebRTC)")
        except Exception as exc:  # pragma: no cover - hardware path
            logger.warning(
                "set_avoidance(%s) failed: %s -- falling back to sport Move path",
                on, exc,
            )
            self._avoidance_on = False

    # ------------------------------------------------------------------ #
    # Posture / gestures
    # ------------------------------------------------------------------ #
    def _sport(self, cmd_name: str) -> None:
        """Fire the named SPORT_CMD with no parameter; warn if it is unknown."""
        api = self._SPORT_CMD.get(cmd_name)
        if api is None:
            logger.warning("unknown sport command %r", cmd_name)
            return
        self._fire(self._RTC_TOPIC["SPORT_MOD"], api)

    def balance_stand(self) -> None:
        """Send the ``BalanceStand`` sport command."""
        self._sport("BalanceStand")

    def stand_up(self) -> None:
        """Send the ``StandUp`` sport command."""
        self._sport("StandUp")

    def damp(self) -> None:
        """Send the ``Damp`` sport command."""
        self._sport("Damp")

    def gesture(self, name: str) -> None:
        """Perform the GoWelcome gesture ``name``; warn if it has no mapping."""
        cmd = _GESTURE_TO_CMD.get(name)
        if cmd is None:
            logger.warning("unknown gesture %r (known: %s)", name, ", ".join(GESTURES))
            return
        self._sport(cmd)

    # ------------------------------------------------------------------ #
    # Greeting audio
    # ------------------------------------------------------------------ #
    def _prepare_audiohub_greeting(self) -> None:
        """Upload greeting.wav to the robot and resolve its AudioHub uuid.

        Best-effort: a missing AudioHub, a missing file or any upload/list
        failure is logged and leaves ``_greeting_uuid`` unset.
        """
        if self._audiohub is None:
            return
        path = self.cfg.greet.wav_path
        if not os.path.isfile(path):
            logger.warning("greeting wav not found for AudioHub upload: %s", path)
            return
        name = os.path.splitext(os.path.basename(path))[0]
        try:
            self._run(self._audiohub.upload_audio_file(path), timeout=60.0)
            resp = self._run(self._audiohub.get_audio_list(), timeout=10.0)
            self._greeting_uuid = self._find_uuid(resp, name)
            if self._greeting_uuid:
                logger.info("AudioHub greeting ready (uuid=%s)", self._greeting_uuid)
            else:
                logger.warning(
                    "uploaded greeting but could not resolve its uuid from the list"
                )
        except Exception as exc:  # pragma: no cover - hardware path
            logger.warning("AudioHub greeting prepare failed: %s", exc)

    @staticmethod
    def _find_uuid(resp, name: str) -> Optional[str]:
        """Best-effort extraction of an audio entry's uuid from a list response.

        The exact shape of the AudioHub list response is firmware-dependent, so
        this walks common shapes: ``resp['data']['data']`` (often a JSON string)
        containing a list of ``{name/file_name, unique_id/uuid/id}`` records.

        Match priority: an entry whose name equals ``name`` exactly, then the
        first entry whose name merely contains ``name``, then -- as a last
        resort -- the final entry of the list.  Returns ``None`` when nothing
        usable is found or the response cannot be parsed.
        """
        try:
            data = resp.get("data") if isinstance(resp, dict) else None
            blob = data.get("data") if isinstance(data, dict) else data
            if isinstance(blob, str):
                blob = json.loads(blob)

            items = []
            if isinstance(blob, dict):
                # Take the first list-valued field, whatever its key is.
                items = next((v for v in blob.values() if isinstance(v, list)), [])
            elif isinstance(blob, list):
                items = blob

            partial = None
            for it in items:
                if not isinstance(it, dict):
                    continue
                nm = it.get("name") or it.get("file_name") or it.get("custom_name")
                uid = it.get("unique_id") or it.get("uuid") or it.get("id")
                if not uid or not isinstance(nm, str):
                    continue
                if nm == name:
                    return uid
                # Remember the first substring hit, but keep scanning so an
                # exact match later in the list still wins.
                if partial is None and name in nm:
                    partial = uid
            if partial is not None:
                return partial

            if items and isinstance(items[-1], dict):
                last = items[-1]
                return last.get("unique_id") or last.get("uuid") or last.get("id")
        except Exception as exc:  # pragma: no cover
            logger.debug("uuid parse failed: %s", exc)
        return None

    def play_greeting(self) -> None:
        """Play the greeting without blocking the caller.

        Uses live streaming when ``audio_method`` is ``"stream"``; otherwise
        plays the uploaded AudioHub clip, or logs a warning when no clip uuid
        is available.
        """
        if self._audio_method == "stream":
            self._play_stream()
            return
        if self._audiohub is not None and self._greeting_uuid:
            self._run_bg(self._audiohub.play_by_uuid(self._greeting_uuid))
        else:
            logger.warning(
                "AudioHub greeting unavailable (uuid=%s); no audio played",
                self._greeting_uuid,
            )

    def _play_stream(self) -> None:
        """Stream the greeting wav live via an aiortc MediaPlayer track."""
        path = self.cfg.greet.wav_path
        if not os.path.isfile(path):
            logger.warning("greeting wav not found: %s", path)
            return

        # NOTE(review): every call adds another track to the peer connection
        # and replaces the stored player without stopping the previous one --
        # confirm repeated greetings behave as intended on hardware.
        def _add():
            try:
                from aiortc.contrib.media import MediaPlayer
                player = MediaPlayer(path)
                self._stream_player = player  # keep a reference alive
                self._conn.pc.addTrack(player.audio)
            except Exception as exc:  # pragma: no cover - hardware path
                logger.warning("stream greeting failed: %s", exc)

        try:
            self._loop.call_soon_threadsafe(_add)
        except Exception as exc:  # pragma: no cover - e.g. loop closed
            logger.warning("stream greeting failed: %s", exc)

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #
    def shutdown(self) -> None:
        """Stop motion, release avoidance, disconnect, stop the loop. Idempotent."""
        with self._shutdown_lock:
            if self._shutdown:
                return
            self._shutdown = True
        self._closing = True  # lets the video receive loop exit

        try:
            self.stop()
        except Exception as exc:  # pragma: no cover
            logger.debug("stop() on shutdown failed: %s", exc)

        if self._avoidance_on:
            try:
                self.set_avoidance(False)
            except Exception as exc:  # pragma: no cover
                logger.warning("avoidance release on shutdown failed: %s", exc)

        try:
            self._run(self._conn.disconnect(), timeout=8.0)
        except Exception as exc:  # pragma: no cover
            logger.warning("WebRTC disconnect failed: %s", exc)

        self._stop_loop()
        logger.info("Go2WebRTCRobot shutdown complete")