419 lines
16 KiB
Python
419 lines
16 KiB
Python
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import numpy as np
|
||
|
||
|
||
def _as_bool(v: Any, default: bool) -> bool:
|
||
if v is None:
|
||
return default
|
||
if isinstance(v, bool):
|
||
return v
|
||
if isinstance(v, (int, float)):
|
||
return bool(v)
|
||
return str(v).strip().lower() in ("1", "true", "yes", "on")
|
||
|
||
|
||
def _safe_float(v: Any, default: float) -> float:
|
||
try:
|
||
return float(v)
|
||
except Exception:
|
||
return float(default)
|
||
|
||
|
||
def _safe_int(v: Any, default: int) -> int:
|
||
try:
|
||
return int(v)
|
||
except Exception:
|
||
return int(default)
|
||
|
||
|
||
def _pose_delta(pose_a: np.ndarray, pose_b: np.ndarray) -> tuple[float, float]:
|
||
dp = pose_a[:3, 3] - pose_b[:3, 3]
|
||
trans = float(np.linalg.norm(dp))
|
||
r = pose_a[:3, :3] @ pose_b[:3, :3].T
|
||
ang = float(np.degrees(np.arccos(np.clip((np.trace(r) - 1.0) * 0.5, -1.0, 1.0))))
|
||
return trans, ang
|
||
|
||
|
||
def _voxel_downsample_numpy(points: np.ndarray, voxel: float) -> np.ndarray:
|
||
if points is None or len(points) == 0:
|
||
return np.zeros((0, 3), dtype=np.float32)
|
||
if voxel <= 0:
|
||
return points.astype(np.float32, copy=False)
|
||
keys = np.floor(points / voxel).astype(np.int32)
|
||
uniq, inv = np.unique(keys, axis=0, return_inverse=True)
|
||
out = np.zeros((len(uniq), 3), dtype=np.float64)
|
||
cnt = np.zeros((len(uniq),), dtype=np.int64)
|
||
np.add.at(out, inv, points.astype(np.float64, copy=False))
|
||
np.add.at(cnt, inv, 1)
|
||
out /= np.maximum(cnt[:, None], 1)
|
||
return out.astype(np.float32)
|
||
|
||
|
||
def _rigid_fit_svd(src_pts: np.ndarray, dst_pts: np.ndarray) -> Optional[np.ndarray]:
|
||
if src_pts is None or dst_pts is None:
|
||
return None
|
||
if len(src_pts) < 3 or len(dst_pts) < 3:
|
||
return None
|
||
a = np.asarray(src_pts, dtype=np.float64)
|
||
b = np.asarray(dst_pts, dtype=np.float64)
|
||
if a.shape != b.shape or a.ndim != 2 or a.shape[1] != 3:
|
||
return None
|
||
c_a = np.mean(a, axis=0)
|
||
c_b = np.mean(b, axis=0)
|
||
aa = a - c_a
|
||
bb = b - c_b
|
||
H = aa.T @ bb
|
||
try:
|
||
U, _, Vt = np.linalg.svd(H, full_matrices=False)
|
||
except Exception:
|
||
return None
|
||
R = Vt.T @ U.T
|
||
if np.linalg.det(R) < 0:
|
||
Vt[2, :] *= -1.0
|
||
R = Vt.T @ U.T
|
||
t = c_b - (R @ c_a)
|
||
T = np.eye(4, dtype=np.float64)
|
||
T[:3, :3] = R
|
||
T[:3, 3] = t
|
||
return T
|
||
|
||
|
||
@dataclass
class LoopClosureConfig:
    """Tuning parameters for the loop-closure backend.

    Use :meth:`from_dict` to build one from an untyped mapping (parsed
    YAML/JSON, CLI overrides, ...): missing or unparsable entries fall back to
    the defaults below and numeric values are clamped to a lower bound.
    """

    # Master switch: when False, frame processing is a no-op.
    enabled: bool = False
    # Add a keyframe once this many frames have passed since the last one ...
    keyframe_every_n_frames: int = 25
    # ... or once the pose moved / rotated this much relative to the last keyframe.
    min_keyframe_translation_m: float = 0.30
    min_keyframe_rotation_deg: float = 8.0
    # Max distance between keyframe positions for a loop candidate.
    loop_search_radius_m: float = 2.0
    # Minimum separation between the current keyframe and a loop candidate.
    # NOTE(review): the backend compares this against keyframe indices, not raw frame counts.
    min_loop_frame_gap: int = 30
    # Voxel edge used to downsample keyframe clouds.
    downsample_voxel_m: float = 0.20
    # ICP: max correspondence distance and iteration cap.
    icp_max_corr_m: float = 1.2
    icp_max_iter: int = 40
    # A loop is accepted when ICP fitness >= accept_fitness and inlier RMSE <= accept_rmse.
    accept_fitness: float = 0.30
    accept_rmse: float = 0.35
    # NOTE(review): not read by the backend in this file; kept for config compatibility.
    optimize_every_n_keyframes: int = 8
    # Oldest keyframes are dropped beyond this count.
    max_keyframes: int = 500
    # Loop corrections larger than these limits are rejected.
    max_correction_translation_m: float = 2.0
    max_correction_rotation_deg: float = 20.0

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "LoopClosureConfig":
        """Build a config from a loosely-typed mapping.

        Missing or unparsable entries use the field defaults. Numeric values
        are clamped from below per field. ``accept_fitness`` is clamped into
        [0, 1] and stored as a plain Python ``float``.
        """
        src = d or {}

        def _float_at_least(key: str, default: float, floor: float) -> float:
            # Parse src[key] as float (default on failure), then clamp from below.
            return max(floor, _safe_float(src.get(key, default), default))

        def _int_at_least(key: str, default: int, floor: int) -> int:
            # Parse src[key] as int (default on failure), then clamp from below.
            return max(floor, _safe_int(src.get(key, default), default))

        fitness = _safe_float(src.get("accept_fitness", 0.30), 0.30)
        if not np.isfinite(fitness):
            # np.clip would propagate NaN, and `fit >= nan` would never accept a loop.
            fitness = 0.30

        return LoopClosureConfig(
            enabled=_as_bool(src.get("enabled", False), False),
            keyframe_every_n_frames=_int_at_least("keyframe_every_n_frames", 25, 1),
            min_keyframe_translation_m=_float_at_least("min_keyframe_translation_m", 0.30, 0.01),
            min_keyframe_rotation_deg=_float_at_least("min_keyframe_rotation_deg", 8.0, 0.1),
            loop_search_radius_m=_float_at_least("loop_search_radius_m", 2.0, 0.2),
            min_loop_frame_gap=_int_at_least("min_loop_frame_gap", 30, 5),
            downsample_voxel_m=_float_at_least("downsample_voxel_m", 0.20, 0.01),
            icp_max_corr_m=_float_at_least("icp_max_corr_m", 1.2, 0.2),
            icp_max_iter=_int_at_least("icp_max_iter", 40, 10),
            accept_fitness=float(min(max(fitness, 0.0), 1.0)),
            accept_rmse=_float_at_least("accept_rmse", 0.35, 0.01),
            optimize_every_n_keyframes=_int_at_least("optimize_every_n_keyframes", 8, 2),
            max_keyframes=_int_at_least("max_keyframes", 500, 50),
            max_correction_translation_m=_float_at_least("max_correction_translation_m", 2.0, 0.01),
            max_correction_rotation_deg=_float_at_least("max_correction_rotation_deg", 20.0, 0.1),
        )
|
||
|
||
|
||
@dataclass
class LoopClosureResult:
    """Outcome of processing a single frame through the loop-closure backend."""

    # True when the frame was stored as a new keyframe.
    keyframe_added: bool = False
    # True when ICP against a candidate keyframe met the acceptance thresholds.
    loop_detected: bool = False
    # True when the resulting pose correction was applied to the keyframe chain.
    optimized: bool = False
    # 4x4 world-frame correction at the current keyframe; identity when nothing was applied.
    correction: np.ndarray = field(default_factory=lambda: np.identity(4, dtype=np.float64))
    # Free-form diagnostics (candidate index, ICP metrics, counters, rejection reasons).
    info: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class LoopClosureBackend:
    """
    Lightweight loop-closure backend using keyframes + pose-chain correction.

    Call :meth:`process_frame` once per frame with the sensor-frame point cloud
    and the current world_T_sensor pose estimate. Frames that move far enough
    (or arrive after enough frames) become keyframes. Each new keyframe is
    matched by point-to-point ICP against the closest sufficiently old
    keyframe. An accepted match is turned into a world-frame correction that
    is spread along the keyframe chain (see :meth:`_optimize`).

    Open3D is deliberately not used (see :meth:`_import_open3d`). ICP relies on
    ``scipy.spatial.cKDTree`` and is skipped when SciPy is unavailable.
    """

    def __init__(self, cfg: LoopClosureConfig):
        self.cfg = cfg
        self.frame_index = 0
        self.last_keyframe_frame = -10**9
        self.keyframe_poses: List[np.ndarray] = []  # world_T_sensor, 4x4 float64
        self.keyframe_clouds: List[np.ndarray] = []  # sensor-frame clouds (voxel-downsampled)
        self._o3d = None
        self._pg = None
        self._loop_count = 0
        self._opt_count = 0
        self._last_opt_reject: Optional[Dict[str, Any]] = None
        self._import_open3d()

    def _import_open3d(self) -> None:
        # SAFE mode: avoid Open3D pose-graph path (native crashes observed).
        self._o3d = None
        self._pg = None

    def reset(self) -> None:
        """Drop all keyframes and counters; the config is kept."""
        self.frame_index = 0
        self.last_keyframe_frame = -10**9
        self.keyframe_poses.clear()
        self.keyframe_clouds.clear()
        self._loop_count = 0
        self._opt_count = 0
        self._last_opt_reject = None
        self._pg = None

    def _should_add_keyframe(self, pose_world: np.ndarray) -> bool:
        """True for the first keyframe, after enough frames, or after enough motion."""
        if len(self.keyframe_poses) == 0:
            return True
        if (self.frame_index - self.last_keyframe_frame) >= int(self.cfg.keyframe_every_n_frames):
            return True
        trans, ang = _pose_delta(pose_world, self.keyframe_poses[-1])
        return (trans >= float(self.cfg.min_keyframe_translation_m)) or (
            ang >= float(self.cfg.min_keyframe_rotation_deg)
        )

    def _to_o3d_cloud(self, pts: np.ndarray):
        """Wrap points in an Open3D PointCloud; returns None while Open3D is disabled."""
        o3d = self._o3d
        if o3d is None:
            return None
        pc = o3d.geometry.PointCloud()
        pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64, copy=False))
        return pc

    def _relative_source_to_target(self, pose_source: np.ndarray, pose_target: np.ndarray) -> np.ndarray:
        # T_source_to_target = inv(world_T_target) @ world_T_source
        return np.linalg.inv(pose_target) @ pose_source

    def _find_candidate(self, idx_cur: int) -> Optional[int]:
        """Return the index of the closest eligible earlier keyframe, or None.

        Eligible keyframes have index < idx_cur - min_loop_frame_gap (the gap
        is counted in keyframes) and lie within loop_search_radius_m of
        keyframe ``idx_cur``. Ties go to the lowest index; non-finite
        distances never qualify.
        """
        gap = max(0, int(self.cfg.min_loop_frame_gap))
        if idx_cur <= gap:
            return None
        limit = idx_cur - gap
        cur_t = np.asarray(self.keyframe_poses[idx_cur][:3, 3], dtype=np.float64)
        past_t = np.array([p[:3, 3] for p in self.keyframe_poses[:limit]], dtype=np.float64)
        dists = np.linalg.norm(past_t - cur_t, axis=1)
        dists = np.where(np.isfinite(dists), dists, np.inf)
        best = int(np.argmin(dists))  # argmin returns the first index on ties
        if np.isfinite(dists[best]) and dists[best] <= float(self.cfg.loop_search_radius_m):
            return best
        return None

    def _run_icp(
        self, src_pts: np.ndarray, tgt_pts: np.ndarray, init: np.ndarray, max_corr: float
    ) -> Optional[Dict[str, Any]]:
        """Point-to-point ICP aligning ``src_pts`` onto ``tgt_pts``.

        Parameters
        ----------
        src_pts, tgt_pts : (N, 3) point arrays.
        init : initial 4x4 guess mapping src coordinates into tgt coordinates;
            identity is used when it is not 4x4.
        max_corr : max nearest-neighbour distance for a correspondence
            (floored at 0.05).

        Returns
        -------
        Dict with "transformation" (4x4 float64, src -> tgt), "fitness"
        (inlier fraction of src points), "inlier_rmse" (9e9 when there are no
        inliers) and "inliers". Fitness/RMSE/inliers are evaluated at the
        returned transformation. Returns None for malformed input or when
        SciPy is unavailable.
        """
        src0 = np.asarray(src_pts, dtype=np.float64)
        tgt = np.asarray(tgt_pts, dtype=np.float64)
        if src0.ndim != 2 or tgt.ndim != 2 or src0.shape[1] != 3 or tgt.shape[1] != 3:
            return None
        if len(src0) < 3 or len(tgt) < 3:
            return None
        try:
            from scipy.spatial import cKDTree
        except Exception:
            return None

        init_arr = np.asarray(init)
        T = np.array(init_arr, dtype=np.float64) if init_arr.shape == (4, 4) else np.eye(4, dtype=np.float64)
        tree = cKDTree(tgt)
        max_corr_m = float(max(0.05, max_corr))
        min_corr = max(20, int(0.12 * len(src0)))

        def _correspondences(T_cur: np.ndarray):
            # Transform src by T_cur and match each point to its nearest tgt point.
            src_w = (src0 @ T_cur[:3, :3].T) + T_cur[:3, 3]
            d, idx = tree.query(src_w, k=1)
            keep = np.isfinite(d) & (d <= max_corr_m)
            return src_w, d, idx, keep

        last_rmse = 9e9
        refined = False
        for _ in range(int(max(6, self.cfg.icp_max_iter))):
            src_w, d, idx, keep = _correspondences(T)
            if int(np.count_nonzero(keep)) < min_corr:
                break
            delta = _rigid_fit_svd(src_w[keep], tgt[idx[keep]])
            if delta is None:
                break
            T = np.asarray(delta, dtype=np.float64) @ T
            refined = True
            rmse = float(np.sqrt(np.mean(np.square(d[keep]))))
            converged = abs(last_rmse - rmse) < 1e-5
            last_rmse = rmse
            if converged:
                break

        if not refined:
            # No successful update: report a failed registration.
            return {
                "fitness": 0.0,
                "inlier_rmse": 9e9,
                "transformation": np.asarray(T, dtype=np.float64),
                "inliers": 0,
            }

        # Score the transformation actually returned (the in-loop residuals
        # were measured before the last update was applied).
        _, d, _, keep = _correspondences(T)
        n_in = int(np.count_nonzero(keep))
        rmse = float(np.sqrt(np.mean(np.square(d[keep])))) if n_in > 0 else 9e9
        fitness = float(n_in) / max(1.0, float(len(src0)))
        return {
            "fitness": float(fitness),
            "inlier_rmse": float(rmse),
            "transformation": np.asarray(T, dtype=np.float64),
            "inliers": n_in,
        }

    def _rebuild_graph_odometry_only(self) -> None:
        # No pose-graph object is maintained in SAFE mode.
        self._pg = None

    def _maybe_trim(self) -> None:
        """Drop the oldest keyframes so at most cfg.max_keyframes remain."""
        kmax = int(self.cfg.max_keyframes)
        if len(self.keyframe_poses) <= kmax:
            return
        drop = len(self.keyframe_poses) - kmax
        self.keyframe_poses = self.keyframe_poses[drop:]
        self.keyframe_clouds = self.keyframe_clouds[drop:]
        self._rebuild_graph_odometry_only()

    @staticmethod
    def _slerp_rotation(R0: np.ndarray, R1: np.ndarray, t: float) -> np.ndarray:
        """Geodesic interpolation between rotation matrices (t=0 -> R0, t=1 -> R1).

        Computes ``R0 @ exp(t * log(R0.T @ R1))`` via axis-angle and
        Rodrigues' formula. The axis comes from the skew-symmetric part of
        the relative rotation. Near an angle of pi that part vanishes, so the
        axis is recovered from the symmetric part instead.
        """
        dR = R0.T @ R1
        cos_angle = np.clip((np.trace(dR) - 1.0) * 0.5, -1.0, 1.0)
        angle = float(np.arccos(cos_angle))
        if angle < 1e-8:
            return R0.copy()

        skew = np.array([dR[2, 1] - dR[1, 2], dR[0, 2] - dR[2, 0], dR[1, 0] - dR[0, 1]])
        sin_a = float(np.sin(angle))
        if sin_a > 1e-6:
            axis = skew / (2.0 * sin_a)
        else:
            # angle ~= pi: dR ~= 2 k k^T - I, hence (dR + I) / 2 ~= k k^T.
            B = 0.5 * (dR + np.eye(3))
            j = int(np.argmax(np.diag(B)))
            axis = B[:, j] / np.sqrt(max(float(B[j, j]), 1e-12))
            # Keep the sign consistent with the (tiny) skew part when it carries information.
            if float(np.dot(axis, skew)) < 0.0:
                axis = -axis
        axis_norm = float(np.linalg.norm(axis))
        if axis_norm < 1e-12:
            return R0.copy()
        axis = axis / axis_norm

        angle_t = angle * float(t)
        K = np.array([
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ])
        dRt = np.eye(3) + np.sin(angle_t) * K + (1.0 - np.cos(angle_t)) * (K @ K)
        return R0 @ dRt

    @staticmethod
    def _interpolate_se3(T0: np.ndarray, T1: np.ndarray, alpha: float) -> np.ndarray:
        """Interpolate between two SE3 transforms; alpha=0 -> T0, alpha=1 -> T1.

        Rotation is slerped and translation is interpolated linearly; alpha is
        clipped into [0, 1].
        """
        a = float(np.clip(alpha, 0.0, 1.0))
        out = np.eye(4, dtype=np.float64)
        out[:3, :3] = LoopClosureBackend._slerp_rotation(T0[:3, :3], T1[:3, :3], a)
        out[:3, 3] = (1.0 - a) * T0[:3, 3] + a * T1[:3, 3]
        return out

    def _optimize(self, src_idx: int, dst_idx: int, T_meas: np.ndarray) -> Optional[List[np.ndarray]]:
        """
        Distribute loop-closure error linearly across the keyframe chain.

        Poses from dst_idx+1 to src_idx receive a fraction of the correction
        proportional to their distance from dst_idx. All poses after src_idx
        receive the full correction. Poses at or before dst_idx are unchanged.

        The correction is rejected (None is returned) when it would move the
        src keyframe by more than cfg.max_correction_translation_m or rotate
        it by more than cfg.max_correction_rotation_deg, or when it is not
        finite.

        Parameters
        ----------
        src_idx : index of the current (drifted) keyframe
        dst_idx : index of the matching earlier keyframe
        T_meas : 4x4 relative transform; src expressed in dst's frame
                 (output of ICP; transforms src points into dst frame)

        Returns
        -------
        List of updated 4x4 pose matrices (new list; self.keyframe_poses is
        not modified), or None if inputs are invalid or the correction is
        rejected.
        """
        n = len(self.keyframe_poses)
        if n < 2:
            return None
        if not (0 <= dst_idx < src_idx < n):
            return None

        poses = [np.array(p, dtype=np.float64) for p in self.keyframe_poses]

        # Where src *should* be in world frame if there were no drift:
        T_src_expected = poses[dst_idx] @ T_meas
        T_src_actual = poses[src_idx]

        # Full (left-multiplied, world-frame) correction moving src_actual -> src_expected:
        try:
            T_correction = T_src_expected @ np.linalg.inv(T_src_actual)
        except np.linalg.LinAlgError:
            return None
        if not np.all(np.isfinite(T_correction)):
            return None

        # Gate on how far the src keyframe actually moves. The translation
        # column of T_correction also contains (I - R_c) @ p_src, a lever-arm
        # term that grows with the keyframe's distance from the world origin,
        # so it would reject valid loops far from the origin.
        corr_t, corr_deg = _pose_delta(T_src_expected, T_src_actual)
        if corr_t > float(self.cfg.max_correction_translation_m) or corr_deg > float(
            self.cfg.max_correction_rotation_deg
        ):
            return None

        T_identity = np.eye(4, dtype=np.float64)
        chain_len = src_idx - dst_idx

        # Linearly distribute from dst_idx+1 (alpha ~ 0) to src_idx (alpha = 1)
        for i in range(dst_idx + 1, src_idx + 1):
            alpha = float(i - dst_idx) / float(chain_len)
            T_partial = self._interpolate_se3(T_identity, T_correction, alpha)
            poses[i] = T_partial @ poses[i]

        # Propagate full correction to all poses after src_idx
        for i in range(src_idx + 1, n):
            poses[i] = T_correction @ poses[i]

        return poses

    def process_frame(self, points_sensor: np.ndarray, pose_world: np.ndarray) -> LoopClosureResult:
        """Ingest one frame: maybe add a keyframe, detect a loop and correct poses.

        Parameters
        ----------
        points_sensor : (N, >=3) points in the sensor frame; only the first
            three columns are used. Frames with fewer than 40 points (before
            or after downsampling) are ignored.
        pose_world : 4x4 world_T_sensor pose estimate. Frames with a missing,
            mis-shaped or non-finite pose are ignored.

        Returns
        -------
        LoopClosureResult. When ``optimized`` is True, ``correction`` is the
        world-frame transform applied at the current keyframe and
        ``self.keyframe_poses`` has been replaced by the corrected poses.
        """
        self.frame_index += 1
        result = LoopClosureResult()

        if not self.cfg.enabled:
            return result
        if points_sensor is None or len(points_sensor) < 40:
            return result
        if pose_world is None or np.asarray(pose_world).shape != (4, 4):
            return result

        pose = np.asarray(pose_world, dtype=np.float64)
        if not np.all(np.isfinite(pose)):
            # A NaN/inf keyframe pose would poison every later distance test and correction.
            return result
        pts = np.asarray(points_sensor, dtype=np.float32)
        if pts.ndim != 2 or pts.shape[1] < 3:
            return result

        if not self._should_add_keyframe(pose):
            return result

        ds = _voxel_downsample_numpy(pts[:, :3], float(self.cfg.downsample_voxel_m))
        if len(ds) < 40:
            return result

        self.keyframe_clouds.append(ds)
        self.keyframe_poses.append(pose.copy())
        self.last_keyframe_frame = self.frame_index
        result.keyframe_added = True

        idx = len(self.keyframe_poses) - 1
        cand = self._find_candidate(idx)
        if cand is not None:
            # Odometry-based initial guess mapping the current cloud into the candidate's frame.
            init = self._relative_source_to_target(self.keyframe_poses[idx], self.keyframe_poses[cand])
            reg = self._run_icp(
                self.keyframe_clouds[idx],
                self.keyframe_clouds[cand],
                init=init,
                max_corr=float(self.cfg.icp_max_corr_m),
            )
            if reg is not None:
                fit = float(reg.get("fitness", 0.0))
                rmse = float(reg.get("inlier_rmse", 9e9))
                ok = (fit >= float(self.cfg.accept_fitness)) and (rmse <= float(self.cfg.accept_rmse))
                result.info["loop_candidate"] = int(cand)
                result.info["loop_fitness"] = fit
                result.info["loop_rmse"] = rmse
                if ok:
                    result.loop_detected = True
                    self._loop_count += 1
                    t_meas = np.asarray(reg.get("transformation", np.eye(4, dtype=np.float64)), dtype=np.float64)
                    optimized_poses = self._optimize(idx, cand, t_meas)
                    if optimized_poses is None:
                        self._last_opt_reject = {
                            "reason": "optimize_rejected",
                            "src_idx": int(idx),
                            "dst_idx": int(cand),
                        }
                        result.info["rejected"] = dict(self._last_opt_reject)
                    else:
                        # Correction seen at src_idx (for callers that need a global correction estimate)
                        pose_cur = np.asarray(self.keyframe_poses[idx], dtype=np.float64)
                        correction = optimized_poses[idx] @ np.linalg.inv(pose_cur)
                        result.optimized = True
                        result.correction = np.asarray(correction, dtype=np.float64)
                        self._opt_count += 1
                        self.keyframe_poses = optimized_poses
                result.info["loop_count"] = int(self._loop_count)
                result.info["opt_count"] = int(self._opt_count)

        self._maybe_trim()
        return result
|