"""Go2 voice client: streams microphone PCM to the Gemini Live API over a
websocket and plays back the model's spoken replies, with barge-in
(interruption) detection and anti-freeze safeguards."""
import array
import asyncio
import base64
import functools
import json
import os
import sys
import time

import pyaudio
import websockets

# ==================================================
# ⚙️ CONFIGURATION
# ==================================================
# SECURITY FIX: the previous code shipped a real API key as the env-var
# fallback. Never hard-code credentials — read them from the environment
# only, and fail loudly when absent.
API_KEY = os.environ.get("GEMINI_API_KEY", "")
if not API_KEY:
    print("⚠️ GEMINI_API_KEY is not set — the connection will be rejected.", file=sys.stderr)

MODEL = "models/gemini-2.5-flash-native-audio-preview-12-2025"
# Bidirectional (full-duplex) generate-content websocket endpoint; the API
# key is passed as a query parameter.
URI = (
    "wss://generativelanguage.googleapis.com/ws/"
    "google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent"
    f"?key={API_KEY}"
)

FORMAT = pyaudio.paInt16       # 16-bit signed little-endian PCM samples
CHANNELS = 1                   # mono capture and playback
SEND_SAMPLE_RATE = 16000       # mic upload rate expected by the Live API
RECEIVE_SAMPLE_RATE = 24000    # playback rate of audio returned by Gemini
CHUNK_SIZE = 1024  # Larger chunk to prevent cutting off words
VOICE_NAME = "Charon"          # prebuilt Gemini voice used for speech output

# ==================================================
# LOGGER
# ==================================================
try:
    from Logger import Logs

    logger = Logs()
    logger.LogEngine("go2_voice_logs", "Go2Voice.log")

    def log(msg, mtype="info"):
        """Route a message through the project's file-backed logger."""
        logger.print_and_log(msg, mtype)

except ImportError:
    def log(msg, mtype="info"):
        """Console-only fallback used when the Logger module is unavailable."""
        print(f"[{mtype.upper()}] {msg}")

# ==================================================
# ✅ Python 3.8 Compatibility
# ==================================================
_native_to_thread = getattr(asyncio, "to_thread", None)

if _native_to_thread is not None:
    # Python 3.9+: use the stdlib implementation directly.
    to_thread = _native_to_thread
else:
    async def to_thread(func, *args, **kwargs):
        """Backport of asyncio.to_thread: run func(*args, **kwargs) in the default executor."""
        call = functools.partial(func, *args, **kwargs)
        return await asyncio.get_running_loop().run_in_executor(None, call)

# ==================================================
# 🧠 System Prompt
# ==================================================
def load_system_prompt():
    """Load the 'Go2' persona text from go2_script.txt next to this file.

    Returns:
        The stripped file contents, or a generic assistant prompt when the
        file is missing or unreadable.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(base_dir, "go2_script.txt")
    try:
        # utf-8-sig transparently drops a BOM written by Windows editors.
        with open(path, "r", encoding="utf-8-sig") as f:
            content = f.read().strip()
    except OSError:
        # FIX: was FileNotFoundError only — a PermissionError (or any other
        # I/O failure) would crash at import time. Fall back instead.
        log("⚠️ Using default persona.", "warning")
        return "You are a helpful robot assistant."
    log("✅ 'Go2' persona loaded.", "info")
    return content


SYSTEM_PROMPT = load_system_prompt()

# ==================================================
# 🎤 Main Client Class (Anti-Freeze Version)
# ==================================================
class HamadGeminiVoice:
    """Full-duplex voice client for the Gemini Live API.

    Three cooperating asyncio tasks share state through plain attributes:
    capture_mic() streams mic PCM upstream and detects user barge-in,
    receive_audio() queues PCM chunks coming back from the model, and
    play_audio() drains the queue to the speakers with anti-freeze guards.
    """

    def __init__(self):
        self.audio_q = None        # asyncio.Queue of raw PCM chunks; created in run()
        self.speaking = False      # True while model audio is queued/being played
        self.interrupted = False   # True after a barge-in until server confirms (or timeout)
        self.pya = pyaudio.PyAudio()

        # Tuning
        self.MIN_THRESHOLD = 3000        # floor for the barge-in energy threshold
        self.barge_in_threshold = 3000   # recalibrated from ambient noise in calibrate_mic()
        self.REQUIRED_LOUD_CHUNKS = 5    # consecutive loud chunks required to trigger barge-in

        # Stability
        self.PREBUFFER_CHUNKS = 4     # chunks buffered before playback starts (avoids stutter)
        self.PLAYBACK_TIMEOUT = 0.35  # seconds to wait on the queue before end-of-turn check
        self.BARGE_IN_COOLDOWN = 0.7  # seconds to ignore further barge-ins after one fires
        self.AI_SPEAK_GRACE = 0.25    # seconds after AI starts talking before barge-in allowed

        # 🛡️ ANTI-FREEZE VARIABLES
        self._last_interruption_time = 0.0
        self.INTERRUPTION_RESET_TIMEOUT = 2.0  # Reset interruption after 2 seconds if stuck

        self._last_ai_audio_time = 0.0    # wall-clock time of the last received AI chunk
        self._ai_speaking_since = 0.0     # when the current AI turn started
        self._barge_in_block_until = 0.0  # barge-in detection disabled until this time

        # Echo Protection
        # NOTE(review): ECHO_GUARD_SEC/_ignore_input_until are written in
        # receive_audio() but never read in capture_mic() — the echo guard
        # appears unfinished. Confirm before relying on it.
        self.ECHO_GUARD_SEC = 0.8
        self._ignore_input_until = 0.0
        self.SEND_SILENCE_WHEN_SPEAKING = True
        self.SPEAKING_ENERGY_GATE = 0.85  # fraction of threshold below which mic is muted
        self._silence_pcm = b"\x00" * (CHUNK_SIZE * 2)  # one chunk of 16-bit silence

    def audio_energy(self, pcm):
        """Return the mean absolute amplitude of a 16-bit LE PCM buffer (0 on error)."""
        try:
            samples = array.array("h", pcm)
            if not samples:
                return 0
            return sum(abs(s) for s in samples) // len(samples)
        except Exception:
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt
            # and SystemExit. A malformed buffer (e.g. odd length) is silence.
            return 0

    def calibrate_mic(self):
        """Sample ~20 chunks of ambient noise and set barge_in_threshold to 3x the baseline."""
        log("🤫 Calibrating Microphone... (Stay Silent)", "info")
        try:
            stream = self.pya.open(format=FORMAT, channels=CHANNELS, rate=SEND_SAMPLE_RATE, input=True, frames_per_buffer=CHUNK_SIZE)
            values = []
            for _ in range(20):
                data = stream.read(CHUNK_SIZE, exception_on_overflow=False)
                values.append(self.audio_energy(data))
            stream.stop_stream()
            stream.close()

            avg_noise = sum(values) / len(values)
            # Never drop below MIN_THRESHOLD, even in a very quiet room.
            self.barge_in_threshold = max(self.MIN_THRESHOLD, avg_noise * 3.0)
            log(f"✅ Baseline: {avg_noise:.1f} | Threshold: {self.barge_in_threshold:.1f}", "info")
        except Exception as e:
            log(f"⚠️ Calibration failed: {e}", "warning")

    async def run(self):
        """Connect to Gemini, send the setup message, and run the three I/O tasks."""
        self.audio_q = asyncio.Queue()
        self.calibrate_mic()

        log(f"🚀 Connecting to Gemini ({MODEL})...", "info")
        # NOTE(review): websockets >= 14 renamed `extra_headers` to
        # `additional_headers` — pin the websockets version or confirm.
        async with websockets.connect(URI, extra_headers={"Content-Type": "application/json"}) as ws:

            setup_msg = {
                "setup": {
                    "model": MODEL,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": VOICE_NAME}}},
                    },
                    "systemInstruction": {"parts": [{"text": SYSTEM_PROMPT}]},
                }
            }
            await ws.send(json.dumps(setup_msg))
            await ws.recv()  # wait for the server's setup acknowledgement
            log("🎙️ Connected! Listening...", "success")

            tasks = [
                asyncio.create_task(self.capture_mic(ws)),
                asyncio.create_task(self.receive_audio(ws)),
                asyncio.create_task(self.play_audio()),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # If any task fails, tear the others down too.
                for t in tasks:
                    t.cancel()

    async def capture_mic(self, ws):
        """Read mic chunks, detect barge-in while the AI speaks, and stream PCM upstream."""
        stream = await to_thread(self.pya.open, format=FORMAT, channels=CHANNELS, rate=SEND_SAMPLE_RATE, input=True, frames_per_buffer=CHUNK_SIZE)
        loud_chunks = 0

        try:
            while True:
                try:
                    data = await to_thread(stream.read, CHUNK_SIZE, exception_on_overflow=False)
                    energy = self.audio_energy(data)
                    now = time.time()

                    # --- INTERRUPTION LOGIC ---
                    # Only consider barge-in while the AI is speaking, outside the
                    # post-interruption cooldown and past the initial grace window.
                    if self.speaking and (now >= self._barge_in_block_until):
                        if (now - self._ai_speaking_since) >= self.AI_SPEAK_GRACE:
                            if energy > self.barge_in_threshold:
                                loud_chunks += 1
                            else:
                                loud_chunks = 0  # must be *consecutive* loud chunks

                            if loud_chunks > self.REQUIRED_LOUD_CHUNKS:
                                log(f"🛑 Interruption! (Energy: {energy})", "warning")
                                self.interrupted = True
                                self.speaking = False
                                self._last_interruption_time = now  # Mark time of interruption
                                loud_chunks = 0
                                self._barge_in_block_until = now + self.BARGE_IN_COOLDOWN

                                # Drain the audio queue instantly so robot stops talking
                                while not self.audio_q.empty():
                                    try:
                                        self.audio_q.get_nowait()
                                    except asyncio.QueueEmpty:
                                        break

                    # Send Silence if Robot is Speaking
                    # (quiet mic input during playback is assumed to be speaker echo)
                    data_to_send = data
                    if self.SEND_SILENCE_WHEN_SPEAKING and self.speaking:
                        gate = self.barge_in_threshold * self.SPEAKING_ENERGY_GATE
                        if energy < gate:
                            data_to_send = self._silence_pcm

                    b64_audio = base64.b64encode(data_to_send).decode("utf-8")
                    msg = {"realtime_input": {"media_chunks": [{"data": b64_audio, "mime_type": f"audio/pcm;rate={SEND_SAMPLE_RATE}"}]}}
                    await ws.send(json.dumps(msg))

                except websockets.exceptions.ConnectionClosed:
                    log("⚠️ WebSocket closed.", "error")
                    break
                except Exception as e:
                    log(f"❌ Mic Error: {e}", "error")
                    break
        finally:
            # FIX: release the capture stream when the task exits or is cancelled
            # (previously leaked).
            stream.close()

    async def receive_audio(self, ws):
        """Parse server messages and enqueue decoded AI audio chunks for playback."""
        async for msg in ws:
            try:
                response = json.loads(msg)
                server_content = response.get("serverContent", {})

                # If server confirms interruption, unlock immediately
                if server_content.get("interrupted"):
                    self.interrupted = False

                if self.interrupted:
                    continue  # drop stale audio until the server acknowledges

                model_turn = server_content.get("modelTurn")
                if model_turn:
                    parts = model_turn.get("parts", [])
                    for part in parts:
                        inline_data = part.get("inlineData")
                        if inline_data:
                            audio_b64 = inline_data.get("data")
                            if audio_b64:
                                now = time.time()
                                if not self.speaking:
                                    self._ai_speaking_since = now  # a new AI turn begins
                                self.speaking = True
                                self._last_ai_audio_time = now
                                self._ignore_input_until = now + self.ECHO_GUARD_SEC
                                await self.audio_q.put(base64.b64decode(audio_b64))
            except Exception as e:
                log(f"❌ Parse Error: {e}", "error")

    async def play_audio(self):
        """Play queued AI audio; prebuffer for smoothness and force-reset stuck interruptions."""
        stream = await to_thread(self.pya.open, format=FORMAT, channels=CHANNELS, rate=RECEIVE_SAMPLE_RATE, output=True, frames_per_buffer=CHUNK_SIZE)
        buffered = False
        try:
            while True:
                try:
                    # 🛑 ANTI-FREEZE CHECK
                    # If interrupted for more than 2 seconds, Force Reset.
                    if self.interrupted:
                        if (time.time() - self._last_interruption_time) > self.INTERRUPTION_RESET_TIMEOUT:
                            log("⚠️ Interruption lock timed out. Force resetting.", "warning")
                            self.interrupted = False

                        # While interrupted, discard audio and sleep
                        while not self.audio_q.empty():
                            try:
                                self.audio_q.get_nowait()
                            except asyncio.QueueEmpty:
                                break
                        await asyncio.sleep(0.01)
                        continue

                    # Wait for a few chunks before starting playback to avoid stutter.
                    if self.speaking and not buffered:
                        while self.audio_q.qsize() < self.PREBUFFER_CHUNKS and self.speaking and not self.interrupted:
                            await asyncio.sleep(0.01)
                        buffered = True

                    try:
                        data = await asyncio.wait_for(self.audio_q.get(), timeout=self.PLAYBACK_TIMEOUT)
                    except asyncio.TimeoutError:
                        # Queue dry and no fresh AI audio for a while: turn has ended.
                        if self.audio_q.empty() and (time.time() - self._last_ai_audio_time) > 0.25:
                            self.speaking = False
                            buffered = False
                        continue

                    if data:
                        await to_thread(stream.write, data)
                    if self.audio_q.empty():
                        if (time.time() - self._last_ai_audio_time) > 0.25:
                            self.speaking = False
                            buffered = False
                except Exception as e:
                    log(f"❌ Speaker Error: {e}", "error")
                    break
        finally:
            # FIX: release the output stream when the task exits or is cancelled
            # (previously leaked).
            stream.close()

if __name__ == "__main__":
    try:
        asyncio.run(HamadGeminiVoice().run())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the client — exit quietly.
        pass