414 lines
16 KiB
Python
414 lines
16 KiB
Python
from __future__ import annotations

import os
import pickle
import sys
import threading
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Deque, Dict, Iterable, Optional, Tuple

import numpy as np
|
|
|
|
|
|
def _as_bool(v: Any, default: bool) -> bool:
|
|
if v is None:
|
|
return default
|
|
if isinstance(v, bool):
|
|
return v
|
|
if isinstance(v, (int, float)):
|
|
return bool(v)
|
|
return str(v).strip().lower() in ("1", "true", "yes", "on")
|
|
|
|
|
|
def _safe_int(v: Any, default: int) -> int:
|
|
try:
|
|
return int(v)
|
|
except Exception:
|
|
return int(default)
|
|
|
|
|
|
def _safe_float(v: Any, default: float) -> float:
|
|
try:
|
|
return float(v)
|
|
except Exception:
|
|
return float(default)
|
|
|
|
|
|
def _safe_profiles(v: Any, default: Tuple[str, ...]) -> Tuple[str, ...]:
|
|
if isinstance(v, (list, tuple)):
|
|
out = []
|
|
for item in v:
|
|
s = str(item).upper().strip()
|
|
if s:
|
|
out.append(s)
|
|
if out:
|
|
return tuple(out)
|
|
return tuple(default)
|
|
|
|
|
|
def _voxel_downsample(points: np.ndarray, voxel_m: float) -> np.ndarray:
|
|
pts = np.asarray(points, dtype=np.float32)
|
|
if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0:
|
|
return np.zeros((0, 3), dtype=np.float32)
|
|
v = float(voxel_m)
|
|
if v <= 0.0:
|
|
return pts
|
|
keys = np.floor(pts / v).astype(np.int32)
|
|
_, idx = np.unique(keys, axis=0, return_index=True)
|
|
return pts[np.sort(idx)]
|
|
|
|
|
|
def _pose_delta(pose_a: np.ndarray, pose_b: np.ndarray) -> Tuple[float, float]:
|
|
dp = pose_a[:3, 3] - pose_b[:3, 3]
|
|
trans = float(np.linalg.norm(dp))
|
|
r = pose_a[:3, :3] @ pose_b[:3, :3].T
|
|
ang = float(np.degrees(np.arccos(np.clip((np.trace(r) - 1.0) * 0.5, -1.0, 1.0))))
|
|
return trans, ang
|
|
|
|
|
|
@dataclass
class SubmapConfig:
    """Tunable parameters for the local/global submap pipeline."""

    enabled: bool = True
    local_window_frames: int = 18
    local_voxel_m: float = 0.08
    global_voxel_m: float = 0.14
    merge_period_sec: float = 0.8
    merge_min_translation_m: float = 0.25
    merge_min_rotation_deg: float = 6.0
    max_global_points: int = 350000
    display_max_points: int = 250000
    apply_profiles: Tuple[str, ...] = ("LOCALIZE_MAP", "LIVE_NAV_MAP")

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "SubmapConfig":
        """Build a SubmapConfig from a plain dict, clamping values to sane floors.

        Missing or unparsable entries fall back to the dataclass defaults.
        """
        src = d or {}
        dflt = SubmapConfig()  # single source of truth for fallback values
        return SubmapConfig(
            enabled=_as_bool(src.get("enabled", dflt.enabled), dflt.enabled),
            local_window_frames=max(
                4,
                _safe_int(
                    src.get("local_window_frames", dflt.local_window_frames),
                    dflt.local_window_frames,
                ),
            ),
            local_voxel_m=max(
                0.02,
                _safe_float(src.get("local_voxel_m", dflt.local_voxel_m), dflt.local_voxel_m),
            ),
            global_voxel_m=max(
                0.02,
                _safe_float(src.get("global_voxel_m", dflt.global_voxel_m), dflt.global_voxel_m),
            ),
            merge_period_sec=max(
                0.1,
                _safe_float(src.get("merge_period_sec", dflt.merge_period_sec), dflt.merge_period_sec),
            ),
            merge_min_translation_m=max(
                0.01,
                _safe_float(
                    src.get("merge_min_translation_m", dflt.merge_min_translation_m),
                    dflt.merge_min_translation_m,
                ),
            ),
            merge_min_rotation_deg=max(
                0.1,
                _safe_float(
                    src.get("merge_min_rotation_deg", dflt.merge_min_rotation_deg),
                    dflt.merge_min_rotation_deg,
                ),
            ),
            max_global_points=max(
                20000,
                _safe_int(src.get("max_global_points", dflt.max_global_points), dflt.max_global_points),
            ),
            display_max_points=max(
                10000,
                _safe_int(src.get("display_max_points", dflt.display_max_points), dflt.display_max_points),
            ),
            apply_profiles=_safe_profiles(
                src.get("apply_profiles", dflt.apply_profiles),
                dflt.apply_profiles,
            ),
        )
|
|
|
|
|
|
class SubmapCheckpointer:
    """
    Periodically saves LocalGlobalSubmapMapper state to disk so the accumulated
    map survives restarts and crashes.

    Usage
    -----
    ckpt = SubmapCheckpointer(save_dir, interval_s=60.0)
    # in your main loop:
    ckpt.maybe_save(mapper)
    # on startup:
    ckpt.load_into(mapper)   # returns True if data was restored
    # on shutdown:
    ckpt.stop()
    """

    _FILENAME = "submap_checkpoint.pkl"
    _PROTO = 5  # pickle protocol (requires Python >= 3.8)

    def __init__(self, save_dir: str, interval_s: float = 60.0) -> None:
        """Create the checkpoint directory and prepare target paths.

        Parameters
        ----------
        save_dir : directory the checkpoint lives in (created if missing).
        interval_s : minimum seconds between maybe_save() writes (floor 5 s).
        """
        self._dir = Path(save_dir)
        self._dir.mkdir(parents=True, exist_ok=True)
        self._interval = max(5.0, float(interval_s))
        self._dst = self._dir / self._FILENAME
        self._tmp = self._dir / (self._FILENAME + ".tmp")
        self._last_save_t: float = 0.0
        # Guards the rate-limit bookkeeping AND the .tmp write: all writers
        # share one .tmp path, so unserialized concurrent _write() calls could
        # interleave and corrupt the "atomic" checkpoint.
        self._lock = threading.Lock()

    def _payload(self, mapper: "LocalGlobalSubmapMapper") -> dict:
        """Snapshot all persistent fields from the mapper (deep copies,
        so later mapper mutations cannot change what gets pickled)."""
        return {
            "global_pts": np.array(mapper._global_pts, dtype=np.float32, copy=True),
            "frames": [np.array(f, dtype=np.float32, copy=True) for f in mapper._frames],
            "last_merge_t": float(mapper._last_merge_t),
            "last_merge_pose": (
                np.array(mapper._last_merge_pose, dtype=np.float64, copy=True)
                if mapper._last_merge_pose is not None
                else None
            ),
            "merge_count": int(mapper._merge_count),
            "insert_count": int(mapper._insert_count),
        }

    def maybe_save(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """Save if the configured interval has elapsed. Thread-safe. Returns True on save."""
        now = time.time()
        with self._lock:
            if (now - self._last_save_t) < self._interval:
                return False
            self._last_save_t = now
        # Lock released before _write(), which re-acquires it (Lock is not
        # reentrant) — this keeps the pickle work out of the rate-limit check.
        return self._write(self._payload(mapper))

    def save(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """Force an immediate save. Returns True on success."""
        return self._write(self._payload(mapper))

    def _write(self, payload: dict) -> bool:
        """Atomic write: pickle to .tmp then rename to avoid corrupt checkpoints."""
        with self._lock:  # serialize writers sharing the same .tmp path
            try:
                with open(self._tmp, "wb") as f:
                    pickle.dump(payload, f, protocol=self._PROTO)
                os.replace(self._tmp, self._dst)  # atomic on POSIX
                return True
            except Exception as exc:
                print(f"[SubmapCheckpointer] save failed: {exc}", file=sys.stderr)
                try:
                    # Don't leave a partial temp file behind on failure.
                    self._tmp.unlink(missing_ok=True)
                except OSError:
                    pass
                return False

    def load_into(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """
        Restore checkpoint into mapper. Returns True if data was loaded.
        Must be called before the mapper receives any frames.

        NOTE(security): this unpickles the checkpoint file — only load
        checkpoints written by this process; never a file from an untrusted
        source.
        """
        if not self._dst.exists():
            return False
        try:
            with open(self._dst, "rb") as f:
                payload = pickle.load(f)
            global_pts = np.asarray(payload["global_pts"], dtype=np.float32)
            frames_raw = payload.get("frames", [])
            mapper._global_pts = global_pts
            mapper._frames = deque(
                [np.asarray(fr, dtype=np.float32) for fr in frames_raw],
                maxlen=int(mapper._frames.maxlen or mapper.cfg.local_window_frames),
            )
            mapper._last_merge_t = float(payload.get("last_merge_t", 0.0))
            lmp = payload.get("last_merge_pose")
            mapper._last_merge_pose = (
                np.asarray(lmp, dtype=np.float64) if lmp is not None else None
            )
            mapper._merge_count = int(payload.get("merge_count", 0))
            mapper._insert_count = int(payload.get("insert_count", 0))
            # Re-derive the local map and re-apply the current size cap, which
            # may be tighter than when the checkpoint was written.
            mapper._rebuild_local()
            mapper._global_pts = mapper._trim(mapper._global_pts, int(mapper.cfg.max_global_points))
            return True
        except Exception as exc:
            print(f"[SubmapCheckpointer] load failed (starting fresh): {exc}", file=sys.stderr)
            return False

    def delete(self) -> None:
        """Remove the checkpoint file (e.g., after a deliberate RESET)."""
        try:
            # missing_ok avoids the exists()/unlink() race of the naive check.
            self._dst.unlink(missing_ok=True)
            self._tmp.unlink(missing_ok=True)
        except OSError:
            pass

    def stop(self) -> None:
        """No-op — kept for API symmetry; saves are driven by maybe_save() calls."""
        pass
|
|
|
|
|
|
class LocalGlobalSubmapMapper:
    """
    Maintains a short-horizon local submap and periodically merges it into a
    long-horizon global submap.

    The local submap is rebuilt from a sliding window of the most recent
    frames; the global submap is a voxel-downsampled accumulation of merges,
    capped at ``cfg.max_global_points``.
    """

    def __init__(self, cfg: SubmapConfig):
        self.cfg = cfg
        # Sliding window of per-frame (N, 3) float32 point arrays.
        self._frames: Deque[np.ndarray] = deque(maxlen=int(cfg.local_window_frames))
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = np.zeros((0, 3), dtype=np.float32)
        self._last_merge_t = 0.0  # timestamp of the last local->global merge
        self._last_merge_pose: Optional[np.ndarray] = None  # 4x4 pose at last merge
        self._merge_count = 0
        self._insert_count = 0

    def set_config(self, cfg: SubmapConfig, keep_points: bool = True) -> None:
        """Swap in a new config, optionally preserving accumulated state.

        When keep_points is True, the newest frames that fit the new window,
        the global map, and the merge bookkeeping are carried over; the local
        map is rebuilt and the global map re-trimmed under the new caps.
        """
        # Snapshot old state before replacing anything.
        old_frames = list(self._frames) if keep_points else []
        old_global = (
            np.array(self._global_pts, dtype=np.float32, copy=True)
            if keep_points
            else np.zeros((0, 3), dtype=np.float32)
        )
        old_last_pose = (
            np.array(self._last_merge_pose, dtype=np.float64, copy=True)
            if (keep_points and self._last_merge_pose is not None)
            else None
        )
        old_last_merge_t = float(self._last_merge_t) if keep_points else 0.0
        old_merge_count = int(self._merge_count) if keep_points else 0
        old_insert_count = int(self._insert_count) if keep_points else 0

        self.cfg = cfg
        self._frames = deque(maxlen=int(cfg.local_window_frames))
        if keep_points and old_frames:
            # Keep only the most recent frames that still fit the new window.
            for fr in old_frames[-int(cfg.local_window_frames):]:
                self._frames.append(np.asarray(fr, dtype=np.float32))
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = (
            np.asarray(old_global, dtype=np.float32)
            if keep_points
            else np.zeros((0, 3), dtype=np.float32)
        )
        self._last_merge_pose = old_last_pose
        self._last_merge_t = old_last_merge_t
        self._merge_count = old_merge_count
        self._insert_count = old_insert_count
        self._rebuild_local()
        self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points))

    def reset(self) -> None:
        """Drop all points and bookkeeping, returning to the initial state."""
        self._frames.clear()
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = np.zeros((0, 3), dtype=np.float32)
        self._last_merge_t = 0.0
        self._last_merge_pose = None
        self._merge_count = 0
        self._insert_count = 0

    def apply_correction(self, transform: np.ndarray) -> bool:
        """
        Apply a rigid 4x4 transform to all currently held submap data.

        Useful when localization updates the global frame estimate and we
        want old submap points to remain consistent with new incoming points.

        The update is staged: every transformed array is computed first and
        only then committed, so a mid-way failure can no longer leave the
        mapper half-transformed (e.g. global map corrected but the frame
        window not), which the previous in-place version allowed.

        Returns True on success, False for bad input or on failure.
        """
        if not self.has_points:
            return False
        tf = np.asarray(transform, dtype=np.float64)
        if tf.shape != (4, 4):
            return False
        rot = np.asarray(tf[:3, :3], dtype=np.float64)
        trans = np.asarray(tf[:3, 3], dtype=np.float64)
        try:
            # Stage 1: compute every transformed piece without touching state.
            new_global = self._global_pts
            if len(self._global_pts) > 0:
                new_global = np.asarray((self._global_pts @ rot.T) + trans, dtype=np.float32)
            new_frames = self._frames
            if len(self._frames) > 0:
                new_frames = deque(maxlen=int(self._frames.maxlen or len(self._frames)))
                for fr in self._frames:
                    fr_np = np.asarray(fr, dtype=np.float32)
                    if len(fr_np) == 0:
                        new_frames.append(fr_np)
                    else:
                        new_frames.append(np.asarray((fr_np @ rot.T) + trans, dtype=np.float32))
            new_pose = self._last_merge_pose
            if self._last_merge_pose is not None and np.asarray(self._last_merge_pose).shape == (4, 4):
                new_pose = tf @ np.asarray(self._last_merge_pose, dtype=np.float64)
            # Stage 2: commit (plain attribute rebinds — nothing here raises).
            self._global_pts = new_global
            self._frames = new_frames
            self._last_merge_pose = new_pose
            self._rebuild_local()
            self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points))
            return True
        except Exception:
            return False

    @property
    def has_points(self) -> bool:
        """True when either the local or the global submap holds any points."""
        return bool(len(self._local_pts) > 0 or len(self._global_pts) > 0)

    def _trim(self, pts: np.ndarray, max_n: int) -> np.ndarray:
        """Uniformly subsample pts down to at most max_n points, keeping order."""
        n = int(len(pts))
        if n <= int(max_n):
            return pts
        stride = max(2, int(np.ceil(float(n) / float(max_n))))
        out = pts[::stride]
        if len(out) > int(max_n):  # ceil stride can still leave a few extra
            out = out[: int(max_n)]
        return np.asarray(out, dtype=np.float32)

    def _rebuild_local(self) -> None:
        """Recompute the local submap from the current frame window."""
        if len(self._frames) == 0:
            self._local_pts = np.zeros((0, 3), dtype=np.float32)
            return
        cat = np.concatenate(list(self._frames), axis=0)
        ds = _voxel_downsample(cat, float(self.cfg.local_voxel_m))
        self._local_pts = self._trim(ds, max(5000, int(self.cfg.display_max_points)))

    def _should_merge(self, now: float, pose_world: Optional[np.ndarray]) -> bool:
        """Decide whether the local submap should merge into the global one.

        Merge when the global map is still empty, when merge_period_sec has
        elapsed since the last merge, or when the pose moved/rotated past the
        configured thresholds since the last merge.
        """
        if len(self._local_pts) == 0:
            return False
        if len(self._global_pts) == 0:
            return True
        if (float(now) - float(self._last_merge_t)) >= float(self.cfg.merge_period_sec):
            return True
        if pose_world is None or self._last_merge_pose is None:
            return False
        try:
            pose = np.asarray(pose_world, dtype=np.float64)
            if pose.shape != (4, 4):
                return False
            trans, ang = _pose_delta(pose, self._last_merge_pose)
            return bool(
                trans >= float(self.cfg.merge_min_translation_m)
                or ang >= float(self.cfg.merge_min_rotation_deg)
            )
        except Exception:
            return False

    def integrate(
        self,
        points_world: np.ndarray,
        now: float,
        pose_world: Optional[np.ndarray] = None,
    ) -> Dict[str, Any]:
        """Insert one frame of world-frame points; merge to global when due.

        Parameters
        ----------
        points_world : (N, 3) array of points already in the world frame.
        now : current timestamp in seconds.
        pose_world : optional 4x4 sensor pose used for motion-gated merging.

        Returns the status() dict with an extra "merged" flag.
        """
        pts = np.asarray(points_world, dtype=np.float32)
        if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0:
            return self.status(active=False, reason="empty")

        ds = _voxel_downsample(pts, float(self.cfg.local_voxel_m))
        if len(ds) == 0:
            return self.status(active=False, reason="empty_downsample")

        self._frames.append(ds)
        self._insert_count += 1
        self._rebuild_local()

        merged = False
        if self._should_merge(float(now), pose_world):
            if len(self._global_pts) == 0:
                combo = np.array(self._local_pts, dtype=np.float32, copy=True)
            else:
                combo = np.concatenate([self._global_pts, self._local_pts], axis=0)
            self._global_pts = _voxel_downsample(combo, float(self.cfg.global_voxel_m))
            self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points))
            self._last_merge_t = float(now)
            if pose_world is not None:
                pose = np.asarray(pose_world, dtype=np.float64)
                if pose.shape == (4, 4):
                    self._last_merge_pose = np.array(pose, dtype=np.float64, copy=True)
            merged = True
            self._merge_count += 1

        out = self.status(active=True, reason="ok")
        out["merged"] = bool(merged)
        return out

    def get_display_points(self) -> np.ndarray:
        """Return a size-capped point cloud combining global and local submaps."""
        if len(self._global_pts) == 0 and len(self._local_pts) == 0:
            return np.zeros((0, 3), dtype=np.float32)
        if len(self._global_pts) == 0:
            return np.asarray(self._local_pts, dtype=np.float32)
        if len(self._local_pts) == 0:
            return self._trim(np.asarray(self._global_pts, dtype=np.float32), int(self.cfg.display_max_points))
        combo = np.concatenate([self._global_pts, self._local_pts], axis=0)
        disp = _voxel_downsample(combo, float(self.cfg.local_voxel_m))
        return self._trim(disp, int(self.cfg.display_max_points))

    def status(
        self,
        active: bool,
        reason: str = "ok",
        profile: str = "",
        profile_allowed: bool = True,
    ) -> Dict[str, Any]:
        """Build the telemetry dict returned by integrate() and callers."""
        return {
            "enabled": bool(self.cfg.enabled),
            "active": bool(active),
            "reason": str(reason),
            "profile": str(profile).upper().strip(),
            "profile_allowed": bool(profile_allowed),
            "local_points": int(len(self._local_pts)),
            "global_points": int(len(self._global_pts)),
            "window_frames": int(len(self._frames)),
            "merge_count": int(self._merge_count),
            "insert_count": int(self._insert_count),
        }
|