Marcus/Lidar/SLAM_Submap.py
2026-04-12 18:50:22 +04:00

414 lines
16 KiB
Python

from __future__ import annotations
import os
import pickle
import sys
import threading
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Deque, Dict, Iterable, Optional, Tuple
import numpy as np
def _as_bool(v: Any, default: bool) -> bool:
    """Interpret a loosely-typed config value as a boolean.

    ``None`` yields *default* unchanged. ``bool``/``int``/``float`` values use
    their truthiness. Anything else is converted with ``str()`` and matched,
    case-insensitively and ignoring surrounding whitespace, against
    "1", "true", "yes" and "on"; every other string is ``False``.
    """
    if v is None:
        return default
    if isinstance(v, (bool, int, float)):
        return bool(v)
    return str(v).strip().lower() in {"1", "true", "yes", "on"}
def _safe_int(v: Any, default: int) -> int:
    """Best-effort ``int(v)``; fall back to ``int(default)`` if that raises.

    Non-integral strings such as ``"3.5"`` make ``int()`` raise and therefore
    produce the default, while float values are truncated toward zero.
    """
    try:
        result = int(v)
    except Exception:
        result = int(default)
    return result
def _safe_float(v: Any, default: float) -> float:
    """Best-effort ``float(v)``; fall back to ``float(default)`` if that raises.

    Strings like ``"nan"`` or ``"inf"`` convert successfully and are returned
    as-is, so callers needing finite values must check for themselves.
    """
    try:
        result = float(v)
    except Exception:
        result = float(default)
    return result
def _safe_profiles(v: Any, default: Tuple[str, ...]) -> Tuple[str, ...]:
    """Normalize a list/tuple of profile names into an upper-case tuple.

    Each item is ``str()``-converted, upper-cased and stripped; blank results
    are dropped. When *v* is not a list or tuple, or nothing survives the
    cleanup, ``tuple(default)`` is returned (the default is not normalized).
    """
    if not isinstance(v, (list, tuple)):
        return tuple(default)
    names = tuple(
        name
        for name in (str(item).upper().strip() for item in v)
        if name
    )
    return names or tuple(default)
def _voxel_downsample(points: np.ndarray, voxel_m: float) -> np.ndarray:
    """Thin a point cloud to at most one point per axis-aligned cubic voxel.

    Args:
        points: array-like of shape (N, 3); anything else yields an empty result.
        voxel_m: voxel edge length. Values that are not strictly positive
            (including NaN) disable thinning.

    Returns:
        float32 array of shape (M, 3). Rows containing NaN/inf are always
        dropped. Of the points sharing a voxel, the first one in input order
        is kept, and the output preserves input order.
    """
    pts = np.asarray(points, dtype=np.float32)
    if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0:
        return np.zeros((0, 3), dtype=np.float32)
    # Invalid lidar returns often arrive as NaN/inf. Such rows have no
    # well-defined voxel (casting NaN to an integer is undefined), and letting
    # one through would plant a bogus point in the map.
    finite = np.isfinite(pts).all(axis=1)
    if not finite.all():
        pts = pts[finite]
        if len(pts) == 0:
            return np.zeros((0, 3), dtype=np.float32)
    v = float(voxel_m)
    # `not v > 0` also rejects NaN, which would otherwise map every point
    # into one single bogus voxel.
    if not v > 0.0:
        return pts
    # int64 keys: int32 overflows once |coord| / voxel exceeds ~2.1e9.
    keys = np.floor(pts / v).astype(np.int64)
    _, idx = np.unique(keys, axis=0, return_index=True)
    return pts[np.sort(idx)]
def _pose_delta(pose_a: np.ndarray, pose_b: np.ndarray) -> Tuple[float, float]:
    """Return ``(translation_distance, rotation_angle_deg)`` between two poses.

    Expects homogeneous transforms (rotation in ``[:3, :3]``, translation in
    ``[:3, 3]``; presumably 4x4). The angle is that of the relative rotation
    ``R_a @ R_b.T``, recovered from its trace. The cosine is clipped to
    [-1, 1] to absorb floating-point drift.
    """
    translation = float(np.linalg.norm(pose_a[:3, 3] - pose_b[:3, 3]))
    relative_rot = pose_a[:3, :3] @ pose_b[:3, :3].T
    cos_angle = np.clip(0.5 * (np.trace(relative_rot) - 1.0), -1.0, 1.0)
    angle_deg = float(np.degrees(np.arccos(cos_angle)))
    return translation, angle_deg
@dataclass
class SubmapConfig:
    """Tuning parameters for LocalGlobalSubmapMapper.

    Build from a loosely-typed mapping with :meth:`from_dict`. That method
    coerces types and clamps each numeric field to a lower bound.
    """

    enabled: bool = True
    # Maximum number of recent frames kept in the local window.
    local_window_frames: int = 18
    # Voxel edge length for per-frame, local-submap and display thinning.
    local_voxel_m: float = 0.08
    # Voxel edge length used when merging into the global submap.
    global_voxel_m: float = 0.14
    # Time-based merge trigger: seconds since the previous merge.
    merge_period_sec: float = 0.8
    # Motion-based merge triggers, relative to the pose at the last merge.
    merge_min_translation_m: float = 0.25
    merge_min_rotation_deg: float = 6.0
    # Point budgets for the global submap and for display output.
    max_global_points: int = 350000
    display_max_points: int = 250000
    # Upper-case profile names; not read by the mapper itself (presumably
    # checked by the caller against the active profile -- confirm).
    apply_profiles: Tuple[str, ...] = ("LOCALIZE_MAP", "LIVE_NAV_MAP")

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "SubmapConfig":
        """Build a config from a loosely-typed mapping.

        Missing or unparsable entries fall back to the dataclass defaults.
        Numeric fields are clamped to minimums: window >= 4 frames,
        voxels >= 0.02 m, merge period >= 0.1 s, translation >= 0.01 m,
        rotation >= 0.1 deg, global budget >= 20000 and display
        budget >= 10000 points.
        """
        src = d or {}
        base = SubmapConfig()

        def raw(key: str) -> Any:
            # Absent keys read the dataclass default, mirroring src.get(key, default).
            return src.get(key, getattr(base, key))

        return SubmapConfig(
            enabled=_as_bool(raw("enabled"), base.enabled),
            local_window_frames=max(
                4, _safe_int(raw("local_window_frames"), base.local_window_frames)
            ),
            local_voxel_m=max(
                0.02, _safe_float(raw("local_voxel_m"), base.local_voxel_m)
            ),
            global_voxel_m=max(
                0.02, _safe_float(raw("global_voxel_m"), base.global_voxel_m)
            ),
            merge_period_sec=max(
                0.1, _safe_float(raw("merge_period_sec"), base.merge_period_sec)
            ),
            merge_min_translation_m=max(
                0.01,
                _safe_float(raw("merge_min_translation_m"), base.merge_min_translation_m),
            ),
            merge_min_rotation_deg=max(
                0.1,
                _safe_float(raw("merge_min_rotation_deg"), base.merge_min_rotation_deg),
            ),
            max_global_points=max(
                20000, _safe_int(raw("max_global_points"), base.max_global_points)
            ),
            display_max_points=max(
                10000, _safe_int(raw("display_max_points"), base.display_max_points)
            ),
            apply_profiles=_safe_profiles(raw("apply_profiles"), base.apply_profiles),
        )
class SubmapCheckpointer:
    """
    Periodically saves LocalGlobalSubmapMapper state to disk so the accumulated
    map survives restarts and crashes.

    Usage
    -----
        ckpt = SubmapCheckpointer(save_dir, interval_s=60.0)
        # in your main loop:
        ckpt.maybe_save(mapper)
        # on startup:
        ckpt.load_into(mapper)  # returns True if data was restored
        # on shutdown:
        ckpt.stop()

    maybe_save(), save() and delete() serialize on one internal lock, so
    concurrent callers never interleave writes to the shared temp file. The
    mapper itself is read without locking. Callers must not mutate it while a
    save is in progress.
    """

    _FILENAME = "submap_checkpoint.pkl"
    _PROTO = 5  # pickle protocol (requires Python >= 3.8)

    def __init__(self, save_dir: str, interval_s: float = 60.0) -> None:
        self._dir = Path(save_dir)
        self._dir.mkdir(parents=True, exist_ok=True)
        self._interval = max(5.0, float(interval_s))  # seconds, floor of 5 s
        self._dst = self._dir / self._FILENAME
        self._tmp = self._dir / (self._FILENAME + ".tmp")
        # -inf so the very first maybe_save() always writes, regardless of the
        # (arbitrary) starting value of the monotonic clock.
        self._last_save_t: float = float("-inf")
        self._lock = threading.Lock()

    def _payload(self, mapper: "LocalGlobalSubmapMapper") -> dict:
        """Snapshot all persistent fields from the mapper (deep copies of arrays)."""
        return {
            "global_pts": np.array(mapper._global_pts, dtype=np.float32, copy=True),
            "frames": [np.array(f, dtype=np.float32, copy=True) for f in mapper._frames],
            "last_merge_t": float(mapper._last_merge_t),
            "last_merge_pose": (
                np.array(mapper._last_merge_pose, dtype=np.float64, copy=True)
                if mapper._last_merge_pose is not None
                else None
            ),
            "merge_count": int(mapper._merge_count),
            "insert_count": int(mapper._insert_count),
        }

    def maybe_save(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """Save if the configured interval has elapsed. Thread-safe. Returns True on save."""
        import time

        # Monotonic clock: immune to wall-clock jumps (NTP sync, manual changes)
        # that would otherwise stall or spam periodic saves.
        now = time.monotonic()
        with self._lock:
            if (now - self._last_save_t) < self._interval:
                return False
            self._last_save_t = now
            return self._save_locked(mapper)

    def save(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """Force an immediate save. Thread-safe. Returns True on success."""
        with self._lock:
            return self._save_locked(mapper)

    def _save_locked(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """Snapshot *mapper* and write it; the caller must hold ``self._lock``."""
        try:
            payload = self._payload(mapper)
        except Exception as exc:
            # e.g. the mapper's frame deque mutated by another thread mid-snapshot.
            print(f"[SubmapCheckpointer] snapshot failed: {exc}", file=sys.stderr)
            return False
        return self._write(payload)

    def _write(self, payload: dict) -> bool:
        """Atomic, durable write: pickle to .tmp, fsync, then rename over the checkpoint."""
        try:
            with open(self._tmp, "wb") as f:
                pickle.dump(payload, f, protocol=self._PROTO)
                f.flush()
                # Without fsync, a crash or power loss soon after os.replace()
                # can leave a renamed-but-empty checkpoint on some filesystems.
                os.fsync(f.fileno())
            os.replace(self._tmp, self._dst)  # atomic on POSIX
            return True
        except Exception as exc:
            print(f"[SubmapCheckpointer] save failed: {exc}", file=sys.stderr)
            try:
                self._tmp.unlink()  # don't leave a partial temp file behind
            except OSError:
                pass
            return False

    @staticmethod
    def _as_points(raw: Any) -> Optional[np.ndarray]:
        """Return *raw* as a float32 (N, 3) array, or None if it cannot be one."""
        try:
            arr = np.asarray(raw, dtype=np.float32)
        except (TypeError, ValueError):
            return None
        if arr.size == 0:
            return np.zeros((0, 3), dtype=np.float32)
        if arr.ndim != 2 or arr.shape[1] != 3:
            return None
        return arr

    def load_into(self, mapper: "LocalGlobalSubmapMapper") -> bool:
        """
        Restore checkpoint into mapper. Returns True if data was loaded.
        Must be called before the mapper receives any frames.

        The payload is fully parsed and validated before any mapper attribute
        is touched, so a corrupt or incompatible checkpoint leaves the mapper
        unchanged. Frames that are not (N, 3) arrays are skipped. A last-merge
        pose that is not 4x4 is discarded.
        """
        if not self._dst.exists():
            return False
        try:
            with open(self._dst, "rb") as f:
                # SECURITY: pickle.load can execute arbitrary code embedded in the
                # file. This is only safe if save_dir is writable by trusted users only.
                payload = pickle.load(f)
            global_pts = self._as_points(payload["global_pts"])
            if global_pts is None:
                raise ValueError("global_pts is not an (N, 3) point array")
            raw_frames = payload.get("frames")
            if raw_frames is None:
                raw_frames = []
            frames = [
                fr
                for fr in (self._as_points(raw) for raw in raw_frames)
                if fr is not None
            ]
            last_merge_t = float(payload.get("last_merge_t", 0.0))
            last_merge_pose = None
            lmp = payload.get("last_merge_pose")
            if lmp is not None:
                pose = np.asarray(lmp, dtype=np.float64)
                if pose.shape == (4, 4):
                    last_merge_pose = pose
            merge_count = int(payload.get("merge_count", 0))
            insert_count = int(payload.get("insert_count", 0))
        except Exception as exc:
            print(f"[SubmapCheckpointer] load failed (starting fresh): {exc}", file=sys.stderr)
            return False

        try:
            mapper._global_pts = global_pts
            mapper._frames = deque(
                frames,
                maxlen=int(mapper._frames.maxlen or mapper.cfg.local_window_frames),
            )
            mapper._last_merge_t = last_merge_t
            mapper._last_merge_pose = last_merge_pose
            mapper._merge_count = merge_count
            mapper._insert_count = insert_count
            mapper._rebuild_local()
            mapper._global_pts = mapper._trim(mapper._global_pts, int(mapper.cfg.max_global_points))
            return True
        except Exception as exc:
            print(f"[SubmapCheckpointer] load failed (starting fresh): {exc}", file=sys.stderr)
            # Honour the "starting fresh" promise: never leave a half-restored map.
            try:
                mapper.reset()
            except Exception:
                pass
            return False

    def delete(self) -> None:
        """Remove the checkpoint and temp files (e.g., after a deliberate RESET). Best-effort."""
        with self._lock:  # don't race a concurrent save re-creating the file
            for path in (self._dst, self._tmp):
                try:
                    path.unlink()
                except FileNotFoundError:
                    pass
                except OSError as exc:
                    print(f"[SubmapCheckpointer] delete failed for {path}: {exc}", file=sys.stderr)

    def stop(self) -> None:
        """No-op — kept for API symmetry; saves are driven by maybe_save() calls."""
        pass
class LocalGlobalSubmapMapper:
    """
    Maintains a short-horizon local submap and periodically merges it into a
    long-horizon global submap.

    * Local submap: the last ``cfg.local_window_frames`` frames (each already
      voxel-thinned at ``cfg.local_voxel_m``), concatenated, thinned again and
      capped at ``max(5000, cfg.display_max_points)`` points.
    * Global submap: accumulated merges of the local submap, thinned at
      ``cfg.global_voxel_m`` and capped at ``cfg.max_global_points`` points.

    A merge is triggered when the global submap is empty, when at least
    ``cfg.merge_period_sec`` has elapsed since the last merge, or when the pose
    moved at least ``cfg.merge_min_translation_m`` / rotated at least
    ``cfg.merge_min_rotation_deg`` since the last merge.

    There is no internal locking; concurrent access must be serialized by the caller.
    """

    def __init__(self, cfg: SubmapConfig):
        self.cfg = cfg
        self._frames: Deque[np.ndarray] = deque(maxlen=int(cfg.local_window_frames))
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = np.zeros((0, 3), dtype=np.float32)
        self._last_merge_t = 0.0
        self._last_merge_pose: Optional[np.ndarray] = None
        self._merge_count = 0
        self._insert_count = 0

    def set_config(self, cfg: SubmapConfig, keep_points: bool = True) -> None:
        """Swap in a new config, optionally keeping the accumulated map state.

        With ``keep_points`` the most recent frames (up to the new window size),
        the global points and the merge bookkeeping are carried over. The local
        submap is then rebuilt and the global submap re-capped under the new
        limits. Otherwise every field starts empty/zeroed.
        """
        if keep_points:
            old_frames = list(self._frames)
            old_global = np.array(self._global_pts, dtype=np.float32, copy=True)
            old_last_pose = (
                np.array(self._last_merge_pose, dtype=np.float64, copy=True)
                if self._last_merge_pose is not None
                else None
            )
            old_last_merge_t = float(self._last_merge_t)
            old_merge_count = int(self._merge_count)
            old_insert_count = int(self._insert_count)
        else:
            old_frames = []
            old_global = np.zeros((0, 3), dtype=np.float32)
            old_last_pose = None
            old_last_merge_t = 0.0
            old_merge_count = 0
            old_insert_count = 0

        self.cfg = cfg
        window = int(cfg.local_window_frames)
        self._frames = deque(
            (np.asarray(fr, dtype=np.float32) for fr in old_frames[-window:]),
            maxlen=window,
        )
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = np.asarray(old_global, dtype=np.float32)
        self._last_merge_pose = old_last_pose
        self._last_merge_t = old_last_merge_t
        self._merge_count = old_merge_count
        self._insert_count = old_insert_count
        self._rebuild_local()
        self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points))

    def reset(self) -> None:
        """Drop all points and merge bookkeeping (the config is kept)."""
        self._frames.clear()
        self._local_pts = np.zeros((0, 3), dtype=np.float32)
        self._global_pts = np.zeros((0, 3), dtype=np.float32)
        self._last_merge_t = 0.0
        self._last_merge_pose = None
        self._merge_count = 0
        self._insert_count = 0

    def apply_correction(self, transform: np.ndarray) -> bool:
        """
        Apply a rigid transform to all currently held submap data.
        Useful when localization updates the global frame estimate and we
        want old submap points to remain consistent with new incoming points.

        Every stored point ``p`` becomes ``R @ p + t``, with ``R = transform[:3, :3]``
        and ``t = transform[:3, 3]``. The last merge pose ``P`` (if 4x4) becomes
        ``transform @ P``. The local submap is rebuilt afterwards.

        Returns False, leaving the state unchanged, when there are no points,
        when the transform is not a finite 4x4 array, or when the computation
        fails.
        """
        if not self.has_points:
            return False
        try:
            tf = np.asarray(transform, dtype=np.float64)
        except (TypeError, ValueError):
            return False
        # A NaN/inf entry would irreversibly poison every stored point.
        if tf.shape != (4, 4) or not np.all(np.isfinite(tf)):
            return False
        rot = tf[:3, :3]
        shift = tf[:3, 3]
        try:
            # Compute everything first and commit only on success, so a failure
            # part-way never leaves global and local data in different frames.
            new_global = self._global_pts
            if len(new_global) > 0:
                new_global = np.asarray((new_global @ rot.T) + shift, dtype=np.float32)
            new_frames: Deque[np.ndarray] = deque(maxlen=self._frames.maxlen)
            for fr in self._frames:
                fr_np = np.asarray(fr, dtype=np.float32)
                if len(fr_np) > 0:
                    fr_np = np.asarray((fr_np @ rot.T) + shift, dtype=np.float32)
                new_frames.append(fr_np)
            new_pose = self._last_merge_pose
            if new_pose is not None and np.asarray(new_pose).shape == (4, 4):
                new_pose = tf @ np.asarray(new_pose, dtype=np.float64)
        except Exception:
            return False
        self._global_pts = self._trim(new_global, int(self.cfg.max_global_points))
        self._frames = new_frames
        self._last_merge_pose = new_pose
        self._rebuild_local()
        return True

    @property
    def has_points(self) -> bool:
        """True if either the local or the global submap holds any point."""
        return bool(len(self._local_pts) > 0 or len(self._global_pts) > 0)

    def _trim(self, pts: np.ndarray, max_n: int) -> np.ndarray:
        """Cap *pts* at *max_n* rows by evenly spaced index sampling.

        Returns *pts* unchanged when it is already within budget. Otherwise
        returns a float32 array of exactly ``max_n`` rows (empty when
        ``max_n <= 0``). The first row is always kept and relative order is
        preserved.
        """
        n = int(len(pts))
        max_n = int(max_n)
        if n <= max_n:
            return pts
        if max_n <= 0:
            return np.asarray(pts[:0], dtype=np.float32)
        # floor(i * n / max_n) is strictly increasing when n > max_n, so this
        # picks max_n distinct, evenly spread rows. A fixed stride (>= 2) would
        # discard about half the map when it is barely over budget.
        idx = (np.arange(max_n, dtype=np.int64) * n) // max_n
        return np.asarray(pts[idx], dtype=np.float32)

    def _rebuild_local(self) -> None:
        """Recompute the local submap from the frames currently in the window."""
        if len(self._frames) == 0:
            self._local_pts = np.zeros((0, 3), dtype=np.float32)
            return
        cat = np.concatenate(list(self._frames), axis=0)
        ds = _voxel_downsample(cat, float(self.cfg.local_voxel_m))
        self._local_pts = self._trim(ds, max(5000, int(self.cfg.display_max_points)))

    def _should_merge(self, now: float, pose_world: Optional[np.ndarray]) -> bool:
        """Decide whether the local submap should be folded into the global one now."""
        if len(self._local_pts) == 0:
            return False
        if len(self._global_pts) == 0:
            return True
        if (float(now) - float(self._last_merge_t)) >= float(self.cfg.merge_period_sec):
            return True
        if pose_world is None or self._last_merge_pose is None:
            return False
        try:
            pose = np.asarray(pose_world, dtype=np.float64)
            if pose.shape != (4, 4):
                return False
            trans, ang = _pose_delta(pose, self._last_merge_pose)
            return bool(
                trans >= float(self.cfg.merge_min_translation_m)
                or ang >= float(self.cfg.merge_min_rotation_deg)
            )
        except Exception:
            return False

    def integrate(
        self,
        points_world: np.ndarray,
        now: float,
        pose_world: Optional[np.ndarray] = None,
    ) -> Dict[str, Any]:
        """Insert one frame of points, rebuild the local submap, and merge if due.

        Args:
            points_world: (N, 3) points, presumably already expressed in the
                world/map frame by the caller (not checked here).
            now: timestamp compared against the time of the last merge; use the
                same clock on every call.
            pose_world: optional 4x4 pose, used for the motion-based trigger and
                recorded at each merge.

        Returns:
            The :meth:`status` dict plus a ``"merged"`` flag. Input that is not a
            non-empty (N, 3) array, or that thins to nothing, returns an inactive
            status (reason ``"empty"`` / ``"empty_downsample"``) with no ``"merged"`` key.
        """
        pts = np.asarray(points_world, dtype=np.float32)
        if pts.ndim != 2 or pts.shape[1] != 3 or len(pts) == 0:
            return self.status(active=False, reason="empty")
        ds = _voxel_downsample(pts, float(self.cfg.local_voxel_m))
        if len(ds) == 0:
            return self.status(active=False, reason="empty_downsample")
        self._frames.append(ds)
        self._insert_count += 1
        self._rebuild_local()
        merged = False
        if self._should_merge(float(now), pose_world):
            if len(self._global_pts) == 0:
                combo = np.array(self._local_pts, dtype=np.float32, copy=True)
            else:
                # Global points come first, so existing map points win each voxel.
                combo = np.concatenate([self._global_pts, self._local_pts], axis=0)
            self._global_pts = _voxel_downsample(combo, float(self.cfg.global_voxel_m))
            self._global_pts = self._trim(self._global_pts, int(self.cfg.max_global_points))
            self._last_merge_t = float(now)
            if pose_world is not None:
                pose = np.asarray(pose_world, dtype=np.float64)
                if pose.shape == (4, 4):
                    self._last_merge_pose = np.array(pose, dtype=np.float64, copy=True)
            merged = True
            self._merge_count += 1
        out = self.status(active=True, reason="ok")
        out["merged"] = bool(merged)
        return out

    def get_display_points(self) -> np.ndarray:
        """Return global + local points for rendering, capped at ``cfg.display_max_points``.

        When both submaps have points, they are combined and thinned at
        ``cfg.local_voxel_m`` before capping.
        """
        if len(self._global_pts) == 0 and len(self._local_pts) == 0:
            return np.zeros((0, 3), dtype=np.float32)
        if len(self._global_pts) == 0:
            return np.asarray(self._local_pts, dtype=np.float32)
        if len(self._local_pts) == 0:
            return self._trim(np.asarray(self._global_pts, dtype=np.float32), int(self.cfg.display_max_points))
        combo = np.concatenate([self._global_pts, self._local_pts], axis=0)
        disp = _voxel_downsample(combo, float(self.cfg.local_voxel_m))
        return self._trim(disp, int(self.cfg.display_max_points))

    def status(
        self,
        active: bool,
        reason: str = "ok",
        profile: str = "",
        profile_allowed: bool = True,
    ) -> Dict[str, Any]:
        """Build a plain-dict status snapshot (counts plus caller-supplied flags)."""
        return {
            "enabled": bool(self.cfg.enabled),
            "active": bool(active),
            "reason": str(reason),
            "profile": str(profile).upper().strip(),
            "profile_allowed": bool(profile_allowed),
            "local_points": int(len(self._local_pts)),
            "global_points": int(len(self._global_pts)),
            "window_frames": int(len(self._frames)),
            "merge_count": int(self._merge_count),
            "insert_count": int(self._insert_count),
        }