126 lines
3.9 KiB
Python
126 lines
3.9 KiB
Python
from __future__ import annotations

import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
|
|
|
|
|
|
def _safe_int(v: Any, default: int) -> int:
    """Coerce *v* to ``int``; on any conversion error return ``int(default)``.

    Floats are truncated by ``int()``; strings must be integer literals
    (``"3.5"`` is rejected and yields the default).
    """
    try:
        converted = int(v)
    except Exception:
        # Any failure (None, garbage strings, inf/NaN, odd types) -> fallback.
        converted = int(default)
    return converted
|
|
|
|
|
|
def _as_bool(v: Any, default: bool) -> bool:
    """Interpret *v* as a boolean flag.

    ``None`` yields *default*; bools and Python numbers use their truthiness.
    Anything else is stringified, stripped and lower-cased, and is True only
    for "1", "true", "yes" or "on" -- every other string (including
    unrecognised ones) is False.
    """
    if v is None:
        return default
    # bool is a subclass of int, so one check covers bool, int and float;
    # bool(True)/bool(False) return the same singletons as the input.
    if isinstance(v, (bool, int, float)):
        return bool(v)
    normalized = str(v).strip().lower()
    return normalized in ("1", "true", "yes", "on")
|
|
|
|
|
|
@dataclass
class SessionMemoryConfig:
    """Configuration for ``SessionTransformStore``."""

    # When False, the store neither reads nor writes its JSON file.
    enabled: bool = True
    # Name of the JSON file, resolved relative to the store's data folder.
    filename: str = "SLAM_session_memory.json"
    # Maximum number of reference maps remembered; from_dict clamps it to >= 20.
    max_entries: int = 200

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "SessionMemoryConfig":
        """Build a config from a loosely-typed mapping (e.g. parsed YAML/JSON).

        Missing keys take the dataclass defaults. ``enabled`` accepts bools,
        numbers or strings like "true"/"on"; ``max_entries`` falls back to 200
        when not an integer and is never below 20. A ``None`` or blank
        ``filename`` falls back to the default name -- otherwise ``None`` would
        become the literal file name "None", and "" would point the store at
        its data folder itself, silently disabling persistence.

        Args:
            d: Source mapping, or None for all defaults.

        Returns:
            A new ``SessionMemoryConfig``.
        """
        default_name = "SLAM_session_memory.json"
        src = d or {}
        raw_name = src.get("filename")
        filename = "" if raw_name is None else str(raw_name)
        if not filename.strip():
            filename = default_name
        return SessionMemoryConfig(
            enabled=_as_bool(src.get("enabled", True), True),
            filename=filename,
            max_entries=max(20, _safe_int(src.get("max_entries", 200), 200)),
        )
|
|
|
|
|
|
class SessionTransformStore:
    """Remember the last successful 4x4 alignment transform per reference map.

    Entries are kept in a JSON file (``cfg.filename`` inside ``data_folder``)
    shaped as ``{"entries": {<resolved map path>: {"ref_map", "timestamp",
    "fitness", "rmse", "transform"}}}``. Reading and writing that file is
    best-effort: failures are logged and never raised to the caller.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, data_folder: str, cfg: SessionMemoryConfig):
        """Create the data folder if needed and load any existing entries.

        Args:
            data_folder: Directory that holds the JSON file; created if missing.
            cfg: Store settings (``enabled``, ``filename``, ``max_entries``).
        """
        self.cfg = cfg
        self.data_folder = Path(data_folder)
        self.data_folder.mkdir(parents=True, exist_ok=True)
        self.path = self.data_folder / self.cfg.filename
        self._db: Dict[str, Any] = {"entries": {}}
        self._load()

    def _load(self) -> None:
        """Adopt the file's content as the in-memory DB if it is well-formed.

        A missing file leaves the DB empty. An unreadable or corrupt file
        resets it to empty (with a warning) so the store keeps working.
        """
        if not self.cfg.enabled:
            return
        try:
            if not self.path.exists():
                return
            obj = json.loads(self.path.read_text(encoding="utf-8"))
        except Exception as exc:  # OSError, UnicodeDecodeError, JSONDecodeError, ...
            self._log.warning("Ignoring unreadable session memory %s: %s", self.path, exc)
            self._db = {"entries": {}}
            return
        if isinstance(obj, dict) and isinstance(obj.get("entries"), dict):
            self._db = obj
        else:
            self._log.warning("Ignoring session memory %s: unexpected structure", self.path)

    def _save(self) -> None:
        """Write the DB to disk atomically (temp file, then ``os.replace``).

        Writing in place could leave a truncated file after a crash, which
        ``_load`` would then discard along with every remembered transform.
        Failures are logged, not raised.
        """
        if not self.cfg.enabled:
            return
        tmp: Optional[Path] = None
        try:
            payload = json.dumps(self._db, indent=2)
            # PID suffix keeps concurrent processes from sharing one temp file.
            tmp = self.path.with_name(f"{self.path.name}.{os.getpid()}.tmp")
            tmp.write_text(payload, encoding="utf-8")
            os.replace(tmp, self.path)
        except Exception as exc:
            self._log.warning("Could not save session memory to %s: %s", self.path, exc)
            if tmp is not None:
                try:
                    tmp.unlink()
                except OSError:
                    pass

    def _key(self, ref_map_path: str) -> str:
        """Normalise a reference-map path into its DB key.

        The key is absolute, with ``~`` expanded and symlinks resolved.
        """
        try:
            # realpath is non-strict: it resolves the existing part of the path
            # and keeps the rest verbatim, so maps not on disk still get a key.
            return os.path.realpath(os.path.expanduser(str(ref_map_path)))
        except Exception:
            return str(ref_map_path)

    @staticmethod
    def _entry_timestamp(entry: Any) -> float:
        """Sort key for pruning: the entry's timestamp, or -inf if unusable.

        Entries loaded from disk may be malformed; ranking them as oldest
        lets pruning drop them instead of raising.
        """
        if not isinstance(entry, dict):
            return float("-inf")
        try:
            ts = float(entry.get("timestamp", 0.0))
        except (TypeError, ValueError, OverflowError):
            return float("-inf")
        # NaN would make the sort order undefined.
        return float("-inf") if np.isnan(ts) else ts

    def get_transform(self, ref_map_path: str) -> Optional[np.ndarray]:
        """Return the remembered transform for *ref_map_path*, or None.

        Returns a (4, 4) float64 array. None is returned when the store is
        disabled, no entry exists, or the stored matrix is malformed (not a
        list, wrong shape, non-numeric, or containing NaN/inf).
        """
        if not self.cfg.enabled:
            return None
        entry = self._db.get("entries", {}).get(self._key(ref_map_path))
        if not isinstance(entry, dict):
            return None
        raw = entry.get("transform")
        if not isinstance(raw, list):
            return None
        try:
            mat = np.asarray(raw, dtype=np.float64)
        except Exception:  # ragged / non-numeric data from the file
            return None
        if mat.shape != (4, 4) or not np.all(np.isfinite(mat)):
            return None
        return mat

    def record_success(self, ref_map_path: str, transform: np.ndarray, fitness: float, rmse: float) -> None:
        """Store *transform* as the latest good alignment for *ref_map_path*, then save.

        Input that is not a finite 4x4 matrix, or non-numeric fitness/rmse, is
        ignored. When more than ``cfg.max_entries`` maps are stored, the most
        recent ones (by timestamp) are kept, and the entry just recorded is
        always among them.

        Args:
            ref_map_path: Path of the reference map the transform aligns to.
            transform: 4x4 homogeneous transform (anything ``np.asarray`` accepts).
            fitness: Registration fitness score to record alongside.
            rmse: Registration RMSE to record alongside.
        """
        if not self.cfg.enabled:
            return
        key = self._key(ref_map_path)
        try:
            tf = np.asarray(transform, dtype=np.float64)
            if tf.shape != (4, 4) or not np.all(np.isfinite(tf)):
                return
            entry = {
                "ref_map": key,
                "timestamp": float(time.time()),
                "fitness": float(fitness),
                "rmse": float(rmse),
                "transform": tf.tolist(),
            }
        except Exception:  # best-effort: unusable input is skipped
            return

        entries = self._db.setdefault("entries", {})
        entries[key] = entry

        # Keep the latest-N entries by timestamp, always including the new one
        # (guards against max_entries <= 0 and wall-clock regressions).
        limit = max(1, int(self.cfg.max_entries))
        if len(entries) > limit:
            others = sorted(
                ((k, v) for k, v in entries.items() if k != key),
                key=lambda kv: self._entry_timestamp(kv[1]),
                reverse=True,
            )
            kept = dict(others[: limit - 1])
            kept[key] = entry
            self._db["entries"] = kept

        self._save()
|