from __future__ import annotations

import math
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def _safe_float(v: Any, default: float) -> float:
    """Return ``float(v)``, or ``float(default)`` if ``v`` is unusable.

    ``v`` is unusable when ``float()`` rejects it or when it converts to NaN.
    NaN is mapped to the default explicitly: the ``max(floor, value)`` clamps
    in :meth:`FusionConfig.from_dict` would otherwise turn it into the floor
    (e.g. a weight of 0.0) rather than the intended default.
    """
    try:
        out = float(v)
    except (TypeError, ValueError, OverflowError):
        return float(default)
    return float(default) if math.isnan(out) else out


def _as_bool(v: Any, default: bool) -> bool:
    """Interpret a loosely-typed config value as a boolean.

    ``None`` yields ``default``; ``bool`` values pass through; ints/floats use
    their truthiness; anything else is stringified and matched
    case-insensitively against ``"1"``, ``"true"``, ``"yes"``, ``"on"``.
    """
    if v is None:
        return default
    if isinstance(v, bool):
        return v
    if isinstance(v, (int, float)):
        return bool(v)
    return str(v).strip().lower() in ("1", "true", "yes", "on")


@dataclass
class FusionConfig:
    """Tuning knobs for :class:`SensorPoseFusion` (times are in seconds)."""

    # When False, fuse_pose returns the lidar pose untouched.
    enabled: bool = False
    # Priors older than this (relative to fuse_pose's ``now``) are "stale".
    max_age_sec: float = 0.30
    # Priors timestamped further than this into the future are rejected.
    max_future_sec: float = 0.10
    # Max |prior timestamp - previous fuse_pose time| before a prior is "skew".
    max_timestamp_skew_sec: float = 0.50
    # Base blend weights; sensor weights are scaled by the prior's confidence.
    lidar_weight: float = 1.0
    imu_weight: float = 0.35  # sensor "imu"
    wheel_weight: float = 0.45  # sensors "wheel", "odom", "odometry"
    vision_weight: float = 0.70  # sensors "vision", "camera", "vio"
    # Which pose components are blended; disabled parts keep the lidar value.
    use_translation: bool = True
    use_rotation: bool = True

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "FusionConfig":
        """Build a config from a (possibly partial or ``None``) mapping.

        Missing, ``None`` or unparsable entries fall back to the defaults.
        Time windows are floored at 0.01 (0.0 for ``max_future_sec``) and
        weights at 0.0.
        """
        src = d or {}
        return FusionConfig(
            enabled=_as_bool(src.get("enabled", False), False),
            max_age_sec=max(0.01, _safe_float(src.get("max_age_sec", 0.30), 0.30)),
            max_future_sec=max(0.0, _safe_float(src.get("max_future_sec", 0.10), 0.10)),
            max_timestamp_skew_sec=max(
                0.01, _safe_float(src.get("max_timestamp_skew_sec", 0.50), 0.50)
            ),
            lidar_weight=max(0.0, _safe_float(src.get("lidar_weight", 1.0), 1.0)),
            imu_weight=max(0.0, _safe_float(src.get("imu_weight", 0.35), 0.35)),
            wheel_weight=max(0.0, _safe_float(src.get("wheel_weight", 0.45), 0.45)),
            vision_weight=max(0.0, _safe_float(src.get("vision_weight", 0.70), 0.70)),
            use_translation=_as_bool(src.get("use_translation", True), True),
            use_rotation=_as_bool(src.get("use_rotation", True), True),
        )


def _rot_to_quat(r: np.ndarray) -> np.ndarray:
    """Convert a 3x3 rotation matrix to a unit quaternion ``[w, x, y, z]``.

    Branches on the trace / largest diagonal element so the square root is
    taken of the largest available quantity. Returns the identity quaternion
    if the result has (numerically) zero norm.
    """
    tr = float(np.trace(r))
    if tr > 0.0:
        s = np.sqrt(tr + 1.0) * 2.0
        w = 0.25 * s
        x = (r[2, 1] - r[1, 2]) / s
        y = (r[0, 2] - r[2, 0]) / s
        z = (r[1, 0] - r[0, 1]) / s
    elif (r[0, 0] > r[1, 1]) and (r[0, 0] > r[2, 2]):
        s = np.sqrt(1.0 + r[0, 0] - r[1, 1] - r[2, 2]) * 2.0
        w = (r[2, 1] - r[1, 2]) / s
        x = 0.25 * s
        y = (r[0, 1] + r[1, 0]) / s
        z = (r[0, 2] + r[2, 0]) / s
    elif r[1, 1] > r[2, 2]:
        s = np.sqrt(1.0 + r[1, 1] - r[0, 0] - r[2, 2]) * 2.0
        w = (r[0, 2] - r[2, 0]) / s
        x = (r[0, 1] + r[1, 0]) / s
        y = 0.25 * s
        z = (r[1, 2] + r[2, 1]) / s
    else:
        s = np.sqrt(1.0 + r[2, 2] - r[0, 0] - r[1, 1]) * 2.0
        w = (r[1, 0] - r[0, 1]) / s
        x = (r[0, 2] + r[2, 0]) / s
        y = (r[1, 2] + r[2, 1]) / s
        z = 0.25 * s
    q = np.array([w, x, y, z], dtype=np.float64)
    n = float(np.linalg.norm(q))
    if n <= 1e-12:
        return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)
    return q / n


def _quat_to_rot(q: np.ndarray) -> np.ndarray:
    """Convert a quaternion ``[w, x, y, z]`` to a 3x3 rotation matrix.

    Uses the unit-quaternion formula without normalising, so ``q`` should
    already have unit norm.
    """
    w, x, y, z = q
    return np.array(
        [
            [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
            [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
            [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
        ],
        dtype=np.float64,
    )


def _weighted_mean_rotation(rots: np.ndarray, weights: np.ndarray) -> Optional[np.ndarray]:
    """Approximate the weighted mean of a stack of 3x3 rotations.

    Each rotation is converted to a quaternion and sign-flipped into the
    hemisphere of the first one (``q`` and ``-q`` encode the same rotation).
    The weighted sum is then normalised. This is a common approximation of
    the rotation mean and is most accurate when the inputs are close together.

    Returns ``None`` when the weighted sum is numerically zero, so the caller
    can keep its existing rotation instead of using a meaningless result.
    """
    quats = np.stack([_rot_to_quat(r) for r in rots])
    signs = np.where(quats @ quats[0] < 0.0, -1.0, 1.0)
    q = (weights * signs) @ quats
    n = float(np.linalg.norm(q))
    if n <= 1e-12:
        return None
    return _quat_to_rot(q / n)


class SensorPoseFusion:
    """Blend a lidar pose with recent pose priors from auxiliary sensors.

    Priors are registered with :meth:`update_prior`, which keeps one slot per
    sensor name; a newer prior replaces the older one. :meth:`fuse_pose`
    combines them with the lidar pose using a weighted average of translations
    and a sign-aligned, normalised weighted average of quaternions.
    """

    def __init__(self, cfg: FusionConfig):
        self.cfg = cfg
        # sensor name (lower-cased, stripped) -> {"pose", "confidence", "timestamp"}
        self._priors: Dict[str, Dict[str, Any]] = {}
        # Time of the previous fuse_pose call that passed shape validation;
        # 0.0 means "no frame yet" and disables the skew check.
        self._last_lidar_t = 0.0

    def reset(self) -> None:
        """Drop all stored priors and forget the previous fuse time."""
        self._priors.clear()
        self._last_lidar_t = 0.0

    def _sensor_weight(self, sensor: str) -> float:
        """Return the configured base weight for ``sensor`` (0.0 if unknown)."""
        s = str(sensor).lower().strip()
        if s == "imu":
            return float(self.cfg.imu_weight)
        if s in ("wheel", "odom", "odometry"):
            return float(self.cfg.wheel_weight)
        if s in ("vision", "camera", "vio"):
            return float(self.cfg.vision_weight)
        return 0.0

    def update_prior(
        self,
        sensor: str,
        pose: Any,
        confidence: float = 1.0,
        timestamp: Optional[float] = None,
    ) -> bool:
        """Store ``pose`` as the latest prior for ``sensor``.

        Args:
            sensor: Sensor name; normalised with ``str().lower().strip()``.
            pose: 4x4 homogeneous transform (anything ``np.asarray`` accepts).
                A private copy is stored, so later caller mutation has no effect.
            confidence: Multiplier on the sensor weight, clipped to [0, 1].
            timestamp: Sample time in seconds; defaults to ``time.time()``.

        Returns:
            ``True`` if the prior was stored. ``False`` if ``pose`` is not 4x4
            or contains non-finite values, if ``timestamp`` is non-finite or
            more than ``max_future_sec`` ahead of the wall clock, or if
            ``confidence`` is NaN.
        """
        arr = np.array(pose, dtype=np.float64, copy=True)
        if arr.shape != (4, 4) or not np.all(np.isfinite(arr)):
            return False
        wall = float(time.time())
        t = wall if timestamp is None else float(timestamp)
        # A NaN timestamp makes every age/skew comparison in fuse_pose False,
        # which would let the prior bypass all time gating.
        if not math.isfinite(t) or (t - wall) > float(self.cfg.max_future_sec):
            return False
        conf = float(np.clip(confidence, 0.0, 1.0))
        if math.isnan(conf):  # np.clip keeps NaN, which would poison the blend
            return False
        self._priors[str(sensor).lower().strip()] = {
            "pose": arr,
            "confidence": conf,
            "timestamp": t,
        }
        return True

    def _collect_priors(
        self, t_now: float
    ) -> Tuple[List[np.ndarray], List[float], List[str], List[str]]:
        """Gate stored priors against ``t_now``.

        Returns ``(poses, weights, used, rejected)``. The first three are
        parallel lists for accepted priors. ``rejected`` holds
        ``"<sensor>:<reason>"`` strings with reasons ``future``, ``skew``,
        ``stale``, ``weight`` or ``pose``.
        """
        max_future = float(self.cfg.max_future_sec)
        max_skew = float(self.cfg.max_timestamp_skew_sec)
        max_age = float(self.cfg.max_age_sec)
        poses: List[np.ndarray] = []
        weights: List[float] = []
        used: List[str] = []
        rejected: List[str] = []
        for sensor, ent in self._priors.items():
            ts = float(ent.get("timestamp", 0.0))
            age = t_now - ts
            if age < -max_future:
                rejected.append(f"{sensor}:future")
                continue
            # Skew is measured against the *previous* fuse_pose time and is
            # skipped until a first frame has been seen.
            if self._last_lidar_t > 0.0 and abs(ts - self._last_lidar_t) > max_skew:
                rejected.append(f"{sensor}:skew")
                continue
            # Written as `not <=` so a NaN age (e.g. from a NaN `now`) is
            # rejected rather than silently accepted.
            if not age <= max_age:
                rejected.append(f"{sensor}:stale")
                continue
            w = self._sensor_weight(sensor) * float(
                np.clip(ent.get("confidence", 1.0), 0.0, 1.0)
            )
            if not w > 0.0:  # also rejects NaN, which `w <= 0.0` would accept
                rejected.append(f"{sensor}:weight")
                continue
            p = np.asarray(ent.get("pose"), dtype=np.float64)
            if p.shape != (4, 4) or not np.all(np.isfinite(p)):
                rejected.append(f"{sensor}:pose")
                continue
            poses.append(p)
            weights.append(w)
            used.append(sensor)
        return poses, weights, used, rejected

    def fuse_pose(
        self, lidar_pose: np.ndarray, now: Optional[float] = None
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Fuse ``lidar_pose`` with the stored priors.

        Args:
            lidar_pose: 4x4 homogeneous lidar pose.
            now: Reference time in seconds; defaults to ``time.time()``.

        Returns:
            ``(pose, diag)``. When nothing is fused (bad shape, fusion disabled,
            or no acceptable prior), ``pose`` is ``lidar_pose`` as returned by
            ``np.asarray``. Otherwise it is a new 4x4 array. ``diag`` always
            contains ``"used"`` (blended sensor names), ``"rejected"``
            (``"<sensor>:<reason>"`` strings) and ``"enabled"``.
        """
        pose_lidar = np.asarray(lidar_pose, dtype=np.float64)
        if pose_lidar.shape != (4, 4):
            return pose_lidar, {"used": [], "rejected": [], "enabled": bool(self.cfg.enabled)}

        t_now = float(time.time()) if now is None else float(now)
        if not self.cfg.enabled:
            self._last_lidar_t = t_now
            return pose_lidar, {"used": [], "rejected": [], "enabled": False}

        # Gate against the previous frame time *before* recording this one.
        poses, weights, used, rejected = self._collect_priors(t_now)
        self._last_lidar_t = t_now
        if not used:
            return pose_lidar, {"used": [], "rejected": rejected, "enabled": True}

        stack = np.stack([pose_lidar, *poses])  # (n, 4, 4); lidar first
        w = np.asarray([float(self.cfg.lidar_weight), *weights], dtype=np.float64)
        w = w / max(float(np.sum(w)), 1e-9)

        out = np.array(pose_lidar, dtype=np.float64, copy=True)
        if self.cfg.use_translation:
            out[:3, 3] = w @ stack[:, :3, 3]
        if self.cfg.use_rotation:
            rot = _weighted_mean_rotation(stack[:, :3, :3], w)
            if rot is not None:  # degenerate mean: keep the lidar rotation
                out[:3, :3] = rot
        return out, {"used": used, "rejected": rejected, "enabled": True}