371 lines
15 KiB
Python
371 lines
15 KiB
Python
"""Gemini brain — live conversation loop using the google-genai SDK.
|
|
|
|
Implements the VoiceBrain contract documented in `voice/model_script.py`:
|
|
|
|
__init__(audio_io, recorder, voice_name, system_prompt)
|
|
async run()
|
|
stop()
|
|
|
|
Owns everything Gemini-specific: the `genai.Client`, `LiveConnectConfig`,
|
|
the session connect/receive loop, VAD-based barge-in, echo suppression,
|
|
reconnect backoff. Hardware I/O is delegated to `audio_io` and per-turn
|
|
WAV capture to `recorder` — both are model-agnostic.
|
|
|
|
Env overrides:
|
|
SANAD_GEMINI_MODEL — Gemini Live model id (without "models/" prefix)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import array
|
|
import asyncio
|
|
import os
|
|
import time
|
|
from typing import Any, Optional
|
|
|
|
import numpy as np
|
|
|
|
from google import genai
|
|
from google.genai import types
|
|
|
|
from Project.Sanad.config import (
|
|
CHUNK_SIZE,
|
|
GEMINI_API_KEY,
|
|
GEMINI_VOICE,
|
|
RECEIVE_SAMPLE_RATE,
|
|
SEND_SAMPLE_RATE,
|
|
)
|
|
from Project.Sanad.core.config_loader import section as _cfg_section
|
|
from Project.Sanad.core.logger import get_logger
|
|
|
|
log = get_logger("gemini_brain")

# Config sections: voice tuning, VAD thresholds, barge-in behavior.
_SV = _cfg_section("voice", "sanad_voice")
_VAD = _cfg_section("voice", "vad")
_BI = _cfg_section("voice", "barge_in")

# Gemini Live model id — overridable via SANAD_GEMINI_MODEL (no "models/" prefix).
_MODEL = os.environ.get(
    "SANAD_GEMINI_MODEL",
    "gemini-2.5-flash-native-audio-preview-12-2025",
)
_MIC_GAIN = _SV.get("mic_gain", 1.0)  # software gain applied to mic samples
_SESSION_TIMEOUT = _SV.get("session_timeout_sec", 660)  # hard cap per session before reconnect
_MAX_RECONNECT_DELAY = _SV.get("max_reconnect_delay_sec", 30)  # ceiling for exponential backoff
_MAX_CONSECUTIVE_ERRORS = _SV.get("max_consecutive_errors", 10)  # errors before client recreation
_NO_MESSAGES_TIMEOUT = _SV.get("no_messages_timeout_sec", 30)  # receive-side dead-session watchdog

_CHUNK_BYTES = CHUNK_SIZE * 2  # 16-bit mono PCM: 2 bytes per sample
_SILENCE_PCM = b"\x00" * _CHUNK_BYTES  # one chunk of silence, used for echo masking
|
|
|
|
|
|
def _audio_energy(pcm: bytes) -> int:
|
|
try:
|
|
samples = array.array("h", pcm)
|
|
return sum(abs(s) for s in samples) // len(samples) if samples else 0
|
|
except Exception:
|
|
return 0
|
|
|
|
|
|
class GeminiBrain:
    """Gemini Live conversation brain — reconnect-safe."""

    def __init__(self, audio_io, recorder, voice_name: Optional[str] = None,
                 system_prompt: str = ""):
        """Wire up hardware I/O and per-turn recording; no network work here.

        Args:
            audio_io: object exposing ``.mic`` (read_chunk/flush) and
                ``.speaker`` (begin_stream/send_chunk/stop/wait_finish) —
                model-agnostic hardware layer.
            recorder: per-turn audio/transcript capture sink.
            voice_name: Gemini prebuilt voice name; falls back to GEMINI_VOICE.
            system_prompt: system instruction text sent in the session config.
        """
        self._audio = audio_io
        self._mic = audio_io.mic
        self._speaker = audio_io.speaker
        self._recorder = recorder
        self._voice = voice_name or GEMINI_VOICE
        self._system_prompt = system_prompt
        self._api_key = GEMINI_API_KEY
        self._stop_flag = asyncio.Event()  # set by stop(); polled by all loops
        # per-session state (reset in the outer reconnect loop)
        self._speaking = False            # AI audio is currently being played
        self._stream_started = False      # speaker stream opened for this turn
        self._barge_block_until = 0.0     # ignore barge-in before this timestamp
        self._ai_speak_start = 0.0        # when the current AI turn began
        self._last_ai_audio = 0.0         # last time an AI audio chunk arrived
        self._done: Optional[asyncio.Event] = None  # per-session end signal
|
|
|
|
def stop(self) -> None:
|
|
"""Signal the run loop to exit at the next opportunity."""
|
|
try:
|
|
self._stop_flag.set()
|
|
except Exception:
|
|
pass
|
|
|
|
# ─── public entry point ───────────────────────────────
|
|
|
|
    async def run(self) -> None:
        """Main loop: connect to Gemini Live and pump audio until stop().

        Reconnects indefinitely. On errors, backs off exponentially (capped at
        _MAX_RECONNECT_DELAY) and recreates the genai client after
        _MAX_CONSECUTIVE_ERRORS failures in a row.
        """
        client = genai.Client(api_key=self._api_key)
        config = self._build_config()
        session_num = 0
        start_time = time.time()
        consecutive_errors = 0

        while not self._stop_flag.is_set():
            session_num += 1
            self._reset_turn_state()
            uptime_min = (time.time() - start_time) / 60

            try:
                log.info("connecting to Gemini (session #%d, uptime %.0fm)...",
                         session_num, uptime_min)
                async with client.aio.live.connect(model=_MODEL, config=config) as session:
                    log.info("connected — speak anytime!")
                    consecutive_errors = 0
                    self._mic.flush()  # drop any stale mic audio from before connect
                    self._done = asyncio.Event()  # shared end-of-session signal for both tasks

                    try:
                        # Pump mic-send and receive concurrently; the timeout
                        # forces a periodic clean reconnect of long sessions.
                        await asyncio.wait_for(
                            asyncio.gather(
                                self._send_mic_loop(session),
                                self._receive_loop(session),
                            ),
                            timeout=_SESSION_TIMEOUT,
                        )
                    except asyncio.TimeoutError:
                        log.warning("session timed out after %ds", _SESSION_TIMEOUT)
                    except asyncio.CancelledError:
                        log.warning("session cancelled")

                log.info("session #%d ended — reconnecting in 1s", session_num)
                self._speaker.stop()
                self._mic.flush()
                await asyncio.sleep(1)

            except asyncio.CancelledError:
                log.info("cancelled — stopping")
                break
            except KeyboardInterrupt:
                log.info("keyboard interrupt — stopping")
                break
            except Exception as exc:
                consecutive_errors += 1
                # Exponential backoff: 2, 4, 8, ... capped at _MAX_RECONNECT_DELAY.
                delay = min(_MAX_RECONNECT_DELAY, 2 ** consecutive_errors)
                log.error("session error (#%d): %s — reconnecting in %ds",
                          consecutive_errors, exc, delay)
                await asyncio.sleep(delay)
                if consecutive_errors >= _MAX_CONSECUTIVE_ERRORS:
                    # Repeated failures may mean the client itself is wedged.
                    log.warning("%d consecutive errors — recreating client",
                                consecutive_errors)
                    try:
                        client = genai.Client(api_key=self._api_key)
                        consecutive_errors = 0
                    except Exception as ce:
                        log.error("client recreation failed: %s", ce)
|
|
|
|
# ─── Gemini config ────────────────────────────────────
|
|
|
|
def _build_config(self) -> types.LiveConnectConfig:
|
|
return types.LiveConnectConfig(
|
|
response_modalities=["AUDIO"],
|
|
speech_config=types.SpeechConfig(
|
|
voice_config=types.VoiceConfig(
|
|
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
|
voice_name=self._voice,
|
|
),
|
|
),
|
|
),
|
|
realtime_input_config=types.RealtimeInputConfig(
|
|
automatic_activity_detection=types.AutomaticActivityDetection(
|
|
disabled=False,
|
|
start_of_speech_sensitivity=getattr(
|
|
types.StartSensitivity,
|
|
_VAD.get("start_sensitivity", "START_SENSITIVITY_HIGH"),
|
|
),
|
|
end_of_speech_sensitivity=getattr(
|
|
types.EndSensitivity,
|
|
_VAD.get("end_sensitivity", "END_SENSITIVITY_LOW"),
|
|
),
|
|
prefix_padding_ms=_VAD.get("prefix_padding_ms", 20),
|
|
silence_duration_ms=_VAD.get("silence_duration_ms", 200),
|
|
),
|
|
),
|
|
input_audio_transcription=types.AudioTranscriptionConfig(),
|
|
output_audio_transcription=types.AudioTranscriptionConfig(),
|
|
system_instruction=types.Content(
|
|
parts=[types.Part(text=self._system_prompt)],
|
|
),
|
|
)
|
|
|
|
# ─── state helpers ────────────────────────────────────
|
|
|
|
def _reset_turn_state(self) -> None:
|
|
self._speaking = False
|
|
self._stream_started = False
|
|
self._barge_block_until = 0.0
|
|
self._ai_speak_start = 0.0
|
|
self._last_ai_audio = 0.0
|
|
|
|
    def _interrupt(self, source: str = "local") -> None:
        """Hard-stop AI playback (barge-in, server interrupt, or local request).

        Clears the speaking flags first so the mic loop stops echo-masking
        immediately, then halts the speaker, drops buffered mic audio, and
        closes the recorder's current turn.
        """
        self._speaking = False
        self._stream_started = False
        self._speaker.stop()
        self._mic.flush()
        self._recorder.finish_turn()
        log.info("interrupt (%s)", source)
|
|
|
|
# ─── mic send loop ────────────────────────────────────
|
|
|
|
async def _send_mic_loop(self, session: Any) -> None:
|
|
threshold = _BI.get("threshold", 500)
|
|
chunks_needed = _BI.get("loud_chunks_needed", 3)
|
|
cooldown = _BI.get("cooldown_sec", 0.3)
|
|
echo_suppress_below = _BI.get("echo_suppress_below", 500)
|
|
grace = _BI.get("ai_speak_grace_sec", 0.15)
|
|
|
|
loop = asyncio.get_event_loop()
|
|
loud_count = 0
|
|
last_activity = time.time()
|
|
|
|
while not self._done.is_set() and not self._stop_flag.is_set():
|
|
try:
|
|
raw = await loop.run_in_executor(
|
|
None, self._mic.read_chunk, _CHUNK_BYTES,
|
|
)
|
|
except Exception:
|
|
break
|
|
|
|
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
|
|
samples = np.clip(samples * _MIC_GAIN, -32768, 32767).astype(np.int16)
|
|
data = samples.tobytes()
|
|
energy = _audio_energy(data)
|
|
now = time.time()
|
|
|
|
# Barge-in: after AI starts speaking, sustained user energy cuts it.
|
|
if self._speaking and now >= self._barge_block_until:
|
|
if (now - self._ai_speak_start) >= grace:
|
|
if energy > threshold:
|
|
loud_count += 1
|
|
else:
|
|
loud_count = max(0, loud_count - 1)
|
|
if loud_count > chunks_needed:
|
|
log.info("BARGE-IN (e=%d)", energy)
|
|
self._interrupt("barge-in")
|
|
loud_count = 0
|
|
self._barge_block_until = now + cooldown
|
|
|
|
# Echo suppression: while AI is speaking, mask quiet frames so the
|
|
# mic doesn't feed the model its own voice bleed.
|
|
send_data = data
|
|
if self._speaking and energy < echo_suppress_below:
|
|
send_data = _SILENCE_PCM
|
|
|
|
# Record user audio when clearly speaking and AI isn't.
|
|
if energy > 250 and not self._speaking:
|
|
self._recorder.capture_user(data)
|
|
|
|
# Keep-alive watchdog
|
|
if energy > 250:
|
|
last_activity = now
|
|
elif now - last_activity > 10:
|
|
log.info("alive (no speech %.0fs, e=%d)",
|
|
now - last_activity, energy)
|
|
last_activity = now
|
|
|
|
try:
|
|
await session.send_realtime_input(
|
|
audio=types.Blob(
|
|
data=send_data,
|
|
mime_type=f"audio/pcm;rate={SEND_SAMPLE_RATE}",
|
|
),
|
|
)
|
|
except asyncio.CancelledError:
|
|
return
|
|
except Exception as exc:
|
|
log.warning("mic send failed: %s — ending session", exc)
|
|
self._done.set()
|
|
return
|
|
|
|
await asyncio.sleep(CHUNK_SIZE / SEND_SAMPLE_RATE)
|
|
|
|
log.info("send_mic task ended")
|
|
|
|
# ─── receive loop ─────────────────────────────────────
|
|
|
|
async def _receive_loop(self, session: Any) -> None:
|
|
loop = asyncio.get_event_loop()
|
|
try:
|
|
last_recv = time.time()
|
|
while not self._done.is_set() and not self._stop_flag.is_set():
|
|
async for response in session.receive():
|
|
last_recv = time.time()
|
|
if self._done.is_set():
|
|
break
|
|
|
|
if hasattr(response, "go_away") and response.go_away is not None:
|
|
log.info("server going away — will reconnect")
|
|
self._done.set()
|
|
return
|
|
|
|
sc = response.server_content
|
|
if sc is None:
|
|
continue
|
|
|
|
if sc.interrupted is True:
|
|
if self._speaking:
|
|
log.info("Gemini interrupted")
|
|
self._interrupt("gemini")
|
|
continue
|
|
|
|
if sc.input_transcription:
|
|
text = (sc.input_transcription.text or "").strip()
|
|
if text and not self._speaking:
|
|
log.info("USER: %s", text)
|
|
self._recorder.add_user_text(text)
|
|
|
|
if sc.output_transcription:
|
|
text = (sc.output_transcription.text or "").strip()
|
|
if text:
|
|
log.info("BOT : %s", text)
|
|
self._recorder.add_robot_text(text)
|
|
|
|
if sc.model_turn:
|
|
for part in sc.model_turn.parts:
|
|
if part.inline_data and part.inline_data.data:
|
|
now = time.time()
|
|
if not self._speaking:
|
|
self._ai_speak_start = now
|
|
self._speaking = True
|
|
self._last_ai_audio = now
|
|
raw_audio = part.inline_data.data
|
|
self._recorder.capture_robot(raw_audio)
|
|
audio = np.frombuffer(raw_audio, dtype=np.int16)
|
|
if not self._stream_started:
|
|
await loop.run_in_executor(
|
|
None, self._speaker.begin_stream,
|
|
)
|
|
self._stream_started = True
|
|
await loop.run_in_executor(
|
|
None, self._speaker.send_chunk,
|
|
audio, RECEIVE_SAMPLE_RATE,
|
|
)
|
|
|
|
if sc.turn_complete:
|
|
if (self._speaking and self._stream_started
|
|
and not self._speaker.interrupted):
|
|
log.info("speaker %.1fs", self._speaker.total_sent_sec)
|
|
await loop.run_in_executor(
|
|
None, self._speaker.wait_finish,
|
|
)
|
|
elif self._speaking and self._speaker.interrupted:
|
|
log.info("speaker interrupted")
|
|
self._speaking = False
|
|
self._stream_started = False
|
|
self._mic.flush()
|
|
self._recorder.finish_turn()
|
|
log.info("listening")
|
|
|
|
if time.time() - last_recv > _NO_MESSAGES_TIMEOUT:
|
|
log.warning("no messages from Gemini for %ds — session dead",
|
|
_NO_MESSAGES_TIMEOUT)
|
|
break
|
|
await asyncio.sleep(0.1)
|
|
|
|
except Exception as exc:
|
|
log.warning("receive ended: %s", exc)
|
|
finally:
|
|
self._done.set()
|