206 lines
7.2 KiB
Python
206 lines
7.2 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
|
|
def _safe_float(v: Any, default: float) -> float:
|
|
try:
|
|
return float(v)
|
|
except Exception:
|
|
return float(default)
|
|
|
|
|
|
def _as_bool(v: Any, default: bool) -> bool:
|
|
if v is None:
|
|
return default
|
|
if isinstance(v, bool):
|
|
return v
|
|
if isinstance(v, (int, float)):
|
|
return bool(v)
|
|
return str(v).strip().lower() in ("1", "true", "yes", "on")
|
|
|
|
|
|
@dataclass
class FusionConfig:
    """Tunable parameters for blending a LiDAR pose with other sensor priors."""

    # Master switch; when False the fusion step is bypassed.
    enabled: bool = False
    # Timing gates, in seconds, applied to each prior's timestamp.
    max_age_sec: float = 0.30
    max_future_sec: float = 0.10
    max_timestamp_skew_sec: float = 0.50
    # Relative blend weights per sensor family.
    lidar_weight: float = 1.0
    imu_weight: float = 0.35
    wheel_weight: float = 0.45
    vision_weight: float = 0.70
    # Which parts of the pose take part in the blend.
    use_translation: bool = True
    use_rotation: bool = True

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "FusionConfig":
        """Build a config from a loosely-typed mapping.

        Missing keys fall back to the dataclass defaults.  Numeric entries
        are parsed with ``_safe_float`` and clamped to a lower bound (0.01
        for ``max_age_sec`` and ``max_timestamp_skew_sec``, 0.0 for the
        rest); boolean entries are parsed with ``_as_bool``.

        Args:
            d: Source mapping; ``None`` or an empty mapping yields defaults.

        Returns:
            A new ``FusionConfig``.
        """
        src = d or {}

        def flag(key: str, fallback: bool) -> bool:
            return _as_bool(src.get(key, fallback), fallback)

        def at_least(floor: float, key: str, fallback: float) -> float:
            # Argument order of max() is kept: floor first, parsed value second.
            return max(floor, _safe_float(src.get(key, fallback), fallback))

        return FusionConfig(
            enabled=flag("enabled", False),
            max_age_sec=at_least(0.01, "max_age_sec", 0.30),
            max_future_sec=at_least(0.0, "max_future_sec", 0.10),
            max_timestamp_skew_sec=at_least(0.01, "max_timestamp_skew_sec", 0.50),
            lidar_weight=at_least(0.0, "lidar_weight", 1.0),
            imu_weight=at_least(0.0, "imu_weight", 0.35),
            wheel_weight=at_least(0.0, "wheel_weight", 0.45),
            vision_weight=at_least(0.0, "vision_weight", 0.70),
            use_translation=flag("use_translation", True),
            use_rotation=flag("use_rotation", True),
        )
|
|
|
|
|
|
def _rot_to_quat(r: np.ndarray) -> np.ndarray:
|
|
tr = float(np.trace(r))
|
|
if tr > 0.0:
|
|
s = np.sqrt(tr + 1.0) * 2.0
|
|
w = 0.25 * s
|
|
x = (r[2, 1] - r[1, 2]) / s
|
|
y = (r[0, 2] - r[2, 0]) / s
|
|
z = (r[1, 0] - r[0, 1]) / s
|
|
elif (r[0, 0] > r[1, 1]) and (r[0, 0] > r[2, 2]):
|
|
s = np.sqrt(1.0 + r[0, 0] - r[1, 1] - r[2, 2]) * 2.0
|
|
w = (r[2, 1] - r[1, 2]) / s
|
|
x = 0.25 * s
|
|
y = (r[0, 1] + r[1, 0]) / s
|
|
z = (r[0, 2] + r[2, 0]) / s
|
|
elif r[1, 1] > r[2, 2]:
|
|
s = np.sqrt(1.0 + r[1, 1] - r[0, 0] - r[2, 2]) * 2.0
|
|
w = (r[0, 2] - r[2, 0]) / s
|
|
x = (r[0, 1] + r[1, 0]) / s
|
|
y = 0.25 * s
|
|
z = (r[1, 2] + r[2, 1]) / s
|
|
else:
|
|
s = np.sqrt(1.0 + r[2, 2] - r[0, 0] - r[1, 1]) * 2.0
|
|
w = (r[1, 0] - r[0, 1]) / s
|
|
x = (r[0, 2] + r[2, 0]) / s
|
|
y = (r[1, 2] + r[2, 1]) / s
|
|
z = 0.25 * s
|
|
q = np.array([w, x, y, z], dtype=np.float64)
|
|
n = float(np.linalg.norm(q))
|
|
if n <= 1e-12:
|
|
return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
|
return q / n
|
|
|
|
|
|
def _quat_to_rot(q: np.ndarray) -> np.ndarray:
|
|
w, x, y, z = q
|
|
return np.array(
|
|
[
|
|
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
|
|
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
|
|
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
|
|
],
|
|
dtype=np.float64,
|
|
)
|
|
|
|
|
|
class SensorPoseFusion:
    """Blend a LiDAR pose with time-stamped pose priors from other sensors.

    Priors are registered per sensor with :meth:`update_prior` and mixed into
    the LiDAR pose by :meth:`fuse_pose`.  Each prior's weight is the
    per-sensor weight from ``cfg`` multiplied by the prior's confidence.
    Poses are 4x4 arrays: the upper-left 3x3 block is handled as a rotation
    and the first three rows of the last column as a translation.
    """

    def __init__(self, cfg: FusionConfig):
        """Keep ``cfg`` and start with no stored priors."""
        self.cfg = cfg
        # Latest accepted prior per normalised (lower-cased, stripped) sensor name.
        self._priors: Dict[str, Dict[str, Any]] = {}
        # Clock value of the previous fuse_pose call that received a 4x4 pose;
        # 0.0 doubles as "no such call yet" and disables the skew gate.
        self._last_lidar_t = 0.0

    def reset(self) -> None:
        """Forget all priors and the last LiDAR time."""
        self._priors.clear()
        self._last_lidar_t = 0.0

    def _sensor_weight(self, sensor: str) -> float:
        """Return the configured weight for ``sensor``; 0.0 if unrecognised."""
        s = str(sensor).lower().strip()
        if s == "imu":
            return float(self.cfg.imu_weight)
        if s in ("wheel", "odom", "odometry"):
            return float(self.cfg.wheel_weight)
        if s in ("vision", "camera", "vio"):
            return float(self.cfg.vision_weight)
        return 0.0

    def update_prior(self, sensor: str, pose: Any, confidence: float = 1.0, timestamp: Optional[float] = None) -> bool:
        """Store the newest pose prior for ``sensor``.

        Args:
            sensor: Sensor name; matched case-insensitively after stripping.
            pose: Anything convertible to a 4x4 float64 array.
            confidence: Scale factor for the sensor weight, clipped to [0, 1].
            timestamp: Time of the measurement; ``time.time()`` when ``None``.

        Returns:
            ``True`` if the prior was stored.  ``False`` if the pose is not
            4x4 or contains NaN/inf, the timestamp is not finite or lies more
            than ``cfg.max_future_sec`` ahead of ``time.time()``, or the
            confidence is NaN.
        """
        # np.array always copies, so later mutation of the caller's array
        # cannot silently change the stored prior.
        arr = np.array(pose, dtype=np.float64)
        if arr.shape != (4, 4):
            return False
        if not np.all(np.isfinite(arr)):
            # A single NaN/inf entry would poison the weighted blend.
            return False

        wall = float(time.time())
        t = wall if timestamp is None else float(timestamp)
        if not np.isfinite(t):
            # A NaN timestamp makes every age/skew comparison false, so the
            # prior would otherwise never be rejected as stale.
            return False
        if (t - wall) > float(self.cfg.max_future_sec):
            return False

        conf = float(np.clip(confidence, 0.0, 1.0))
        if np.isnan(conf):
            return False

        self._priors[str(sensor).lower().strip()] = {
            "pose": arr,
            "confidence": conf,
            "timestamp": t,
        }
        return True

    def fuse_pose(self, lidar_pose: np.ndarray, now: Optional[float] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Blend ``lidar_pose`` with every stored prior that passes the gates.

        A prior is skipped, and recorded in ``rejected`` as
        ``"<sensor>:<reason>"``, when checked in this order: its timestamp is
        more than ``max_future_sec`` ahead of ``now`` (``future``); it differs
        from the previous fusion time by more than ``max_timestamp_skew_sec``
        (``skew``); it is older than ``max_age_sec`` (``stale``); its
        effective weight is not positive (``weight``); or its pose is not 4x4
        (``pose``).

        Args:
            lidar_pose: Anything convertible to a 4x4 float64 array.
            now: Reference time; ``time.time()`` when ``None``.

        Returns:
            ``(pose, info)``.  ``info`` always has ``"used"`` and
            ``"enabled"``; ``"rejected"`` is present only when fusion is
            enabled and the input is 4x4.  The input pose is returned as-is
            when it is not 4x4, when fusion is disabled, or when no prior
            qualifies; otherwise a new blended array is returned.
        """
        pose_lidar = np.asarray(lidar_pose, dtype=np.float64)
        if pose_lidar.shape != (4, 4):
            return pose_lidar, {"used": [], "enabled": bool(self.cfg.enabled)}
        t_now = float(time.time()) if now is None else float(now)

        if not self.cfg.enabled:
            self._last_lidar_t = t_now
            return pose_lidar, {"used": [], "enabled": False}

        # Loop-invariant limits, read once.
        max_future = float(self.cfg.max_future_sec)
        max_skew = float(self.cfg.max_timestamp_skew_sec)
        max_age = float(self.cfg.max_age_sec)

        used = []
        rejected = []
        # Index 0 is always the LiDAR pose; it also anchors the quaternion sign.
        weights = [float(self.cfg.lidar_weight)]
        poses = [pose_lidar]

        for sensor, ent in self._priors.items():
            ts = float(ent.get("timestamp", 0.0))
            age = t_now - ts
            if age < -max_future:
                rejected.append(f"{sensor}:future")
                continue
            if self._last_lidar_t > 0.0 and abs(ts - self._last_lidar_t) > max_skew:
                rejected.append(f"{sensor}:skew")
                continue
            if age > max_age:
                rejected.append(f"{sensor}:stale")
                continue
            w = self._sensor_weight(sensor) * float(np.clip(ent.get("confidence", 1.0), 0.0, 1.0))
            # "not w > 0" rather than "w <= 0" so a NaN weight is rejected too.
            if not w > 0.0:
                rejected.append(f"{sensor}:weight")
                continue
            p = np.asarray(ent.get("pose"), dtype=np.float64)
            if p.shape != (4, 4):
                rejected.append(f"{sensor}:pose")
                continue
            poses.append(p)
            weights.append(w)
            used.append(sensor)

        if len(poses) <= 1:
            self._last_lidar_t = t_now
            return pose_lidar, {"used": [], "rejected": rejected, "enabled": True}

        w_arr = np.asarray(weights, dtype=np.float64)
        # The floor guards the division if every weight is zero.
        w_arr = w_arr / max(np.sum(w_arr), 1e-9)

        out = np.array(pose_lidar, dtype=np.float64, copy=True)

        if self.cfg.use_translation:
            # Weighted mean of the translation columns.
            t = np.zeros((3,), dtype=np.float64)
            for wi, p in zip(w_arr, poses):
                t += wi * p[:3, 3]
            out[:3, 3] = t

        if self.cfg.use_rotation:
            # Weighted quaternion sum, renormalised.  q and -q encode the same
            # rotation, so each quaternion is flipped onto the LiDAR
            # quaternion's hemisphere before summing.
            q_ref = _rot_to_quat(poses[0][:3, :3])
            q_sum = np.zeros((4,), dtype=np.float64)
            for wi, p in zip(w_arr, poses):
                q = _rot_to_quat(p[:3, :3])
                if np.dot(q, q_ref) < 0:
                    q = -q
                q_sum += wi * q
            qn = float(np.linalg.norm(q_sum))
            if qn > 1e-12:
                q_sum /= qn
                out[:3, :3] = _quat_to_rot(q_sum)
            # Otherwise the LiDAR rotation already in ``out`` is kept.

        self._last_lidar_t = t_now
        return out, {"used": used, "rejected": rejected, "enabled": True}
|