# Sanad/gemini/script.py — 371 lines, 15 KiB, Python

"""Gemini brain — live conversation loop using the google-genai SDK.
Implements the VoiceBrain contract documented in `voice/model_script.py`:
__init__(audio_io, recorder, voice_name, system_prompt)
async run()
stop()
Owns everything Gemini-specific: the `genai.Client`, `LiveConnectConfig`,
the session connect/receive loop, VAD-based barge-in, echo suppression,
reconnect backoff. Hardware I/O is delegated to `audio_io` and per-turn
WAV capture to `recorder` — both are model-agnostic.
Env overrides:
SANAD_GEMINI_MODEL — Gemini Live model id (without "models/" prefix)
"""
from __future__ import annotations
import array
import asyncio
import os
import time
from typing import Any, Optional
import numpy as np
from google import genai
from google.genai import types
from Project.Sanad.config import (
CHUNK_SIZE,
GEMINI_API_KEY,
GEMINI_VOICE,
RECEIVE_SAMPLE_RATE,
SEND_SAMPLE_RATE,
)
from Project.Sanad.core.config_loader import section as _cfg_section
from Project.Sanad.core.logger import get_logger
# Module logger and the three voice-config sections consumed below.
log = get_logger("gemini_brain")
_SV = _cfg_section("voice", "sanad_voice")   # session / reconnect tuning
_VAD = _cfg_section("voice", "vad")          # Gemini automatic activity detection
_BI = _cfg_section("voice", "barge_in")      # local barge-in / echo tuning

# Gemini Live model id (without the "models/" prefix); the
# SANAD_GEMINI_MODEL environment variable overrides the default.
_MODEL = os.getenv(
    "SANAD_GEMINI_MODEL",
    "gemini-2.5-flash-native-audio-preview-12-2025",
)

# Session tuning, each value overridable from the `sanad_voice` section.
_MIC_GAIN = _SV.get("mic_gain", 1.0)
_SESSION_TIMEOUT = _SV.get("session_timeout_sec", 660)
_MAX_RECONNECT_DELAY = _SV.get("max_reconnect_delay_sec", 30)
_MAX_CONSECUTIVE_ERRORS = _SV.get("max_consecutive_errors", 10)
_NO_MESSAGES_TIMEOUT = _SV.get("no_messages_timeout_sec", 30)

# A mic chunk holds CHUNK_SIZE int16 samples, i.e. two bytes per sample.
_CHUNK_BYTES = 2 * CHUNK_SIZE
# All-zero chunk substituted for echo-suppressed mic frames.
_SILENCE_PCM = bytes(_CHUNK_BYTES)
def _audio_energy(pcm: bytes) -> int:
try:
samples = array.array("h", pcm)
return sum(abs(s) for s in samples) // len(samples) if samples else 0
except Exception:
return 0
class GeminiBrain:
    """Gemini Live conversation brain — reconnect-safe.

    Implements the VoiceBrain contract: ``__init__(audio_io, recorder,
    voice_name, system_prompt)``, ``async run()`` and ``stop()``. Hardware
    I/O goes through ``audio_io.mic`` / ``audio_io.speaker``; per-turn
    capture goes through ``recorder``.
    """

    def __init__(self, audio_io, recorder, voice_name: Optional[str] = None,
                 system_prompt: str = ""):
        """Store collaborators and initialise per-session state.

        Args:
            audio_io: object exposing ``.mic`` and ``.speaker``.
            recorder: per-turn capture sink (``capture_user``,
                ``capture_robot``, ``add_user_text``, ``add_robot_text``,
                ``finish_turn``).
            voice_name: Gemini prebuilt voice; falls back to ``GEMINI_VOICE``.
            system_prompt: text sent as the session's system instruction.
        """
        self._audio = audio_io
        self._mic = audio_io.mic
        self._speaker = audio_io.speaker
        self._recorder = recorder
        self._voice = voice_name or GEMINI_VOICE
        self._system_prompt = system_prompt
        self._api_key = GEMINI_API_KEY
        # Only is_set()/set() are ever used on this event (never wait()), so
        # it is not tied to whichever event loop later executes run().
        self._stop_flag = asyncio.Event()
        # per-session state (reset at the top of every reconnect iteration)
        self._speaking = False
        self._stream_started = False
        self._barge_block_until = 0.0
        self._ai_speak_start = 0.0
        self._last_ai_audio = 0.0
        self._done: Optional[asyncio.Event] = None

    def stop(self) -> None:
        """Signal the run loop to exit at the next opportunity.

        The mic loop checks the flag before every chunk, and the reconnect
        and back-off waits poll it, so the active session and run() wind
        down promptly.
        """
        try:
            self._stop_flag.set()
        except Exception:
            pass

    # ─── public entry point ───────────────────────────────
    async def run(self) -> None:
        """Connect, converse and reconnect until stop() is called.

        Each iteration opens one Live session and drives it with
        ``_run_session``. A session that ends cleanly is followed by a 1 s
        pause before reconnecting. Errors reconnect with exponential back-off
        capped at ``_MAX_RECONNECT_DELAY``. After ``_MAX_CONSECUTIVE_ERRORS``
        failures in a row the client is recreated. Task cancellation or
        KeyboardInterrupt breaks the loop and run() returns.
        """
        client = genai.Client(api_key=self._api_key)
        config = self._build_config()
        session_num = 0
        start_time = time.time()
        consecutive_errors = 0
        while not self._stop_flag.is_set():
            session_num += 1
            self._reset_turn_state()
            uptime_min = (time.time() - start_time) / 60
            try:
                log.info("connecting to Gemini (session #%d, uptime %.0fm)...",
                         session_num, uptime_min)
                async with client.aio.live.connect(model=_MODEL, config=config) as session:
                    log.info("connected — speak anytime!")
                    consecutive_errors = 0
                    self._mic.flush()
                    self._done = asyncio.Event()
                    try:
                        await self._run_session(session)
                    except asyncio.CancelledError:
                        log.warning("session cancelled")
                        # Propagate: swallowing cancellation here made run()
                        # ignore task.cancel() and keep reconnecting.
                        raise
                log.info("session #%d ended — reconnecting in 1s", session_num)
                self._speaker.stop()
                self._mic.flush()
                await self._sleep_unless_stopped(1)
            except asyncio.CancelledError:
                log.info("cancelled — stopping")
                break
            except KeyboardInterrupt:
                log.info("keyboard interrupt — stopping")
                break
            except Exception as exc:
                consecutive_errors += 1
                delay = min(_MAX_RECONNECT_DELAY, 2 ** consecutive_errors)
                log.error("session error (#%d): %s — reconnecting in %ds",
                          consecutive_errors, exc, delay)
                await self._sleep_unless_stopped(delay)
                if consecutive_errors >= _MAX_CONSECUTIVE_ERRORS:
                    log.warning("%d consecutive errors — recreating client",
                                consecutive_errors)
                    try:
                        client = genai.Client(api_key=self._api_key)
                        consecutive_errors = 0
                    except Exception as ce:
                        log.error("client recreation failed: %s", ce)

    async def _run_session(self, session: Any) -> None:
        """Drive one connected session until it ends.

        Runs the mic-send and receive loops as tasks and returns once either
        one finishes or ``_SESSION_TIMEOUT`` elapses. It then sets ``_done``
        and gives the other loop a short grace period to exit on its own
        (so an in-flight mic read completes) before cancelling it. A
        ``gather`` of both loops would leave the session blocked in
        ``session.receive()`` until the timeout whenever the mic loop exited
        first (mic failure, stop()).

        Raises:
            Exception: a non-cancellation error raised by either loop, so the
                caller's error/back-off accounting still sees it.
        """
        tasks = [
            asyncio.create_task(self._send_mic_loop(session)),
            asyncio.create_task(self._receive_loop(session)),
        ]
        try:
            finished, _ = await asyncio.wait(
                tasks,
                timeout=_SESSION_TIMEOUT,
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not finished:
                log.warning("session timed out after %ds", _SESSION_TIMEOUT)
        finally:
            self._done.set()
            try:
                await asyncio.wait(tasks, timeout=1.0)
            finally:
                for task in tasks:
                    task.cancel()  # no-op for tasks that already finished
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            # CancelledError is a BaseException (3.8+), so it is skipped here.
            if isinstance(outcome, Exception):
                raise outcome

    async def _sleep_unless_stopped(self, delay: float) -> None:
        """Sleep up to *delay* seconds, returning early once stop() is called.

        Polls the stop flag rather than awaiting ``Event.wait()``, so no
        event-loop wake-up from stop() is needed.
        """
        deadline = time.monotonic() + delay
        while not self._stop_flag.is_set():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            await asyncio.sleep(min(remaining, 0.25))

    # ─── Gemini config ────────────────────────────────────
    def _build_config(self) -> types.LiveConnectConfig:
        """Build the LiveConnectConfig for every session.

        The config requests audio-only responses in the configured prebuilt
        voice. It enables automatic activity detection tuned from the
        ``voice.vad`` config section (sensitivity values are enum member
        *names*; an unknown name raises AttributeError). It also enables
        input and output transcription and sets the system prompt as the
        system instruction.
        """
        return types.LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=self._voice,
                    ),
                ),
            ),
            realtime_input_config=types.RealtimeInputConfig(
                automatic_activity_detection=types.AutomaticActivityDetection(
                    disabled=False,
                    start_of_speech_sensitivity=getattr(
                        types.StartSensitivity,
                        _VAD.get("start_sensitivity", "START_SENSITIVITY_HIGH"),
                    ),
                    end_of_speech_sensitivity=getattr(
                        types.EndSensitivity,
                        _VAD.get("end_sensitivity", "END_SENSITIVITY_LOW"),
                    ),
                    prefix_padding_ms=_VAD.get("prefix_padding_ms", 20),
                    silence_duration_ms=_VAD.get("silence_duration_ms", 200),
                ),
            ),
            input_audio_transcription=types.AudioTranscriptionConfig(),
            output_audio_transcription=types.AudioTranscriptionConfig(),
            system_instruction=types.Content(
                parts=[types.Part(text=self._system_prompt)],
            ),
        )

    # ─── state helpers ────────────────────────────────────
    def _reset_turn_state(self) -> None:
        """Clear speaking/stream/barge-in bookkeeping before a new session."""
        self._speaking = False
        self._stream_started = False
        self._barge_block_until = 0.0
        self._ai_speak_start = 0.0
        self._last_ai_audio = 0.0

    def _interrupt(self, source: str = "local") -> None:
        """Stop playback, flush the mic and close the recorder turn.

        Args:
            source: label used only in the log line ("barge-in", "gemini").
        """
        self._speaking = False
        self._stream_started = False
        self._speaker.stop()
        self._mic.flush()
        self._recorder.finish_turn()
        log.info("interrupt (%s)", source)

    # ─── mic send loop ────────────────────────────────────
    async def _send_mic_loop(self, session: Any) -> None:
        """Stream gained mic chunks to Gemini with local barge-in and echo masking.

        Runs until ``_done`` or the stop flag is set, a mic read fails, or a
        send fails. The two failure cases also set ``_done``.
        """
        threshold = _BI.get("threshold", 500)
        chunks_needed = _BI.get("loud_chunks_needed", 3)
        cooldown = _BI.get("cooldown_sec", 0.3)
        echo_suppress_below = _BI.get("echo_suppress_below", 500)
        grace = _BI.get("ai_speak_grace_sec", 0.15)
        speech_energy = 250  # frames above this count as user speech
        mime_type = f"audio/pcm;rate={SEND_SAMPLE_RATE}"
        loop = asyncio.get_running_loop()
        loud_count = 0
        last_activity = time.time()
        while not self._done.is_set() and not self._stop_flag.is_set():
            try:
                raw = await loop.run_in_executor(
                    None, self._mic.read_chunk, _CHUNK_BYTES,
                )
            except Exception as exc:
                log.warning("mic read failed: %s — ending session", exc)
                self._done.set()
                break
            samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
            samples = np.clip(samples * _MIC_GAIN, -32768, 32767).astype(np.int16)
            data = samples.tobytes()
            energy = _audio_energy(data)
            now = time.time()

            # Barge-in: once the AI has been speaking for `grace` seconds (and
            # no cooldown is active), sustained user energy cuts it.
            if not self._speaking:
                # Start each AI turn from zero; a count left over from the
                # previous turn could otherwise fire on a single loud chunk.
                loud_count = 0
            elif (now >= self._barge_block_until
                  and now - self._ai_speak_start >= grace):
                if energy > threshold:
                    loud_count += 1
                else:
                    loud_count = max(0, loud_count - 1)
                # Fires once the decaying count *exceeds* loud_chunks_needed.
                if loud_count > chunks_needed:
                    log.info("BARGE-IN (e=%d)", energy)
                    self._interrupt("barge-in")
                    loud_count = 0
                    self._barge_block_until = now + cooldown

            # Echo suppression: while AI is speaking, mask quiet frames so the
            # mic doesn't feed the model its own voice bleed.
            send_data = data
            if self._speaking and energy < echo_suppress_below:
                send_data = _SILENCE_PCM

            # Record user audio when clearly speaking and AI isn't.
            if energy > speech_energy and not self._speaking:
                self._recorder.capture_user(data)

            # Keep-alive watchdog: log roughly every 10 s without user speech.
            if energy > speech_energy:
                last_activity = now
            elif now - last_activity > 10:
                log.info("alive (no speech %.0fs, e=%d)",
                         now - last_activity, energy)
                last_activity = now

            try:
                await session.send_realtime_input(
                    audio=types.Blob(data=send_data, mime_type=mime_type),
                )
            except asyncio.CancelledError:
                return
            except Exception as exc:
                log.warning("mic send failed: %s — ending session", exc)
                self._done.set()
                return
            # NOTE(review): if read_chunk already blocks for one chunk of real
            # time, this extra sleep halves the send rate; confirm against the
            # mic implementation.
            await asyncio.sleep(CHUNK_SIZE / SEND_SAMPLE_RATE)
        log.info("send_mic task ended")

    # ─── receive loop ─────────────────────────────────────
    async def _receive_loop(self, session: Any) -> None:
        """Consume server messages: transcripts, audio playback and turn events.

        ``session.receive()`` is re-entered each time its iteration ends. If
        it keeps ending with no message for ``_NO_MESSAGES_TIMEOUT`` seconds,
        the session is treated as dead. Always sets ``_done`` on exit.
        """
        loop = asyncio.get_running_loop()
        try:
            last_recv = time.time()
            while not self._done.is_set() and not self._stop_flag.is_set():
                async for response in session.receive():
                    last_recv = time.time()
                    if self._done.is_set():
                        break
                    if getattr(response, "go_away", None) is not None:
                        log.info("server going away — will reconnect")
                        self._done.set()
                        return
                    sc = response.server_content
                    if sc is None:
                        continue
                    if sc.interrupted is True:
                        if self._speaking:
                            log.info("Gemini interrupted")
                            self._interrupt("gemini")
                        continue
                    if sc.input_transcription:
                        text = (sc.input_transcription.text or "").strip()
                        # Transcripts heard while the AI talks are likely echo.
                        if text and not self._speaking:
                            log.info("USER: %s", text)
                            self._recorder.add_user_text(text)
                    if sc.output_transcription:
                        text = (sc.output_transcription.text or "").strip()
                        if text:
                            log.info("BOT : %s", text)
                            self._recorder.add_robot_text(text)
                    if sc.model_turn:
                        # Guard against a None `parts` list.
                        for part in sc.model_turn.parts or ():
                            if part.inline_data and part.inline_data.data:
                                now = time.time()
                                if not self._speaking:
                                    self._ai_speak_start = now
                                    self._speaking = True
                                self._last_ai_audio = now
                                raw_audio = part.inline_data.data
                                self._recorder.capture_robot(raw_audio)
                                audio = np.frombuffer(raw_audio, dtype=np.int16)
                                if not self._stream_started:
                                    await loop.run_in_executor(
                                        None, self._speaker.begin_stream,
                                    )
                                    self._stream_started = True
                                await loop.run_in_executor(
                                    None, self._speaker.send_chunk,
                                    audio, RECEIVE_SAMPLE_RATE,
                                )
                    if sc.turn_complete:
                        if (self._speaking and self._stream_started
                                and not self._speaker.interrupted):
                            log.info("speaker %.1fs", self._speaker.total_sent_sec)
                            await loop.run_in_executor(
                                None, self._speaker.wait_finish,
                            )
                        elif self._speaking and self._speaker.interrupted:
                            log.info("speaker interrupted")
                        self._speaking = False
                        self._stream_started = False
                        self._mic.flush()
                        self._recorder.finish_turn()
                        log.info("listening")
                if time.time() - last_recv > _NO_MESSAGES_TIMEOUT:
                    log.warning("no messages from Gemini for %ds — session dead",
                                _NO_MESSAGES_TIMEOUT)
                    break
                await asyncio.sleep(0.1)
        except Exception as exc:
            log.warning("receive ended: %s", exc)
        finally:
            self._done.set()