Sanad/voice/gemini_client.py

342 lines
12 KiB
Python

"""Gemini WebSocket client for real-time voice interaction.
Provides:
- Bidirectional audio streaming (mic → Gemini → speaker)
- Text-to-speech via typed input
- Voice-command detection through transcription parsing
- System instruction injection for persona control
"""
from __future__ import annotations
import asyncio
import base64
import inspect
import json
from typing import Any
import websockets
from Project.Sanad.config import (
GEMINI_API_KEY,
GEMINI_MODEL,
GEMINI_VOICE,
GEMINI_WS_TIMEOUT,
GEMINI_WS_URI,
)
from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.event_bus import bus
from Project.Sanad.core.logger import get_logger
log = get_logger("gemini_client")
_GC = _cfg_section("voice", "gemini_client")
# Default system prompt — SINGLE SOURCE in core.gemini_defaults.
# The literal below is only a fallback. It is used when the configured value is
# missing OR empty/None; a blank value would otherwise reach
# GeminiVoiceClient as an empty system instruction (no persona at all).
_DEFAULT_SYSTEM_PROMPT = (
    _cfg_section("core", "gemini_defaults").get("default_system_prompt")
    or "You are Sanad (Bousandah), a wise and friendly Emirati assistant. "
    "Speak in UAE dialect (Khaleeji). Be helpful and concise."
)
# Tunables read from the voice.gemini_client config section (code defaults shown).
_RECV_TIMEOUT_SEC = _GC.get("recv_timeout_sec", 30)  # receive timeout, seconds
# Reconnect policy: number of attempts, then exponential backoff that starts at
# the initial delay and doubles up to the max delay (seconds).
_RECONNECT_MAX_ATTEMPTS = _GC.get("reconnect_max_attempts", 3)
_RECONNECT_INITIAL_DELAY_SEC = _GC.get("reconnect_initial_delay_sec", 1.0)
_RECONNECT_MAX_DELAY_SEC = _GC.get("reconnect_max_delay_sec", 10.0)
class GeminiVoiceClient:
    """Manages one WebSocket session to the Gemini Bidi audio API.

    Concurrency model:
    - `_send_lock` serializes ALL websocket writes.
    - `_session_lock` ensures only one consumer (live loop OR typed replay)
      owns the receive stream at a time. It is acquired by send_text and by
      the acquire_session() context manager, which receive_stream requires.
    - `_connect_lock` serializes reconnect attempts in _ensure_connected.
    - `_owner` records who currently holds the session lock for diagnostics.
    """

    def __init__(self, system_prompt: str = ""):
        # An empty prompt falls back to the module-level default persona.
        self.system_prompt = system_prompt or _DEFAULT_SYSTEM_PROMPT
        self._ws: Any = None
        self._connected = False
        self._send_lock = asyncio.Lock()
        self._session_lock = asyncio.Lock()
        self._connect_lock = asyncio.Lock()  # serializes reconnect attempts
        self._owner: str | None = None
        # Failed reconnect attempts; reset to 0 by a successful connect().
        self._reconnect_attempts = 0

    @property
    def connected(self) -> bool:
        """True after a successful setup handshake, until a drop is detected."""
        return self._connected

    @property
    def session_owner(self) -> str | None:
        """Label of the consumer currently holding the session lock, if any."""
        return self._owner

    def _ws_kwargs(self) -> dict[str, Any]:
        """Build keyword arguments for ``websockets.connect``.

        The header keyword differs between websockets releases
        (``extra_headers`` vs ``additional_headers``), so pick whichever the
        installed ``connect`` signature exposes. ``max_size=None`` disables the
        incoming message size limit.
        """
        kwargs: dict[str, Any] = {"max_size": None, "open_timeout": 30}
        try:
            sig = inspect.signature(websockets.connect)
            key = "extra_headers" if "extra_headers" in sig.parameters else "additional_headers"
        except Exception:
            key = "extra_headers"
        kwargs[key] = {"Content-Type": "application/json"}
        return kwargs

    async def _close_ws_quietly(self) -> None:
        """Best-effort close of the current socket. Always clears ``self._ws``."""
        ws, self._ws = self._ws, None
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as exc:
            # Closing is best effort; the socket is dropped either way.
            log.warning("Error while closing Gemini websocket: %r", exc)

    async def connect(self):
        """Open the websocket and perform the Gemini ``setup`` handshake.

        Any previously held socket is closed first, so reconnects after a drop
        do not leak the old connection. On failure, the half-open socket is
        closed, the client is left disconnected, and the exception re-raises.

        Raises:
            asyncio.TimeoutError: the setup ACK did not arrive within
                ``_RECV_TIMEOUT_SEC`` seconds.
            Exception: anything raised by websockets connect/send/recv or by
                the ``voice.connected`` event emission.
        """
        # A dropped session leaves _connected False but _ws still set. Mark the
        # client down first so no writer uses the new socket before setup
        # completes, then release the stale socket.
        self._connected = False
        await self._close_ws_quietly()
        uri = f"{GEMINI_WS_URI}?key={GEMINI_API_KEY}"
        try:
            self._ws = await websockets.connect(uri, **self._ws_kwargs())
            setup = {
                "setup": {
                    "model": GEMINI_MODEL,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {
                            "voiceConfig": {
                                "prebuiltVoiceConfig": {"voiceName": GEMINI_VOICE}
                            }
                        },
                    },
                    "systemInstruction": {"parts": [{"text": self.system_prompt}]},
                }
            }
            await self._ws.send(json.dumps(setup))
            # Bound the ACK wait so a silent server cannot hang connect() forever.
            await asyncio.wait_for(self._ws.recv(), timeout=_RECV_TIMEOUT_SEC)  # ACK
            self._connected = True
            self._reconnect_attempts = 0
            log.info("Connected to Gemini (%s)", GEMINI_MODEL)
            await bus.emit("voice.connected")
        except Exception:
            self._connected = False
            await self._close_ws_quietly()  # don't leak a half-open socket
            log.exception("Failed to connect to Gemini")
            raise

    async def disconnect(self):
        """Close the websocket (best effort) and reset all session state."""
        self._connected = False
        self._owner = None
        await self._close_ws_quietly()
        log.info("Disconnected from Gemini")
        await bus.emit("voice.disconnected")

    async def _ensure_connected(self) -> bool:
        """Reconnect if dropped, with bounded retries and capped exponential backoff.

        Serialized via _connect_lock so concurrent callers don't trigger
        duplicate handshakes.

        Returns:
            True if connected (already, or after a retry); False after
            ``_RECONNECT_MAX_ATTEMPTS`` failures, with ``voice.error`` emitted.
        """
        # Fast path — no lock needed
        if self._connected and self._ws is not None:
            return True
        async with self._connect_lock:
            # Re-check inside the lock (another coroutine may have just connected)
            if self._connected and self._ws is not None:
                return True
            max_attempts = _RECONNECT_MAX_ATTEMPTS
            delay = _RECONNECT_INITIAL_DELAY_SEC
            for attempt in range(1, max_attempts + 1):
                log.warning("Reconnecting to Gemini (attempt %d/%d)", attempt, max_attempts)
                try:
                    await self.connect()
                    return True
                except Exception:
                    # connect() has already logged the traceback.
                    self._reconnect_attempts += 1
                if attempt < max_attempts:
                    # Back off before the next try. No sleep after the final failure.
                    await asyncio.sleep(delay)
                    delay = min(delay * 2, _RECONNECT_MAX_DELAY_SEC)
            log.error("Reconnect failed after %d attempts", max_attempts)
            await bus.emit("voice.error", reason="reconnect_failed")
            return False

    async def _send_json(self, payload: dict, label: str) -> bool:
        """Serialize *payload* and write it to the socket under ``_send_lock``.

        Returns False, without raising for ordinary send errors, when not
        connected or when the write fails. A closed connection also marks the
        client disconnected and emits ``voice.error``. *label* prefixes log
        messages.
        """
        if not self._connected or self._ws is None:
            return False
        try:
            async with self._send_lock:
                ws = self._ws
                if ws is None:  # disconnected while waiting for the lock
                    return False
                await ws.send(json.dumps(payload))
            return True
        except websockets.exceptions.ConnectionClosed:
            log.warning("%s: connection closed", label)
            self._connected = False
            await bus.emit("voice.error", reason="connection_closed")
            return False
        except Exception:
            log.exception("%s failed", label)
            return False

    async def send_audio_chunk(self, pcm_b64: str) -> bool:
        """Send one base64-encoded PCM chunk of mic audio (16 kHz per the MIME type).

        Returns False when not connected or when the write fails, so the
        caller can react instead of silently dropping audio forever.
        """
        msg = {
            "realtimeInput": {
                "mediaChunks": [
                    {"mimeType": "audio/pcm;rate=16000", "data": pcm_b64}
                ]
            }
        }
        return await self._send_json(msg, "send_audio_chunk")

    async def send_text(self, text: str, owner: str = "send_text") -> tuple[bytes, list[str]]:
        """Send *text* as a user turn and collect the model's reply.

        Holds the session lock for the entire request/response cycle so no
        other consumer can steal frames from the receive side. The connection
        is checked after the lock is acquired, so a drop that happens while
        waiting for the lock is still repaired. If the connection closes while
        sending, reconnects once and retries.

        Args:
            text: user message to send.
            owner: diagnostic label recorded as the session owner.

        Returns:
            ``(audio_bytes, text_parts)``: the decoded audio chunks joined
            together, and the stripped non-empty text parts of the reply.

        Raises:
            RuntimeError: if connecting or reconnecting fails.
        """
        async with self._session_lock:
            self._owner = owner
            try:
                if not await self._ensure_connected():
                    raise RuntimeError("Not connected to Gemini and reconnect failed.")
                try:
                    return await self._send_text_inner(text)
                except websockets.exceptions.ConnectionClosed:
                    log.warning("send_text: connection died on send — reconnecting once")
                    self._connected = False
                    if not await self._ensure_connected():
                        raise RuntimeError("Reconnect after send failure also failed.")
                    return await self._send_text_inner(text)
            finally:
                self._owner = None

    async def _send_text_inner(self, text: str) -> tuple[bytes, list[str]]:
        """Send one user turn and read frames until the turn ends.

        Caller must hold ``_session_lock``. Reading stops on turn or generation
        completion, a server ``error`` frame, a receive timeout
        (``GEMINI_WS_TIMEOUT``), or a mid-stream close. Whatever was collected
        up to that point is returned. A close during the initial send
        propagates as ``ConnectionClosed``.
        """
        ws = self._ws  # pin the socket so a concurrent disconnect can't swap it to None
        request = {
            "client_content": {
                "turns": [{"role": "user", "parts": [{"text": text}]}],
                "turn_complete": True,
            }
        }
        async with self._send_lock:
            await ws.send(json.dumps(request))
        audio_chunks: list[bytes] = []
        text_parts: list[str] = []
        while True:
            try:
                raw = await asyncio.wait_for(ws.recv(), timeout=GEMINI_WS_TIMEOUT)
            except asyncio.TimeoutError:
                log.warning("send_text: recv timed out")
                break
            except websockets.exceptions.ConnectionClosed:
                log.warning("send_text: connection closed mid-stream")
                if self._ws is ws:
                    self._connected = False
                break
            try:
                resp = json.loads(raw)
            except json.JSONDecodeError:
                log.warning("send_text: bad JSON from server")
                continue
            if not isinstance(resp, dict):
                # Valid JSON but not an object (e.g. a list). Nothing to read from it.
                log.warning("send_text: unexpected non-object frame from server")
                continue
            if "error" in resp:
                log.error("Gemini error: %s", resp["error"])
                await bus.emit("voice.error", reason=str(resp["error"]))
                break
            # `or {}` also covers keys that are present but explicitly null.
            sc = resp.get("serverContent") or {}
            mt = sc.get("modelTurn") or {}
            for part in mt.get("parts") or []:
                inline = part.get("inlineData")
                if inline and inline.get("data"):
                    audio_chunks.append(base64.b64decode(inline["data"]))
                tp = part.get("text")
                if isinstance(tp, str) and tp.strip():
                    text_parts.append(tp.strip())
            input_tr = sc.get("inputTranscription") or {}
            if input_tr.get("text"):
                await bus.emit("voice.user_said", text=input_tr["text"])
            if sc.get("turnComplete") or sc.get("generationComplete"):
                break
        audio_bytes = b"".join(audio_chunks)
        if audio_bytes:
            await bus.emit("voice.gemini_spoke", audio_len=len(audio_bytes))
        return audio_bytes, text_parts

    def acquire_session(self, owner: str) -> "_SessionGuard":
        """Return an async context manager for exclusive session ownership.

        Use as `async with client.acquire_session("live_voice"):`.
        While held, no other consumer may call send_text or receive_stream.
        """
        return _SessionGuard(self, owner)

    async def receive_stream(self):
        """Yield the ``serverContent`` dict of each server frame.

        The caller MUST hold the session lock (see acquire_session); otherwise
        RuntimeError is raised. Returns immediately when not connected. Frames
        that are not valid JSON objects are skipped. Frames without
        ``serverContent`` yield ``{}``. When the socket closes, cleanly or
        not, the client is marked disconnected (if that socket is still the
        current one) and the generator ends.
        """
        if self._owner is None:
            raise RuntimeError(
                "receive_stream requires session lock — use acquire_session() first"
            )
        ws = self._ws
        if not self._connected or ws is None:
            return
        try:
            async for raw in ws:
                try:
                    resp = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if isinstance(resp, dict):
                    yield resp.get("serverContent") or {}
        except websockets.exceptions.ConnectionClosed:
            log.warning("receive_stream: connection closed")
            if self._ws is ws:  # a reconnect may already have replaced it
                self._connected = False
            await bus.emit("voice.error", reason="connection_closed")
        else:
            # websockets ends `async for` without raising on a clean close
            # (ConnectionClosedOK). Record the drop so _ensure_connected does
            # not keep trusting a dead socket.
            if self._ws is ws:
                self._connected = False

    async def raw_send(self, payload: dict) -> bool:
        """Low-level JSON send for the live loop.

        Writes are serialized through ``_send_lock`` internally. Returns False
        when not connected or on failure. A closed connection marks the
        client disconnected and emits ``voice.error`` (same as
        send_audio_chunk).
        """
        return await self._send_json(payload, "raw_send")

    def status(self) -> dict[str, Any]:
        """Snapshot of connection state for diagnostics."""
        return {
            "connected": self._connected,
            "model": GEMINI_MODEL,
            "voice": GEMINI_VOICE,
            "session_owner": self._owner,
            "reconnect_attempts": self._reconnect_attempts,
        }
class _SessionGuard:
    """Async context manager that gives one consumer exclusive use of a client's session.

    Entering acquires ``client._session_lock`` and stamps the client's
    ``_owner`` with the given label. Leaving clears ``_owner`` and releases
    the lock, even when the body raised. Exceptions are never suppressed.
    """

    def __init__(self, client: GeminiVoiceClient, owner: str):
        self._client = client
        self._owner = owner
        # True only between a successful acquire and the matching release.
        self._held = False

    async def __aenter__(self):
        target = self._client
        await target._session_lock.acquire()
        self._held = True
        target._owner = self._owner
        return target

    async def __aexit__(self, exc_type, exc, tb):
        target = self._client
        try:
            target._owner = None
        finally:
            if self._held:
                target._session_lock.release()
                self._held = False
        # Falsy result: let any exception from the body propagate.
        return False