Sanad/voice/text_utils.py

342 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Arabic text normalization and voice-command phrase matching.
Ported from gemini_interact/sanad_text_utils.py — unified for Sanad.
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import Any
# Arabic combining marks (tashkeel) — stripped for matching.
_DIACRITICS_RE = re.compile(r"[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]")
# Arabic question mark / comma / semicolon — treated as separators.
_AR_PUNCT = re.compile(r"[؟،؛]")
# Anything outside word characters, the Arabic block, and whitespace.
_NON_WORD = re.compile(r"[^\w\u0600-\u06FF\s]", re.UNICODE)
# Runs of whitespace, collapsed to a single space.
_MULTI_WS = re.compile(r"\s+")

# Single-character folds applied after punctuation cleanup: hamza-carrier
# alifs → bare alif, ta marbuta → ha, alif maqsoora → ya, tatweel removed.
_CHAR_FOLDS = (
    ("\u0623", "\u0627"),  # أ → ا
    ("\u0625", "\u0627"),  # إ → ا
    ("\u0622", "\u0627"),  # آ → ا
    ("\u0629", "\u0647"),  # ة → ه
    ("\u0649", "\u064A"),  # ى → ي
    ("\u0640", ""),        # tatweel (kashida) → removed
)


def normalize_arabic(text: str) -> str:
    """Normalize Arabic + English text for matching.

    Lowercases, turns punctuation into spaces, collapses whitespace,
    folds common Arabic letter variants, and strips diacritics.
    """
    folded = text.strip().lower()
    for pattern, repl in ((_AR_PUNCT, " "), (_NON_WORD, " "), (_MULTI_WS, " ")):
        folded = pattern.sub(repl, folded)
    for src, dst in _CHAR_FOLDS:
        folded = folded.replace(src, dst)
    # Diacritics go last so marks attached to folded letters are caught too.
    folded = _DIACRITICS_RE.sub("", folded)
    return folded.strip()


def strip_diacritics(text: str) -> str:
    """Return *text* with Arabic diacritical marks removed, otherwise unchanged."""
    return _DIACRITICS_RE.sub("", text)
def load_phrase_map(filepath: str | Path) -> dict[str, set[str]]:
    """Load a phrase file mapping command names to trigger phrases.

    Expected format (one or more quoted phrases per line):

        WAKE_PHRASES_shake_hand = {
        "مصافحه", "handshake", "shake hands",
        }

    Returns:
        {"shake_hand": {"مصافحه", "handshake", ...}, ...} with every phrase
        normalized via :func:`normalize_arabic`.  A missing file yields {}.

    Fix: the previous parser used ``re.match`` on each phrase line and so
    kept only the FIRST quoted phrase of a line, even though the documented
    format places several phrases per line; ``re.findall`` now captures all
    of them.  Closing braces with trailing commas ("},") and phrases on the
    header line are also recognized.
    """
    path = Path(filepath)
    if not path.exists():
        return {}
    result: dict[str, set[str]] = {}
    current_name: str | None = None
    current_phrases: set[str] = set()
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Header: WAKE_PHRASES_<name> = {   (phrases may follow on the line)
        header_match = re.match(r"WAKE_PHRASES_(\w+)\s*=\s*\{", line)
        if header_match:
            if current_name and current_phrases:
                result[current_name] = current_phrases
            current_name = header_match.group(1)
            current_phrases = set()
            line = line[header_match.end():]
        if current_name is not None:
            # BUG FIX: capture EVERY quoted phrase on the line, not just
            # the first (was re.match, silently dropping the rest).
            for quoted in re.findall(r'"([^"]+)"', line):
                phrase = normalize_arabic(quoted)
                if phrase:
                    current_phrases.add(phrase)
        # A closing brace ends the current set ("}" or "},", possibly
        # after phrases on the same line).
        if line.rstrip(", \t").endswith("}"):
            if current_name and current_phrases:
                result[current_name] = current_phrases
            current_name = None
            current_phrases = set()
    # Flush a trailing set whose closing brace is missing.
    if current_name and current_phrases:
        result[current_name] = current_phrases
    return result
def match_phrase(text: str, phrase_sets: dict[str, set[str]]) -> str | None:
    """Return the command name whose phrase best matches *text*, else None.

    Token-set matching: every word of a phrase must appear as a whole word
    in the normalized *text* — this prevents short phrases (e.g. 'hi') from
    matching inside longer words (e.g. 'this').  Among all matching
    phrases, the longest one (by character count) decides the command.
    """
    tokens = set(normalize_arabic(text).split())
    if not tokens:
        return None
    winner: str | None = None
    winner_size = 0
    for command, phrases in phrase_sets.items():
        for candidate in phrases:
            words = candidate.split()
            # Skip degenerate phrases; require every word to be present.
            if not words or not all(word in tokens for word in words):
                continue
            if len(candidate) > winner_size:
                winner = command
                winner_size = len(candidate)
    return winner
# ────────────────────── stateful ASR-buffer matcher ──────────────────────
# Port of gemini_interact/sanad_text_utils.py:_maybe_trigger_arm
#
# Why stateful: Gemini streams short ASR pieces like "مر", "حب", "ا" that
# need to be joined across ~2 s to match "مرحبا". This matcher buffers
# incoming transcript pieces, dedups repeats, and fires when any phrase
# in the wake set is found.
import time
import asyncio
import threading
_YA_PREFIX_RE = re.compile(r"^يا\s*")
def _strip_ya_prefix(s: str) -> str:
s = (s or "").strip()
return _YA_PREFIX_RE.sub("", s).strip()
def _remove_al_prefix_words(text: str) -> str:
if not text:
return ""
out = []
for w in text.split():
if w.startswith("ال") and len(w) > 2:
out.append(w[2:])
else:
out.append(w)
return " ".join(out).strip()
def _is_valid_text(s: str) -> bool:
has_ar = bool(re.search(r"[\u0600-\u06FF]", s or ""))
has_en = bool(re.search(r"[a-zA-Z]", s or ""))
return has_ar or has_en
def maybe_trigger_arm(
    state: Any,
    transcript_text: str,
    wake_phrases: set[str],
    *,
    fire_on_wake_match: bool = True,
    arm_trigger_fn=None,
) -> bool:
    """Buffer-aware wake-phrase matcher.

    Gemini streams short ASR pieces ("مر", "حب", "ا") that must be joined
    across a ~2 s window to match a phrase such as "مرحبا"; this function
    buffers incoming transcript pieces on *state*, dedups repeats, and
    fires when any phrase of *wake_phrases* is found in the buffer.

    Args:
        state: Any object — buffer/config attributes are lazily initialized
            on first use.  Suitable targets: a session dataclass, or even a
            plain ``types.SimpleNamespace``.
        transcript_text: Latest raw ASR piece.
        wake_phrases: Set of (un-normalized) wake phrases.
        fire_on_wake_match: When True, run *arm_trigger_fn* in a background
            thread immediately on match; when False, mark
            ``state._pending_arm_wave = True`` so the caller can fire it on
            turn_complete.
        arm_trigger_fn: Optional zero-argument callback to run on match.

    Returns:
        True if a phrase fired (or was marked pending), False otherwise.

    On match the ASR buffer is cleared to avoid re-triggering on the next
    chunk.
    """
    if not transcript_text or not wake_phrases:
        return False
    # ── lazy state init ────────────────────────────────────────
    # Defaults apply only to attributes the caller has not set, so a
    # session object can pre-tune any of these knobs.
    for attr, default in (
        ("_asr_buf", ""), ("_asr_last_time", 0.0),
        ("ASR_WINDOW_SEC", 2.0), ("ASR_SHORT_TOKEN_BONUS_SEC", 1.0),
        ("ASR_JOIN_NO_SPACE_MAXLEN", 2), ("ASR_MAX_CHARS", 120),
        ("_last_trigger_norm", ""), ("_last_trigger_time", 0.0),
        ("TRIGGER_DEDUP_WINDOW", 2.0),
        ("_pending_arm_wave", False), ("_pending_arm_wave_fired", False),
        ("_pending_arm_wave_set_time", 0.0), ("PENDING_ARM_TTL", 6.0),
        ("_pending_arm_trigger_fn", None), ("_pending_arm_fallback_time", 0.0),
        ("_last_piece_call_norm", ""), ("_last_piece_call_time", 0.0),
        ("_asr_stream", ""), ("ASR_STREAM_MAX_CHARS", 80),
    ):
        if not hasattr(state, attr):
            setattr(state, attr, default)
    # Tunables with a read-side default rather than a stored attribute.
    dup_call_window = float(getattr(state, "DUP_CALL_WINDOW_SEC", 0.25))
    dup_asr_repeat_window = float(getattr(state, "DUP_ASR_REPEAT_WINDOW_SEC", 0.9))
    pending_fallback_sec = float(getattr(state, "PENDING_ARM_FALLBACK_SEC", 0.65))
    piece_raw = transcript_text.strip()
    if not piece_raw:
        return False
    piece_norm = normalize_arabic(piece_raw)
    if not piece_norm or not _is_valid_text(piece_norm):
        return False
    now = time.time()
    # The same normalized piece arriving nearly instantly (duplicate_call)
    # or within the ASR-repeat window (repeated_asr) is not buffered again;
    # matching below still runs on the existing buffer.
    duplicate_call = (
        piece_norm == state._last_piece_call_norm
        and (now - state._last_piece_call_time) < dup_call_window
    )
    repeated_asr = (
        piece_norm == state._last_piece_call_norm
        and (now - state._last_piece_call_time) < dup_asr_repeat_window
    )
    state._last_piece_call_norm = piece_norm
    state._last_piece_call_time = now
    # Buffer update
    if not duplicate_call and not repeated_asr:
        if state._asr_last_time:
            gap = now - state._asr_last_time
            window = state.ASR_WINDOW_SEC
            if len(piece_norm) <= state.ASR_JOIN_NO_SPACE_MAXLEN:
                # Very short tokens get extra time before the buffer resets.
                window += state.ASR_SHORT_TOKEN_BONUS_SEC
            if gap > window:
                # Too long a silence — start a fresh utterance.
                state._asr_buf = ""
                state._asr_stream = ""
        state._asr_last_time = now
        # Join logic — no-space for very short pieces
        if state._asr_buf:
            if len(piece_norm) <= state.ASR_JOIN_NO_SPACE_MAXLEN:
                state._asr_buf = (state._asr_buf + piece_norm).strip()
            else:
                state._asr_buf = (state._asr_buf + " " + piece_norm).strip()
        else:
            state._asr_buf = piece_norm
        # _asr_stream is a space-free rolling tail used for substring hits.
        compact = piece_norm.replace(" ", "")
        state._asr_stream = (state._asr_stream + compact)[-state.ASR_STREAM_MAX_CHARS:]
        if len(state._asr_buf) > state.ASR_MAX_CHARS:
            state._asr_buf = state._asr_buf[-state.ASR_MAX_CHARS:]
    # Matching views of the buffer: normalized, space-free, and with the
    # Arabic definite article removed per word.
    buf_norm = normalize_arabic(state._asr_buf)
    buf_nospace = buf_norm.replace(" ", "")
    buf_noal = _remove_al_prefix_words(buf_norm)
    buf_noal_nospace = buf_noal.replace(" ", "")
    stream = normalize_arabic(state._asr_stream).replace(" ", "")
    stream_noal = _remove_al_prefix_words(stream)
    # Dedup — don't fire same buffer twice within TRIGGER_DEDUP_WINDOW
    if (buf_norm == state._last_trigger_norm
            and (now - state._last_trigger_time) < state.TRIGGER_DEDUP_WINDOW):
        return False
    # Match loop
    for phrase in wake_phrases:
        # Phrases are normalized the same way, with an optional vocative
        # "يا" prefix stripped.
        p_norm = _strip_ya_prefix(normalize_arabic(str(phrase)))
        if not p_norm:
            continue
        p_nospace = p_norm.replace(" ", "")
        p_noal = _remove_al_prefix_words(p_norm)
        p_noal_nospace = p_noal.replace(" ", "")
        pattern = r"\b" + re.escape(p_norm) + r"\b"
        # Whole-word hit in the buffer, exact space-free equality, or
        # article-stripped containment.
        hit_buf = bool(re.search(pattern, buf_norm)) \
            or (p_nospace and p_nospace == buf_nospace) \
            or (p_noal and (p_noal in buf_noal
                            or (p_noal_nospace and p_noal_nospace in buf_noal_nospace)))
        # Substring hit in the rolling space-free stream.
        hit_stream = bool(p_nospace and p_nospace in stream) \
            or bool(p_noal_nospace and p_noal_nospace in stream_noal)
        if hit_buf or hit_stream:
            state._last_trigger_norm = buf_norm
            state._last_trigger_time = now
            # Clear the buffer so the next chunk cannot re-trigger.
            state._asr_buf = ""
            state._asr_last_time = 0.0
            state._asr_stream = ""
            if fire_on_wake_match:
                if arm_trigger_fn:
                    _fire_arm_trigger(arm_trigger_fn)
                # Reset any stale pending-arm bookkeeping after firing.
                state._pending_arm_wave = False
                state._pending_arm_wave_fired = False
                state._pending_arm_wave_set_time = 0.0
                state._pending_arm_trigger_fn = None
                state._pending_arm_fallback_time = 0.0
            else:
                # Defer firing: the caller fires on turn_complete, with a
                # fallback deadline in case that event never arrives.
                state._pending_arm_wave = True
                state._pending_arm_wave_fired = False
                state._pending_arm_wave_set_time = now
                state._pending_arm_trigger_fn = arm_trigger_fn
                state._pending_arm_fallback_time = now + pending_fallback_sec
            return True
    return False
def _fire_arm_trigger(fn) -> None:
"""Run the arm trigger callback in a background thread, regardless
of whether we're inside an asyncio loop."""
try:
asyncio.get_running_loop()
asyncio.create_task(asyncio.to_thread(fn))
except RuntimeError:
threading.Thread(target=fn, daemon=True).start()
def load_arm_phrase_dispatch(
    sanad_arm_txt: str | Path,
    option_list: list,
) -> dict[int, set[str]]:
    """Build {action_id: set_of_phrases} from sanad_arm.txt × OPTION_LIST.

    Each OPTION has .id and .name.  The sanad_arm.txt file defines
    WAKE_PHRASES_<option.name with _ instead of spaces/hyphens>; options
    without a non-empty phrase set are omitted from the result.
    """
    name_to_phrases = load_phrase_map(sanad_arm_txt)  # {name_var: set[phrase]}
    return {
        option.id: phrases
        for option in option_list
        if (phrases := name_to_phrases.get(
            option.name.replace(" ", "_").replace("-", "_")))
    }