342 lines
12 KiB
Python
342 lines
12 KiB
Python
"""Gemini WebSocket client for real-time voice interaction.
|
|
|
|
Provides:
|
|
- Bidirectional audio streaming (mic → Gemini → speaker)
|
|
- Text-to-speech via typed input
|
|
- Voice-command detection through transcription parsing
|
|
- System instruction injection for persona control
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import inspect
|
|
import json
|
|
from typing import Any
|
|
|
|
import websockets
|
|
|
|
from Project.Sanad.config import (
|
|
GEMINI_API_KEY,
|
|
GEMINI_MODEL,
|
|
GEMINI_VOICE,
|
|
GEMINI_WS_TIMEOUT,
|
|
GEMINI_WS_URI,
|
|
)
|
|
from Project.Sanad.core.config_loader import section as _cfg_section
|
|
from Project.Sanad.core.event_bus import bus
|
|
from Project.Sanad.core.logger import get_logger
|
|
|
|
log = get_logger("gemini_client")

# Tunables for this client are read from the "gemini" / "client" config section.
_GC = _cfg_section("gemini", "client")

# The persona prompt is looked up in the "core" / "gemini_defaults" section;
# the literal below is only the fallback used when that key is absent.
_GEMINI_DEFAULTS = _cfg_section("core", "gemini_defaults")
_DEFAULT_SYSTEM_PROMPT = _GEMINI_DEFAULTS.get(
    "default_system_prompt",
    "You are Sanad (Bousandah), a wise and friendly Emirati assistant. "
    "Speak in UAE dialect (Khaleeji). Be helpful and concise.",
)

# Receive timeout in seconds (default 30).
_RECV_TIMEOUT_SEC = _GC.get("recv_timeout_sec", 30)

# Reconnect policy: number of attempts, first back-off delay and the cap
# applied to the doubling delay (both in seconds).
_RECONNECT_MAX_ATTEMPTS = _GC.get("reconnect_max_attempts", 3)
_RECONNECT_INITIAL_DELAY_SEC = _GC.get("reconnect_initial_delay_sec", 1.0)
_RECONNECT_MAX_DELAY_SEC = _GC.get("reconnect_max_delay_sec", 10.0)
|
|
|
|
|
|
class GeminiVoiceClient:
    """Manages one WebSocket session to the Gemini Bidi audio API.

    Concurrency model:
    - `_send_lock` serializes ALL websocket writes.
    - `_session_lock` ensures only one consumer (live loop OR typed replay)
      owns the receive stream at a time. It is acquired by `send_text` and by
      the guard returned from `acquire_session`; `receive_stream` only checks
      that an owner is registered.
    - `_connect_lock` serializes reconnect attempts.
    - `_owner` records who currently holds the session lock for diagnostics.
    """

    def __init__(self, system_prompt: str = ""):
        """Create a disconnected client.

        Args:
            system_prompt: Persona/system instruction sent in the setup
                message. An empty value falls back to the configured default.
        """
        self.system_prompt = system_prompt or _DEFAULT_SYSTEM_PROMPT
        self._ws: Any = None
        self._connected = False
        self._send_lock = asyncio.Lock()
        self._session_lock = asyncio.Lock()
        self._connect_lock = asyncio.Lock()  # serializes reconnect attempts
        self._owner: str | None = None
        self._reconnect_attempts = 0

    @property
    def connected(self) -> bool:
        """Whether the client currently believes the socket is usable."""
        return self._connected

    @property
    def session_owner(self) -> str | None:
        """Label of the consumer holding the session lock, or None."""
        return self._owner

    def _ws_kwargs(self) -> dict[str, Any]:
        """Build keyword arguments for `websockets.connect`.

        The header keyword is chosen by inspecting the installed
        `websockets.connect` signature ("extra_headers" if it is accepted,
        otherwise "additional_headers"). If inspection fails for any reason,
        "extra_headers" is used.
        """
        kwargs: dict[str, Any] = {"max_size": None, "open_timeout": 30}
        try:
            sig = inspect.signature(websockets.connect)
            key = "extra_headers" if "extra_headers" in sig.parameters else "additional_headers"
        except Exception:
            key = "extra_headers"
        kwargs[key] = {"Content-Type": "application/json"}
        return kwargs

    async def _close_quietly(self, ws: Any) -> None:
        """Best-effort close of a socket; close errors are logged, not raised."""
        if ws is None:
            return
        try:
            await ws.close()
        except Exception:
            log.debug("Ignoring error while closing Gemini websocket", exc_info=True)

    async def connect(self):
        """Open the WebSocket, send the setup message and wait for its ACK.

        On success the client is marked connected, the reconnect counter is
        reset and "voice.connected" is emitted on the bus.

        Raises:
            Exception: Whatever the connect/handshake step raised (including
                `asyncio.TimeoutError` when the setup ACK does not arrive
                within `_RECV_TIMEOUT_SEC`). The client is left disconnected
                and any socket opened along the way is closed.
        """
        uri = f"{GEMINI_WS_URI}?key={GEMINI_API_KEY}"
        stale = self._ws  # socket from a previous session, if any
        try:
            self._ws = await websockets.connect(uri, **self._ws_kwargs())
            # The old socket is no longer reachable through self._ws; close it
            # instead of leaking it.
            await self._close_quietly(stale)
            setup = {
                "setup": {
                    "model": GEMINI_MODEL,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {
                            "voiceConfig": {
                                "prebuiltVoiceConfig": {"voiceName": GEMINI_VOICE}
                            }
                        },
                    },
                    "systemInstruction": {"parts": [{"text": self.system_prompt}]},
                }
            }
            await self._ws.send(json.dumps(setup))
            # Bounded wait for the ACK: an unbounded recv() here would hang
            # every caller queued behind _connect_lock.
            await asyncio.wait_for(self._ws.recv(), timeout=_RECV_TIMEOUT_SEC)
            self._connected = True
            self._reconnect_attempts = 0
            log.info("Connected to Gemini (%s)", GEMINI_MODEL)
            await bus.emit("voice.connected")
        except Exception:
            failed = self._ws
            self._connected = False
            self._ws = None
            log.exception("Failed to connect to Gemini")
            # Close whatever socket was current so a failed handshake does not
            # leak an open connection.
            await self._close_quietly(failed)
            raise

    async def disconnect(self):
        """Close the socket (best effort) and reset connection state.

        Always clears the socket, the connected flag and the session owner,
        then emits "voice.disconnected".
        """
        try:
            await self._close_quietly(self._ws)
        finally:
            self._ws = None
            self._connected = False
            self._owner = None
            log.info("Disconnected from Gemini")
            await bus.emit("voice.disconnected")

    async def _ensure_connected(self):
        """Reconnect if dropped, with bounded retries.

        Serialized via _connect_lock so concurrent callers don't trigger
        duplicate handshakes.

        Returns:
            True when a connection is available, False when every attempt
            failed (in which case "voice.error" is emitted with
            reason="reconnect_failed").
        """
        # Fast path — no lock needed
        if self._connected and self._ws is not None:
            return True

        async with self._connect_lock:
            # Re-check inside the lock (another coroutine may have just connected)
            if self._connected and self._ws is not None:
                return True

            max_attempts = _RECONNECT_MAX_ATTEMPTS
            delay = _RECONNECT_INITIAL_DELAY_SEC
            for attempt in range(1, max_attempts + 1):
                log.warning("Reconnecting to Gemini (attempt %d/%d)", attempt, max_attempts)
                try:
                    await self.connect()
                    return True
                except Exception:
                    # connect() has already logged the failure with traceback.
                    self._reconnect_attempts += 1
                # Back off only when another attempt follows; sleeping after
                # the last failure would just delay the error report.
                if attempt < max_attempts:
                    await asyncio.sleep(delay)
                    delay = min(delay * 2, _RECONNECT_MAX_DELAY_SEC)
            log.error("Reconnect failed after %d attempts", max_attempts)
            await bus.emit("voice.error", reason="reconnect_failed")
            return False

    async def send_audio_chunk(self, pcm_b64: str) -> bool:
        """Send a base64-encoded PCM audio chunk (mic input).

        Returns False on failure so the caller can react instead of silently
        no-op'ing forever (the original bug).
        """
        if not self._connected or self._ws is None:
            return False
        msg = {
            "realtimeInput": {
                "mediaChunks": [
                    {"mimeType": "audio/pcm;rate=16000", "data": pcm_b64}
                ]
            }
        }
        try:
            async with self._send_lock:
                await self._ws.send(json.dumps(msg))
            return True
        except websockets.exceptions.ConnectionClosed:
            log.warning("send_audio_chunk: connection closed")
            self._connected = False
            await bus.emit("voice.error", reason="connection_closed")
            return False
        except Exception:
            log.exception("send_audio_chunk failed")
            return False

    async def send_text(self, text: str, owner: str = "send_text") -> tuple[bytes, list[str]]:
        """Send text, receive audio response. Returns (audio_bytes, text_parts).

        Acquires the session lock for the entire request/response cycle so
        no other consumer can steal frames from the receive side.
        If the connection turns out to be closed when the request is sent,
        reconnects once and retries. A close while receiving is handled inside
        the inner loop and yields whatever was collected so far.

        Args:
            text: User text for the turn.
            owner: Label recorded as session owner while the turn runs.

        Raises:
            RuntimeError: If no connection could be (re-)established.
        """
        if not await self._ensure_connected():
            raise RuntimeError("Not connected to Gemini and reconnect failed.")

        async with self._session_lock:
            self._owner = owner
            try:
                return await self._send_text_inner(text)
            except websockets.exceptions.ConnectionClosed:
                log.warning("send_text: connection died on send — reconnecting once")
                self._connected = False
                if not await self._ensure_connected():
                    raise RuntimeError("Reconnect after send failure also failed.")
                return await self._send_text_inner(text)
            finally:
                self._owner = None

    async def _send_text_inner(self, text: str) -> tuple[bytes, list[str]]:
        """Inner send/receive loop — caller must hold _session_lock.

        Sends one complete user turn, then reads frames until the server
        signals turnComplete/generationComplete, reports an error, the
        per-frame timeout (GEMINI_WS_TIMEOUT) expires, or the socket closes.

        Returns:
            (audio_bytes, text_parts): concatenated decoded audio and the
            stripped, non-empty text parts of the model turn.
        """
        request = {
            "client_content": {
                "turns": [{"role": "user", "parts": [{"text": text}]}],
                "turn_complete": True,
            }
        }
        async with self._send_lock:
            await self._ws.send(json.dumps(request))

        audio_chunks: list[bytes] = []
        text_parts: list[str] = []

        while True:
            try:
                raw = await asyncio.wait_for(self._ws.recv(), timeout=GEMINI_WS_TIMEOUT)
            except asyncio.TimeoutError:
                log.warning("send_text: recv timed out")
                break
            except websockets.exceptions.ConnectionClosed:
                log.warning("send_text: connection closed mid-stream")
                self._connected = False
                break

            try:
                resp = json.loads(raw)
            except json.JSONDecodeError:
                log.warning("send_text: bad JSON from server")
                continue
            if not isinstance(resp, dict):
                # Valid JSON but not an object; nothing to extract from it.
                log.warning("send_text: unexpected non-object frame from server")
                continue

            if "error" in resp:
                log.error("Gemini error: %s", resp["error"])
                await bus.emit("voice.error", reason=str(resp["error"]))
                break

            # `or {}` also covers explicit JSON nulls, not just missing keys.
            sc = resp.get("serverContent") or {}
            mt = sc.get("modelTurn") or {}
            for part in mt.get("parts") or []:
                inline = part.get("inlineData")
                if inline and inline.get("data"):
                    try:
                        audio_chunks.append(base64.b64decode(inline["data"]))
                    except ValueError:
                        # binascii.Error is a ValueError. Skip the bad chunk
                        # rather than abandon the turn with frames unread.
                        log.warning("send_text: undecodable audio chunk skipped")
                tp = part.get("text")
                if isinstance(tp, str) and tp.strip():
                    text_parts.append(tp.strip())

            input_tr = sc.get("inputTranscription") or {}
            if input_tr.get("text"):
                await bus.emit("voice.user_said", text=input_tr["text"])

            if sc.get("turnComplete") or sc.get("generationComplete"):
                break

        audio_bytes = b"".join(audio_chunks)
        if audio_bytes:
            await bus.emit("voice.gemini_spoke", audio_len=len(audio_bytes))
        return audio_bytes, text_parts

    def acquire_session(self, owner: str) -> "_SessionGuard":
        """Return an async context manager for exclusive session ownership.

        Use as `async with client.acquire_session("live_voice"):`.
        While held, no other consumer may call send_text or receive_stream.
        """
        return _SessionGuard(self, owner)

    async def receive_stream(self):
        """Yield the `serverContent` object of each server frame.

        Caller MUST hold the session lock. Frames that are not valid JSON
        objects are skipped; a frame without `serverContent` (or with a null
        one) yields an empty dict. Yields nothing when disconnected.

        Raises:
            RuntimeError: If no session owner is registered.
        """
        if self._owner is None:
            raise RuntimeError(
                "receive_stream requires session lock — use acquire_session() first"
            )
        if not self._connected or self._ws is None:
            return
        try:
            async for raw in self._ws:
                try:
                    resp = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if not isinstance(resp, dict):
                    continue
                yield resp.get("serverContent") or {}
        except websockets.exceptions.ConnectionClosed:
            log.warning("receive_stream: connection closed")
            self._connected = False
            await bus.emit("voice.error", reason="connection_closed")

    async def raw_send(self, payload: dict) -> bool:
        """Low-level send for the live loop. Always use through send lock.

        Returns:
            True when the payload was written, False when the client is not
            connected or the write failed.
        """
        if not self._connected or self._ws is None:
            return False
        try:
            async with self._send_lock:
                await self._ws.send(json.dumps(payload))
            return True
        except websockets.exceptions.ConnectionClosed:
            # Same handling as send_audio_chunk: flag the drop so
            # _ensure_connected stops treating the dead socket as healthy.
            log.warning("raw_send: connection closed")
            self._connected = False
            await bus.emit("voice.error", reason="connection_closed")
            return False
        except Exception:
            log.exception("raw_send failed")
            return False

    def status(self) -> dict[str, Any]:
        """Return a snapshot of connection state for diagnostics."""
        return {
            "connected": self._connected,
            "model": GEMINI_MODEL,
            "voice": GEMINI_VOICE,
            "session_owner": self._owner,
            "reconnect_attempts": self._reconnect_attempts,
        }
|
|
|
|
|
|
class _SessionGuard:
    """Async context manager granting exclusive ownership of a client session.

    Entering waits for the client's session lock and registers the owner
    label on the client; leaving clears the label and releases the lock,
    whether or not the body raised.
    """

    def __init__(self, client: GeminiVoiceClient, owner: str):
        self._held = False  # True only between a successful acquire and release
        self._owner = owner
        self._client = client

    async def __aenter__(self):
        target = self._client
        await target._session_lock.acquire()
        self._held = True
        target._owner = self._owner
        return target

    async def __aexit__(self, exc_type, exc, tb):
        target = self._client
        try:
            target._owner = None
        finally:
            # Release only a lock this guard actually acquired.
            if self._held:
                session_lock = target._session_lock
                session_lock.release()
                self._held = False
        # A falsy return lets any exception from the body propagate.
        return False
|