Marcus/Voice/turn_recorder.py

189 lines
6.6 KiB
Python

"""Per-turn WAV recorder for voice brains.
Direct port of Project/Sanad/voice/sanad_voice.py::TurnRecorder. Saves each
conversation turn as two WAV files:
<timestamp>_user.wav mono int16 @ 16 kHz (what the mic captured)
<timestamp>_robot.wav mono int16 @ 24 kHz (what the brain spoke)
Plus an index.json that appends one entry per turn with the transcripts.
A turn starts when audio first flows through `capture_user` or
`capture_robot`, and ends on `finish_turn`. Call pattern matches Sanad
exactly: `capture_user`, `capture_robot`, `add_user_text`, `add_robot_text`,
`finish_turn`.
Disable via config: stt.gemini_record_enabled = false (the caller passes
`enabled=False`).
"""
from __future__ import annotations
import json
import logging
import os
import threading
import time
import wave
from datetime import datetime
log = logging.getLogger("turn_recorder")
class TurnRecorder:
    """Saves each turn as two WAV files: user mic + model output.

    PCM chunks are buffered in memory by ``capture_user`` / ``capture_robot``
    and written to disk by ``finish_turn``, which also appends a transcript
    entry to ``index.json`` and prunes the oldest WAVs.

    Buffer and transcript state is guarded by a ``threading.Lock``; the disk
    I/O in ``finish_turn`` runs after that lock has been released.
    """

    def __init__(
        self,
        enabled: bool = True,
        out_dir: str = "",
        user_rate: int = 16000,
        robot_rate: int = 24000,
        keep_count: int = 50,
    ):
        """Configure the recorder and create ``out_dir`` when enabled.

        Args:
            enabled: Master switch; every public method is a no-op when off.
            out_dir: Directory for the WAVs and ``index.json``. An empty
                string disables recording regardless of ``enabled``.
            user_rate: Sample rate written into the user WAV header.
            robot_rate: Sample rate written into the robot WAV header.
            keep_count: Maximum number of turn pairs kept on disk;
                0 or negative means unlimited.
        """
        self.enabled = bool(enabled) and bool(out_dir)
        self.out_dir = out_dir
        self.user_rate = int(user_rate)
        self.robot_rate = int(robot_rate)
        # Cap the number of turn-pairs on disk. 0/negative = unlimited.
        # finish_turn() prunes the oldest WAVs after every save so the
        # folder can never balloon past keep_count*2 files in the wild.
        self.keep_count = int(keep_count)
        if self.enabled:
            os.makedirs(self.out_dir, exist_ok=True)
        self._lock = threading.Lock()
        self._user_buf = []    # PCM chunks captured from the mic this turn
        self._robot_buf = []   # PCM chunks spoken by the model this turn
        self._user_text = ""
        self._robot_text = ""
        self._started_at = 0.0  # wall-clock time of the first chunk of the turn

    def _capture(self, buf: list, pcm: bytes) -> None:
        """Append ``pcm`` to ``buf``, stamping the turn start on first audio."""
        if not self.enabled or not pcm:
            return
        with self._lock:
            # Both buffers empty means this chunk opens a new turn.
            if not self._user_buf and not self._robot_buf:
                self._started_at = time.time()
            buf.append(pcm)

    def capture_user(self, pcm: bytes) -> None:
        """Buffer one chunk of microphone PCM for the current turn."""
        self._capture(self._user_buf, pcm)

    def capture_robot(self, pcm: bytes) -> None:
        """Buffer one chunk of model-output PCM for the current turn."""
        self._capture(self._robot_buf, pcm)

    def add_user_text(self, text: str) -> None:
        """Append ``text`` (space-separated) to the user transcript."""
        if text and self.enabled:
            with self._lock:
                self._user_text = (self._user_text + " " + text).strip()

    def add_robot_text(self, text: str) -> None:
        """Append ``text`` (space-separated) to the robot transcript."""
        if text and self.enabled:
            with self._lock:
                self._robot_text = (self._robot_text + " " + text).strip()

    def finish_turn(self) -> dict:
        """Flush the buffered turn to disk and reset state for the next one.

        Returns:
            The index entry for the turn (``timestamp``, ``started_at``,
            ``user_text``, ``robot_text``, plus ``*_wav`` and
            ``*_duration_sec`` for each side that had audio). Returns ``{}``
            when recording is disabled or no audio was captured. If saving
            fails the error is logged and the entry built so far is still
            returned; this method does not raise on I/O errors.
        """
        if not self.enabled:
            return {}
        # Snapshot and reset under the lock so capture can continue while
        # the (slow) disk writes below are in progress.
        with self._lock:
            user_data = b"".join(self._user_buf)
            robot_data = b"".join(self._robot_buf)
            user_text = self._user_text
            robot_text = self._robot_text
            started_at = self._started_at
            self._user_buf.clear()
            self._robot_buf.clear()
            self._user_text = ""
            self._robot_text = ""
        if not user_data and not robot_data:
            return {}
        stamp = datetime.fromtimestamp(started_at).strftime("%Y%m%d_%H%M%S")
        entry = {
            "timestamp": stamp,
            "started_at": started_at,
            "user_text": user_text,
            "robot_text": robot_text,
        }
        try:
            # The stamp has one-second resolution; pick a file prefix that
            # does not overwrite a previous turn started in the same second.
            file_stamp = self._unique_stamp(stamp)
            if user_data:
                path = os.path.join(self.out_dir, f"{file_stamp}_user.wav")
                self._save_wav(path, user_data, self.user_rate)
                entry["user_wav"] = path
                # 2 bytes per sample (mono int16).
                entry["user_duration_sec"] = round(
                    len(user_data) / (self.user_rate * 2), 3,
                )
            if robot_data:
                path = os.path.join(self.out_dir, f"{file_stamp}_robot.wav")
                self._save_wav(path, robot_data, self.robot_rate)
                entry["robot_wav"] = path
                entry["robot_duration_sec"] = round(
                    len(robot_data) / (self.robot_rate * 2), 3,
                )
            self._append_index(entry)
            self._rotate()
            log.info(
                "recorded turn → %s (user %.1fs, robot %.1fs)",
                file_stamp,
                entry.get("user_duration_sec", 0),
                entry.get("robot_duration_sec", 0),
            )
        except Exception as exc:
            # Recording is best-effort: a failed save is logged, not raised.
            log.warning("recording save failed: %s", exc)
        return entry

    def _unique_stamp(self, stamp: str) -> str:
        """Return ``stamp``, or ``stamp_N`` if WAVs for ``stamp`` already exist."""
        candidate = stamp
        suffix = 0
        while any(
            os.path.exists(os.path.join(self.out_dir, f"{candidate}_{role}.wav"))
            for role in ("user", "robot")
        ):
            suffix += 1
            candidate = f"{stamp}_{suffix}"
        return candidate

    def _rotate(self) -> None:
        """Prune oldest WAV files when count exceeds keep_count*2 (one
        user + one robot per turn). index.json is never touched."""
        if self.keep_count <= 0:
            return
        limit = self.keep_count * 2
        try:
            names = [f for f in os.listdir(self.out_dir) if f.endswith(".wav")]
        except OSError as exc:
            log.warning("recording rotate failed: %s", exc)
            return
        if len(names) <= limit:
            return
        dated = []
        for name in names:
            path = os.path.join(self.out_dir, name)
            try:
                dated.append((os.path.getmtime(path), path))
            except OSError:
                # Vanished between listdir and stat; nothing left to prune.
                continue
        excess = len(dated) - limit
        if excess <= 0:
            return
        # Sort oldest-first by mtime; remove the head of the list.
        dated.sort(key=lambda item: item[0])
        for _mtime, path in dated[:excess]:
            try:
                os.remove(path)
            except OSError as exc:
                log.debug("could not remove %s: %s", path, exc)

    @staticmethod
    def _save_wav(path: str, pcm: bytes, rate: int) -> None:
        """Write ``pcm`` to ``path`` as a mono, 16-bit WAV at ``rate`` Hz."""
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(rate)
            wf.writeframes(pcm)

    def _append_index(self, entry: dict) -> None:
        """Append ``entry`` to ``index.json`` and refresh ``total_records``.

        A missing, unreadable or malformed index is replaced by a fresh one.
        The new content is written to a temporary file and moved into place
        with ``os.replace`` so an interrupted write cannot truncate the
        existing index. Write failures are logged, not raised.
        """
        idx_path = os.path.join(self.out_dir, "index.json")
        payload = {}
        try:
            with open(idx_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if isinstance(loaded, dict):
                payload = loaded
        except FileNotFoundError:
            pass  # first turn ever recorded in this directory
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            log.warning("index.json unreadable, starting a new index: %s", exc)

        records = payload.get("records")
        if not isinstance(records, list):
            records = []
            payload["records"] = records
        records.append(entry)
        payload["total_records"] = len(records)

        tmp_path = idx_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, idx_path)
        except (OSError, TypeError, ValueError) as exc:
            log.warning("index.json write failed: %s", exc)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone