# Sanad/voice/typed_replay.py
# (source-viewer header: 710 lines, 28 KiB, Python)

"""Typed Replay Engine — send text to Gemini, play audio, capture + persist.
Full-featured port of gemini_voice_v2/sanad_webserver.py's SanadReplayEngine:
- Generate audio via GeminiVoiceClient (reuses existing WebSocket client)
- Play via PulseAudio + optionally capture speaker output (what was actually
heard) via parec or PyAudio monitor-source
- Save two WAVs per record: speaker capture + Gemini raw output
- JSON record index with rename/delete/replay
- In-memory "last session" for quick replay without re-hitting Gemini
"""
from __future__ import annotations
import asyncio
import json
import os
import re
import shutil
import subprocess
import tempfile
import threading
import time
import wave
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from Project.Sanad.config import (
AUDIO_RECORDINGS_DIR,
CHANNELS,
CHUNK_SIZE,
RECEIVE_SAMPLE_RATE,
SINK as DEFAULT_SINK,
SOURCE as DEFAULT_SOURCE,
MONITOR_SOURCE as DEFAULT_MONITOR_SOURCE,
)
from Project.Sanad.core.logger import get_logger
try:
    import pyaudio
except ImportError:
    pyaudio = None  # degraded mode — can still generate, but not capture/play

log = get_logger("typed_replay")

# ─── constants (from config/voice_config.json) ──────────────────────
# Values come from the "typed_replay" entry of the "voice" config section;
# the literal fallbacks apply when the loader or the section is unavailable.
try:
    from Project.Sanad.core.config_loader import section as _cfg_section
    # `or {}`: if the loader yields None for a missing section, the `.get`
    # calls below would otherwise fail at import time.
    _TR = _cfg_section("voice", "typed_replay") or {}
except Exception:
    _TR = {}

RECORD_INDEX_PATH = AUDIO_RECORDINGS_DIR / "records.json"
# Frames per read for speaker-monitor capture.
MONITOR_CHUNK_SIZE = _TR.get("monitor_chunk_size", CHUNK_SIZE)
# Extra capture time (seconds) kept after playback ends, before stopping.
MONITOR_TAIL_SEC = _TR.get("monitor_tail_sec", 0.2)
MAX_TEXT_LEN = _TR.get("max_text_len", 2000)
# ─── helpers ─────────────────────────────────────────────────────────
def format_timestamp(dt: Optional[datetime] = None) -> str:
    """Render *dt* (default: the current local time) as ``YYYY-MM-DD HH:MM:SS``."""
    moment = dt if dt else datetime.now()
    return moment.strftime("%Y-%m-%d %H:%M:%S")
def sanitize_record_name(name: str) -> str:
    """Turn a user-supplied record name into a safe file stem.

    Blank/None input becomes ``record_<YYYYmmdd_HHMMSS>``. Every character
    other than word characters, '-', Arabic letters and whitespace becomes
    '_'. Whitespace runs collapse to '_', and the result is capped at
    80 characters.

    Dots are replaced as well. Record stems are turned into file names with
    ``Path.with_suffix(".wav")``, which treats "v1.2" as stem "v1" plus
    suffix ".2". Different records would then share "v1.wav", and the
    uniqueness probe would never terminate. A bare "." would also resolve to
    the recordings directory itself.
    """
    name = (name or "").strip() or f"record_{datetime.now():%Y%m%d_%H%M%S}"
    name = re.sub(r"[^\w\-\u0600-\u06FF\s]", "_", name, flags=re.UNICODE)
    name = re.sub(r"\s+", "_", name)
    return name[:80]
def build_default_name(text: str) -> str:
    """Derive a record name from the spoken *text* plus a timestamp.

    Whitespace runs become '_', anything that is not a word character or an
    Arabic letter is dropped, and the slug is cut to 40 characters. An empty
    slug falls back to "record". A ``_YYYYmmdd_HHMMSS`` stamp is appended.
    """
    underscored = re.sub(r"\s+", "_", (text or "").strip())
    slug = re.sub(r"[^\w\u0600-\u06FF]", "", underscored, flags=re.UNICODE)[:40]
    if not slug:
        slug = "record"
    return "{}_{}".format(slug, datetime.now().strftime("%Y%m%d_%H%M%S"))
def audio_duration_seconds(pcm: bytes, sample_rate: int, channels: int,
                           sample_width: int) -> float:
    """Return the playback length of raw PCM in seconds.

    Returns 0.0 for empty audio or when any format parameter is not positive.
    """
    if not pcm:
        return 0.0
    if min(sample_rate, channels, sample_width) <= 0:
        return 0.0
    bytes_per_second = sample_rate * channels * sample_width
    return len(pcm) / bytes_per_second
def ensure_unique_record_stem(base_name: str, out_dir: Path) -> Path:
    """Return a stem under *out_dir* for which neither WAV of a record exists.

    Creates *out_dir* if needed. The sanitized *base_name* is tried first,
    then ``<name>_1``, ``<name>_2`` … until both ``<stem>.wav`` and
    ``<stem>_raw.wav`` are free.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    def _candidate(attempt: int) -> Path:
        # First attempt has no numeric suffix.
        tail = f"_{attempt}" if attempt else ""
        return out_dir / f"{sanitize_record_name(base_name)}{tail}"

    attempt = 0
    while True:
        stem = _candidate(attempt)
        speaker_wav = stem.with_suffix(".wav")
        raw_wav = stem.with_name(stem.name + "_raw.wav")
        if speaker_wav.exists() or raw_wav.exists():
            attempt += 1
            continue
        return stem
def run_pactl(args: list[str]) -> subprocess.CompletedProcess[str]:
    """Run ``pactl <args>`` and return the completed process.

    Output is captured as text. Raises ``CalledProcessError`` on a non-zero
    exit and ``TimeoutExpired`` after 5 seconds.
    """
    command = ["pactl"]
    command.extend(args)
    return subprocess.run(
        command,
        capture_output=True,
        text=True,
        timeout=5,
        check=True,
    )
# ─── monitor recorders (speaker output capture) ──────────────────────
class MonitorRecorder:
    """Capture speaker output via PyAudio on the monitor source.

    A daemon thread reads ``device_config["chunk_size"]`` frames at a time
    into ``self.frames`` until ``stop()`` is called. ``device_config`` must
    provide ``index``, ``rate``, ``channels`` and ``chunk_size``.
    """

    def __init__(self, pya, device_config: dict[str, Any]):
        self.pya = pya
        self.device_config = device_config
        self.frames: list[bytes] = []
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._stream = None
        self._error: Optional[BaseException] = None

    def start(self):
        """Open the int16 input stream and start the reader thread.

        Raises:
            RuntimeError: if pyaudio is not installed.
        """
        if pyaudio is None:
            raise RuntimeError("pyaudio unavailable — cannot capture speaker")
        self._stop_event.clear()
        self._error = None
        self.frames = []
        self._stream = self.pya.open(
            format=pyaudio.paInt16,
            channels=self.device_config["channels"],
            rate=self.device_config["rate"],
            input=True,
            input_device_index=self.device_config["index"],
            frames_per_buffer=self.device_config["chunk_size"],
        )
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()
        # Short head start so the reader is running before playback begins.
        time.sleep(0.05)

    def _loop(self):
        """Reader thread: append chunks until stop is requested or a read fails."""
        chunk = self.device_config["chunk_size"]
        while not self._stop_event.is_set():
            try:
                data = self._stream.read(chunk, exception_on_overflow=False)
            except Exception as exc:
                # Errors raised after stop() was requested are teardown noise.
                if not self._stop_event.is_set():
                    self._error = exc
                break
            self.frames.append(data)

    def stop(self) -> bytes:
        """Stop capturing and return the captured PCM bytes.

        Sleeps MONITOR_TAIL_SEC first so the tail of playback is captured.
        The reader thread is joined *before* the stream is stopped and closed,
        because closing a PyAudio/PortAudio stream while another thread is
        inside ``read()`` on it is not safe. The stream is only closed under
        the reader if the join times out.

        Raises:
            RuntimeError: if the reader thread hit an error.
        """
        time.sleep(MONITOR_TAIL_SEC)
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=1.0)
            if self._thread.is_alive():
                log.warning("speaker capture thread did not exit; closing stream anyway")
        if self._stream is not None:
            try:
                self._stream.stop_stream()
            except Exception:
                pass  # best effort — the stream may already be broken
            try:
                self._stream.close()
            except Exception:
                pass
            self._stream = None
        if self._error is not None:
            raise RuntimeError(f"Speaker capture failed: {self._error}")
        return b"".join(self.frames)
class ParecMonitorRecorder:
    """Capture speaker output via `parec` (PulseAudio CLI).

    Spawns ``parec`` on ``device_config["name"]`` producing s16le PCM and
    reads its stdout on a daemon thread. ``device_config`` must provide
    ``name``, ``rate``, ``channels`` and ``chunk_size``.
    """

    def __init__(self, device_config: dict[str, Any]):
        self.device_config = device_config
        self.frames: list[bytes] = []
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._proc: Optional[subprocess.Popen[bytes]] = None
        self._error: Optional[BaseException] = None

    def start(self):
        """Spawn parec and start the reader thread.

        Raises:
            OSError: if the parec executable cannot be started.
        """
        cmd = [
            "parec",
            f"--device={self.device_config['name']}",
            "--format=s16le",
            f"--rate={self.device_config['rate']}",
            f"--channels={self.device_config['channels']}",
        ]
        self._stop_event.clear()
        self._error = None
        self.frames = []
        self._proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()
        time.sleep(0.05)

    def _loop(self):
        """Reader thread: append parec output until stop, EOF, or a read error."""
        proc = self._proc
        if proc is None or proc.stdout is None:
            self._error = RuntimeError("parec did not start")
            return
        # s16le: 2 bytes per sample, one sample per channel per frame.
        size = self.device_config["chunk_size"] * self.device_config["channels"] * 2
        while not self._stop_event.is_set():
            try:
                data = proc.stdout.read(size)
            except Exception as exc:
                if not self._stop_event.is_set():
                    self._error = exc
                break
            if not data:
                # EOF: parec exited (terminated by stop() or on its own).
                # Further reads would keep returning b"", so stop here.
                break
            self.frames.append(data)

    def stop(self) -> bytes:
        """Stop parec and return the captured PCM bytes.

        Sleeps MONITOR_TAIL_SEC first so the tail of playback is captured.
        Then it terminates parec (escalating to kill after 1 s) and reaps the
        process. Finally it joins the reader and closes the stdout pipe.

        Raises:
            RuntimeError: if reading parec's output failed.
        """
        time.sleep(MONITOR_TAIL_SEC)
        self._stop_event.set()
        proc = self._proc
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=1.0)
            except subprocess.TimeoutExpired:
                proc.kill()
                try:
                    proc.wait(timeout=1.0)  # reap it so no zombie is left behind
                except subprocess.TimeoutExpired:
                    log.warning("parec (pid %s) did not exit after kill", proc.pid)
        if self._thread is not None:
            self._thread.join(timeout=1.0)
        reader_done = self._thread is None or not self._thread.is_alive()
        if proc is not None and proc.stdout is not None and reader_done:
            # Release our end of the pipe. This is done only once the reader
            # has finished, to avoid closing it underneath a read in progress.
            proc.stdout.close()
        if self._error is not None:
            raise RuntimeError(f"parec capture failed: {self._error}")
        return b"".join(self.frames)
# ─── session state ──────────────────────────────────────────────────
@dataclass
class ReplaySessionState:
    """In-memory copy of the most recent generation.

    Lets the last clip be replayed or saved without asking Gemini again.
    """

    text: str = ""                # text that was sent for generation
    audio_bytes: bytes = b""      # PCM returned by generation
    speaker_capture: bytes = b""  # PCM captured from the speaker (may be empty)
    generated_at: str = ""        # timestamp string of generation
    last_playback_at: str = ""    # timestamp string of the latest playback
    replay_count: int = 0
    saved_as: str = ""            # record name once persisted, else ""

    def as_status(self) -> dict[str, Any]:
        """Return a JSON-friendly summary; audio payloads become presence flags."""
        return dict(
            text=self.text,
            has_audio=bool(self.audio_bytes),
            has_capture=bool(self.speaker_capture),
            generated_at=self.generated_at,
            last_playback_at=self.last_playback_at,
            replay_count=self.replay_count,
            saved_as=self.saved_as,
        )
# ─── record index ───────────────────────────────────────────────────
def _quarantine_index() -> None:
    """Move an unreadable records index aside so the next save can't destroy it."""
    backup = RECORD_INDEX_PATH.with_name(
        f"{RECORD_INDEX_PATH.name}.corrupt-{datetime.now():%Y%m%d_%H%M%S}")
    try:
        os.replace(RECORD_INDEX_PATH, backup)
        log.warning("moved unreadable record index to %s", backup)
    except OSError as exc:
        log.warning("couldn't move unreadable record index aside: %s", exc)


def _load_index() -> dict[str, Any]:
    """Load the records index from RECORD_INDEX_PATH.

    Returns:
        A dict with a ``records`` list (non-dict entries dropped) and a
        ``total_records`` count. A missing file yields an empty index.

    An unreadable or malformed file also yields an empty index. It is first
    renamed to ``records.json.corrupt-<timestamp>``, because the next
    ``_save_index`` would otherwise overwrite it and the old entries would be
    lost for good.
    """
    if not RECORD_INDEX_PATH.exists():
        return {"total_records": 0, "records": []}
    try:
        payload = json.loads(RECORD_INDEX_PATH.read_text(encoding="utf-8"))
        if not isinstance(payload, dict) or not isinstance(payload.get("records"), list):
            raise ValueError("bad index structure")
    except Exception as exc:
        log.warning("record index unreadable, resetting: %s", exc)
        _quarantine_index()
        return {"total_records": 0, "records": []}
    # Callers use entry.get(...); a stray non-dict entry would raise there.
    payload["records"] = [r for r in payload["records"] if isinstance(r, dict)]
    payload.setdefault("total_records", len(payload["records"]))
    return payload
def _save_index(payload: dict[str, Any]):
    """Atomically write *payload* as JSON to RECORD_INDEX_PATH.

    The JSON goes to a temp file in the same directory, is flushed and
    fsync'd, and is then ``os.replace``d over the index. Readers never see a
    partial file, and the new contents are on disk before the rename, so a
    crash cannot leave a truncated records.json.
    The temp file is removed if anything fails.
    """
    RECORD_INDEX_PATH.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=str(RECORD_INDEX_PATH.parent),
                               suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, RECORD_INDEX_PATH)
    except BaseException:
        # Also covers KeyboardInterrupt so no stray .tmp file is left behind.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
def _resolve_record_path(path_str: str) -> Path:
    """Resolve a path from the records index.

    Paths in records.json can be either:
      - absolute (legacy — may be stale after scp to another machine). If the
        absolute path is gone but a file with the same basename exists in
        AUDIO_RECORDINGS_DIR, that local file is returned instead.
      - relative / basename — looked up under AUDIO_RECORDINGS_DIR.

    An empty value maps to AUDIO_RECORDINGS_DIR itself.
    """
    if not path_str:
        return AUDIO_RECORDINGS_DIR
    p = Path(path_str)
    if not p.is_absolute():
        return AUDIO_RECORDINGS_DIR / p
    if p.exists():
        return p
    local = AUDIO_RECORDINGS_DIR / p.name
    # If neither exists, keep the recorded path so error messages show it.
    return local if local.exists() else p
def _reconcile_index(payload: dict[str, Any]) -> dict[str, Any]:
    """Drop records whose speaker WAV no longer exists on disk.

    Entries lacking ``files.speaker_recording.path`` (or malformed ones) are
    dropped too. Updates ``records`` and ``total_records`` on *payload* in
    place and returns it. Nothing is written to disk.
    """

    def _speaker_file_present(entry: Any) -> bool:
        try:
            speaker = _resolve_record_path(
                entry["files"]["speaker_recording"]["path"])
        except (KeyError, TypeError):
            return False
        # is_file(), not exists(): an empty path resolves to the recordings
        # directory itself, which must not count as a surviving recording.
        return speaker.is_file()

    surviving = [e for e in payload.get("records", []) if _speaker_file_present(e)]
    payload["records"] = surviving
    payload["total_records"] = len(surviving)
    return payload
def _build_file_info(path: Path, pcm: bytes, rate: int,
                     channels: int, sample_width: int) -> dict[str, Any]:
    """Describe one WAV file for records.json.

    Only the basename of *path* is stored, under both "name" and "path". This
    keeps the index valid when the recordings directory is copied between a
    workstation and the robot; ``_resolve_record_path`` rebuilds the full
    path.
    """
    basename = path.name
    seconds = audio_duration_seconds(pcm, rate, channels, sample_width)
    return dict(
        name=basename,
        path=basename,
        size_bytes=len(pcm),
        sample_rate=rate,
        channels=channels,
        sample_width_bytes=sample_width,
        duration_seconds=round(seconds, 3),
    )
# ─── engine ─────────────────────────────────────────────────────────
class TypedReplayEngine:
    """Full-featured typed replay — generate, play, capture, save, replay.

    ``say()`` sends text to Gemini and plays the returned PCM on the speaker,
    optionally capturing what was actually heard. The result is cached in
    ``self.session``, so ``replay_last()`` / ``save_last()`` can reuse it
    without another Gemini round-trip. Persisted records are WAV pairs in
    AUDIO_RECORDINGS_DIR, indexed by records.json.
    """

    def __init__(self, voice_client, audio_mgr):
        """Create the engine.

        Args:
            voice_client: GeminiVoiceClient; may be None (generation then
                raises RuntimeError).
            audio_mgr: AudioManager exposing a ``pya`` PyAudio instance; may
                be None (playback then raises RuntimeError).
        """
        self.voice_client = voice_client
        self.audio_mgr = audio_mgr
        self.session = ReplaySessionState()
        self._gen_lock = threading.Lock()   # one generation at a time
        self._play_lock = threading.Lock()  # one speaker playback at a time
        self._monitor_config = self._resolve_monitor_config()
        AUDIO_RECORDINGS_DIR.mkdir(parents=True, exist_ok=True)

    # ── monitor config ───────────────────────────────────────────
    def _resolve_monitor_config(self) -> Optional[dict[str, Any]]:
        """Pick the backend for capturing speaker output.

        Priority:
          1. parec (cleanest — just listens to the speaker monitor source)
          2. PyAudio input device matching 'pulse' or 'default'
          3. None → capture disabled (generation still works)
        """
        if shutil.which("parec"):
            log.info("speaker capture: parec monitor=%s", DEFAULT_MONITOR_SOURCE)
            return {
                "backend": "parec",
                "name": DEFAULT_MONITOR_SOURCE,
                "rate": RECEIVE_SAMPLE_RATE,
                "channels": CHANNELS,
                "chunk_size": MONITOR_CHUNK_SIZE,
            }
        if pyaudio is None:
            log.warning("speaker capture disabled — no parec and no pyaudio")
            return None
        # Prefer the AudioManager's PyAudio instance. Otherwise create a
        # temporary one just for probing and release it afterwards.
        pya = getattr(self.audio_mgr, "pya", None)
        owns_pya = pya is None
        if owns_pya:
            try:
                pya = pyaudio.PyAudio()
            except Exception as exc:
                log.warning("speaker capture disabled — PyAudio init failed: %s", exc)
                return None
        try:
            return self._find_pyaudio_monitor(pya)
        finally:
            if owns_pya:
                pya.terminate()

    @staticmethod
    def _find_pyaudio_monitor(pya) -> Optional[dict[str, Any]]:
        """Return a capture config for the first PyAudio input device named
        like 'pulse'/'default', or None if there is none."""
        for i in range(pya.get_device_count()):
            info = pya.get_device_info_by_index(i)
            name = str(info.get("name", "")).lower()
            if ("pulse" in name or "default" in name) and int(info.get("maxInputChannels", 0)) > 0:
                log.info("speaker capture: pyaudio device=%s", info.get("name"))
                return {
                    "backend": "pyaudio",
                    "index": i,
                    "name": str(info.get("name")),
                    "rate": int(info.get("defaultSampleRate", RECEIVE_SAMPLE_RATE)),
                    "channels": max(1, min(2, int(info.get("maxInputChannels", 1)))),
                    "chunk_size": MONITOR_CHUNK_SIZE,
                }
        log.warning("speaker capture disabled — no pulse/default pyaudio device")
        return None

    def sample_width(self) -> int:
        """Bytes per paInt16 sample (2 when PyAudio isn't available)."""
        if pyaudio is None or self.audio_mgr is None or self.audio_mgr.pya is None:
            return 2  # int16
        return self.audio_mgr.pya.get_sample_size(pyaudio.paInt16)

    # ── generation ───────────────────────────────────────────────
    async def generate_audio(self, text: str) -> tuple[bytes, list[str]]:
        """Send text to Gemini, return (pcm_audio, text_parts).

        Connects the voice client first if it is not connected.

        Raises:
            RuntimeError: if no voice client was provided.
        """
        if self.voice_client is None:
            raise RuntimeError("voice_client unavailable")
        if not self.voice_client.connected:
            await self.voice_client.connect()
        return await self.voice_client.send_text(text, owner="typed_replay")

    # ── playback + capture ───────────────────────────────────────
    def play_audio(self, audio_bytes: bytes, capture_speaker: bool) -> bytes:
        """Play PCM on the speaker; optionally capture what was heard.

        Args:
            audio_bytes: int16 PCM at RECEIVE_SAMPLE_RATE with CHANNELS channels.
            capture_speaker: record the speaker while playing. This only has
                an effect if a capture backend was found at construction.

        Returns:
            The captured PCM (b"" when not capturing or nothing to play).

        Raises:
            RuntimeError: audio_mgr unavailable, or speaker capture failed.

        The recorder is always stopped and the default PulseAudio source is
        always restored, even when playback itself raises.
        """
        if not audio_bytes:
            return b""
        if self.audio_mgr is None or self.audio_mgr.pya is None:
            raise RuntimeError("audio_mgr unavailable — cannot play")
        with self._play_lock:
            recorder = None
            restore_source = False
            if capture_speaker and self._monitor_config is not None:
                recorder = self._make_recorder()
                restore_source = self._switch_source_to_monitor()
            try:
                try:
                    self._write_to_speaker(audio_bytes, recorder)
                except BaseException:
                    # Playback failed: still stop the recorder (kills parec /
                    # joins the reader thread) without masking the real error.
                    if recorder is not None:
                        self._stop_recorder_quietly(recorder)
                    raise
                return recorder.stop() if recorder is not None else b""
            finally:
                if restore_source:
                    self._restore_default_source()

    def _make_recorder(self):
        """Build (but don't start) the recorder for the configured backend."""
        cfg = self._monitor_config
        if cfg["backend"] == "parec":
            return ParecMonitorRecorder(cfg)
        return MonitorRecorder(self.audio_mgr.pya, cfg)

    def _switch_source_to_monitor(self) -> bool:
        """Point PulseAudio's default source at the speaker monitor.

        Needed only for the PyAudio backend. Its 'pulse'/'default' device
        presumably records from the default source; parec is instead given
        the monitor device explicitly. The source to switch to is
        DEFAULT_MONITOR_SOURCE, not the PyAudio device name stored in the
        config.

        Returns:
            True if DEFAULT_SOURCE must be restored afterwards.
        """
        if self._monitor_config["backend"] == "parec":
            return False
        try:
            run_pactl(["set-default-source", DEFAULT_MONITOR_SOURCE])
        except Exception as exc:
            log.warning("couldn't switch default source to monitor: %s", exc)
            return False
        return True

    @staticmethod
    def _restore_default_source() -> None:
        """Best effort: point the default source back at DEFAULT_SOURCE."""
        try:
            run_pactl(["set-default-source", DEFAULT_SOURCE])
        except Exception as exc:
            log.warning("couldn't restore default source: %s", exc)

    def _write_to_speaker(self, audio_bytes: bytes, recorder) -> None:
        """Open an output stream, start *recorder* (if any), write all PCM, close."""
        stream = self.audio_mgr.pya.open(
            format=pyaudio.paInt16,
            channels=CHANNELS,
            rate=RECEIVE_SAMPLE_RATE,
            output=True,
            frames_per_buffer=CHUNK_SIZE,
        )
        try:
            if recorder is not None:
                recorder.start()
            chunk = CHUNK_SIZE * CHANNELS * self.sample_width()
            for offset in range(0, len(audio_bytes), chunk):
                stream.write(audio_bytes[offset:offset + chunk])
        finally:
            try:
                stream.stop_stream()
            finally:
                stream.close()

    @staticmethod
    def _stop_recorder_quietly(recorder) -> None:
        """Stop *recorder*, logging instead of raising (used on error paths)."""
        try:
            recorder.stop()
        except Exception as exc:
            log.warning("speaker capture cleanup failed: %s", exc)

    def save_audio(self, pcm: bytes, path: Path, channels: int, rate: int) -> None:
        """Write *pcm* to *path* as a WAV with this engine's sample width."""
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(channels)
            wf.setsampwidth(self.sample_width())
            wf.setframerate(rate)
            wf.writeframes(pcm)

    # ── high-level API ───────────────────────────────────────────
    async def say(self, text: str, record: bool = False,
                  record_name: str = "") -> dict[str, Any]:
        """Generate, play, capture, return metadata. Optionally persist.

        Args:
            text: text to send; must be non-blank and at most MAX_TEXT_LEN
                characters (when MAX_TEXT_LEN is set).
            record: capture the speaker during playback and save a record.
            record_name: name for the record (default derived from *text*).

        Returns:
            Result dict; includes ``record`` (the index entry) when persisted.

        Raises:
            ValueError: blank or over-long text.
            RuntimeError: a generation is already running, or Gemini returned
                no audio.
        """
        if not text or not text.strip():
            raise ValueError("text cannot be empty")
        if MAX_TEXT_LEN and len(text) > MAX_TEXT_LEN:
            raise ValueError(f"text too long ({len(text)} > {MAX_TEXT_LEN} characters)")
        # Non-blocking: reject a concurrent caller instead of queuing it.
        if not self._gen_lock.acquire(blocking=False):
            raise RuntimeError("another typed-replay generation is in progress")
        try:
            audio_bytes, text_parts = await self.generate_audio(text)
            if not audio_bytes:
                raise RuntimeError("Gemini returned no audio — parts: "
                                   + " | ".join(text_parts or []))
            generated_at = format_timestamp()
            # Play + capture in a worker thread (PyAudio is sync)
            captured = await asyncio.to_thread(
                self.play_audio, audio_bytes, record)
            playback_finished_at = format_timestamp()
            # Update session state
            self.session.text = text
            self.session.audio_bytes = audio_bytes
            self.session.speaker_capture = captured
            self.session.generated_at = generated_at
            self.session.last_playback_at = playback_finished_at
            self.session.replay_count = 1
            self.session.saved_as = ""
            result = {
                "ok": True,
                "text": text,
                "gemini_text": text_parts,
                "generated_at": generated_at,
                "playback_finished_at": playback_finished_at,
                "raw_duration_sec": round(
                    audio_duration_seconds(audio_bytes, RECEIVE_SAMPLE_RATE,
                                           CHANNELS, self.sample_width()), 3),
                "captured_speaker_bytes": len(captured),
                "recorded": False,
            }
            if record:
                entry = self._persist_session(record_name or build_default_name(text))
                self.session.saved_as = entry["record_name"]
                result["record"] = entry
                result["recorded"] = True
            return result
        finally:
            self._gen_lock.release()

    def replay_last(self) -> dict[str, Any]:
        """Re-play the cached audio without hitting Gemini.

        Raises:
            RuntimeError: nothing cached yet, or audio_mgr unavailable.
        """
        if not self.session.audio_bytes:
            raise RuntimeError("no cached generation — call say() first")
        self.play_audio(self.session.audio_bytes, capture_speaker=False)
        self.session.replay_count += 1
        self.session.last_playback_at = format_timestamp()
        return {
            "ok": True,
            "replay_count": self.session.replay_count,
            "text": self.session.text,
            "played_at": self.session.last_playback_at,
        }

    def save_last(self, record_name: str = "") -> dict[str, Any]:
        """Persist the last generation to the records index.

        Raises:
            RuntimeError: nothing cached yet.
        """
        if not self.session.audio_bytes:
            raise RuntimeError("no cached generation — call say() first")
        entry = self._persist_session(record_name or build_default_name(self.session.text))
        self.session.saved_as = entry["record_name"]
        return entry

    def _persist_session(self, record_name: str) -> dict[str, Any]:
        """Write the cached session as two WAVs and append an index entry.

        ``<stem>.wav`` holds the speaker capture, or the raw audio when
        nothing was captured. ``<stem>_raw.wav`` holds the raw generated
        audio. The entry is appended to a reconciled copy of the index and
        saved.

        Returns:
            The new index entry.
        """
        base = ensure_unique_record_stem(record_name, AUDIO_RECORDINGS_DIR)
        speaker_path = base.with_suffix(".wav")
        raw_path = base.with_name(f"{base.name}_raw.wav")
        capture = self.session.speaker_capture
        audio = self.session.audio_bytes
        sw = self.sample_width()
        monitor = self._monitor_config or {}
        if capture:
            cap_rate = monitor.get("rate", RECEIVE_SAMPLE_RATE)
            cap_channels = monitor.get("channels", CHANNELS)
        else:
            # No capture available → save raw as speaker too so every record
            # has a .wav file for reconciliation checks.
            capture, cap_rate, cap_channels = audio, RECEIVE_SAMPLE_RATE, CHANNELS
        self.save_audio(capture, speaker_path, cap_channels, cap_rate)
        self.save_audio(audio, raw_path, CHANNELS, RECEIVE_SAMPLE_RATE)
        entry = {
            "record_name": base.name,
            "text": self.session.text,
            "replay_count": self.session.replay_count,
            "timeline": {
                "audio_generated_at": self.session.generated_at,
                "last_playback_finished_at": self.session.last_playback_at,
                "saved_at": format_timestamp(),
            },
            "audio_capture": {
                "backend": monitor.get("backend", "none"),
                "sink": DEFAULT_SINK,
                "monitor_source": DEFAULT_MONITOR_SOURCE,
                "restored_microphone_source": DEFAULT_SOURCE,
            },
            "files": {
                "speaker_recording": _build_file_info(
                    speaker_path, capture, cap_rate, cap_channels, sw),
                "gemini_raw_output": _build_file_info(
                    raw_path, audio, RECEIVE_SAMPLE_RATE, CHANNELS, sw),
            },
        }
        payload = _reconcile_index(_load_index())
        payload["records"].append(entry)
        payload["total_records"] = len(payload["records"])
        _save_index(payload)
        log.info("saved record %s (%.1fs speaker, %.1fs raw)",
                 base.name,
                 entry["files"]["speaker_recording"]["duration_seconds"],
                 entry["files"]["gemini_raw_output"]["duration_seconds"])
        return entry

    # ── records CRUD ─────────────────────────────────────────────
    @staticmethod
    def _find_entry(records: list[Any], name: str) -> Optional[dict[str, Any]]:
        """Return the first dict entry whose record_name equals *name*, or None."""
        return next((e for e in records
                     if isinstance(e, dict) and e.get("record_name") == name), None)

    def list_records(self) -> dict[str, Any]:
        """Return the index with missing-file records dropped (not saved back)."""
        return _reconcile_index(_load_index())

    def find_record(self, name: str) -> dict[str, Any]:
        """Return the index entry named *name*.

        Raises:
            KeyError: no such record.
        """
        entry = self._find_entry(_load_index().get("records", []), name)
        if entry is None:
            raise KeyError(f"record not found: {name}")
        return entry

    def rename_record(self, name: str, new_name: str) -> dict[str, Any]:
        """Rename a record's WAV files and index entry.

        *new_name* is sanitized first; renaming to the current name is a
        no-op. On any failure, completed file renames are rolled back so disk
        and index stay consistent. A missing raw WAV is skipped rather than
        aborting half-way.

        Raises:
            KeyError: no such record.
            ValueError: empty name, or the new name/files are already taken.
        """
        new_name = sanitize_record_name(new_name)
        if not new_name:
            raise ValueError("new_name empty after sanitize")
        payload = _reconcile_index(_load_index())
        target = self._find_entry(payload["records"], name)
        if target is None:
            raise KeyError(f"record not found: {name}")
        if new_name == name:
            return target
        if any(e.get("record_name") == new_name for e in payload["records"]):
            raise ValueError(f"a record named {new_name} already exists")
        old_speaker = _resolve_record_path(target["files"]["speaker_recording"]["path"])
        old_raw = _resolve_record_path(target["files"]["gemini_raw_output"]["path"])
        new_base = AUDIO_RECORDINGS_DIR / new_name
        new_speaker = new_base.with_suffix(".wav")
        new_raw = new_base.with_name(f"{new_base.name}_raw.wav")
        # Path.rename silently replaces an existing file on POSIX; refuse
        # rather than clobber a recording that isn't in the index.
        for dest in (new_speaker, new_raw):
            if dest.exists():
                raise ValueError(f"file already exists: {dest.name}")
        renamed: list[tuple[Path, Path]] = []
        try:
            old_speaker.rename(new_speaker)
            renamed.append((old_speaker, new_speaker))
            if old_raw.exists():
                old_raw.rename(new_raw)
                renamed.append((old_raw, new_raw))
            target["record_name"] = new_name
            target["files"]["speaker_recording"]["path"] = new_speaker.name  # basename only
            target["files"]["speaker_recording"]["name"] = new_speaker.name
            target["files"]["gemini_raw_output"]["path"] = new_raw.name
            target["files"]["gemini_raw_output"]["name"] = new_raw.name
            _save_index(payload)
        except BaseException:
            for src, dst in reversed(renamed):
                try:
                    dst.rename(src)
                except OSError as exc:
                    log.warning("couldn't roll back rename %s -> %s: %s", dst, src, exc)
            raise
        if self.session.saved_as == name:
            self.session.saved_as = new_name
        return target

    def delete_record(self, name: str) -> dict[str, Any]:
        """Delete a record's WAV files (best effort) and its index entry.

        Raises:
            KeyError: no such record.
        """
        payload = _reconcile_index(_load_index())
        target = self._find_entry(payload["records"], name)
        if target is None:
            raise KeyError(f"record not found: {name}")
        for key in ("speaker_recording", "gemini_raw_output"):
            path = _resolve_record_path(target["files"][key]["path"])
            try:
                path.unlink()
            except FileNotFoundError:
                pass
            except Exception as exc:
                log.warning("couldn't delete %s: %s", path, exc)
        payload["records"] = [e for e in payload["records"] if e.get("record_name") != name]
        payload["total_records"] = len(payload["records"])
        _save_index(payload)
        if self.session.saved_as == name:
            self.session.saved_as = ""
        return {"deleted": name, "total_records": payload["total_records"]}

    def play_record(self, name: str, file_kind: str = "speaker") -> dict[str, Any]:
        """Play a saved WAV. file_kind = 'speaker' or 'raw'.

        Any value other than 'speaker' selects the raw Gemini output. Without
        an audio manager nothing is played (a warning is logged). The result
        then reports ``"played": False``.

        Raises:
            KeyError: no such record.
            FileNotFoundError: the WAV is missing on disk.
        """
        entry = self.find_record(name)
        file_key = "speaker_recording" if file_kind == "speaker" else "gemini_raw_output"
        path = _resolve_record_path(entry["files"][file_key]["path"])
        if not path.exists():
            raise FileNotFoundError(str(path))
        with wave.open(str(path), "rb") as wf:
            channels = wf.getnchannels()
            sample_width = wf.getsampwidth()
            sample_rate = wf.getframerate()
            frames = wf.readframes(wf.getnframes())
        played = False
        pya = getattr(self.audio_mgr, "pya", None)
        if pya is None:
            log.warning("play_record(%s): audio_mgr unavailable — not playing", name)
        else:
            with self._play_lock:
                stream = pya.open(
                    format=pya.get_format_from_width(sample_width),
                    channels=channels, rate=sample_rate,
                    output=True, frames_per_buffer=CHUNK_SIZE,
                )
                try:
                    chunk = CHUNK_SIZE * channels * sample_width
                    for offset in range(0, len(frames), chunk):
                        stream.write(frames[offset:offset + chunk])
                finally:
                    try:
                        stream.stop_stream()
                    finally:
                        stream.close()
            played = True
        return {
            "ok": True, "record_name": name, "file_kind": file_kind,
            "duration_sec": round(audio_duration_seconds(
                frames, sample_rate, channels, sample_width), 3),
            "played": played,
        }

    # ── status ───────────────────────────────────────────────────
    def status(self) -> dict[str, Any]:
        """Summarise readiness, capture backend, cached session and record count."""
        return {
            "voice_client_connected": bool(
                self.voice_client and self.voice_client.connected),
            "audio_mgr_ready": bool(self.audio_mgr and self.audio_mgr.pya),
            "capture_backend": (self._monitor_config or {}).get("backend", "none"),
            "records_dir": str(AUDIO_RECORDINGS_DIR),
            "session": self.session.as_status(),
            "total_records": len(_load_index().get("records", [])),
        }