"""Silero VAD wrapper — CPU-only speech boundary detection.
|
|
|
|
Phase 1 of the local pipeline. Consumes 16 kHz mono int16 PCM in short
|
|
frames, emits speech_start / speech_end events. All thresholds + frame
|
|
sizes come from config/local_config.json > vad.
|
|
|
|
Install (on the robot):
|
|
pip install silero-vad torch==2.2.* torchaudio==2.2.*
|
|
|
|
Usage:
|
|
vad = SileroVAD()
|
|
vad.start()
|
|
evt = vad.process(pcm_bytes)
|
|
if evt == 'speech_start': ...
|
|
elif evt == 'speech_end': buf = vad.collected_audio()
|
|
"""
|
|
|
|
from __future__ import annotations

import time
from typing import Optional

import numpy as np

from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.logger import get_logger

log = get_logger("local_vad")
# "vad" section of config/local_config.json; the .get() defaults below mirror it
_CFG = _cfg_section("local", "vad")

SAMPLE_RATE = _CFG.get("sample_rate", 16000)      # Hz; mono int16 PCM expected
FRAME_MS = _CFG.get("frame_ms", 32)               # duration of each frame fed to process()
THRESHOLD = _CFG.get("threshold", 0.55)           # speech-probability cutoff
MIN_SILENCE_MS = _CFG.get("min_silence_ms", 400)  # silence run that ends an utterance
MIN_SPEECH_MS = _CFG.get("min_speech_ms", 250)    # shorter utterances are dropped
PAD_START_MS = _CFG.get("pad_start_ms", 200)      # lead-in audio kept before speech_start
PAD_END_MS = _CFG.get("pad_end_ms", 200)          # trailing audio kept after speech ends

FRAME_SAMPLES = SAMPLE_RATE * FRAME_MS // 1000  # 512 @ 16k/32ms
|
class SileroVAD:
    """Streaming VAD with buffered utterance capture.

    Fed one mic frame at a time via `process()`. Internal state tracks
    whether we're inside an utterance; on speech_end, `collected_audio()`
    returns the full utterance (with configured padding).
    """

    def __init__(self) -> None:
        self._model = None                 # set by start(); while None, process() is a no-op
        self._audio_buf: list[bytes] = []  # frames of the utterance being collected
        self._pre_buf: list[bytes] = []    # rolling "pre-speech" ring (lead-in pad)
        self._pre_frames = max(1, PAD_START_MS // FRAME_MS)
        self._pad_end_frames = max(1, PAD_END_MS // FRAME_MS)
        self._in_speech = False
        self._last_speech_time = 0.0       # monotonic ts of last voiced frame
        self._speech_start_time = 0.0      # monotonic ts of utterance start
        self._trailing_silence_frames = 0  # consecutive silent frames while in speech
        self._last_utterance: Optional[bytes] = None

    def start(self) -> None:
        """Load the Silero model once. Call before `process()`.

        Raises:
            RuntimeError: if 'silero-vad' or torch is not installed
                (chained from the original ImportError).
        """
        try:
            import torch  # noqa: F401 — verify torch is importable up front
            from silero_vad import load_silero_vad
        except ImportError as exc:
            # chain the cause so the missing package is visible in tracebacks
            raise RuntimeError(
                f"SileroVAD requires 'silero-vad' + torch: {exc}"
            ) from exc
        self._model = load_silero_vad()
        log.info("SileroVAD ready (threshold=%.2f, frame=%dms)",
                 THRESHOLD, FRAME_MS)

    def process(self, pcm: bytes) -> Optional[str]:
        """Feed one frame (≈ FRAME_MS of 16 kHz mono int16 audio).

        Returns:
            'speech_start' | 'speech_end' | None
        """
        if self._model is None:
            return None

        # keep a rolling pre-buffer so captured utterances include lead-in
        self._pre_buf.append(pcm)
        if len(self._pre_buf) > self._pre_frames:
            self._pre_buf.pop(0)

        # VAD expects float32 in [-1, 1]; pad/trim to the exact frame size
        arr = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
        if arr.size < FRAME_SAMPLES:
            # pad if a short tail chunk arrived
            arr = np.concatenate(
                [arr, np.zeros(FRAME_SAMPLES - arr.size, dtype=np.float32)]
            )
        elif arr.size > FRAME_SAMPLES:
            arr = arr[:FRAME_SAMPLES]

        try:
            import torch
            with torch.no_grad():
                prob = float(self._model(torch.from_numpy(arr), SAMPLE_RATE).item())
        except Exception as exc:
            log.warning("VAD inference failed: %s", exc)
            return None

        # monotonic clock: durations must not jump with wall-clock changes
        now = time.monotonic()
        is_speech = prob >= THRESHOLD

        if is_speech:
            self._trailing_silence_frames = 0
            self._last_speech_time = now
            if not self._in_speech:
                # transition → speech: seed with the lead-in pad
                self._in_speech = True
                self._speech_start_time = now
                self._audio_buf = list(self._pre_buf)
                self._audio_buf.append(pcm)
                return "speech_start"
            self._audio_buf.append(pcm)
            return None

        # silent frame
        if self._in_speech:
            self._audio_buf.append(pcm)  # collect trailing pad
            self._trailing_silence_frames += 1
            silence_ms = self._trailing_silence_frames * FRAME_MS
            if silence_ms >= MIN_SILENCE_MS:
                speech_dur_ms = (now - self._speech_start_time) * 1000
                self._in_speech = False
                # trim trailing silence down to the configured end pad
                # (bug fix: _pad_end_frames was computed but never applied,
                # so utterances carried MIN_SILENCE_MS of tail silence)
                excess = self._trailing_silence_frames - self._pad_end_frames
                if excess > 0:
                    del self._audio_buf[-excess:]
                self._trailing_silence_frames = 0
                if speech_dur_ms < MIN_SPEECH_MS:
                    log.debug("drop short utterance (%.0fms)", speech_dur_ms)
                    self._audio_buf.clear()
                    self._last_utterance = None
                    return None
                self._last_utterance = b"".join(self._audio_buf)
                self._audio_buf.clear()
                return "speech_end"
        return None

    def collected_audio(self) -> Optional[bytes]:
        """After a speech_end event, return the full utterance bytes."""
        return self._last_utterance

    def reset(self) -> None:
        """Drop any in-flight utterance (used on barge-in)."""
        self._in_speech = False
        self._audio_buf.clear()
        self._trailing_silence_frames = 0
        self._last_utterance = None

    def stop(self) -> None:
        """Release the model reference; call start() again to resume."""
        self._model = None
|