"""Gemini WebSocket client for real-time voice interaction. Provides: - Bidirectional audio streaming (mic → Gemini → speaker) - Text-to-speech via typed input - Voice-command detection through transcription parsing - System instruction injection for persona control """ from __future__ import annotations import asyncio import base64 import inspect import json from typing import Any import websockets from Project.Sanad.config import ( GEMINI_API_KEY, GEMINI_MODEL, GEMINI_VOICE, GEMINI_WS_TIMEOUT, GEMINI_WS_URI, ) from Project.Sanad.core.config_loader import section as _cfg_section from Project.Sanad.core.event_bus import bus from Project.Sanad.core.logger import get_logger log = get_logger("gemini_client") _GC = _cfg_section("voice", "gemini_client") # Default system prompt — SINGLE SOURCE in core.gemini_defaults _DEFAULT_SYSTEM_PROMPT = _cfg_section("core", "gemini_defaults").get( "default_system_prompt", "You are Sanad (Bousandah), a wise and friendly Emirati assistant. " "Speak in UAE dialect (Khaleeji). Be helpful and concise." ) _RECV_TIMEOUT_SEC = _GC.get("recv_timeout_sec", 30) _RECONNECT_MAX_ATTEMPTS = _GC.get("reconnect_max_attempts", 3) _RECONNECT_INITIAL_DELAY_SEC = _GC.get("reconnect_initial_delay_sec", 1.0) _RECONNECT_MAX_DELAY_SEC = _GC.get("reconnect_max_delay_sec", 10.0) class GeminiVoiceClient: """Manages one WebSocket session to the Gemini Bidi audio API. Concurrency model: - `_send_lock` serializes ALL websocket writes. - `_session_lock` ensures only one consumer (live loop OR typed replay) owns the receive stream at a time. Acquired by send_text and receive_stream context managers. - `_owner` records who currently holds the session lock for diagnostics. """ def __init__(self, system_prompt: str = ""): self.system_prompt = system_prompt or _DEFAULT_SYSTEM_PROMPT self._ws: Any = None self._connected = False self._send_lock = asyncio.Lock() self._session_lock = asyncio.Lock() self._connect_lock = asyncio.Lock() # serializes reconnect attempts self._owner: str | None = None self._reconnect_attempts = 0 @property def connected(self) -> bool: return self._connected @property def session_owner(self) -> str | None: return self._owner def _ws_kwargs(self) -> dict[str, Any]: kwargs: dict[str, Any] = {"max_size": None, "open_timeout": 30} try: sig = inspect.signature(websockets.connect) key = "extra_headers" if "extra_headers" in sig.parameters else "additional_headers" except Exception: key = "extra_headers" kwargs[key] = {"Content-Type": "application/json"} return kwargs async def connect(self): uri = f"{GEMINI_WS_URI}?key={GEMINI_API_KEY}" try: self._ws = await websockets.connect(uri, **self._ws_kwargs()) setup = { "setup": { "model": GEMINI_MODEL, "generationConfig": { "responseModalities": ["AUDIO"], "speechConfig": { "voiceConfig": { "prebuiltVoiceConfig": {"voiceName": GEMINI_VOICE} } }, }, "systemInstruction": {"parts": [{"text": self.system_prompt}]}, } } await self._ws.send(json.dumps(setup)) await self._ws.recv() # ACK self._connected = True self._reconnect_attempts = 0 log.info("Connected to Gemini (%s)", GEMINI_MODEL) await bus.emit("voice.connected") except Exception: self._connected = False self._ws = None log.exception("Failed to connect to Gemini") raise async def disconnect(self): try: if self._ws is not None: await self._ws.close() except Exception: pass finally: self._ws = None self._connected = False self._owner = None log.info("Disconnected from Gemini") await bus.emit("voice.disconnected") async def _ensure_connected(self): """Reconnect if dropped, with bounded retries. Serialized via _connect_lock so concurrent callers don't trigger duplicate handshakes. """ # Fast path — no lock needed if self._connected and self._ws is not None: return True async with self._connect_lock: # Re-check inside the lock (another coroutine may have just connected) if self._connected and self._ws is not None: return True max_attempts = _RECONNECT_MAX_ATTEMPTS delay = _RECONNECT_INITIAL_DELAY_SEC for attempt in range(max_attempts): try: log.warning("Reconnecting to Gemini (attempt %d/%d)", attempt + 1, max_attempts) await self.connect() return True except Exception: self._reconnect_attempts += 1 await asyncio.sleep(delay) delay = min(delay * 2, _RECONNECT_MAX_DELAY_SEC) log.error("Reconnect failed after %d attempts", max_attempts) await bus.emit("voice.error", reason="reconnect_failed") return False async def send_audio_chunk(self, pcm_b64: str) -> bool: """Send a base64-encoded PCM audio chunk (mic input). Returns False on failure so the caller can react instead of silently no-op'ing forever (the original bug). """ if not self._connected or self._ws is None: return False msg = { "realtimeInput": { "mediaChunks": [ {"mimeType": "audio/pcm;rate=16000", "data": pcm_b64} ] } } try: async with self._send_lock: await self._ws.send(json.dumps(msg)) return True except websockets.exceptions.ConnectionClosed: log.warning("send_audio_chunk: connection closed") self._connected = False await bus.emit("voice.error", reason="connection_closed") return False except Exception: log.exception("send_audio_chunk failed") return False async def send_text(self, text: str, owner: str = "send_text") -> tuple[bytes, list[str]]: """Send text, receive audio response. Returns (audio_bytes, text_parts). Acquires the session lock for the entire request/response cycle so no other consumer can steal frames from the receive side. If the connection drops mid-request, reconnects once and retries. """ if not await self._ensure_connected(): raise RuntimeError("Not connected to Gemini and reconnect failed.") async with self._session_lock: self._owner = owner try: return await self._send_text_inner(text) except websockets.exceptions.ConnectionClosed: log.warning("send_text: connection died on send — reconnecting once") self._connected = False if not await self._ensure_connected(): raise RuntimeError("Reconnect after send failure also failed.") return await self._send_text_inner(text) finally: self._owner = None async def _send_text_inner(self, text: str) -> tuple[bytes, list[str]]: """Inner send/receive loop — caller must hold _session_lock.""" request = { "client_content": { "turns": [{"role": "user", "parts": [{"text": text}]}], "turn_complete": True, } } async with self._send_lock: await self._ws.send(json.dumps(request)) audio_chunks: list[bytes] = [] text_parts: list[str] = [] while True: try: raw = await asyncio.wait_for(self._ws.recv(), timeout=GEMINI_WS_TIMEOUT) except asyncio.TimeoutError: log.warning("send_text: recv timed out") break except websockets.exceptions.ConnectionClosed: log.warning("send_text: connection closed mid-stream") self._connected = False break try: resp = json.loads(raw) except json.JSONDecodeError: log.warning("send_text: bad JSON from server") continue if "error" in resp: log.error("Gemini error: %s", resp["error"]) await bus.emit("voice.error", reason=str(resp["error"])) break sc = resp.get("serverContent", {}) mt = sc.get("modelTurn", {}) for part in mt.get("parts", []): inline = part.get("inlineData") if inline and inline.get("data"): audio_chunks.append(base64.b64decode(inline["data"])) tp = part.get("text") if isinstance(tp, str) and tp.strip(): text_parts.append(tp.strip()) input_tr = sc.get("inputTranscription", {}) if input_tr.get("text"): await bus.emit("voice.user_said", text=input_tr["text"]) if sc.get("turnComplete") or sc.get("generationComplete"): break audio_bytes = b"".join(audio_chunks) if audio_bytes: await bus.emit("voice.gemini_spoke", audio_len=len(audio_bytes)) return audio_bytes, text_parts def acquire_session(self, owner: str) -> "_SessionGuard": """Return an async context manager for exclusive session ownership. Use as `async with client.acquire_session("live_voice"):`. While held, no other consumer may call send_text or receive_stream. """ return _SessionGuard(self, owner) async def receive_stream(self): """Yield server events. Caller MUST hold the session lock.""" if self._owner is None: raise RuntimeError( "receive_stream requires session lock — use acquire_session() first" ) if not self._connected or self._ws is None: return try: async for raw in self._ws: try: resp = json.loads(raw) except json.JSONDecodeError: continue yield resp.get("serverContent", {}) except websockets.exceptions.ConnectionClosed: log.warning("receive_stream: connection closed") self._connected = False await bus.emit("voice.error", reason="connection_closed") async def raw_send(self, payload: dict): """Low-level send for the live loop. Always use through send lock.""" if not self._connected or self._ws is None: return False try: async with self._send_lock: await self._ws.send(json.dumps(payload)) return True except Exception: log.exception("raw_send failed") return False def status(self) -> dict[str, Any]: return { "connected": self._connected, "model": GEMINI_MODEL, "voice": GEMINI_VOICE, "session_owner": self._owner, "reconnect_attempts": self._reconnect_attempts, } class _SessionGuard: """Async context manager for exclusive session ownership. Always releases owner + lock on exit, even on exceptions. """ def __init__(self, client: GeminiVoiceClient, owner: str): self._client = client self._owner = owner self._held = False async def __aenter__(self): await self._client._session_lock.acquire() self._held = True self._client._owner = self._owner return self._client async def __aexit__(self, exc_type, exc, tb): try: self._client._owner = None finally: if self._held: self._client._session_lock.release() self._held = False return False # don't suppress exceptions