"""LocalBrain — fully on-device voice pipeline.
|
|
|
|
Implements the same contract as `gemini/script.py:GeminiBrain` so
|
|
`voice/sanad_voice.py` can swap it in via `SANAD_VOICE_BRAIN=local`.
|
|
Wires together four subsystems:
|
|
|
|
Phase 1 — Silero VAD (mic → speech boundaries)
|
|
Phase 2 — faster-whisper (speech → text)
|
|
Phase 3 — llama.cpp + Qwen (text → streaming text chunks)
|
|
Phase 4 — CosyVoice2 streaming (text chunk → cloned-voice audio)
|
|
Phase 5 — barge-in (user speaks → cancel LLM + stop speaker)
|
|
Phase 6 — stability — model load fails cleanly, crashes are logged.
|
|
|
|
Async structure:
|
|
run() is the main coroutine. It spawns three tasks:
|
|
_mic_task — reads mic, VAD, Whisper, pushes user text to _llm_queue
|
|
_dialogue_task — pops user text, streams LLM tokens into _tts_queue
|
|
_tts_task — pops text chunks, synthesises, feeds the speaker
|
|
|
|
Logging contract (matched by local/subprocess.py._track_line):
|
|
"connecting to local pipeline"
|
|
"listening"
|
|
"USER: <text>"
|
|
"BOT: <text>"
|
|
"BARGE-IN (local)"
|
|
"session error: <msg>"
|
|
"""
from __future__ import annotations

import asyncio
import time
from typing import Optional

from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.logger import get_logger

from Project.Sanad.local.llm import LlamaServer
from Project.Sanad.local.stt import WhisperSTT
from Project.Sanad.local.tts import CosyVoiceTTS
from Project.Sanad.local.vad import SileroVAD, FRAME_SAMPLES

log = get_logger("local_brain")

_CFG_SV = _cfg_section("voice", "sanad_voice")

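# Each mic read is one VAD frame of int16 mono audio. (Assuming FRAME_SAMPLES
# is Silero's usual 512-sample frame at 16 kHz, that is 1024 bytes per read,
# roughly 32 ms; the authoritative value lives in local/vad.py.)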
_CHUNK_BYTES = FRAME_SAMPLES * 2  # int16 mono


class LocalBrain:
    """Fully on-device Gemini replacement."""

    def __init__(self, audio_io, recorder, voice_name: Optional[str] = None,
                 system_prompt: str = ""):
        self._audio = audio_io
        self._mic = audio_io.mic
        self._speaker = audio_io.speaker
        self._recorder = recorder
        self._voice = voice_name
        self._system_prompt = system_prompt

        # subsystems — instantiated here, loaded in run()
        self._vad = SileroVAD()
        self._stt = WhisperSTT()
        self._llm = LlamaServer()
        self._tts = CosyVoiceTTS()

        # pipeline queues (maxsize=4 bounds backpressure; _mic_task drops the
        # oldest pending utterance rather than blocking)
        self._llm_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=4)
        self._tts_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=4)

        # control flags
        self._stop_flag = asyncio.Event()  # full shutdown
        self._interrupt = asyncio.Event()  # per-turn barge-in
        self._speaking = False
        self._speak_start_time = 0.0

    # ─── lifecycle ────────────────────────────────────────

    def stop(self) -> None:
        """Signal all pipeline tasks to shut down."""
        self._stop_flag.set()
        self._interrupt.set()

    async def run(self) -> None:
        """Main entry. Loads models, runs the pipeline, handles shutdown."""
        log.info("connecting to local pipeline")
        try:
            await asyncio.to_thread(self._vad.start)
            await asyncio.to_thread(self._stt.start)
            await asyncio.to_thread(self._llm.start)
            await asyncio.to_thread(self._tts.start)
        except Exception as exc:
            log.error("session error: local pipeline startup failed — %s", exc)
            return

        log.info("listening")
        try:
            await asyncio.gather(
                self._mic_task(),
                self._dialogue_task(),
                self._tts_task(),
            )
        except asyncio.CancelledError:
            log.info("cancelled — stopping")
        except Exception as exc:
            log.error("session error: %s", exc)
        finally:
            # make sure sibling tasks exit their loops before teardown
            self.stop()
            for sub in (self._llm, self._tts, self._stt, self._vad):
                try:
                    sub.stop()
                except Exception:
                    log.warning("%s.stop failed", type(sub).__name__, exc_info=True)
            log.info("local pipeline stopped")

    # ─── barge-in ─────────────────────────────────────────

    def _begin_barge_in(self) -> None:
        """Called from the mic task when the user starts speaking over the bot."""
        if not self._speaking:
            return
        log.info("BARGE-IN (local)")
        self._interrupt.set()
        try:
            self._speaker.stop()
        except Exception:
            log.warning("speaker.stop during barge-in failed", exc_info=True)
        # drain pipelines — discard any pending LLM/TTS chunks for this turn
        self._drain_queue(self._llm_queue)
        self._drain_queue(self._tts_queue)
        self._speaking = False
        try:
            self._recorder.finish_turn()
        except Exception:
            pass  # recorder failures must never break barge-in

    @staticmethod
    def _drain_queue(q: asyncio.Queue) -> None:
        """Discard everything currently queued, without blocking."""
        try:
            while True:
                q.get_nowait()
                q.task_done()
        except asyncio.QueueEmpty:
            pass

    # ─── Task 1: mic → VAD → Whisper → LLM queue ──────────

    async def _mic_task(self) -> None:
        loop = asyncio.get_running_loop()
        while not self._stop_flag.is_set():
            try:
                pcm = await loop.run_in_executor(
                    None, self._mic.read_chunk, _CHUNK_BYTES,
                )
            except Exception:
                await asyncio.sleep(0.01)  # transient mic error — retry
                continue

            event = self._vad.process(pcm)
            if event == "speech_start":
                # user started talking — if bot is speaking, it's a barge-in
                if self._speaking:
                    self._begin_barge_in()
            elif event == "speech_end":
                utt = self._vad.collected_audio()
                if not utt:
                    continue
                try:
                    self._recorder.capture_user(utt)
                except Exception:
                    pass
                text = await loop.run_in_executor(None, self._stt.transcribe, utt)
                if not text:
                    continue
                log.info("USER: %s", text)
                try:
                    self._recorder.add_user_text(text)
                except Exception:
                    pass
                # wake the LLM side — drop the oldest pending item if full
                # (latency matters more than throughput here)
                if self._llm_queue.full():
                    try:
                        self._llm_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        pass
                await self._llm_queue.put(text)

    # ─── Task 2: LLM streaming → TTS queue ────────────────

    async def _dialogue_task(self) -> None:
        while not self._stop_flag.is_set():
            try:
                user_text = await asyncio.wait_for(
                    self._llm_queue.get(), timeout=0.2)
            except asyncio.TimeoutError:
                continue
            self._interrupt.clear()
            full_response: list[str] = []
            async for chunk in self._llm.stream(
                    user_text, self._system_prompt, self._interrupt):
                if self._interrupt.is_set():
                    break
                full_response.append(chunk)
                await self._tts_queue.put(chunk)
            self._llm_queue.task_done()
            if full_response and not self._interrupt.is_set():
                bot_text = " ".join(full_response).strip()
                if bot_text:
                    log.info("BOT: %s", bot_text)
                    try:
                        self._recorder.add_robot_text(bot_text)
                    except Exception:
                        pass

    # ─── Task 3: TTS → speaker ────────────────────────────

    async def _tts_task(self) -> None:
        loop = asyncio.get_running_loop()
        while not self._stop_flag.is_set():
            try:
                chunk_text = await asyncio.wait_for(
                    self._tts_queue.get(), timeout=0.2)
            except asyncio.TimeoutError:
                # idle — if we've been speaking and the queues drained, close the stream
                if self._speaking and self._llm_queue.empty() and self._tts_queue.empty():
                    await loop.run_in_executor(None, self._speaker.wait_finish)
                    self._speaking = False
                    log.info("listening")
                    try:
                        self._recorder.finish_turn()
                    except Exception:
                        pass
                continue
            if self._interrupt.is_set():
                self._tts_queue.task_done()  # discard — this turn was barged in
                continue

            # synthesise this text chunk → stream to speaker
            if not self._speaking:
                await loop.run_in_executor(None, self._speaker.begin_stream)
                self._speaking = True
                self._speak_start_time = time.time()

            try:
                # pull PCM from the synchronous generator via the executor so
                # synthesis never blocks the event loop and barge-in stays live
                # (None is the end-of-stream sentinel; assumes the generator
                # never yields None itself)
                gen = self._tts.synthesize_stream(chunk_text)
                while True:
                    pcm = await loop.run_in_executor(None, next, gen, None)
                    if pcm is None or self._interrupt.is_set():
                        break
                    try:
                        self._recorder.capture_robot(pcm)
                    except Exception:
                        pass
                    await loop.run_in_executor(
                        None, self._speaker.send_chunk,
                        pcm, self._tts.output_rate,
                    )
            except Exception as exc:
                log.warning("TTS chunk failed: %s", exc)
            finally:
                self._tts_queue.task_done()