168 lines
5.8 KiB
Python
168 lines
5.8 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from typing import Any, Deque, Dict, Optional
|
|
|
|
|
|
def _safe_float(v: Any, default: float) -> float:
|
|
try:
|
|
return float(v)
|
|
except Exception:
|
|
return float(default)
|
|
|
|
|
|
def _safe_int(v: Any, default: int) -> int:
|
|
try:
|
|
return int(v)
|
|
except Exception:
|
|
return int(default)
|
|
|
|
|
|
def _as_bool(v: Any, default: bool) -> bool:
|
|
if v is None:
|
|
return default
|
|
if isinstance(v, bool):
|
|
return v
|
|
if isinstance(v, (int, float)):
|
|
return bool(v)
|
|
return str(v).strip().lower() in ("1", "true", "yes", "on")
|
|
|
|
|
|
@dataclass
class LocalizationStateConfig:
    """Tunable thresholds for ``LocalizationStateMachine``."""

    # When false, the state machine always reports "TRACKING".
    enabled: bool = True
    # Maximum number of good/bad samples kept in the history window.
    window_size: int = 12
    # A result is "good" when fitness >= degraded_fitness and
    # rmse <= degraded_rmse.
    degraded_fitness: float = 0.20
    # A result is "lost-like" when fitness < lost_fitness or
    # rmse > lost_rmse.
    lost_fitness: float = 0.10
    degraded_rmse: float = 0.45
    lost_rmse: float = 0.80
    # Consecutive bad results needed (together with a lost-like result)
    # before the state becomes "LOST".
    max_bad_before_lost: int = 4
    # Consecutive good results needed to return to "TRACKING".
    min_good_to_recover: int = 3
    # Minimum time between recovery attempts.
    recovery_cooldown_sec: float = 2.0

    @staticmethod
    def from_dict(d: Dict[str, Any] | None) -> "LocalizationStateConfig":
        """Build a config from a loosely-typed mapping.

        Missing keys and values that cannot be coerced fall back to the
        field defaults. Numeric values are clamped: fitness thresholds to
        [0.0, 1.0], the others to a per-field lower bound.

        Args:
            d: Source mapping; ``None`` or an empty mapping yields the
                defaults.

        Returns:
            A new ``LocalizationStateConfig``.
        """
        src = d if d else {}

        def int_at_least(key: str, fallback: int, floor: int) -> int:
            # Coerce to int, then enforce the lower bound.
            return max(floor, _safe_int(src.get(key, fallback), fallback))

        def float_at_least(key: str, fallback: float, floor: float) -> float:
            # Coerce to float, then enforce the lower bound.
            return max(floor, _safe_float(src.get(key, fallback), fallback))

        def unit_float(key: str, fallback: float) -> float:
            # Coerce to float, then clamp into [0.0, 1.0].
            return max(0.0, min(1.0, _safe_float(src.get(key, fallback), fallback)))

        return LocalizationStateConfig(
            enabled=_as_bool(src.get("enabled", True), True),
            window_size=int_at_least("window_size", 12, 3),
            degraded_fitness=unit_float("degraded_fitness", 0.20),
            lost_fitness=unit_float("lost_fitness", 0.10),
            degraded_rmse=float_at_least("degraded_rmse", 0.45, 0.01),
            lost_rmse=float_at_least("lost_rmse", 0.80, 0.02),
            max_bad_before_lost=int_at_least("max_bad_before_lost", 4, 2),
            min_good_to_recover=int_at_least("min_good_to_recover", 3, 1),
            recovery_cooldown_sec=float_at_least("recovery_cooldown_sec", 2.0, 0.1),
        )
|
|
|
|
|
|
class LocalizationStateMachine:
    """Track localization health from a stream of per-update results.

    ``state`` is one of the strings "TRACKING", "DEGRADED", "LOST" or
    "RECOVERY". Each ``update`` classifies the result as good or bad,
    records it in a bounded history and applies these transitions:

    * disabled: the state is forced to "TRACKING" and nothing is recorded;
    * "RECOVERY": a bad result moves to "LOST"; ``min_good_to_recover``
      consecutive good results move to "TRACKING";
    * otherwise, good: "DEGRADED"/"LOST" return to "TRACKING" after
      ``min_good_to_recover`` consecutive good results;
    * otherwise, bad: the state becomes "LOST" once the result is
      lost-like and ``max_bad_before_lost`` consecutive bad results have
      been seen; if not, it becomes "DEGRADED" unless already "LOST".

    A result is a dict that may carry "accepted", "fitness" and "rmse";
    ``None`` stands for "no result". Missing, unconvertible or NaN metrics
    are treated as the worst case (fitness 0.0, rmse 9e9) rather than
    raising.
    """

    def __init__(self, cfg: LocalizationStateConfig):
        """Create a machine in "TRACKING" with an empty history.

        Args:
            cfg: Thresholds and limits; ``cfg.window_size`` bounds the
                history length.
        """
        self.cfg = cfg
        self.state = "TRACKING"
        self._history: Deque[bool] = deque(maxlen=int(self.cfg.window_size))
        self._consecutive_bad = 0
        self._consecutive_good = 0
        self._last_update_t = 0.0
        self._last_recovery_t = 0.0

    def reset(self) -> None:
        """Return to "TRACKING" and clear history, counters and timestamps."""
        self.state = "TRACKING"
        self._history.clear()
        self._consecutive_bad = 0
        self._consecutive_good = 0
        self._last_update_t = 0.0
        self._last_recovery_t = 0.0

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable the machine; disabling also resets it.

        Note that this writes ``enabled`` on the config object that was
        passed to the constructor.
        """
        self.cfg.enabled = bool(enabled)
        if not self.cfg.enabled:
            # reset() already puts the state back to "TRACKING".
            self.reset()

    def enter_recovery(self, now: float) -> None:
        """Switch to "RECOVERY" and stamp the recovery time.

        Does nothing while the machine is disabled.
        """
        if not self.cfg.enabled:
            return
        self.state = "RECOVERY"
        self._last_recovery_t = float(now)

    @staticmethod
    def _metric(result: Dict[str, Any], key: str, worst: float) -> float:
        """Read ``result[key]`` as a float, falling back to ``worst``.

        ``worst`` is returned when the key is absent, when ``float()``
        cannot convert the value (e.g. ``None`` or non-numeric text), or
        when the value is NaN.
        """
        try:
            value = float(result.get(key, worst))
        except (TypeError, ValueError, OverflowError):
            return worst
        # NaN is the only float that differs from itself. Every threshold
        # comparison with NaN is False, which would make a result neither
        # good nor lost-like, so NaN is mapped to the worst case.
        return worst if value != value else value

    def _is_good(self, result: Optional[Dict[str, Any]]) -> bool:
        """Return True when the result counts as a good sample.

        A result is good when its "accepted" entry is truthy, or when
        fitness >= ``degraded_fitness`` and rmse <= ``degraded_rmse``.
        ``None`` is never good.
        """
        if result is None:
            return False
        if bool(result.get("accepted", False)):
            return True
        fit = self._metric(result, "fitness", 0.0)
        rmse = self._metric(result, "rmse", 9e9)
        return fit >= float(self.cfg.degraded_fitness) and rmse <= float(self.cfg.degraded_rmse)

    def _is_lost_like(self, result: Optional[Dict[str, Any]]) -> bool:
        """Return True when the result is bad enough to suggest a lost pose.

        That is the case for ``None``, for fitness < ``lost_fitness`` and
        for rmse > ``lost_rmse``.
        """
        if result is None:
            return True
        fit = self._metric(result, "fitness", 0.0)
        rmse = self._metric(result, "rmse", 9e9)
        return fit < float(self.cfg.lost_fitness) or rmse > float(self.cfg.lost_rmse)

    def _count(self, good: bool) -> None:
        """Advance the matching consecutive counter and zero the other."""
        if good:
            self._consecutive_good += 1
            self._consecutive_bad = 0
        else:
            self._consecutive_bad += 1
            self._consecutive_good = 0

    def update(self, result: Optional[Dict[str, Any]], now: float) -> str:
        """Feed one result into the machine and return the new state.

        Args:
            result: Result dict, or ``None`` when no result is available.
            now: Timestamp of this update; stored as the last update time.

        Returns:
            The state after applying the transition rules.
        """
        if not self.cfg.enabled:
            self.state = "TRACKING"
            return self.state

        good = self._is_good(result)
        self._history.append(good)
        self._last_update_t = float(now)
        self._count(good)

        recovered = self._consecutive_good >= int(self.cfg.min_good_to_recover)

        if self.state == "RECOVERY":
            if not good:
                # A single bad result aborts the recovery.
                self.state = "LOST"
            elif recovered:
                self.state = "TRACKING"
            return self.state

        if good:
            if recovered and self.state in ("DEGRADED", "LOST"):
                self.state = "TRACKING"
        elif (
            self._consecutive_bad >= int(self.cfg.max_bad_before_lost)
            and self._is_lost_like(result)
        ):
            self.state = "LOST"
        elif self.state != "LOST":
            # A bad result never upgrades "LOST" back to "DEGRADED".
            self.state = "DEGRADED"

        return self.state

    def should_recover(self, now: float) -> bool:
        """Return True when a recovery attempt is due.

        That requires the state to be "LOST" or "RECOVERY" and at least
        ``recovery_cooldown_sec`` to have elapsed since the last recorded
        recovery time.
        """
        if self.state not in ("LOST", "RECOVERY"):
            return False
        elapsed = float(now) - float(self._last_recovery_t)
        return elapsed >= float(self.cfg.recovery_cooldown_sec)

    def mark_recovery_attempt(self, now: float) -> None:
        """Record ``now`` as the time of the latest recovery attempt."""
        self._last_recovery_t = float(now)

    def snapshot(self) -> Dict[str, Any]:
        """Return a plain-dict summary of the current status.

        Keys: "enabled", "state", "history_len", "good_ratio" (share of
        good samples in the history, 0.0 when it is empty),
        "consecutive_bad" and "consecutive_good".
        """
        samples = len(self._history)
        # The history holds bools only, so sum() counts the good samples.
        good_ratio = sum(self._history) / samples if samples else 0.0
        return {
            "enabled": bool(self.cfg.enabled),
            "state": self.state,
            "history_len": int(samples),
            "good_ratio": float(good_ratio),
            "consecutive_bad": int(self._consecutive_bad),
            "consecutive_good": int(self._consecutive_good),
        }
|