515 lines
21 KiB
Python
515 lines
21 KiB
Python
"""WebRTC robot backend for the Go2 (the default ``webrtc`` transport).
|
|
|
|
Drives a Unitree Go2 over WebRTC using ``unitree_webrtc_connect`` -- the same
protocol the Unitree Go / Explore apps use, so it works on **AIR / PRO / EDU
over Wi-Fi** and can play the greeting **from the dog's own speaker** (AudioHub),
which the official DDS SDK cannot do on a Go2.

The library is fully asynchronous (asyncio + aiortc); GoWelcome is threaded and
synchronous. :class:`Go2WebRTCRobot` bridges the two: it runs the WebRTC
connection on a dedicated event-loop thread and exposes the synchronous
:class:`RobotInterface`.

* High-rate commands (drive/gesture/posture) are sent **fire-and-forget** via
  ``publish_without_callback`` scheduled with ``loop.call_soon_threadsafe`` --
  the control loop never blocks and no response-future accumulates.
* Setup steps that need confirmation (connect, mode switch, avoidance enable,
  AudioHub upload) are awaited via ``run_coroutine_threadsafe(...).result()``.

Velocity convention (matches the rest of GoWelcome): vx forward+, vy left+,
vyaw CCW/left+ (rad/s).

All heavy deps (``unitree_webrtc_connect``, ``aiortc``, ``cv2``) are imported
lazily so this module imports fine off-robot.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from gowelcome.robot.interface import GESTURES, RobotInterface
|
|
|
|
if TYPE_CHECKING: # pragma: no cover - type hints only
|
|
from config import GoWelcomeConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# GoWelcome gesture name -> Unitree SPORT_CMD key. (The Go2's heart gesture is
|
|
# the "FingerHeart" action; the rest map by name.)
|
|
_GESTURE_TO_CMD = {
|
|
"hello": "Hello",
|
|
"heart": "FingerHeart",
|
|
"stretch": "Stretch",
|
|
"dance1": "Dance1",
|
|
"dance2": "Dance2",
|
|
"scrape": "Scrape",
|
|
"content": "Content",
|
|
"sit": "Sit",
|
|
"rise_sit": "RiseSit",
|
|
# dog-play actions
|
|
"wiggle": "WiggleHips",
|
|
"wallow": "Wallow",
|
|
"pounce": "FrontPounce",
|
|
}
|
|
|
|
|
|
def _clamp(value: float, lo: float, hi: float) -> float:
|
|
return lo if value < lo else hi if value > hi else value
|
|
|
|
|
|
class Go2WebRTCRobot(RobotInterface):
|
|
"""Drive a Go2 over WebRTC (app protocol) via ``unitree_webrtc_connect``.
|
|
|
|
Args:
|
|
cfg: Fully-populated :class:`~config.GoWelcomeConfig`.
|
|
|
|
Raises:
|
|
ImportError: if ``unitree_webrtc_connect`` / ``cv2`` are not installed.
|
|
ValueError: if the WebRTC connection parameters are incomplete.
|
|
"""
|
|
|
|
def __init__(self, cfg: "GoWelcomeConfig") -> None:
    """Connect to the robot and bring it to a ready-to-drive state.

    Construction blocks while it: builds the connection object, starts the
    asyncio loop on a daemon thread, connects (30 s timeout), requests the
    ``normal`` motion mode, switches the video channel on, creates the
    AudioHub helper (uploading the greeting when the audio method is
    ``audiohub``), optionally enables LiDAR avoidance and finally sends
    ``BalanceStand``. The motion-mode switch, video setup and AudioHub
    creation are best-effort: a failure is logged as a warning and
    construction continues.

    Args:
        cfg: Fully-populated configuration; ``cfg.webrtc``, ``cfg.safety``
            and ``cfg.dry_run`` are read here.

    Raises:
        ImportError: if ``unitree_webrtc_connect`` / ``cv2`` are not installed.
        ValueError: if the WebRTC connection parameters are incomplete.
        Exception: whatever the blocking connect raises (including a
            timeout). The event-loop thread is torn down before the error
            propagates.
    """
    try:
        from unitree_webrtc_connect.webrtc_driver import UnitreeWebRTCConnection
        from unitree_webrtc_connect.constants import (
            DATA_CHANNEL_TYPE,
            OBSTACLES_AVOID_API,
            RTC_TOPIC,
            SPORT_CMD,
            WebRTCConnectionMethod,
        )
        from unitree_webrtc_connect.webrtc_audiohub import WebRTCAudioHub
    except ImportError as exc:  # pragma: no cover - off-robot path
        raise ImportError(
            "Go2WebRTCRobot requires 'unitree_webrtc_connect'. Install with: "
            "pip install unitree_webrtc_connect (plus: sudo apt install "
            "portaudio19-dev). For Go2 firmware >= 1.1.15 you also need the "
            "per-device AES key -- see the repo's examples/fetch_aes_key.py."
        ) from exc
    try:
        import cv2
    except ImportError as exc:  # pragma: no cover - off-robot path
        raise ImportError(
            "Go2WebRTCRobot requires OpenCV: pip install opencv-python"
        ) from exc

    self.cfg = cfg
    self._cv2 = cv2
    self._RTC_TOPIC = RTC_TOPIC
    self._SPORT_CMD = SPORT_CMD
    self._OA = OBSTACLES_AVOID_API
    self._REQUEST = DATA_CHANNEL_TYPE["REQUEST"]

    # --- camera frame buffer (written by the video callback) ------------
    self._frame: Optional[np.ndarray] = None
    self._frame_size: Tuple[int, int] = (0, 0)
    self._frame_lock = threading.Lock()

    # --- state ----------------------------------------------------------
    self._closing = False
    self._shutdown = False
    self._shutdown_lock = threading.Lock()
    self._avoidance_on = False
    self._audiohub = None
    self._greeting_uuid: Optional[str] = None
    self._stream_player = None  # keep a ref so the aiortc track isn't GC'd
    self._cmd_id = 0
    self._audio_method = (cfg.webrtc.audio_method or "audiohub").strip().lower()

    # --- build the connection object ------------------------------------
    # Parameter validation happens here, before any thread is started, so a
    # ValueError leaves nothing behind to clean up.
    self._conn = self._build_connection(cfg, UnitreeWebRTCConnection, WebRTCConnectionMethod)

    # --- start the asyncio loop on its own daemon thread ----------------
    self._loop = asyncio.new_event_loop()
    self._loop_thread = threading.Thread(
        target=self._run_loop, name="Go2WebRTCLoop", daemon=True
    )
    self._loop_thread.start()

    # --- connect (blocking) ---------------------------------------------
    try:
        self._run(self._conn.connect(), timeout=30.0)
    except BaseException:
        # When construction fails the caller never receives the instance,
        # so nobody could call shutdown(): stop the loop thread here instead
        # of leaving it (and the loop's descriptors) running for good.
        self._closing = True
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._loop_thread.join(timeout=3.0)
        if not self._loop_thread.is_alive():
            # A loop may only be closed once it is no longer running.
            self._loop.close()
        raise
    logger.info("WebRTC connected (%s)", cfg.webrtc.connection_method)

    # --- ensure normal motion mode --------------------------------------
    try:
        self._run(
            self._conn.datachannel.pub_sub.publish_request_new(
                RTC_TOPIC["MOTION_SWITCHER"],
                {"api_id": 1002, "parameter": {"name": "normal"}},
            ),
            timeout=10.0,
        )
    except Exception as exc:  # pragma: no cover - hardware path
        logger.warning("motion-switcher to 'normal' failed: %s", exc)

    # --- camera on + frame callback (run on the loop thread; switching the
    #     video channel sends on the data channel) -------------------------
    async def _setup_video():
        self._conn.video.switchVideoChannel(True)
        self._conn.video.add_track_callback(self._on_video_track)

    try:
        self._run(_setup_video(), timeout=8.0)
    except Exception as exc:  # pragma: no cover - hardware path
        logger.warning("video channel setup failed: %s", exc)

    # --- audio backend (AudioHub) ---------------------------------------
    try:
        self._audiohub = WebRTCAudioHub(self._conn)
    except Exception as exc:  # pragma: no cover - hardware path
        logger.warning("AudioHub init failed: %s", exc)
        self._audiohub = None
    if self._audio_method == "audiohub":
        self._prepare_audiohub_greeting()

    # --- obstacle avoidance ---------------------------------------------
    if cfg.safety.use_lidar_avoidance:
        self.set_avoidance(True)

    # --- ready posture --------------------------------------------------
    self.balance_stand()
    logger.info(
        "Go2WebRTCRobot ready (audio=%s, avoidance=%s, dry_run=%s)",
        self._audio_method, self._avoidance_on, cfg.dry_run,
    )
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Connection construction
|
|
# ------------------------------------------------------------------ #
|
|
def _build_connection(self, cfg, UnitreeWebRTCConnection, WebRTCConnectionMethod):
|
|
wc = cfg.webrtc
|
|
method = (wc.connection_method or "localsta").strip().lower()
|
|
kwargs = {"region": wc.region or "global", "device_type": "Go2"}
|
|
if wc.aes_128_key:
|
|
kwargs["aes_128_key"] = wc.aes_128_key
|
|
if wc.serial_number:
|
|
kwargs["serialNumber"] = wc.serial_number
|
|
|
|
if method == "localsta":
|
|
cm = WebRTCConnectionMethod.LocalSTA
|
|
if wc.ip:
|
|
kwargs["ip"] = wc.ip
|
|
elif not wc.serial_number:
|
|
raise ValueError(
|
|
"WebRTC localsta needs webrtc.ip or webrtc.serial_number "
|
|
"(pass --robot-ip <ip> or --serial <sn>)."
|
|
)
|
|
elif method == "localap":
|
|
cm = WebRTCConnectionMethod.LocalAP
|
|
elif method == "remote":
|
|
cm = WebRTCConnectionMethod.Remote
|
|
if not (wc.serial_number and wc.username and wc.password):
|
|
raise ValueError(
|
|
"WebRTC remote needs webrtc.serial_number + username + password."
|
|
)
|
|
kwargs["username"] = wc.username
|
|
kwargs["password"] = wc.password
|
|
else:
|
|
raise ValueError(f"Unknown webrtc.connection_method: {wc.connection_method!r}")
|
|
|
|
return UnitreeWebRTCConnection(cm, **kwargs)
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# asyncio bridge
|
|
# ------------------------------------------------------------------ #
|
|
def _run_loop(self) -> None:
|
|
asyncio.set_event_loop(self._loop)
|
|
self._loop.run_forever()
|
|
|
|
def _run(self, coro, timeout: float = 10.0):
|
|
"""Run ``coro`` on the loop thread and block for its result."""
|
|
fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
|
return fut.result(timeout)
|
|
|
|
def _run_bg(self, coro) -> None:
|
|
"""Fire-and-forget a coroutine on the loop thread (non-blocking)."""
|
|
try:
|
|
fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
|
fut.add_done_callback(
|
|
lambda f: f.exception() and logger.debug("bg coro error: %s", f.exception())
|
|
)
|
|
except Exception as exc: # pragma: no cover
|
|
logger.debug("schedule bg coro failed: %s", exc)
|
|
|
|
def _next_id(self) -> int:
|
|
self._cmd_id = (self._cmd_id + 1) % 2147483000
|
|
return self._cmd_id
|
|
|
|
def _fire(self, topic: str, api_id: int, parameter=None) -> None:
|
|
"""Fire-and-forget a request command from any thread (no reply wait).
|
|
|
|
Builds the same request envelope ``publish_request_new`` uses, then
|
|
sends it via the no-future ``publish_without_callback`` scheduled on the
|
|
loop thread -- so the control loop never blocks and no response-future
|
|
leaks (critical for the high-rate Move path).
|
|
"""
|
|
payload = {
|
|
"header": {"identity": {"id": self._next_id(), "api_id": api_id}},
|
|
"parameter": "",
|
|
}
|
|
if parameter is not None:
|
|
payload["parameter"] = (
|
|
parameter if isinstance(parameter, str) else json.dumps(parameter)
|
|
)
|
|
|
|
def _send():
|
|
try:
|
|
self._conn.datachannel.pub_sub.publish_without_callback(
|
|
topic, payload, self._REQUEST
|
|
)
|
|
except Exception as exc: # pragma: no cover - hardware path
|
|
logger.debug("publish failed (api_id=%s): %s", api_id, exc)
|
|
|
|
try:
|
|
self._loop.call_soon_threadsafe(_send)
|
|
except Exception as exc: # pragma: no cover
|
|
logger.debug("call_soon_threadsafe failed: %s", exc)
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Camera
|
|
# ------------------------------------------------------------------ #
|
|
async def _on_video_track(self, track) -> None:
|
|
"""Receive video frames into the shared buffer until shutdown."""
|
|
cv2 = self._cv2
|
|
target = (self.cfg.camera.width, self.cfg.camera.height)
|
|
while not self._closing:
|
|
try:
|
|
frame = await track.recv()
|
|
except Exception as exc: # pragma: no cover - track ended
|
|
logger.debug("video recv ended: %s", exc)
|
|
break
|
|
try:
|
|
img = frame.to_ndarray(format="bgr24")
|
|
if (img.shape[1], img.shape[0]) != target:
|
|
img = cv2.resize(img, target)
|
|
with self._frame_lock:
|
|
self._frame = img
|
|
self._frame_size = (img.shape[1], img.shape[0])
|
|
except Exception as exc: # pragma: no cover
|
|
logger.debug("video decode failed: %s", exc)
|
|
|
|
def get_frame(self) -> "Optional[np.ndarray]":
    """Return the frame currently held in the buffer (not a copy).

    The value is read under the frame lock; it is ``None`` while the buffer
    attribute has not been replaced by a frame.
    """
    with self._frame_lock:
        latest = self._frame
    return latest
|
|
|
|
def frame_size(self) -> "tuple[int, int]":
    """Return the stored ``(width, height)`` of the buffered frame.

    The value is read under the frame lock.
    """
    with self._frame_lock:
        dimensions = self._frame_size
    return dimensions
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Locomotion
|
|
# ------------------------------------------------------------------ #
|
|
def drive(self, vx: float, vy: float, vyaw: float) -> None:
    """Send one velocity command, limited to the configured safety envelope.

    Each component is clamped to ``+/- cfg.safety.max_*``. In dry-run mode
    the clamped intent is logged and a zero command is sent instead. The
    command goes to the obstacle-avoidance topic while avoidance is on and to
    the sport ``Move`` API otherwise; both are fire-and-forget.

    Args:
        vx: Forward velocity (forward positive).
        vy: Lateral velocity (left positive).
        vyaw: Yaw rate in rad/s (counter-clockwise / left positive).
    """
    # NaN is the only float that compares unequal to itself. It also compares
    # false against both clamp bounds, so without this guard it would pass
    # the limits untouched and be serialised as the non-standard JSON token
    # ``NaN``. Any command with a NaN component is replaced by a zero command.
    if vx != vx or vy != vy or vyaw != vyaw:
        logger.warning(
            "drive() received a NaN velocity (vx=%s vy=%s vyaw=%s); sending 0",
            vx, vy, vyaw,
        )
        vx = vy = vyaw = 0.0

    s = self.cfg.safety
    vx = _clamp(vx, -s.max_vx, s.max_vx)
    vy = _clamp(vy, -s.max_vy, s.max_vy)
    vyaw = _clamp(vyaw, -s.max_vyaw, s.max_vyaw)
    if self.cfg.dry_run:
        logger.info(
            "[dry_run] drive intent vx=%.3f vy=%.3f vyaw=%.3f (sending 0)",
            vx, vy, vyaw,
        )
        vx = vy = vyaw = 0.0
    if self._avoidance_on:
        self._fire(self._RTC_TOPIC["OBSTACLES_AVOID"], self._OA["MOVE"],
                   {"x": vx, "y": vy, "yaw": vyaw, "mode": 0})
    else:
        self._fire(self._RTC_TOPIC["SPORT_MOD"], self._SPORT_CMD["Move"],
                   {"x": vx, "y": vy, "z": vyaw})
|
|
|
|
def stop(self) -> None:
    """Command the robot to stop moving (fire-and-forget).

    With avoidance on, a zero-velocity move is sent on the obstacle-avoidance
    topic; otherwise the sport ``StopMove`` command is sent.
    """
    if not self._avoidance_on:
        self._fire(self._RTC_TOPIC["SPORT_MOD"], self._SPORT_CMD["StopMove"])
        return
    zero_move = {"x": 0.0, "y": 0.0, "yaw": 0.0, "mode": 0}
    self._fire(self._RTC_TOPIC["OBSTACLES_AVOID"], self._OA["MOVE"], zero_move)
|
|
|
|
def set_avoidance(self, on: bool) -> None:
    """Enable or disable obstacle avoidance, waiting for each request.

    Enabling sends ``SWITCH_SET`` (enable) and then
    ``USE_REMOTE_COMMAND_FROM_API`` (true); disabling sends the same two
    requests in the opposite order with false values. Each request is
    awaited for up to 8 seconds. On any failure a warning is logged and
    ``self._avoidance_on`` is set to ``False``.

    Args:
        on: ``True`` to enable avoidance, ``False`` to release it.
    """
    topic = self._RTC_TOPIC["OBSTACLES_AVOID"]

    def _request(api_key: str, parameter: dict) -> None:
        # One confirmed request on the avoidance topic.
        self._run(
            self._conn.datachannel.pub_sub.publish_request_new(
                topic, {"api_id": self._OA[api_key], "parameter": parameter}
            ),
            timeout=8.0,
        )

    try:
        if on:
            _request("SWITCH_SET", {"enable": True})
            _request("USE_REMOTE_COMMAND_FROM_API", {"is_remote_commands_from_api": True})
            self._avoidance_on = True
            logger.info("LiDAR avoidance enabled (WebRTC)")
            return
        _request("USE_REMOTE_COMMAND_FROM_API", {"is_remote_commands_from_api": False})
        _request("SWITCH_SET", {"enable": False})
        self._avoidance_on = False
        logger.info("LiDAR avoidance disabled (WebRTC)")
    except Exception as exc:  # pragma: no cover - hardware path
        logger.warning(
            "set_avoidance(%s) failed: %s -- falling back to sport Move path",
            on, exc,
        )
        self._avoidance_on = False
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Posture / gestures
|
|
# ------------------------------------------------------------------ #
|
|
def _sport(self, cmd_name: str) -> None:
|
|
api = self._SPORT_CMD.get(cmd_name)
|
|
if api is None:
|
|
logger.warning("unknown sport command %r", cmd_name)
|
|
return
|
|
self._fire(self._RTC_TOPIC["SPORT_MOD"], api)
|
|
|
|
def balance_stand(self) -> None:
    """Send the ``BalanceStand`` sport command."""
    command = "BalanceStand"
    self._sport(command)
|
|
|
|
def stand_up(self) -> None:
    """Send the ``StandUp`` sport command."""
    command = "StandUp"
    self._sport(command)
|
|
|
|
def damp(self) -> None:
    """Send the ``Damp`` sport command."""
    command = "Damp"
    self._sport(command)
|
|
|
|
def gesture(self, name: str) -> None:
    """Perform the named GoWelcome gesture.

    The name is translated through ``_GESTURE_TO_CMD``; a name without an
    entry is logged as a warning (listing ``GESTURES``) and nothing is sent.
    """
    sport_key = _GESTURE_TO_CMD.get(name)
    if sport_key is not None:
        self._sport(sport_key)
        return
    logger.warning("unknown gesture %r (known: %s)", name, ", ".join(GESTURES))
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Greeting audio
|
|
# ------------------------------------------------------------------ #
|
|
def _prepare_audiohub_greeting(self) -> None:
|
|
"""Upload greeting.wav to the robot and resolve its AudioHub uuid."""
|
|
if self._audiohub is None:
|
|
return
|
|
path = self.cfg.greet.wav_path
|
|
if not os.path.isfile(path):
|
|
logger.warning("greeting wav not found for AudioHub upload: %s", path)
|
|
return
|
|
name = os.path.splitext(os.path.basename(path))[0]
|
|
try:
|
|
self._run(self._audiohub.upload_audio_file(path), timeout=60.0)
|
|
resp = self._run(self._audiohub.get_audio_list(), timeout=10.0)
|
|
self._greeting_uuid = self._find_uuid(resp, name)
|
|
if self._greeting_uuid:
|
|
logger.info("AudioHub greeting ready (uuid=%s)", self._greeting_uuid)
|
|
else:
|
|
logger.warning(
|
|
"uploaded greeting but could not resolve its uuid from the list"
|
|
)
|
|
except Exception as exc: # pragma: no cover - hardware path
|
|
logger.warning("AudioHub greeting prepare failed: %s", exc)
|
|
|
|
@staticmethod
|
|
def _find_uuid(resp, name: str) -> Optional[str]:
|
|
"""Best-effort extraction of an audio entry's uuid from a list response.
|
|
|
|
The exact shape of the AudioHub list response is firmware-dependent, so
|
|
this walks common shapes: ``resp['data']['data']`` (often a JSON string)
|
|
containing a list of ``{name/file_name, unique_id/uuid/id}`` records.
|
|
"""
|
|
try:
|
|
data = resp.get("data") if isinstance(resp, dict) else None
|
|
blob = data.get("data") if isinstance(data, dict) else data
|
|
if isinstance(blob, str):
|
|
blob = json.loads(blob)
|
|
items = []
|
|
if isinstance(blob, dict):
|
|
for v in blob.values():
|
|
if isinstance(v, list):
|
|
items = v
|
|
break
|
|
elif isinstance(blob, list):
|
|
items = blob
|
|
for it in items:
|
|
if not isinstance(it, dict):
|
|
continue
|
|
nm = it.get("name") or it.get("file_name") or it.get("custom_name")
|
|
uid = it.get("unique_id") or it.get("uuid") or it.get("id")
|
|
if uid and (nm == name or (isinstance(nm, str) and name in nm)):
|
|
return uid
|
|
if items and isinstance(items[-1], dict):
|
|
last = items[-1]
|
|
return last.get("unique_id") or last.get("uuid") or last.get("id")
|
|
except Exception as exc: # pragma: no cover
|
|
logger.debug("uuid parse failed: %s", exc)
|
|
return None
|
|
|
|
def play_greeting(self) -> None:
    """Play the greeting using the configured audio method.

    ``stream`` delegates to :meth:`_play_stream`. Otherwise the uploaded
    clip is played through AudioHub in the background; when the AudioHub
    helper or the greeting uuid is missing, a warning is logged and nothing
    is played.
    """
    if self._audio_method == "stream":
        self._play_stream()
        return

    hub_ready = self._audiohub is not None and self._greeting_uuid
    if not hub_ready:
        logger.warning(
            "AudioHub greeting unavailable (uuid=%s); no audio played",
            self._greeting_uuid,
        )
        return
    self._run_bg(self._audiohub.play_by_uuid(self._greeting_uuid))
|
|
|
|
def _play_stream(self) -> None:
|
|
"""Stream the greeting wav live via an aiortc MediaPlayer track."""
|
|
path = self.cfg.greet.wav_path
|
|
if not os.path.isfile(path):
|
|
logger.warning("greeting wav not found: %s", path)
|
|
return
|
|
|
|
def _add():
|
|
try:
|
|
from aiortc.contrib.media import MediaPlayer
|
|
player = MediaPlayer(path)
|
|
self._stream_player = player # keep a reference alive
|
|
self._conn.pc.addTrack(player.audio)
|
|
except Exception as exc: # pragma: no cover - hardware path
|
|
logger.warning("stream greeting failed: %s", exc)
|
|
|
|
self._loop.call_soon_threadsafe(_add)
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Lifecycle
|
|
# ------------------------------------------------------------------ #
|
|
def shutdown(self) -> None:
    """Stop motion, release avoidance, disconnect, stop the loop. Idempotent.

    Only the first call does any work; later calls return immediately. Every
    step is best-effort: a failure is logged and the remaining steps still
    run. Once the loop thread has exited, the event loop is closed so its
    resources are released.
    """
    with self._shutdown_lock:
        if self._shutdown:
            return
        self._shutdown = True
        self._closing = True

    try:
        self.stop()
    except Exception as exc:  # pragma: no cover
        logger.debug("stop() on shutdown failed: %s", exc)

    if self._avoidance_on:
        try:
            self.set_avoidance(False)
        except Exception as exc:  # pragma: no cover
            logger.warning("avoidance release on shutdown failed: %s", exc)

    try:
        self._run(self._conn.disconnect(), timeout=8.0)
    except Exception as exc:  # pragma: no cover
        logger.warning("WebRTC disconnect failed: %s", exc)

    try:
        self._loop.call_soon_threadsafe(self._loop.stop)
    except Exception as exc:  # pragma: no cover
        logger.debug("event loop stop request failed: %s", exc)
    try:
        if self._loop_thread.is_alive():
            self._loop_thread.join(timeout=3.0)
    except Exception as exc:  # pragma: no cover
        logger.debug("event loop thread join failed: %s", exc)

    # A loop may only be closed once it is no longer running, so close it
    # only after the thread has really exited; otherwise leave it alone.
    if self._loop_thread.is_alive():  # pragma: no cover
        logger.warning("event loop thread still running; loop left open")
    else:
        try:
            self._loop.close()
        except Exception as exc:  # pragma: no cover
            logger.debug("event loop close failed: %s", exc)

    logger.info("Go2WebRTCRobot shutdown complete")
|