from __future__ import annotations import os import pickle import sys import threading from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Any, Deque, Dict, Iterable, Optional, Tuple import numpy as np def _as_bool(v: Any, default: bool) -> bool: if v is None: return default if isinstance(v, bool): return v if isinstance(v, (int, float)): return bool(v) return str(v).strip().lower() in ("1", "true", "yes", "on") def _safe_int(v: Any, default: int) -> int: try: return int(v) except Exception: return int(default) def _safe_float(v: Any, default: float) -> float: try: return float(v) except Exception: return float(default) def _safe_profiles(v: Any, default: Tuple[str, ...]) -> Tuple[str, ...]: if isinstance(v, (list, tuple)): out = [] for item in v: s = str(item).upper().strip() if s: out.append(s) if out: return tuple(out) return tuple(default) def _voxel_downsample(points: np.ndarray, voxel_m: float) -> np.ndarray: pts = np.asarray(points, dtype=np.float32) if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0: return np.zeros((0, 3), dtype=np.float32) v = float(voxel_m) if v <= 0.0: return pts keys = np.floor(pts / v).astype(np.int32) _, idx = np.unique(keys, axis=0, return_index=True) return pts[np.sort(idx)] def _pose_delta(pose_a: np.ndarray, pose_b: np.ndarray) -> Tuple[float, float]: dp = pose_a[:3, 3] - pose_b[:3, 3] trans = float(np.linalg.norm(dp)) r = pose_a[:3, :3] @ pose_b[:3, :3].T ang = float(np.degrees(np.arccos(np.clip((np.trace(r) - 1.0) * 0.5, -1.0, 1.0)))) return trans, ang @dataclass class SubmapConfig: enabled: bool = True local_window_frames: int = 18 local_voxel_m: float = 0.08 global_voxel_m: float = 0.14 merge_period_sec: float = 0.8 merge_min_translation_m: float = 0.25 merge_min_rotation_deg: float = 6.0 max_global_points: int = 350000 display_max_points: int = 250000 apply_profiles: Tuple[str, ...] = ("LOCALIZE_MAP", "LIVE_NAV_MAP") @staticmethod def from_dict(d: Dict[str, Any] | None) -> "SubmapConfig": src = d or {} return SubmapConfig( enabled=_as_bool(src.get("enabled", True), True), local_window_frames=max(4, _safe_int(src.get("local_window_frames", 18), 18)), local_voxel_m=max(0.02, _safe_float(src.get("local_voxel_m", 0.08), 0.08)), global_voxel_m=max(0.02, _safe_float(src.get("global_voxel_m", 0.14), 0.14)), merge_period_sec=max(0.1, _safe_float(src.get("merge_period_sec", 0.8), 0.8)), merge_min_translation_m=max( 0.01, _safe_float(src.get("merge_min_translation_m", 0.25), 0.25) ), merge_min_rotation_deg=max( 0.1, _safe_float(src.get("merge_min_rotation_deg", 6.0), 6.0) ), max_global_points=max(20000, _safe_int(src.get("max_global_points", 350000), 350000)), display_max_points=max(10000, _safe_int(src.get("display_max_points", 250000), 250000)), apply_profiles=_safe_profiles( src.get("apply_profiles", ("LOCALIZE_MAP", "LIVE_NAV_MAP")), ("LOCALIZE_MAP", "LIVE_NAV_MAP"), ), ) class SubmapCheckpointer: """ Periodically saves LocalGlobalSubmapMapper state to disk so the accumulated map survives restarts and crashes. Usage ----- ckpt = SubmapCheckpointer(save_dir, interval_s=60.0) # in your main loop: ckpt.maybe_save(mapper) # on startup: ckpt.load_into(mapper) # returns True if data was restored # on shutdown: ckpt.stop() """ _FILENAME = "submap_checkpoint.pkl" _PROTO = 5 # pickle protocol (requires Python ≥ 3.8) def __init__(self, save_dir: str, interval_s: float = 60.0) -> None: self._dir = Path(save_dir) self._dir.mkdir(parents=True, exist_ok=True) self._interval = max(5.0, float(interval_s)) self._dst = self._dir / self._FILENAME self._tmp = self._dir / (self._FILENAME + ".tmp") self._last_save_t: float = 0.0 self._lock = threading.Lock() def _payload(self, mapper: "LocalGlobalSubmapMapper") -> dict: """Snapshot all persistent fields from the mapper.""" return { "global_pts": np.array(mapper._global_pts, dtype=np.float32, copy=True), "frames": [np.array(f, dtype=np.float32, copy=True) for f in mapper._frames], "last_merge_t": float(mapper._last_merge_t), "last_merge_pose": ( np.array(mapper._last_merge_pose, dtype=np.float64, copy=True) if mapper._last_merge_pose is not None else None ), "merge_count": int(mapper._merge_count), "insert_count": int(mapper._insert_count), } def maybe_save(self, mapper: "LocalGlobalSubmapMapper") -> bool: """Save if the configured interval has elapsed. Thread-safe. Returns True on save.""" import time now = time.time() with self._lock: if (now - self._last_save_t) < self._interval: return False self._last_save_t = now return self._write(self._payload(mapper)) def save(self, mapper: "LocalGlobalSubmapMapper") -> bool: """Force an immediate save. Returns True on success.""" return self._write(self._payload(mapper)) def _write(self, payload: dict) -> bool: """Atomic write: pickle to .tmp then rename to avoid corrupt checkpoints.""" try: with open(self._tmp, "wb") as f: pickle.dump(payload, f, protocol=self._PROTO) os.replace(self._tmp, self._dst) # atomic on POSIX return True except Exception as exc: print(f"[SubmapCheckpointer] save failed: {exc}", file=sys.stderr) return False def load_into(self, mapper: "LocalGlobalSubmapMapper") -> bool: """ Restore checkpoint into mapper. Returns True if data was loaded. Must be called before the mapper receives any frames. """ if not self._dst.exists(): return False try: with open(self._dst, "rb") as f: payload = pickle.load(f) global_pts = np.asarray(payload["global_pts"], dtype=np.float32) frames_raw = payload.get("frames", []) mapper._global_pts = global_pts mapper._frames = deque( [np.asarray(fr, dtype=np.float32) for fr in frames_raw], maxlen=int(mapper._frames.maxlen or mapper.cfg.local_window_frames), ) mapper._last_merge_t = float(payload.get("last_merge_t", 0.0)) lmp = payload.get("last_merge_pose") mapper._last_merge_pose = ( np.asarray(lmp, dtype=np.float64) if lmp is not None else None ) mapper._merge_count = int(payload.get("merge_count", 0)) mapper._insert_count = int(payload.get("insert_count", 0)) mapper._rebuild_local() mapper._global_pts = mapper._trim(mapper._global_pts, int(mapper.cfg.max_global_points)) return True except Exception as exc: print(f"[SubmapCheckpointer] load failed (starting fresh): {exc}", file=sys.stderr) return False def delete(self) -> None: """Remove the checkpoint file (e.g., after a deliberate RESET).""" try: if self._dst.exists(): self._dst.unlink() if self._tmp.exists(): self._tmp.unlink() except Exception: pass def stop(self) -> None: """No-op — kept for API symmetry; saves are driven by maybe_save() calls.""" pass class LocalGlobalSubmapMapper: """ Maintains a short-horizon local submap and periodically merges it into a long-horizon global submap. """ def __init__(self, cfg: SubmapConfig): self.cfg = cfg self._frames: Deque[np.ndarray] = deque(maxlen=int(cfg.local_window_frames)) self._local_pts = np.zeros((0, 3), dtype=np.float32) self._global_pts = np.zeros((0, 3), dtype=np.float32) self._last_merge_t = 0.0 self._last_merge_pose: Optional[np.ndarray] = None self._merge_count = 0 self._insert_count = 0 def set_config(self, cfg: SubmapConfig, keep_points: bool = True) -> None: old_frames = list(self._frames) if keep_points else [] old_global = np.array(self._global_pts, dtype=np.float32, copy=True) if keep_points else np.zeros((0, 3), dtype=np.float32) old_last_pose = np.array(self._last_merge_pose, dtype=np.float64, copy=True) if (keep_points and self._last_merge_pose is not None) else None old_last_merge_t = float(self._last_merge_t) if keep_points else 0.0 old_merge_count = int(self._merge_count) if keep_points else 0 old_insert_count = int(self._insert_count) if keep_points else 0 self.cfg = cfg self._frames = deque(maxlen=int(cfg.local_window_frames)) if keep_points and old_frames: for fr in old_frames[-int(cfg.local_window_frames):]: self._frames.append(np.asarray(fr, dtype=np.float32)) self._local_pts = np.zeros((0, 3), dtype=np.float32) self._global_pts = np.asarray(old_global, dtype=np.float32) if keep_points else np.zeros((0, 3), dtype=np.float32) self._last_merge_pose = old_last_pose self._last_merge_t = old_last_merge_t self._merge_count = old_merge_count self._insert_count = old_insert_count self._rebuild_local() self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points)) def reset(self) -> None: self._frames.clear() self._local_pts = np.zeros((0, 3), dtype=np.float32) self._global_pts = np.zeros((0, 3), dtype=np.float32) self._last_merge_t = 0.0 self._last_merge_pose = None self._merge_count = 0 self._insert_count = 0 def apply_correction(self, transform: np.ndarray) -> bool: """ Apply a rigid transform to all currently held submap data. Useful when localization updates the global frame estimate and we want old submap points to remain consistent with new incoming points. """ if not self.has_points: return False tf = np.asarray(transform, dtype=np.float64) if tf.shape != (4, 4): return False R = np.asarray(tf[:3, :3], dtype=np.float64) t = np.asarray(tf[:3, 3], dtype=np.float64) try: if len(self._global_pts) > 0: self._global_pts = np.asarray((self._global_pts @ R.T) + t, dtype=np.float32) if len(self._frames) > 0: frames_tx = deque(maxlen=int(self._frames.maxlen or len(self._frames))) for fr in self._frames: fr_np = np.asarray(fr, dtype=np.float32) if len(fr_np) == 0: frames_tx.append(fr_np) else: frames_tx.append(np.asarray((fr_np @ R.T) + t, dtype=np.float32)) self._frames = frames_tx if self._last_merge_pose is not None and np.asarray(self._last_merge_pose).shape == (4, 4): self._last_merge_pose = tf @ np.asarray(self._last_merge_pose, dtype=np.float64) self._rebuild_local() self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points)) return True except Exception: return False @property def has_points(self) -> bool: return bool(len(self._local_pts) > 0 or len(self._global_pts) > 0) def _trim(self, pts: np.ndarray, max_n: int) -> np.ndarray: n = int(len(pts)) if n <= int(max_n): return pts stride = max(2, int(np.ceil(float(n) / float(max_n)))) out = pts[::stride] if len(out) > int(max_n): out = out[: int(max_n)] return np.asarray(out, dtype=np.float32) def _rebuild_local(self) -> None: if len(self._frames) == 0: self._local_pts = np.zeros((0, 3), dtype=np.float32) return cat = np.concatenate(list(self._frames), axis=0) ds = _voxel_downsample(cat, float(self.cfg.local_voxel_m)) self._local_pts = self._trim(ds, max(5000, int(self.cfg.display_max_points))) def _should_merge(self, now: float, pose_world: Optional[np.ndarray]) -> bool: if len(self._local_pts) == 0: return False if len(self._global_pts) == 0: return True if (float(now) - float(self._last_merge_t)) >= float(self.cfg.merge_period_sec): return True if pose_world is None or self._last_merge_pose is None: return False try: pose = np.asarray(pose_world, dtype=np.float64) if pose.shape != (4, 4): return False trans, ang = _pose_delta(pose, self._last_merge_pose) return bool( trans >= float(self.cfg.merge_min_translation_m) or ang >= float(self.cfg.merge_min_rotation_deg) ) except Exception: return False def integrate( self, points_world: np.ndarray, now: float, pose_world: Optional[np.ndarray] = None, ) -> Dict[str, Any]: pts = np.asarray(points_world, dtype=np.float32) if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0: return self.status(active=False, reason="empty") ds = _voxel_downsample(pts, float(self.cfg.local_voxel_m)) if len(ds) == 0: return self.status(active=False, reason="empty_downsample") self._frames.append(ds) self._insert_count += 1 self._rebuild_local() merged = False if self._should_merge(float(now), pose_world): if len(self._global_pts) == 0: combo = np.array(self._local_pts, dtype=np.float32, copy=True) else: combo = np.concatenate([self._global_pts, self._local_pts], axis=0) self._global_pts = _voxel_downsample(combo, float(self.cfg.global_voxel_m)) self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points)) self._last_merge_t = float(now) if pose_world is not None: pose = np.asarray(pose_world, dtype=np.float64) if pose.shape == (4, 4): self._last_merge_pose = np.array(pose, dtype=np.float64, copy=True) merged = True self._merge_count += 1 out = self.status(active=True, reason="ok") out["merged"] = bool(merged) return out def get_display_points(self) -> np.ndarray: if len(self._global_pts) == 0 and len(self._local_pts) == 0: return np.zeros((0, 3), dtype=np.float32) if len(self._global_pts) == 0: return np.asarray(self._local_pts, dtype=np.float32) if len(self._local_pts) == 0: return self._trim(np.asarray(self._global_pts, dtype=np.float32), int(self.cfg.display_max_points)) combo = np.concatenate([self._global_pts, self._local_pts], axis=0) disp = _voxel_downsample(combo, float(self.cfg.local_voxel_m)) return self._trim(disp, int(self.cfg.display_max_points)) def status( self, active: bool, reason: str = "ok", profile: str = "", profile_allowed: bool = True, ) -> Dict[str, Any]: return { "enabled": bool(self.cfg.enabled), "active": bool(active), "reason": str(reason), "profile": str(profile).upper().strip(), "profile_allowed": bool(profile_allowed), "local_points": int(len(self._local_pts)), "global_points": int(len(self._global_pts)), "window_frames": int(len(self._frames)), "merge_count": int(self._merge_count), "insert_count": int(self._insert_count), }