# Marcus/Lidar/SLAM_Fusion.py
# 2026-04-12 18:50:22 +04:00
#
# 206 lines
# 7.2 KiB
# Python

from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import numpy as np
def _safe_float(v: Any, default: float) -> float:
try:
return float(v)
except Exception:
return float(default)
def _as_bool(v: Any, default: bool) -> bool:
if v is None:
return default
if isinstance(v, bool):
return v
if isinstance(v, (int, float)):
return bool(v)
return str(v).strip().lower() in ("1", "true", "yes", "on")
@dataclass
class FusionConfig:
    """Tunable parameters for ``SensorPoseFusion``.

    Attributes:
        enabled: Master switch for fusion.
        max_age_sec: Priors older than this, relative to the fusion time,
            are rejected.
        max_future_sec: Tolerance for prior timestamps that lie ahead of the
            current time.
        max_timestamp_skew_sec: Maximum allowed timestamp skew between a
            prior and the LiDAR timeline.
        lidar_weight, imu_weight, wheel_weight, vision_weight: Base blend
            weights per source.
        use_translation, use_rotation: Select which pose components are
            blended.
    """

    enabled: bool = False
    max_age_sec: float = 0.30
    max_future_sec: float = 0.10
    max_timestamp_skew_sec: float = 0.50
    lidar_weight: float = 1.0
    imu_weight: float = 0.35
    wheel_weight: float = 0.45
    vision_weight: float = 0.70
    use_translation: bool = True
    use_rotation: bool = True

    @classmethod
    def from_dict(cls, d: Dict[str, Any] | None) -> "FusionConfig":
        """Build a config from a possibly partial or loosely-typed mapping.

        Missing keys and unparseable values fall back to the dataclass field
        defaults, so each default is defined in exactly one place. Numeric
        values are clamped from below:
        - 0.01 for ``max_age_sec`` and ``max_timestamp_skew_sec``;
        - 0.0 for ``max_future_sec`` and all weights.
        """
        src = d or {}
        base = cls()

        def num(key: str, floor: float) -> float:
            default = getattr(base, key)
            return max(floor, _safe_float(src.get(key, default), default))

        def flag(key: str) -> bool:
            default = getattr(base, key)
            return _as_bool(src.get(key, default), default)

        return cls(
            enabled=flag("enabled"),
            max_age_sec=num("max_age_sec", 0.01),
            max_future_sec=num("max_future_sec", 0.0),
            max_timestamp_skew_sec=num("max_timestamp_skew_sec", 0.01),
            lidar_weight=num("lidar_weight", 0.0),
            imu_weight=num("imu_weight", 0.0),
            wheel_weight=num("wheel_weight", 0.0),
            vision_weight=num("vision_weight", 0.0),
            use_translation=flag("use_translation"),
            use_rotation=flag("use_rotation"),
        )
def _rot_to_quat(r: np.ndarray) -> np.ndarray:
tr = float(np.trace(r))
if tr > 0.0:
s = np.sqrt(tr + 1.0) * 2.0
w = 0.25 * s
x = (r[2, 1] - r[1, 2]) / s
y = (r[0, 2] - r[2, 0]) / s
z = (r[1, 0] - r[0, 1]) / s
elif (r[0, 0] > r[1, 1]) and (r[0, 0] > r[2, 2]):
s = np.sqrt(1.0 + r[0, 0] - r[1, 1] - r[2, 2]) * 2.0
w = (r[2, 1] - r[1, 2]) / s
x = 0.25 * s
y = (r[0, 1] + r[1, 0]) / s
z = (r[0, 2] + r[2, 0]) / s
elif r[1, 1] > r[2, 2]:
s = np.sqrt(1.0 + r[1, 1] - r[0, 0] - r[2, 2]) * 2.0
w = (r[0, 2] - r[2, 0]) / s
x = (r[0, 1] + r[1, 0]) / s
y = 0.25 * s
z = (r[1, 2] + r[2, 1]) / s
else:
s = np.sqrt(1.0 + r[2, 2] - r[0, 0] - r[1, 1]) * 2.0
w = (r[1, 0] - r[0, 1]) / s
x = (r[0, 2] + r[2, 0]) / s
y = (r[1, 2] + r[2, 1]) / s
z = 0.25 * s
q = np.array([w, x, y, z], dtype=np.float64)
n = float(np.linalg.norm(q))
if n <= 1e-12:
return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)
return q / n
def _quat_to_rot(q: np.ndarray) -> np.ndarray:
w, x, y, z = q
return np.array(
[
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
],
dtype=np.float64,
)
class SensorPoseFusion:
    """Blend LiDAR poses with recent 4x4 pose priors from auxiliary sensors.

    Priors are stored per sensor with ``update_prior``; the latest one wins.
    ``fuse_pose`` combines them with each LiDAR pose as a weighted average:
    - translations are averaged linearly;
    - rotations use a normalized weighted quaternion sum, with each term
      sign-aligned to the LiDAR rotation.
    """

    def __init__(self, cfg: FusionConfig):
        self.cfg = cfg
        # Latest prior per normalized sensor name:
        # {"pose", "confidence", "timestamp"}.
        self._priors: Dict[str, Dict[str, Any]] = {}
        # Time passed to the previous fuse_pose() call; 0.0 means "none yet".
        self._last_lidar_t = 0.0

    def reset(self) -> None:
        """Drop all stored priors and forget the previous LiDAR time."""
        self._priors.clear()
        self._last_lidar_t = 0.0

    def _sensor_weight(self, sensor: str) -> float:
        """Return the configured base weight for *sensor*, or 0.0 if unknown."""
        s = str(sensor).lower().strip()
        if s == "imu":
            return float(self.cfg.imu_weight)
        if s in ("wheel", "odom", "odometry"):
            return float(self.cfg.wheel_weight)
        if s in ("vision", "camera", "vio"):
            return float(self.cfg.vision_weight)
        return 0.0

    def update_prior(
        self,
        sensor: str,
        pose: Any,
        confidence: float = 1.0,
        timestamp: Optional[float] = None,
    ) -> bool:
        """Store the latest pose prior for *sensor*, replacing any previous one.

        Args:
            sensor: Sensor name. It is lower-cased and stripped before use as
                the storage key.
            pose: 4x4 pose. It is copied, so later edits to the caller's
                buffer do not leak in.
            confidence: Multiplier on the sensor weight, clipped to [0, 1].
            timestamp: Time in seconds. Defaults to ``time.time()``.

        Returns:
            True if the prior was stored. False if any of the following holds:
            - *pose* is not a finite 4x4 array;
            - *confidence* or *timestamp* cannot be converted or is
              non-finite;
            - the timestamp is more than ``max_future_sec`` ahead of
              wall-clock time.
        """
        try:
            arr = np.array(pose, dtype=np.float64, copy=True)
            conf = float(confidence)
            wall = float(time.time())
            t = wall if timestamp is None else float(timestamp)
        except (TypeError, ValueError, OverflowError):
            return False
        if arr.shape != (4, 4) or not np.all(np.isfinite(arr)):
            return False
        # A NaN timestamp or confidence compares False against every
        # threshold. It would slip through the freshness and weight gates in
        # fuse_pose() and turn the fused pose into NaN.
        if not (np.isfinite(t) and np.isfinite(conf)):
            return False
        if (t - wall) > float(self.cfg.max_future_sec):
            return False
        self._priors[str(sensor).lower().strip()] = {
            "pose": arr,
            "confidence": float(np.clip(conf, 0.0, 1.0)),
            "timestamp": t,
        }
        return True

    def _collect_priors(self, t_now: float) -> Tuple[list, list, list, list]:
        """Gate stored priors against *t_now*.

        Returns ``(poses, weights, used, rejected)``. ``rejected`` holds
        ``"sensor:reason"`` labels.
        """
        poses: list = []
        weights: list = []
        used: list = []
        rejected: list = []
        max_future = float(self.cfg.max_future_sec)
        max_skew = float(self.cfg.max_timestamp_skew_sec)
        max_age = float(self.cfg.max_age_sec)
        for sensor, ent in self._priors.items():
            ts = float(ent.get("timestamp", 0.0))
            age = t_now - ts
            if age < -max_future:
                rejected.append(f"{sensor}:future")
                continue
            # NOTE(review): skew is measured against the *previous*
            # fuse_pose() time, not t_now. After a LiDAR gap longer than
            # max_timestamp_skew_sec, every fresh prior is therefore rejected
            # once. Confirm this is intended.
            if self._last_lidar_t > 0.0 and abs(ts - self._last_lidar_t) > max_skew:
                rejected.append(f"{sensor}:skew")
                continue
            if age > max_age:
                rejected.append(f"{sensor}:stale")
                continue
            w = self._sensor_weight(sensor) * float(
                np.clip(ent.get("confidence", 1.0), 0.0, 1.0)
            )
            # `not w > 0` also rejects NaN weights, which `w <= 0` lets through.
            if not w > 0.0:
                rejected.append(f"{sensor}:weight")
                continue
            p = np.asarray(ent.get("pose"), dtype=np.float64)
            if p.shape != (4, 4):
                rejected.append(f"{sensor}:pose")
                continue
            poses.append(p)
            weights.append(w)
            used.append(sensor)
        return poses, weights, used, rejected

    def _blend(self, poses: list, weights: list) -> np.ndarray:
        """Weighted-average *poses*; ``poses[0]`` is the LiDAR reference."""
        w_arr = np.asarray(weights, dtype=np.float64)
        w_arr = w_arr / max(np.sum(w_arr), 1e-9)
        out = np.array(poses[0], dtype=np.float64, copy=True)
        if self.cfg.use_translation:
            t = np.zeros((3,), dtype=np.float64)
            for wi, p in zip(w_arr, poses):
                t += wi * p[:3, 3]
            out[:3, 3] = t
        if self.cfg.use_rotation:
            q_ref = _rot_to_quat(poses[0][:3, :3])
            q_sum = np.zeros((4,), dtype=np.float64)
            for wi, p in zip(w_arr, poses):
                q = _rot_to_quat(p[:3, :3])
                # q and -q encode the same rotation. Flip into q_ref's
                # hemisphere so opposite-sign representations don't cancel.
                if np.dot(q, q_ref) < 0:
                    q = -q
                q_sum += wi * q
            qn = float(np.linalg.norm(q_sum))
            # A vanishing sum has no meaningful direction. Converting it
            # un-normalized would give a near-identity rotation unrelated to
            # the inputs, so in that case the LiDAR rotation is kept.
            if qn > 1e-12:
                out[:3, :3] = _quat_to_rot(q_sum / qn)
        return out

    def fuse_pose(
        self, lidar_pose: np.ndarray, now: Optional[float] = None
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Fuse *lidar_pose* with every currently acceptable prior.

        Args:
            lidar_pose: 4x4 pose from the LiDAR pipeline.
            now: Fusion time in seconds. Defaults to ``time.time()``.

        Returns:
            ``(pose, info)``, where ``info`` contains:
            - ``"used"``: the fused sensor names;
            - ``"enabled"``: mirrors the config;
            - ``"rejected"`` (only when gating ran): ``"sensor:reason"``
              labels, with reasons future/skew/stale/weight/pose.

            When nothing is fused, or the input is not 4x4, the input is
            returned via ``np.asarray``. That is not a copy if it is already
            a float64 array.
        """
        pose_lidar = np.asarray(lidar_pose, dtype=np.float64)
        if pose_lidar.shape != (4, 4):
            return pose_lidar, {"used": [], "enabled": bool(self.cfg.enabled)}
        t_now = float(time.time()) if now is None else float(now)
        if not self.cfg.enabled:
            self._last_lidar_t = t_now
            return pose_lidar, {"used": [], "enabled": False}
        # Gate before updating _last_lidar_t: the skew check reads the old value.
        poses, weights, used, rejected = self._collect_priors(t_now)
        self._last_lidar_t = t_now
        if not used:
            return pose_lidar, {"used": [], "rejected": rejected, "enabled": True}
        fused = self._blend(
            [pose_lidar] + poses, [float(self.cfg.lidar_weight)] + weights
        )
        return fused, {"used": used, "rejected": rejected, "enabled": True}