from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import numpy as np


def _as_bool(v: Any, default: bool) -> bool:
    """Interpret a loosely-typed config value as a boolean.

    ``None`` yields ``default``; real booleans and numbers use their own
    truthiness; anything else is stringified and compared (case-insensitive,
    whitespace-stripped) against ``"1"``, ``"true"``, ``"yes"`` and ``"on"``.
    """
    if v is None:
        return default
    if isinstance(v, bool):
        return v
    if isinstance(v, (int, float)):
        return bool(v)
    return str(v).strip().lower() in ("1", "true", "yes", "on")


def _safe_float(v: Any, default: float) -> float:
    """Convert ``v`` to ``float``, falling back to ``default``.

    The fallback is used when the conversion fails *or* yields NaN: NaN makes
    every later comparison false and slips through ``max``/``clip`` style
    clamping, so it is treated as "no usable value". Infinities are kept.
    """
    try:
        out = float(v)
    except (TypeError, ValueError, OverflowError):
        return float(default)
    if np.isnan(out):
        return float(default)
    return out


def _safe_int(v: Any, default: int) -> int:
    """Convert ``v`` to ``int``, falling back to ``default`` when it cannot be converted."""
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")); ValueError covers NaN and bad strings.
        return int(default)


def _pose_delta(pose_a: np.ndarray, pose_b: np.ndarray) -> tuple[float, float]:
    """Return ``(translation, angle_deg)`` separating two 4x4 poses.

    The translation is the Euclidean distance between the two translation
    columns; the angle is the rotation angle (in degrees) of
    ``R_a @ R_b.T`` recovered from its trace.
    """
    dp = pose_a[:3, 3] - pose_b[:3, 3]
    trans = float(np.linalg.norm(dp))
    r = pose_a[:3, :3] @ pose_b[:3, :3].T
    # Clip guards arccos against round-off pushing the cosine outside [-1, 1].
    cos_angle = np.clip((np.trace(r) - 1.0) * 0.5, -1.0, 1.0)
    ang = float(np.degrees(np.arccos(cos_angle)))
    return trans, ang


def _voxel_downsample_numpy(points: np.ndarray, voxel: float) -> np.ndarray:
    """Replace all points that fall in the same voxel by their centroid.

    Parameters
    ----------
    points : array-like of shape (N, 3); ``None`` or empty gives an empty result.
    voxel : voxel edge length. When it is not strictly positive (including
        NaN) the points are returned unchanged, cast to ``float32``.

    Returns
    -------
    ``float32`` array of shape (M, 3), one centroid per occupied voxel, ordered
    by voxel index. Rows containing NaN/inf are dropped before binning.

    Raises
    ------
    ValueError
        If downsampling is requested and ``points`` is not of shape (N, 3).
    """
    if points is None or len(points) == 0:
        return np.zeros((0, 3), dtype=np.float32)
    pts = np.asarray(points)
    if not voxel > 0:
        return pts.astype(np.float32, copy=False)
    if pts.ndim != 2 or pts.shape[1] != 3:
        raise ValueError(f"points must have shape (N, 3), got {pts.shape}")

    pts64 = pts.astype(np.float64, copy=False)
    # Non-finite rows would turn into arbitrary integer keys and NaN centroids.
    pts64 = pts64[np.isfinite(pts64).all(axis=1)]
    if len(pts64) == 0:
        return np.zeros((0, 3), dtype=np.float32)

    # int64 keys: int32 can overflow for large coordinates / tiny voxels.
    keys = np.floor(pts64 / voxel).astype(np.int64)
    uniq, inv = np.unique(keys, axis=0, return_inverse=True)
    # Some NumPy releases return the inverse with an extra axis when `axis` is given.
    inv = np.asarray(inv).reshape(-1)

    out = np.zeros((len(uniq), 3), dtype=np.float64)
    cnt = np.zeros((len(uniq),), dtype=np.int64)
    np.add.at(out, inv, pts64)
    np.add.at(cnt, inv, 1)
    out /= np.maximum(cnt[:, None], 1)
    return out.astype(np.float32)


def _rigid_fit_svd(src_pts: np.ndarray, dst_pts: np.ndarray) -> Optional[np.ndarray]:
    """Least-squares rigid transform mapping ``src_pts`` onto ``dst_pts``.

    Both inputs must be (N, 3) with N >= 3 and row-wise correspondence. The
    rotation comes from the SVD of the cross-covariance of the centred point
    sets, with the last singular direction flipped when the raw solution
    would be a reflection.

    Returns
    -------
    4x4 ``float64`` homogeneous transform, or ``None`` when the inputs are
    missing, too small, mismatched in shape, or the SVD fails.
    """
    if src_pts is None or dst_pts is None:
        return None
    if len(src_pts) < 3 or len(dst_pts) < 3:
        return None
    a = np.asarray(src_pts, dtype=np.float64)
    b = np.asarray(dst_pts, dtype=np.float64)
    if a.shape != b.shape or a.ndim != 2 or a.shape[1] != 3:
        return None

    c_a = np.mean(a, axis=0)
    c_b = np.mean(b, axis=0)
    H = (a - c_a).T @ (b - c_b)
    try:
        U, _, Vt = np.linalg.svd(H, full_matrices=False)
    except np.linalg.LinAlgError:
        return None

    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        # Reflection: flip the axis of the smallest singular value.
        Vt[2, :] *= -1.0
        R = Vt.T @ U.T
    t = c_b - (R @ c_a)

    T = np.eye(4, dtype=np.float64)
    T[:3, :3] = R
    T[:3, 3] = t
    return T


@dataclass
class LoopClosureConfig:
    """Tunable parameters of the loop-closure backend.

    Field names carry their intended units as suffixes (``_m``, ``_deg``).
    Use :meth:`from_dict` to build an instance from untrusted/loosely typed
    input; direct construction performs no validation.
    """

    # Master switch; when False the backend ignores every frame.
    enabled: bool = False
    # A keyframe is forced once this many frames passed since the last one.
    keyframe_every_n_frames: int = 25
    # Motion relative to the last keyframe that triggers a new keyframe.
    min_keyframe_translation_m: float = 0.30
    min_keyframe_rotation_deg: float = 8.0
    # Maximum distance between keyframe positions for a loop candidate.
    loop_search_radius_m: float = 2.0
    # Minimum index separation between the current keyframe and a candidate.
    # NOTE(review): the backend compares keyframe indices, not raw frame numbers.
    min_loop_frame_gap: int = 30
    # Voxel edge used when downsampling keyframe clouds.
    downsample_voxel_m: float = 0.20
    # ICP correspondence gate and iteration cap.
    icp_max_corr_m: float = 1.2
    icp_max_iter: int = 40
    # A loop is accepted when fitness >= accept_fitness and rmse <= accept_rmse.
    accept_fitness: float = 0.30
    accept_rmse: float = 0.35
    # NOTE(review): not read by any code visible in this file.
    optimize_every_n_keyframes: int = 8
    # Oldest keyframes are dropped once more than this many are stored.
    max_keyframes: int = 500
    # Loop corrections larger than these limits are rejected.
    max_correction_translation_m: float = 2.0
    max_correction_rotation_deg: float = 20.0

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "LoopClosureConfig":
        """Build a config from a mapping, clamping every value to a sane range.

        Missing keys, ``None`` values, unparsable values and NaN all fall back
        to the dataclass default for that field. Each numeric field is then
        clamped to a lower bound (``accept_fitness`` to ``[0, 1]``).
        """
        src = d or {}
        # Single source of truth for defaults: the dataclass itself.
        base = LoopClosureConfig()

        def _int_at_least(key: str, lower: int) -> int:
            return max(lower, _safe_int(src.get(key), getattr(base, key)))

        def _float_at_least(key: str, lower: float) -> float:
            return max(lower, _safe_float(src.get(key), getattr(base, key)))

        fitness = _safe_float(src.get("accept_fitness"), base.accept_fitness)

        return LoopClosureConfig(
            enabled=_as_bool(src.get("enabled"), base.enabled),
            keyframe_every_n_frames=_int_at_least("keyframe_every_n_frames", 1),
            min_keyframe_translation_m=_float_at_least("min_keyframe_translation_m", 0.01),
            min_keyframe_rotation_deg=_float_at_least("min_keyframe_rotation_deg", 0.1),
            loop_search_radius_m=_float_at_least("loop_search_radius_m", 0.2),
            min_loop_frame_gap=_int_at_least("min_loop_frame_gap", 5),
            downsample_voxel_m=_float_at_least("downsample_voxel_m", 0.01),
            icp_max_corr_m=_float_at_least("icp_max_corr_m", 0.2),
            icp_max_iter=_int_at_least("icp_max_iter", 10),
            # Plain float (not a NumPy scalar) so the field matches its annotation.
            accept_fitness=float(min(1.0, max(0.0, fitness))),
            accept_rmse=_float_at_least("accept_rmse", 0.01),
            optimize_every_n_keyframes=_int_at_least("optimize_every_n_keyframes", 2),
            max_keyframes=_int_at_least("max_keyframes", 50),
            max_correction_translation_m=_float_at_least("max_correction_translation_m", 0.01),
            max_correction_rotation_deg=_float_at_least("max_correction_rotation_deg", 0.1),
        )
@dataclass
class LoopClosureResult:
    """Outcome of a single ``LoopClosureBackend.process_frame`` call."""

    # True when the frame was stored as a new keyframe.
    keyframe_added: bool = False
    # True when ICP against an earlier keyframe met the acceptance thresholds.
    loop_detected: bool = False
    # True when the stored keyframe poses were rewritten by a loop correction.
    optimized: bool = False
    # 4x4 correction applied to the newest keyframe pose (identity when none).
    correction: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float64))
    # Diagnostics: loop_candidate, loop_fitness, loop_rmse, rejected, loop_count, opt_count.
    info: Dict[str, Any] = field(default_factory=dict)


class LoopClosureBackend:
    """
    Lightweight loop-closure backend.

    Frames are thinned into keyframes (pose + downsampled sensor-frame cloud).
    Each new keyframe is matched against spatially close, sufficiently old
    keyframes with a point-to-point ICP; an accepted match is turned into a
    pose correction that is distributed linearly along the keyframe chain.
    The Open3D pose-graph path is disabled (see ``_import_open3d``).
    """

    def __init__(self, cfg: LoopClosureConfig):
        self.cfg = cfg
        self.frame_index = 0
        # Large negative sentinel so the very first frame counts as "overdue".
        self.last_keyframe_frame = -10**9
        self.keyframe_poses: List[np.ndarray] = []  # world_T_sensor
        self.keyframe_clouds: List[np.ndarray] = []  # sensor-frame clouds
        self._o3d = None
        self._pg = None
        self._loop_count = 0
        self._opt_count = 0
        self._last_opt_reject: Optional[Dict[str, Any]] = None
        self._import_open3d()

    def _import_open3d(self) -> None:
        """Leave the Open3D handles unset; the Open3D path is intentionally disabled."""
        # SAFE mode: avoid Open3D pose-graph path (native crashes observed).
        self._o3d = None
        self._pg = None

    def reset(self) -> None:
        """Drop all keyframes and counters, returning to the freshly-constructed state."""
        self.frame_index = 0
        self.last_keyframe_frame = -10**9
        self.keyframe_poses.clear()
        self.keyframe_clouds.clear()
        self._loop_count = 0
        self._opt_count = 0
        self._last_opt_reject = None
        self._pg = None

    def _should_add_keyframe(self, pose_world: np.ndarray) -> bool:
        """Decide whether the current frame becomes a keyframe.

        True for the first keyframe, when enough frames elapsed since the last
        keyframe, or when the pose moved/rotated past the configured minimums
        relative to the last keyframe.
        """
        if len(self.keyframe_poses) == 0:
            return True
        if (self.frame_index - self.last_keyframe_frame) >= int(self.cfg.keyframe_every_n_frames):
            return True
        trans, ang = _pose_delta(pose_world, self.keyframe_poses[-1])
        return (trans >= float(self.cfg.min_keyframe_translation_m)) or (
            ang >= float(self.cfg.min_keyframe_rotation_deg)
        )

    def _to_o3d_cloud(self, pts: np.ndarray):
        """Wrap ``pts`` in an Open3D point cloud, or return ``None`` when Open3D is unset."""
        o3d = self._o3d
        if o3d is None:
            return None
        pc = o3d.geometry.PointCloud()
        pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64, copy=False))
        return pc

    def _relative_source_to_target(self, pose_source: np.ndarray, pose_target: np.ndarray) -> np.ndarray:
        """Return ``inv(pose_target) @ pose_source`` (source frame expressed in target frame).

        Raises ``numpy.linalg.LinAlgError`` when ``pose_target`` is singular.
        """
        # T_source_to_target
        return np.linalg.inv(pose_target) @ pose_source

    def _find_candidate(self, idx_cur: int) -> Optional[int]:
        """Return the index of the closest sufficiently-old keyframe, if any.

        Only keyframes whose index is more than ``min_loop_frame_gap`` below
        ``idx_cur`` are considered, and only those within
        ``loop_search_radius_m`` of the current keyframe position. Ties keep
        the earliest keyframe. Returns ``None`` when nothing qualifies.
        """
        gap = int(self.cfg.min_loop_frame_gap)
        if idx_cur <= gap:
            return None
        radius = float(self.cfg.loop_search_radius_m)
        cur_t = self.keyframe_poses[idx_cur][:3, 3]
        best = None
        best_d = float("inf")
        for i in range(0, idx_cur - gap):
            d = float(np.linalg.norm(cur_t - self.keyframe_poses[i][:3, 3]))
            if d <= radius and d < best_d:
                best = i
                best_d = d
        return best

    def _run_icp(
        self, src_pts: np.ndarray, tgt_pts: np.ndarray, init: np.ndarray, max_corr: float
    ) -> Optional[Dict[str, Any]]:
        """Point-to-point ICP aligning ``src_pts`` onto ``tgt_pts``.

        Parameters
        ----------
        src_pts, tgt_pts : (N, 3) / (M, 3) point sets, at least 3 points each.
        init : initial 4x4 source-to-target guess; anything that is not 4x4
            is replaced by the identity.
        max_corr : correspondence distance gate (floored at 0.05).

        Returns
        -------
        ``None`` when the inputs are malformed or SciPy is unavailable,
        otherwise a dict with ``fitness`` (inliers / source size),
        ``inlier_rmse``, ``transformation`` (4x4) and ``inliers``. When the
        current estimate has too few correspondences the result reports that
        inlier count together with the sentinel RMSE ``9e9``.
        """
        src0 = np.asarray(src_pts, dtype=np.float64)
        tgt = np.asarray(tgt_pts, dtype=np.float64)
        if src0.ndim != 2 or tgt.ndim != 2 or src0.shape[1] != 3 or tgt.shape[1] != 3:
            return None
        if len(src0) < 3 or len(tgt) < 3:
            return None
        try:
            from scipy.spatial import cKDTree
        except Exception:
            # Deliberately broad: a broken SciPy install can raise more than ImportError.
            return None

        if np.asarray(init).shape == (4, 4):
            T = np.asarray(init, dtype=np.float64).copy()
        else:
            T = np.eye(4, dtype=np.float64)

        tree = cKDTree(tgt)
        max_corr_m = float(max(0.05, max_corr))
        # Require a minimum share of the source cloud to take part in each fit.
        min_corr = max(20, int(0.12 * len(src0)))
        last_rmse = 9e9
        last_inliers = 0

        for _ in range(int(max(6, self.cfg.icp_max_iter))):
            src_w = (src0 @ T[:3, :3].T) + T[:3, 3]
            d, idx = tree.query(src_w, k=1)
            keep = np.isfinite(d) & (d <= max_corr_m)
            n_in = int(np.count_nonzero(keep))
            if n_in < min_corr:
                # The statistics must describe the transform that is returned:
                # do not let numbers from an earlier iteration make a collapsed
                # alignment look acceptable.
                last_inliers = n_in
                last_rmse = 9e9
                break
            src_corr = src_w[keep]
            tgt_corr = tgt[np.asarray(idx[keep], dtype=np.int32)]
            delta = _rigid_fit_svd(src_corr, tgt_corr)
            if delta is None:
                break
            T = np.asarray(delta, dtype=np.float64) @ T
            prev_rmse = float(last_rmse)
            rmse = float(np.sqrt(np.mean(np.square(d[keep]))))
            last_rmse = rmse
            last_inliers = n_in
            if abs(prev_rmse - rmse) < 1e-5:
                break

        fitness = float(last_inliers) / max(1.0, float(len(src0)))
        return {
            "fitness": float(fitness),
            "inlier_rmse": float(last_rmse),
            "transformation": np.asarray(T, dtype=np.float64),
            "inliers": int(last_inliers),
        }

    def _rebuild_graph_odometry_only(self) -> None:
        """Discard the pose-graph handle (no graph is kept while Open3D is disabled)."""
        self._pg = None

    def _maybe_trim(self) -> None:
        """Drop the oldest keyframes so that at most ``max_keyframes`` remain."""
        kmax = int(self.cfg.max_keyframes)
        if len(self.keyframe_poses) <= kmax:
            return
        drop = len(self.keyframe_poses) - kmax
        self.keyframe_poses = self.keyframe_poses[drop:]
        self.keyframe_clouds = self.keyframe_clouds[drop:]
        self._rebuild_graph_odometry_only()

    @staticmethod
    def _slerp_rotation(R0: np.ndarray, R1: np.ndarray, t: float) -> np.ndarray:
        """Spherical linear interpolation between two rotation matrices.

        ``t=0`` gives ``R0`` and ``t=1`` gives ``R1``; ``t`` is not clamped.
        For a relative rotation of (almost) half a turn the axis is taken from
        the symmetric part of the relative rotation, because the usual
        antisymmetric-part formula divides by ``sin(angle) ~ 0`` there.
        """
        dR = R0.T @ R1
        cos_angle = float(np.clip((np.trace(dR) - 1.0) * 0.5, -1.0, 1.0))
        angle = float(np.arccos(cos_angle))
        if abs(angle) < 1e-8:
            return R0.copy()

        sin_a = np.sin(angle)
        # Equals 2*sin(angle)*axis for a proper rotation.
        skew_vec = np.array([
            dR[2, 1] - dR[1, 2],
            dR[0, 2] - dR[2, 0],
            dR[1, 0] - dR[0, 1],
        ])
        if cos_angle < 0.0 and sin_a < 1e-6:
            # Near a half-turn: sym(dR) = cos*I + (1-cos)*axis*axis^T, so the
            # outer product (and from it the axis) is still well conditioned.
            outer = (0.5 * (dR + dR.T) - cos_angle * np.eye(3)) / (1.0 - cos_angle)
            k = int(np.argmax(np.diag(outer)))
            axis = outer[:, k] / np.sqrt(max(float(outer[k, k]), 1e-12))
            # Keep the rotation sense of the original whenever it is observable.
            if float(np.dot(skew_vec, axis)) < 0.0:
                axis = -axis
        else:
            axis = skew_vec / (2.0 * sin_a)

        angle_t = angle * float(t)
        K = np.array([
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ])
        # Rodrigues' formula for a rotation of angle_t about `axis`.
        dRt = np.eye(3) + np.sin(angle_t) * K + (1.0 - np.cos(angle_t)) * (K @ K)
        return R0 @ dRt

    @staticmethod
    def _interpolate_se3(T0: np.ndarray, T1: np.ndarray, alpha: float) -> np.ndarray:
        """Interpolate between two SE3 transforms; alpha=0 → T0, alpha=1 → T1.

        ``alpha`` is clamped to ``[0, 1]``. Rotation is interpolated with
        ``_slerp_rotation``; translation is interpolated linearly.
        """
        a = float(np.clip(alpha, 0.0, 1.0))
        out = np.eye(4, dtype=np.float64)
        out[:3, :3] = LoopClosureBackend._slerp_rotation(T0[:3, :3], T1[:3, :3], a)
        out[:3, 3] = (1.0 - a) * T0[:3, 3] + a * T1[:3, 3]
        return out

    def _optimize(self, src_idx: int, dst_idx: int, T_meas: np.ndarray) -> Optional[List[np.ndarray]]:
        """
        Distribute loop-closure error linearly across the keyframe chain.

        Poses from dst_idx+1 to src_idx receive a fraction of the correction
        proportional to their distance from dst_idx.  All poses after src_idx
        receive the full correction.  Poses at or before dst_idx are unchanged.
        The stored poses are not modified; a corrected copy is returned.

        Parameters
        ----------
        src_idx : index of the current (drifted) keyframe
        dst_idx : index of the matching earlier keyframe
        T_meas  : 4×4 relative transform — src expressed in dst's frame
                  (output of ICP; transforms src points into dst frame)

        Returns
        -------
        List of updated 4×4 pose matrices, or None if the indices are invalid,
        the source pose is singular, or the correction exceeds the configured
        translation/rotation limits.
        """
        n = len(self.keyframe_poses)
        if n < 2:
            return None
        if not (0 <= dst_idx < src_idx < n):
            return None

        poses = [np.array(p, dtype=np.float64) for p in self.keyframe_poses]

        # Where src *should* be in world frame if there were no drift:
        T_src_expected = poses[dst_idx] @ T_meas
        T_src_actual = poses[src_idx]

        # Full correction needed to move src_actual → src_expected:
        try:
            T_correction = T_src_expected @ np.linalg.inv(T_src_actual)
        except np.linalg.LinAlgError:
            return None

        # Sanity-check the correction magnitude
        corr_t = float(np.linalg.norm(T_correction[:3, 3]))
        _, corr_deg = _pose_delta(T_correction, np.eye(4, dtype=np.float64))
        if corr_t > float(self.cfg.max_correction_translation_m) or corr_deg > float(
            self.cfg.max_correction_rotation_deg
        ):
            return None

        T_identity = np.eye(4, dtype=np.float64)
        chain_len = src_idx - dst_idx

        # Linearly distribute from dst_idx+1 (alpha≈0) to src_idx (alpha=1)
        for i in range(dst_idx + 1, src_idx + 1):
            alpha = float(i - dst_idx) / float(chain_len)
            T_partial = self._interpolate_se3(T_identity, T_correction, alpha)
            poses[i] = T_partial @ poses[i]

        # Propagate full correction to all poses after src_idx
        for i in range(src_idx + 1, n):
            poses[i] = T_correction @ poses[i]

        return poses

    def _attempt_loop_closure(self, idx: int, cand: int, result: LoopClosureResult) -> None:
        """Register keyframe ``idx`` against ``cand`` and, if accepted, correct the poses.

        Fills ``result`` in place (flags, ``correction`` and ``info`` entries)
        and updates the loop/optimisation counters and stored keyframe poses.
        """
        try:
            init = self._relative_source_to_target(self.keyframe_poses[idx], self.keyframe_poses[cand])
        except np.linalg.LinAlgError:
            # A singular stored pose cannot seed ICP; skip instead of crashing the frame.
            return

        reg = self._run_icp(
            self.keyframe_clouds[idx],
            self.keyframe_clouds[cand],
            init=init,
            max_corr=float(self.cfg.icp_max_corr_m),
        )
        if reg is None:
            return

        fit = float(reg.get("fitness", 0.0))
        rmse = float(reg.get("inlier_rmse", 9e9))
        result.info["loop_candidate"] = int(cand)
        result.info["loop_fitness"] = fit
        result.info["loop_rmse"] = rmse
        accepted = (fit >= float(self.cfg.accept_fitness)) and (rmse <= float(self.cfg.accept_rmse))
        if not accepted:
            return

        result.loop_detected = True
        self._loop_count += 1
        t_meas = np.asarray(reg.get("transformation", np.eye(4, dtype=np.float64)), dtype=np.float64)
        optimized_poses = self._optimize(idx, cand, t_meas)
        if optimized_poses is None:
            self._last_opt_reject = {
                "reason": "optimize_rejected",
                "src_idx": int(idx),
                "dst_idx": int(cand),
            }
            result.info["rejected"] = dict(self._last_opt_reject)
            return

        # Correction seen at src_idx (for callers that need a global correction estimate)
        pose_cur = np.asarray(self.keyframe_poses[idx], dtype=np.float64)
        correction = optimized_poses[idx] @ np.linalg.inv(pose_cur)
        result.optimized = True
        result.correction = np.asarray(correction, dtype=np.float64)
        self._opt_count += 1
        self.keyframe_poses = optimized_poses

    def process_frame(self, points_sensor: np.ndarray, pose_world: np.ndarray) -> LoopClosureResult:
        """Feed one frame (sensor-frame points + world pose) into the backend.

        The frame counter always advances. The frame is otherwise ignored
        (default result returned) when the backend is disabled, fewer than 40
        points are given, the pose is not a finite 4x4 matrix, the frame does
        not qualify as a keyframe, the points are not a 2-D array with at
        least three columns, or fewer than 40 points survive downsampling.
        Only the first three columns of ``points_sensor`` are used, and rows
        containing NaN/inf are discarded.

        For a new keyframe a loop candidate is searched and, when found,
        registered and used to correct the keyframe poses; the result carries
        the flags, the correction and diagnostic ``info`` entries.
        """
        self.frame_index += 1
        result = LoopClosureResult()
        if not self.cfg.enabled:
            return result
        if points_sensor is None or len(points_sensor) < 40:
            return result
        if pose_world is None or np.asarray(pose_world).shape != (4, 4):
            return result
        pose = np.asarray(pose_world, dtype=np.float64)
        if not np.all(np.isfinite(pose)):
            # A NaN/inf pose would poison every later distance and inverse.
            return result
        if not self._should_add_keyframe(pose):
            return result

        pts = np.asarray(points_sensor)
        if pts.ndim != 2 or pts.shape[1] < 3:
            return result
        pts = pts[:, :3].astype(np.float32, copy=False)
        # Non-finite points would break voxel binning and the KD-tree queries.
        pts = pts[np.isfinite(pts).all(axis=1)]
        ds = _voxel_downsample_numpy(pts, float(self.cfg.downsample_voxel_m))
        if len(ds) < 40:
            return result

        self.keyframe_clouds.append(ds)
        self.keyframe_poses.append(pose.copy())
        self.last_keyframe_frame = self.frame_index
        result.keyframe_added = True

        idx = len(self.keyframe_poses) - 1
        cand = self._find_candidate(idx)
        if cand is not None:
            self._attempt_loop_closure(idx, cand, result)

        result.info["loop_count"] = int(self._loop_count)
        result.info["opt_count"] = int(self._opt_count)
        self._maybe_trim()
        return result