# Sanad/local/vad.py
"""Silero VAD wrapper — CPU-only speech boundary detection.
Phase 1 of the local pipeline. Consumes 16 kHz mono int16 PCM in short
frames, emits speech_start / speech_end events. All thresholds + frame
sizes come from config/local_config.json > vad.
Install (on the robot):
pip install silero-vad torch==2.2.* torchaudio==2.2.*
Usage:
vad = SileroVAD()
vad.start()
evt = vad.process(pcm_bytes)
if evt == 'speech_start': ...
elif evt == 'speech_end': buf = vad.collected_audio()
"""
from __future__ import annotations
import time
from typing import Optional
import numpy as np
from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.logger import get_logger
# Module logger + VAD tunables, read once from config/local_config.json > vad.
log = get_logger("local_vad")
_CFG = _cfg_section("local", "vad")
SAMPLE_RATE = _CFG.get("sample_rate", 16000)      # Hz, mono int16 input
FRAME_MS = _CFG.get("frame_ms", 32)               # duration of each frame fed to process()
THRESHOLD = _CFG.get("threshold", 0.55)           # speech-probability cut-off
MIN_SILENCE_MS = _CFG.get("min_silence_ms", 400)  # silence run that closes an utterance
MIN_SPEECH_MS = _CFG.get("min_speech_ms", 250)    # utterances shorter than this are dropped
PAD_START_MS = _CFG.get("pad_start_ms", 200)      # lead-in audio kept before speech_start
PAD_END_MS = _CFG.get("pad_end_ms", 200)          # configured trailing padding
FRAME_SAMPLES = SAMPLE_RATE * FRAME_MS // 1000 # 512 @ 16k/32ms
class SileroVAD:
    """Streaming VAD with buffered utterance capture.

    Fed one mic frame at a time via ``process()``. Internal state tracks
    whether we're inside an utterance; on a ``speech_end`` event,
    ``collected_audio()`` returns the full utterance bytes, including
    PAD_START_MS of lead-in and up to PAD_END_MS of trailing audio.
    """

    def __init__(self) -> None:
        self._model = None                  # Silero model, set by start()
        self._audio_buf: list[bytes] = []   # utterance being collected
        self._pre_buf: list[bytes] = []     # rolling "pre-speech" ring
        self._pre_frames = max(1, PAD_START_MS // FRAME_MS)
        self._pad_end_frames = max(1, PAD_END_MS // FRAME_MS)
        self._in_speech = False
        self._last_speech_time = 0.0        # wall-clock of last voiced frame
        self._speech_start_time = 0.0       # wall-clock of first voiced frame
        self._trailing_silence_frames = 0
        self._last_utterance: Optional[bytes] = None

    def start(self) -> None:
        """Load the Silero model once. Call before `process()`.

        Raises:
            RuntimeError: if 'silero-vad' or torch is not installed.
        """
        try:
            import torch  # noqa: F401 -- fail fast if torch is missing
            from silero_vad import load_silero_vad
        except ImportError as exc:
            # Chain the original ImportError so the missing package is visible.
            raise RuntimeError(
                f"SileroVAD requires 'silero-vad' + torch: {exc}"
            ) from exc
        self._model = load_silero_vad()
        log.info("SileroVAD ready (threshold=%.2f, frame=%dms)",
                 THRESHOLD, FRAME_MS)

    def process(self, pcm: bytes) -> Optional[str]:
        """Feed one frame (≈ FRAME_MS of 16 kHz mono int16 audio).

        Returns:
            'speech_start' | 'speech_end' | None
        """
        if self._model is None:
            return None
        # Keep a rolling pre-buffer so captured utterances include lead-in.
        self._pre_buf.append(pcm)
        if len(self._pre_buf) > self._pre_frames:
            self._pre_buf.pop(0)
        prob = self._frame_probability(pcm)
        if prob is None:
            return None  # inference failed; skip this frame
        now = time.time()
        if prob >= THRESHOLD:
            return self._on_speech_frame(pcm, now)
        return self._on_silent_frame(pcm)

    def _on_speech_frame(self, pcm: bytes, now: float) -> Optional[str]:
        """Handle a voiced frame; returns 'speech_start' on transition."""
        self._trailing_silence_frames = 0
        self._last_speech_time = now
        if not self._in_speech:
            # transition -> speech
            self._in_speech = True
            self._speech_start_time = now
            self._audio_buf = list(self._pre_buf)  # seed with start padding
            self._audio_buf.append(pcm)
            return "speech_start"
        self._audio_buf.append(pcm)
        return None

    def _on_silent_frame(self, pcm: bytes) -> Optional[str]:
        """Handle a silent frame; returns 'speech_end' when the utterance closes."""
        if not self._in_speech:
            return None
        self._audio_buf.append(pcm)  # collect trailing pad
        self._trailing_silence_frames += 1
        if self._trailing_silence_frames * FRAME_MS < MIN_SILENCE_MS:
            return None
        # Speech ended -- validate min_speech against the *voiced* span only.
        # BUG FIX: previously this measured up to time.time(), which included
        # the MIN_SILENCE_MS of trailing silence, so with the defaults
        # (400ms silence > 250ms min_speech) the gate never rejected anything.
        # `+ FRAME_MS` accounts for the duration of the final voiced frame.
        speech_dur_ms = ((self._last_speech_time - self._speech_start_time)
                         * 1000 + FRAME_MS)
        trailing = self._trailing_silence_frames
        self._in_speech = False
        self._trailing_silence_frames = 0
        if speech_dur_ms < MIN_SPEECH_MS:
            log.debug("drop short utterance (%.0fms)", speech_dur_ms)
            self._audio_buf.clear()
            self._last_utterance = None
            return None
        # BUG FIX: _pad_end_frames was computed in __init__ but never applied,
        # so every utterance carried the full MIN_SILENCE_MS of trailing
        # silence. Trim the tail down to the configured end padding.
        excess = trailing - self._pad_end_frames
        if excess > 0:
            del self._audio_buf[-excess:]
        self._last_utterance = b"".join(self._audio_buf)
        self._audio_buf.clear()
        return "speech_end"

    def _frame_probability(self, pcm: bytes) -> Optional[float]:
        """Return the model's speech probability for one frame, or None on error."""
        # Model expects float32 in [-1, 1]; pad/trim to exactly FRAME_SAMPLES.
        arr = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
        if arr.size < FRAME_SAMPLES:
            # Pad if a short tail chunk arrived.
            arr = np.concatenate(
                [arr, np.zeros(FRAME_SAMPLES - arr.size, dtype=np.float32)])
        elif arr.size > FRAME_SAMPLES:
            arr = arr[:FRAME_SAMPLES]
        try:
            import torch
            with torch.no_grad():
                return float(self._model(torch.from_numpy(arr), SAMPLE_RATE).item())
        except Exception as exc:
            # Inference errors must not kill the mic loop; log and carry on.
            log.warning("VAD inference failed: %s", exc)
            return None

    def collected_audio(self) -> Optional[bytes]:
        """After a speech_end event, return the full utterance bytes."""
        return self._last_utterance

    def reset(self) -> None:
        """Drop any in-flight utterance (used on barge-in)."""
        # NOTE(review): the Silero model keeps internal recurrent state; if
        # the package exposes model.reset_states(), calling it here would
        # avoid bleed-over between utterances -- TODO confirm API.
        self._in_speech = False
        self._audio_buf.clear()
        self._trailing_silence_frames = 0
        self._last_utterance = None

    def stop(self) -> None:
        """Release the model reference; process() becomes a no-op."""
        self._model = None