"""Template brain — copy this file to plug in a non-Gemini model.
|
|
|
|
How to use:
|
|
1. Copy this file: `cp voice/model_script.py voice/openai_script.py`
|
|
2. Rename the class: `ModelBrain` → e.g. `OpenAIRealtimeBrain`
|
|
3. Fill in every block marked `TODO` with your provider's SDK calls.
|
|
4. Register the new brain in `voice/sanad_voice.py` inside
|
|
`_build_brain()` (there's a single `elif` to add).
|
|
5. Run with `SANAD_VOICE_BRAIN=openai python3 voice/sanad_voice.py eth0`.
|
|
|
|
Contract that `sanad_voice.py` expects of ANY brain:
|
|
__init__(audio_io, recorder, voice_name, system_prompt)
|
|
audio_io — voice.audio_io.AudioIO (exposes .mic + .speaker)
|
|
recorder — voice.sanad_voice.TurnRecorder (per-turn WAV capture)
|
|
voice_name — provider-specific voice id (e.g. "Charon", "alloy")
|
|
system_prompt — persona string to seed the session with
|
|
async run() — blocks until stopped or fatal. Reconnects are YOUR
|
|
responsibility; the orchestrator won't restart you.
|
|
stop() — sync signal (can be called from a signal handler).
|
|
Set an asyncio.Event and let `run()` notice it.
|
|
|
|
What the mic side looks like:
|
|
data = self._mic.read_chunk(n_bytes) # 16 kHz int16 mono bytes
|
|
# send `data` to your model's realtime-audio endpoint
|
|
|
|
What the speaker side looks like:
|
|
self._speaker.begin_stream()
|
|
self._speaker.send_chunk(pcm, source_rate=24000) # rate is yours
|
|
self._speaker.wait_finish() # blocks until playback drains
|
|
# or self._speaker.stop() # cancel mid-playback (barge-in)
|
|
|
|
What the recorder side looks like:
|
|
self._recorder.capture_user(pcm_bytes) # mic audio for this turn
|
|
self._recorder.capture_robot(pcm_bytes) # model audio for this turn
|
|
self._recorder.add_user_text(str) # partial transcript
|
|
self._recorder.add_robot_text(str) # partial transcript
|
|
self._recorder.finish_turn() # flush to WAV + index.json
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any, Optional
|
|
|
|
from Project.Sanad.core.logger import get_logger
|
|
|
|
log = get_logger("model_brain")
|
|
|
|
|
|
class ModelBrain:
    """Skeleton voice brain — adapt to your provider.

    Implements the brain contract documented in the module docstring:
    ``__init__(audio_io, recorder, voice_name, system_prompt)``, an async
    ``run()`` that blocks until stopped (handling its own reconnects), and
    a synchronous ``stop()`` that is safe to call from a signal handler.
    """

    def __init__(self, audio_io, recorder, voice_name: Optional[str] = None,
                 system_prompt: str = ""):
        # audio_io — voice.audio_io.AudioIO; exposes .mic and .speaker.
        # recorder — voice.sanad_voice.TurnRecorder; per-turn WAV capture.
        self._audio = audio_io
        self._mic = audio_io.mic
        self._speaker = audio_io.speaker
        self._recorder = recorder
        self._voice = voice_name          # provider-specific voice id
        self._system_prompt = system_prompt
        self._stop_flag = asyncio.Event()  # set by stop(); polled by the loops

        # TODO: instantiate your provider's client here. Keep the client
        # creation cheap — connection/handshake should happen inside `run()`
        # so reconnects don't require re-building this object.
        # Example:
        #     from openai import AsyncOpenAI
        #     self._client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self._client: Any = None

    # ─── lifecycle ────────────────────────────────────────

    def stop(self) -> None:
        """Signal the run loop to exit cleanly. Safe to call from anywhere."""
        self._stop_flag.set()

    async def run(self) -> None:
        """Main conversation loop. Blocks until stopped.

        Responsibilities:
        - Open a realtime session with your provider.
        - Forward mic audio to the model in small chunks.
        - Stream the model's audio response to the speaker.
        - Drive barge-in: when the user speaks while the model is speaking,
          cancel model playback and mark the turn interrupted.
        - On disconnect/error, back off and reconnect.
        """
        while not self._stop_flag.is_set():
            send_task = recv_task = None
            try:
                log.info("connecting to model...")
                # TODO: open a session with your provider. For websocket-style
                # APIs, use `async with client.realtime.connect(...) as session:`.
                # For request/response APIs, poll or stream in a loop.
                send_task = asyncio.ensure_future(self._send_mic_loop())
                recv_task = asyncio.ensure_future(self._receive_loop())
                await asyncio.gather(send_task, recv_task)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                log.error("session error: %s — reconnecting in 2s", exc)
                await asyncio.sleep(2)
            finally:
                # gather() propagates the first exception but does NOT cancel
                # the sibling task; without this cleanup each reconnect
                # iteration would leak a still-running loop task.
                for task in (send_task, recv_task):
                    if task is not None and not task.done():
                        task.cancel()

    # ─── mic → model ──────────────────────────────────────

    async def _send_mic_loop(self) -> None:
        """Read mic chunks and forward them to the model.

        Minimum responsibilities:
        - Loop on `self._mic.read_chunk(N_BYTES)`.
        - Encode to whatever format your provider expects
          (PCM16 mono is standard; some want base64 in JSON frames).
        - Respect `self._stop_flag`.

        Optional (highly recommended):
        - Measure energy; feed the mic frame to `self._recorder.capture_user`
          only when the user is actually speaking.
        - Apply echo suppression while the speaker is playing (mute or
          substitute silence when energy is low — keeps the model from
          transcribing its own voice bleed).
        """
        chunk_bytes = 1024  # 32 ms at 16 kHz mono int16 — tune to your API
        # get_event_loop() is deprecated inside coroutines (3.10+);
        # get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        while not self._stop_flag.is_set():
            try:
                data = await loop.run_in_executor(
                    None, self._mic.read_chunk, chunk_bytes,
                )
            except Exception:
                # Don't swallow the failure silently — record why the mic
                # loop died so session errors are diagnosable.
                log.exception("mic read failed — leaving mic loop")
                break

            # TODO: forward `data` to the model. Example for a hypothetical
            # websocket session:
            #     await session.send({"type": "audio", "pcm16": data})
            _ = data

            # Pace to real-time so we don't starve the event loop
            await asyncio.sleep(chunk_bytes / (16000 * 2))

    # ─── model → speaker ──────────────────────────────────

    async def _receive_loop(self) -> None:
        """Receive model events (audio chunks, transcripts, turn markers).

        Event handling you need to implement:
        - Audio chunk → `self._speaker.send_chunk(pcm, source_rate)`
          (first chunk must be preceded by `self._speaker.begin_stream()`).
        - Model interrupted → `self._speaker.stop(); self._mic.flush()`
          and call `self._recorder.finish_turn()`.
        - User transcript → `self._recorder.add_user_text(text)`.
        - Model transcript → `self._recorder.add_robot_text(text)`.
        - Turn complete → `self._speaker.wait_finish();
          self._recorder.finish_turn(); mic.flush()`.
        """
        while not self._stop_flag.is_set():
            # TODO: iterate your provider's event stream and dispatch.
            await asyncio.sleep(0.1)