"""Arabic text normalization and voice-command phrase matching.
|
||
|
||
Ported from gemini_interact/sanad_text_utils.py — unified for Sanad.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
# Arabic diacritics (tashkeel) — stripped for matching.
|
||
_DIACRITICS_RE = re.compile(r"[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]")
|
||
_AR_PUNCT = re.compile(r"[؟،؛]")
|
||
_NON_WORD = re.compile(r"[^\w\u0600-\u06FF\s]", re.UNICODE)
|
||
_MULTI_WS = re.compile(r"\s+")
|
||
|
||
|
||
def normalize_arabic(text: str) -> str:
    """Normalize Arabic + English text for matching.

    Lowercases, turns punctuation into spaces, collapses whitespace, folds
    hamza variants / ta marbuta / alif maqsoora, drops tatweel, and strips
    diacritics last.
    """
    # Character folds applied in one simultaneous pass: no fold produces a
    # character that is itself foldable, so this matches sequential replaces.
    folds = str.maketrans({
        "\u0623": "\u0627",  # أ → ا
        "\u0625": "\u0627",  # إ → ا
        "\u0622": "\u0627",  # آ → ا
        "\u0629": "\u0647",  # ة → ه
        "\u0649": "\u064A",  # ى → ي
        "\u0640": None,      # tatweel removed outright
    })
    cleaned = _AR_PUNCT.sub(" ", text.strip().lower())
    cleaned = _NON_WORD.sub(" ", cleaned)
    cleaned = _MULTI_WS.sub(" ", cleaned)
    cleaned = cleaned.translate(folds)
    # Diacritics sit inside the Arabic block, so they survive _NON_WORD and
    # must be stripped here, at the end.
    return _DIACRITICS_RE.sub("", cleaned).strip()
def strip_diacritics(text: str) -> str:
    """Return *text* with Arabic diacritics (tashkeel) removed; all other
    characters are left untouched."""
    without_marks = _DIACRITICS_RE.sub("", text)
    return without_marks
def load_phrase_map(filepath: str | Path) -> dict[str, set[str]]:
    """Load a phrase file mapping command names to trigger phrases.

    Format (per command):
    WAKE_PHRASES_shake_hand = {
        "مصافحه", "handshake", "shake hands",
    }

    Returns: {"shake_hand": {"مصافحه", "handshake", ...}, ...}
    """
    src = Path(filepath)
    if not src.exists():
        return {}

    commands: dict[str, set[str]] = {}
    name: str | None = None
    phrases: set[str] = set()

    def _flush() -> None:
        # Record the block collected so far, if it produced any phrases.
        if name and phrases:
            commands[name] = phrases

    for raw in src.read_text(encoding="utf-8").splitlines():
        stripped = raw.strip()
        if not stripped or stripped.startswith("#"):
            continue

        # Header: WAKE_PHRASES_shake_hand = {
        header = re.match(r"WAKE_PHRASES_(\w+)\s*=\s*\{", stripped)
        if header:
            _flush()  # a new header implicitly closes an unterminated block
            name = header.group(1)
            phrases = set()
            continue

        # Closing brace terminates the current block.
        if stripped == "}":
            _flush()
            name = None
            phrases = set()
            continue

        # Phrase line: "some phrase",
        quoted = re.match(r'"([^"]+)"', stripped)
        if quoted and name is not None:
            normalized = normalize_arabic(quoted.group(1))
            if normalized:
                phrases.add(normalized)

    _flush()  # the file may end without a closing brace
    return commands
def match_phrase(text: str, phrase_sets: dict[str, set[str]]) -> str | None:
    """Return the command name if normalized *text* matches any phrase set.

    Token-set matching: every word of the phrase must appear as a whole
    word in *text*. Prevents short phrases (e.g. 'hi') from matching
    longer words (e.g. 'this'). Among all hits, the longest phrase wins.
    """
    tokens = set(normalize_arabic(text).split())
    if not tokens:
        return None

    winner: str | None = None
    winner_len = 0
    for name, phrases in phrase_sets.items():
        for candidate in phrases:
            words = candidate.split()
            # Whole-word containment + strict longest-phrase tie-breaking.
            if words and set(words) <= tokens and len(candidate) > winner_len:
                winner = name
                winner_len = len(candidate)
    return winner
# ────────────────────── stateful ASR-buffer matcher ──────────────────────
# Port of gemini_interact/sanad_text_utils.py:_maybe_trigger_arm
#
# Why stateful: Gemini streams short ASR pieces like "مر", "حب", "ا" that
# need to be joined across ~2 s to match "مرحبا". This matcher buffers
# incoming transcript pieces, dedups repeats, and fires when any phrase
# in the wake set is found.

import asyncio
import threading
import time


# Leading vocative "يا" ("O ...") is ignored when matching wake phrases.
_YA_PREFIX_RE = re.compile(r"^يا\s*")
def _strip_ya_prefix(s: str) -> str:
    """Drop a leading vocative "يا" (plus any following spaces) from *s*.

    None-safe: a falsy input yields "".
    """
    trimmed = (s or "").strip()
    without_ya = _YA_PREFIX_RE.sub("", trimmed)
    return without_ya.strip()
def _remove_al_prefix_words(text: str) -> str:
    """Strip the Arabic definite article "ال" from the start of each word.

    Words of length 2 or less (e.g. bare "ال") are kept as-is so we never
    emit an empty token.
    """
    if not text:
        return ""
    stripped = [
        w[2:] if w.startswith("ال") and len(w) > 2 else w
        for w in text.split()
    ]
    return " ".join(stripped).strip()
def _is_valid_text(s: str) -> bool:
    """True if *s* contains at least one Arabic-block or Latin character.

    None-safe: a falsy input is treated as the empty string.
    """
    # Single alternation is equivalent to the two separate searches
    # (Arabic block OR ASCII letter).
    return bool(re.search(r"[\u0600-\u06FF]|[a-zA-Z]", s or ""))
def maybe_trigger_arm(
    state: Any,
    transcript_text: str,
    wake_phrases: set[str],
    *,
    fire_on_wake_match: bool = True,
    arm_trigger_fn=None,  # zero-arg callable to run when a phrase fires
) -> bool:
    """Buffer-aware wake-phrase matcher.

    `state` is any object — attributes are lazily initialized on first use.
    Suitable targets: a session dataclass, or even a plain `types.SimpleNamespace`.

    On match:
    - Clears ASR buffer to avoid re-trigger on next chunk
    - If fire_on_wake_match: runs arm_trigger_fn in a background thread
      immediately (wrapped in asyncio.to_thread if in a loop, else
      threading.Thread)
    - If not fire_on_wake_match: marks _pending_arm_wave=True so the
      caller can fire it on turn_complete

    Returns True if a phrase fired, False otherwise.
    """
    if not transcript_text or not wake_phrases:
        return False

    # ── lazy state init ────────────────────────────────────────
    # Mutable working state uses _underscore names; tunables are UPPER_CASE
    # and only seeded here if the caller did not already set them.
    for attr, default in (
        ("_asr_buf", ""), ("_asr_last_time", 0.0),
        ("ASR_WINDOW_SEC", 2.0), ("ASR_SHORT_TOKEN_BONUS_SEC", 1.0),
        ("ASR_JOIN_NO_SPACE_MAXLEN", 2), ("ASR_MAX_CHARS", 120),
        ("_last_trigger_norm", ""), ("_last_trigger_time", 0.0),
        ("TRIGGER_DEDUP_WINDOW", 2.0),
        ("_pending_arm_wave", False), ("_pending_arm_wave_fired", False),
        ("_pending_arm_wave_set_time", 0.0), ("PENDING_ARM_TTL", 6.0),
        ("_pending_arm_trigger_fn", None), ("_pending_arm_fallback_time", 0.0),
        ("_last_piece_call_norm", ""), ("_last_piece_call_time", 0.0),
        ("_asr_stream", ""), ("ASR_STREAM_MAX_CHARS", 80),
    ):
        if not hasattr(state, attr):
            setattr(state, attr, default)

    # Per-call tunables read fresh each call so callers can override them.
    dup_call_window = float(getattr(state, "DUP_CALL_WINDOW_SEC", 0.25))
    dup_asr_repeat_window = float(getattr(state, "DUP_ASR_REPEAT_WINDOW_SEC", 0.9))
    pending_fallback_sec = float(getattr(state, "PENDING_ARM_FALLBACK_SEC", 0.65))

    piece_raw = transcript_text.strip()
    if not piece_raw:
        return False

    piece_norm = normalize_arabic(piece_raw)
    if not piece_norm or not _is_valid_text(piece_norm):
        return False

    now = time.time()

    # Same normalized piece within a very short window → likely the same
    # callback delivered twice (dup_call_window, default 0.25 s); within a
    # longer window → ASR re-emitting the same partial (default 0.9 s).
    # NOTE(review): with default tunables duplicate_call implies repeated_asr
    # (0.25 < 0.9); both are kept since the windows are overridable.
    duplicate_call = (
        piece_norm == state._last_piece_call_norm
        and (now - state._last_piece_call_time) < dup_call_window
    )
    repeated_asr = (
        piece_norm == state._last_piece_call_norm
        and (now - state._last_piece_call_time) < dup_asr_repeat_window
    )

    state._last_piece_call_norm = piece_norm
    state._last_piece_call_time = now

    # Buffer update — only fresh (non-duplicate) pieces extend the buffers.
    if not duplicate_call and not repeated_asr:
        if state._asr_last_time:
            # Reset both buffers if the gap since the last piece exceeds the
            # join window; short tokens get a bonus so fragments like "مر"
            # can still be glued to the next piece.
            gap = now - state._asr_last_time
            window = state.ASR_WINDOW_SEC
            if len(piece_norm) <= state.ASR_JOIN_NO_SPACE_MAXLEN:
                window += state.ASR_SHORT_TOKEN_BONUS_SEC
            if gap > window:
                state._asr_buf = ""
                state._asr_stream = ""

        state._asr_last_time = now

        # Join logic — no-space for very short pieces (sub-word fragments),
        # space-separated otherwise.
        if state._asr_buf:
            if len(piece_norm) <= state.ASR_JOIN_NO_SPACE_MAXLEN:
                state._asr_buf = (state._asr_buf + piece_norm).strip()
            else:
                state._asr_buf = (state._asr_buf + " " + piece_norm).strip()
        else:
            state._asr_buf = piece_norm

        # _asr_stream is a space-free rolling tail used for substring matches
        # across word boundaries; both buffers are capped by truncating from
        # the left (oldest characters dropped first).
        compact = piece_norm.replace(" ", "")
        state._asr_stream = (state._asr_stream + compact)[-state.ASR_STREAM_MAX_CHARS:]
        if len(state._asr_buf) > state.ASR_MAX_CHARS:
            state._asr_buf = state._asr_buf[-state.ASR_MAX_CHARS:]

    # Precompute the match views: normalized buffer, space-free variant,
    # definite-article-stripped variants, and the rolling stream.
    buf_norm = normalize_arabic(state._asr_buf)
    buf_nospace = buf_norm.replace(" ", "")
    buf_noal = _remove_al_prefix_words(buf_norm)
    buf_noal_nospace = buf_noal.replace(" ", "")
    stream = normalize_arabic(state._asr_stream).replace(" ", "")
    stream_noal = _remove_al_prefix_words(stream)

    # Dedup — don't fire same buffer twice within TRIGGER_DEDUP_WINDOW
    if (buf_norm == state._last_trigger_norm
            and (now - state._last_trigger_time) < state.TRIGGER_DEDUP_WINDOW):
        return False

    # Match loop — a phrase hits if any of its variants is found in the
    # corresponding buffer variant.
    for phrase in wake_phrases:
        p_norm = _strip_ya_prefix(normalize_arabic(str(phrase)))
        if not p_norm:
            continue
        p_nospace = p_norm.replace(" ", "")
        p_noal = _remove_al_prefix_words(p_norm)
        p_noal_nospace = p_noal.replace(" ", "")

        # Whole-word regex hit, exact space-free match, or "ال"-stripped
        # containment (with its own space-free fallback).
        pattern = r"\b" + re.escape(p_norm) + r"\b"
        hit_buf = bool(re.search(pattern, buf_norm)) \
            or (p_nospace and p_nospace == buf_nospace) \
            or (p_noal and (p_noal in buf_noal
                            or (p_noal_nospace and p_noal_nospace in buf_noal_nospace)))

        # Rolling-stream hit catches phrases split across buffer resets.
        hit_stream = bool(p_nospace and p_nospace in stream) \
            or bool(p_noal_nospace and p_noal_nospace in stream_noal)

        if hit_buf or hit_stream:
            # Remember what fired (for dedup) and clear all buffers so the
            # next chunk cannot re-trigger on the same content.
            state._last_trigger_norm = buf_norm
            state._last_trigger_time = now
            state._asr_buf = ""
            state._asr_last_time = 0.0
            state._asr_stream = ""

            if fire_on_wake_match:
                # Fire immediately and clear any pending-arm bookkeeping.
                if arm_trigger_fn:
                    _fire_arm_trigger(arm_trigger_fn)
                state._pending_arm_wave = False
                state._pending_arm_wave_fired = False
                state._pending_arm_wave_set_time = 0.0
                state._pending_arm_trigger_fn = None
                state._pending_arm_fallback_time = 0.0
            else:
                # Defer: the caller fires on turn_complete, or after the
                # fallback deadline (presumably checked by the caller against
                # PENDING_ARM_TTL — confirm in the call site).
                state._pending_arm_wave = True
                state._pending_arm_wave_fired = False
                state._pending_arm_wave_set_time = now
                state._pending_arm_trigger_fn = arm_trigger_fn
                state._pending_arm_fallback_time = now + pending_fallback_sec

            return True

    return False
# Strong references to in-flight background tasks: the event loop keeps only
# weak references to tasks, so a fire-and-forget task with no other reference
# can be garbage-collected before it finishes (see asyncio.create_task docs).
_ARM_TRIGGER_TASKS: set = set()


def _fire_arm_trigger(fn) -> None:
    """Run the arm trigger callback *fn* in a background thread, regardless
    of whether we're inside an asyncio loop.

    Inside a running loop the callback is offloaded via asyncio.to_thread;
    otherwise a detached daemon thread is used. Never blocks the caller.
    """
    try:
        asyncio.get_running_loop()
        task = asyncio.create_task(asyncio.to_thread(fn))
        # Keep the task alive until done, then drop the reference.
        _ARM_TRIGGER_TASKS.add(task)
        task.add_done_callback(_ARM_TRIGGER_TASKS.discard)
    except RuntimeError:
        # No running event loop in this thread — plain thread fallback.
        threading.Thread(target=fn, daemon=True).start()
def load_arm_phrase_dispatch(
    sanad_arm_txt: str | Path,
    option_list: list,
) -> dict[int, set[str]]:
    """Build {action_id: set_of_phrases} from sanad_arm.txt × OPTION_LIST.

    Each OPTION has .id and .name. The sanad_arm.txt file defines
    WAKE_PHRASES_<option.name with _ instead of spaces/hyphens>.
    Options whose variable has no (non-empty) phrase set are omitted.
    """
    phrase_map = load_phrase_map(sanad_arm_txt)  # {name_var: set[phrase]}
    return {
        opt.id: matched
        for opt in option_list
        # Option names become identifiers: spaces and hyphens turn into "_".
        if (matched := phrase_map.get(opt.name.replace(" ", "_").replace("-", "_")))
    }