189 lines
6.6 KiB
Python
189 lines
6.6 KiB
Python
"""Per-turn WAV recorder for voice brains.

Direct port of Project/Sanad/voice/sanad_voice.py::TurnRecorder. Saves each
conversation turn as two WAV files:

    <timestamp>_user.wav    mono int16 @ 16 kHz   (what the mic captured)
    <timestamp>_robot.wav   mono int16 @ 24 kHz   (what the brain spoke)

The sample rates shown are the constructor defaults (`user_rate`,
`robot_rate`). In addition, an index.json in the same directory gains one
entry per turn, holding the transcripts.

A turn starts when audio first flows through `capture_user` or
`capture_robot`, and ends on `finish_turn`. The call pattern matches Sanad
exactly: `capture_user`, `capture_robot`, `add_user_text`, `add_robot_text`,
`finish_turn`.

Disable via config: stt.gemini_record_enabled = false (the caller passes
`enabled=False`).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
import wave
|
|
from datetime import datetime
|
|
|
|
log = logging.getLogger("turn_recorder")


class TurnRecorder:
    """Buffer one conversation turn in memory and save it as WAV files.

    Audio chunks and transcript fragments are accumulated under a lock by
    the ``capture_*`` / ``add_*_text`` methods. ``finish_turn`` drains the
    buffers, writes ``<timestamp>_user.wav`` and/or ``<timestamp>_robot.wav``
    (mono, 16-bit) into ``out_dir``, appends an entry to ``index.json`` and
    prunes the oldest WAV files.

    Args:
        enabled: Master switch. Recording is only active when this is true
            AND ``out_dir`` is non-empty.
        out_dir: Directory receiving the WAV files and ``index.json``. It is
            created on construction when recording is active.
        user_rate: Sample rate (Hz) written into the user WAV header.
        robot_rate: Sample rate (Hz) written into the robot WAV header.
        keep_count: Maximum number of turn pairs kept on disk
            (``keep_count * 2`` WAV files). Zero or negative means unlimited.

    Raises:
        OSError: If recording is active and ``out_dir`` cannot be created.
    """

    def __init__(
        self,
        enabled: bool = True,
        out_dir: str = "",
        user_rate: int = 16000,
        robot_rate: int = 24000,
        keep_count: int = 50,
    ):
        # An empty out_dir disables recording even when enabled=True.
        self.enabled = bool(enabled) and bool(out_dir)
        self.out_dir = out_dir
        self.user_rate = int(user_rate)
        self.robot_rate = int(robot_rate)
        # Cap the number of turn-pairs on disk. 0/negative = unlimited.
        # finish_turn() prunes the oldest WAVs after every save so the
        # folder cannot grow past keep_count*2 WAV files.
        self.keep_count = int(keep_count)
        if self.enabled:
            os.makedirs(self.out_dir, exist_ok=True)
        # Guards the four buffers below and _started_at.
        self._lock = threading.Lock()
        self._user_buf = []
        self._robot_buf = []
        self._user_text = ""
        self._robot_text = ""
        self._started_at = 0.0

    def capture_user(self, pcm: bytes) -> None:
        """Buffer a chunk of microphone PCM for the current turn."""
        self._buffer_pcm(self._user_buf, pcm)

    def capture_robot(self, pcm: bytes) -> None:
        """Buffer a chunk of model-output PCM for the current turn."""
        self._buffer_pcm(self._robot_buf, pcm)

    def _buffer_pcm(self, buf: list, pcm: bytes) -> None:
        """Append ``pcm`` to ``buf``; the first chunk of a turn stamps its start."""
        if not self.enabled or not pcm:
            return
        with self._lock:
            # Both buffers empty means no audio has arrived since the last
            # finish_turn(): this chunk opens a new turn.
            if not self._user_buf and not self._robot_buf:
                self._started_at = time.time()
            buf.append(pcm)

    def add_user_text(self, text: str) -> None:
        """Append a transcript fragment of what the user said."""
        if text and self.enabled:
            with self._lock:
                self._user_text = self._join_text(self._user_text, text)

    def add_robot_text(self, text: str) -> None:
        """Append a transcript fragment of what the model said."""
        if text and self.enabled:
            with self._lock:
                self._robot_text = self._join_text(self._robot_text, text)

    @staticmethod
    def _join_text(current: str, fragment: str) -> str:
        """Join two transcript fragments with one space, trimming the ends."""
        return (current + " " + fragment).strip()

    def finish_turn(self) -> dict:
        """Flush the buffered turn to disk and reset the buffers.

        Returns:
            ``{}`` when recording is disabled or no audio was captured (any
            buffered text is discarded in that case). Otherwise the index
            entry for the turn: ``timestamp``, ``started_at``, ``user_text``,
            ``robot_text`` plus ``<role>_wav`` / ``<role>_duration_sec`` for
            each role that had audio. If saving fails part-way, the failure
            is logged and the entry holds whatever was filled in so far.
        """
        if not self.enabled:
            return {}
        # Snapshot and reset under the lock; the slow disk I/O happens after
        # it is released so capture_* callers are not blocked by it.
        with self._lock:
            user_data = b"".join(self._user_buf)
            robot_data = b"".join(self._robot_buf)
            user_text = self._user_text
            robot_text = self._robot_text
            started_at = self._started_at
            self._user_buf.clear()
            self._robot_buf.clear()
            self._user_text = ""
            self._robot_text = ""

        if not user_data and not robot_data:
            return {}

        stamp = datetime.fromtimestamp(started_at).strftime("%Y%m%d_%H%M%S")
        entry = {
            "timestamp": stamp,
            "started_at": started_at,
            "user_text": user_text,
            "robot_text": robot_text,
        }
        tracks = (
            ("user", user_data, self.user_rate),
            ("robot", robot_data, self.robot_rate),
        )
        try:
            # The stamp only has one-second resolution; pick a file stem that
            # does not clobber a turn that started in the same second.
            stem = self._unique_stem(stamp)
            for role, data, rate in tracks:
                if not data:
                    continue
                path = os.path.join(self.out_dir, "{}_{}.wav".format(stem, role))
                self._save_wav(path, data, rate)
                entry[role + "_wav"] = path
                # 2 bytes per sample (mono int16).
                entry[role + "_duration_sec"] = round(len(data) / (rate * 2), 3)
            self._append_index(entry)
            self._rotate()
            log.info(
                "recorded turn → %s (user %.1fs, robot %.1fs)",
                stamp,
                entry.get("user_duration_sec", 0),
                entry.get("robot_duration_sec", 0),
            )
        except Exception as exc:
            # Recording is best-effort: never let a save failure propagate
            # into the caller's conversation loop.
            log.warning("recording save failed: %s", exc)
        return entry

    def _unique_stem(self, stamp: str) -> str:
        """Return ``stamp``, or ``stamp_<n>`` if its WAV names are already taken."""
        stem = stamp
        suffix = 0
        while any(
            os.path.exists(
                os.path.join(self.out_dir, "{}_{}.wav".format(stem, role))
            )
            for role in ("user", "robot")
        ):
            suffix += 1
            stem = "{}_{}".format(stamp, suffix)
        return stem

    def _rotate(self) -> None:
        """Prune oldest WAV files when count exceeds keep_count*2 (one
        user + one robot per turn). index.json is never touched."""
        if self.keep_count <= 0:
            return
        limit = self.keep_count * 2
        try:
            names = [f for f in os.listdir(self.out_dir) if f.endswith(".wav")]
        except OSError as exc:
            log.warning("recording rotate failed: %s", exc)
            return
        if len(names) <= limit:
            return

        dated = []
        for name in names:
            path = os.path.join(self.out_dir, name)
            try:
                dated.append((os.path.getmtime(path), path))
            except OSError:
                # File vanished between listdir() and stat(); nothing to prune.
                continue
        # Oldest first; ties on mtime fall back to the path for determinism.
        dated.sort()
        for _, path in dated[: max(0, len(dated) - limit)]:
            try:
                os.remove(path)
            except OSError as exc:
                # One undeletable file must not stop the rest of the prune.
                log.debug("could not prune %s: %s", path, exc)

    @staticmethod
    def _save_wav(path: str, pcm: bytes, rate: int) -> None:
        """Write ``pcm`` to ``path`` as a mono, 16-bit WAV at ``rate`` Hz."""
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(rate)
            wf.writeframes(pcm)

    def _append_index(self, entry: dict) -> None:
        """Append ``entry`` to ``index.json`` in ``out_dir``.

        A missing, unreadable or malformed index is replaced by a fresh one
        rather than raising. The file is rewritten through a temporary file
        and ``os.replace`` so an interrupted write cannot leave a truncated
        index behind. Write failures are logged, not raised.
        """
        idx_path = os.path.join(self.out_dir, "index.json")
        payload = None
        try:
            with open(idx_path, "r", encoding="utf-8") as f:
                payload = json.load(f)
        except FileNotFoundError:
            pass  # first turn recorded in this directory
        except (OSError, ValueError) as exc:
            log.warning("index.json unreadable, starting a new index: %s", exc)
        if not isinstance(payload, dict):
            payload = {}
        records = payload.get("records")
        if not isinstance(records, list):
            # Covers a missing key as well as e.g. "records": null.
            records = []
            payload["records"] = records
        records.append(entry)
        payload["total_records"] = len(records)

        tmp_path = idx_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, idx_path)
        except (OSError, TypeError, ValueError) as exc:
            log.warning("index.json write failed: %s", exc)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone
|