# Audio prompt recorder: synthesizes speech through the Gemini Live API and
# re-captures it from the speaker's PulseAudio monitor source.
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import inspect
|
|
import json
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
|
|
import pyaudio
|
|
import websockets
|
|
|
|
from Core import audio_prompts
|
|
from Core import settings as config
|
|
|
|
# PCM parameters shared by playback and monitor capture.
FORMAT = pyaudio.paInt16  # 16-bit signed samples
CHANNELS = 1  # mono
RECEIVE_SAMPLE_RATE = 24000  # Hz; rate used for playback and raw model audio
CHUNK_SIZE = 512  # frames per playback write
MONITOR_CHUNK_SIZE = 1024  # frames per monitor-capture read
MONITOR_TAIL_SEC = 0.20  # extra capture time after playback so the tail is not clipped

# Fallback PulseAudio device names; overridable via the SINK / SOURCE env vars.
DEFAULT_SINK = "alsa_output.usb-Anker_PowerConf_A3321-DEV-SN1-01.analog-stereo"
DEFAULT_SOURCE = "alsa_input.usb-Anker_PowerConf_A3321-DEV-SN1-01.mono-fallback"

# System instruction for the Gemini Live session: the model must speak the
# user-provided text verbatim rather than answering or rephrasing it.
STRICT_REPLAY_SYSTEM_PROMPT = (
    "You are Sanad (Bousandah), a wise and friendly Emirati assistant. "
    "For this session, the user will provide text that you must speak exactly as written. "
    "Do not translate it. Do not summarize it. Do not answer it. "
    "Do not rephrase it into another dialect or style. "
    "Do not add greetings, punctuation changes, or extra words. "
    "Keep the same word order and language as the provided text. "
    "Your job is only to speak the exact user text verbatim. "
    "Speak only the user text. Do not add filler words. "
    "Do not change tense, grammar, dialect, or wording."
)
|
|
|
|
|
|
def _env_sink() -> str:
    """Return the playback sink name: the SINK env var, else DEFAULT_SINK."""
    configured = os.environ.get("SINK", DEFAULT_SINK)
    return str(configured or DEFAULT_SINK).strip()
|
|
|
|
|
|
def _env_source() -> str:
    """Return the capture source name: the SOURCE env var, else DEFAULT_SOURCE."""
    configured = os.environ.get("SOURCE", DEFAULT_SOURCE)
    return str(configured or DEFAULT_SOURCE).strip()
|
|
|
|
|
|
def _monitor_source() -> str:
    """Return the monitor source: MONITOR_SOURCE env override, else '<sink>.monitor'."""
    explicit = str(os.environ.get("MONITOR_SOURCE", "") or "").strip()
    if explicit:
        return explicit
    return f"{_env_sink()}.monitor"
|
|
|
|
|
|
def _build_verbatim_request_text(text: str) -> str:
|
|
return (
|
|
"Speak exactly the text inside the <verbatim> block.\n"
|
|
"Do not add or remove any words.\n"
|
|
"Do not translate or paraphrase.\n"
|
|
"Do not change the language or dialect.\n"
|
|
"Read only the content inside the tags.\n"
|
|
"<verbatim>\n"
|
|
f"{text}\n"
|
|
"</verbatim>"
|
|
)
|
|
|
|
|
|
def _iter_input_devices(audio: pyaudio.PyAudio):
|
|
for index in range(audio.get_device_count()):
|
|
info = audio.get_device_info_by_index(index)
|
|
if int(info.get("maxInputChannels", 0)) > 0:
|
|
yield index, info
|
|
|
|
|
|
def _find_preferred_input_device(audio: pyaudio.PyAudio):
    """Locate a Pulse-backed input device, preferring 'pulse' over 'default'.

    For each hint an exact (case-insensitive) name match wins over the first
    substring match. Returns an (index, info) tuple, or None when no device
    matches either hint.
    """
    for hint in ("pulse", "default"):
        needle = hint.strip().lower()
        exact_match = None
        partial_match = None
        for device_index, device_info in _iter_input_devices(audio):
            device_name = str(device_info.get("name", "")).strip().lower()
            if device_name == needle:
                exact_match = (device_index, device_info)
                break
            if partial_match is None and needle in device_name:
                partial_match = (device_index, device_info)
        if exact_match is not None:
            return exact_match
        if partial_match is not None:
            return partial_match
    return None
|
|
|
|
|
|
class _MonitorRecorder:
    """Records the sink's monitor stream via PyAudio on a background thread.

    Expects the system default source to already point at the monitor; the
    caller (AudioPromptRecorderClient) handles that switch.
    """

    def __init__(self, audio: pyaudio.PyAudio, device_config: dict):
        self.audio = audio
        self.device_config = device_config
        # Raw PCM chunks appended by the reader thread, joined in stop().
        self.frames: list[bytes] = []
        self._stop_event = threading.Event()
        self._thread = None
        self._stream = None
        # First mid-capture exception seen by the reader thread; re-raised from stop().
        self._error = None

    def start(self):
        """Open the input stream and start the background reader thread."""
        self._stream = self.audio.open(
            format=FORMAT,
            channels=self.device_config["channels"],
            rate=self.device_config["rate"],
            input=True,
            input_device_index=self.device_config["index"],
            frames_per_buffer=self.device_config["chunk_size"],
        )
        self._thread = threading.Thread(target=self._record_loop, daemon=True)
        self._thread.start()
        # Give the reader a head start so the beginning of playback is captured.
        time.sleep(0.05)

    def _record_loop(self):
        # Drain the stream until stop() signals; overflows are tolerated
        # (exception_on_overflow=False) rather than aborting the capture.
        while not self._stop_event.is_set():
            try:
                data = self._stream.read(
                    self.device_config["chunk_size"],
                    exception_on_overflow=False,
                )
                self.frames.append(data)
            except Exception as exc:
                # Errors raised while stop() is tearing the stream down are
                # expected; only record failures that happen mid-capture.
                if not self._stop_event.is_set():
                    self._error = exc
                break

    def stop(self) -> bytes:
        """Stop capturing and return the recorded PCM bytes.

        Raises RuntimeError if the reader thread failed mid-capture.
        """
        # Let the playback tail reach the monitor before shutting down.
        time.sleep(MONITOR_TAIL_SEC)
        self._stop_event.set()
        if self._stream is not None:
            # Best-effort teardown: a closed/failed stream must not mask frames
            # we already captured.
            try:
                self._stream.stop_stream()
            except Exception:
                pass
            try:
                self._stream.close()
            except Exception:
                pass
        if self._thread is not None:
            self._thread.join(timeout=1.0)
        if self._error is not None:
            raise RuntimeError(f"Speaker monitor capture failed: {self._error}")
        return b"".join(self.frames)
|
|
|
|
|
|
class _ParecMonitorRecorder:
    """Records the sink's monitor stream by piping raw PCM from a `parec` process.

    Preferred over _MonitorRecorder because it targets the monitor source
    directly and does not require changing the system default source.
    """

    def __init__(self, device_config: dict):
        self.device_config = device_config
        # Raw PCM chunks appended by the reader thread, joined in stop().
        self.frames: list[bytes] = []
        self._stop_event = threading.Event()
        self._thread = None
        self._proc = None
        # First mid-capture exception seen by the reader thread; re-raised from stop().
        self._error = None

    def start(self):
        """Spawn parec on the monitor source and start draining its stdout."""
        command = [
            "parec",
            f"--device={self.device_config['monitor_source']}",
            "--format=s16le",  # matches FORMAT (16-bit signed little-endian)
            f"--rate={self.device_config['rate']}",
            f"--channels={self.device_config['channels']}",
        ]
        self._proc = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
        )
        self._thread = threading.Thread(target=self._record_loop, daemon=True)
        self._thread.start()
        # Give the reader a head start so the beginning of playback is captured.
        time.sleep(0.05)

    def _record_loop(self):
        if self._proc is None or self._proc.stdout is None:
            self._error = RuntimeError("parec process did not start correctly.")
            return
        # Bytes per read: frames * channels * 2 bytes (16-bit samples).
        read_size = self.device_config["chunk_size"] * self.device_config["channels"] * 2
        while not self._stop_event.is_set():
            try:
                data = self._proc.stdout.read(read_size)
                if data:
                    self.frames.append(data)
                    continue
                # Empty read with a dead process means parec exited; stop draining.
                if self._proc.poll() is not None:
                    break
            except Exception as exc:
                # Errors raised during stop()'s terminate/kill are expected;
                # only record failures that happen mid-capture.
                if not self._stop_event.is_set():
                    self._error = exc
                break

    def stop(self) -> bytes:
        """Terminate parec and return the recorded PCM bytes.

        Raises RuntimeError if the reader thread failed mid-capture.
        """
        # Let the playback tail reach the monitor before shutting down.
        time.sleep(MONITOR_TAIL_SEC)
        self._stop_event.set()
        if self._proc is not None and self._proc.poll() is None:
            # Polite terminate first; escalate to kill if parec ignores it.
            self._proc.terminate()
            try:
                self._proc.wait(timeout=1.0)
            except subprocess.TimeoutExpired:
                self._proc.kill()
                try:
                    self._proc.wait(timeout=1.0)
                except subprocess.TimeoutExpired:
                    pass
        if self._thread is not None:
            self._thread.join(timeout=1.0)
        if self._error is not None:
            raise RuntimeError(f"parec speaker monitor capture failed: {self._error}")
        return b"".join(self.frames)
|
|
|
|
|
|
class AudioPromptRecorderClient:
    """Synthesizes speech via the Gemini Live API and re-records it from the speakers.

    Round trip: generate_audio() obtains raw PCM over the Live WebSocket,
    play_audio_and_capture() plays it on the default output while recording
    the sink's monitor, and sample_width() reports the PCM sample size.
    Call close() when done to release the PortAudio handle.
    """

    def __init__(self):
        self.pya = pyaudio.PyAudio()
        self.monitor_config = self._resolve_capture_device()

    def close(self):
        """Release the PortAudio handle; the client is unusable afterwards."""
        self.pya.terminate()

    def _ws_connect_kwargs(self):
        """Build websockets.connect kwargs that work across library versions.

        Newer websockets releases renamed ``extra_headers`` to
        ``additional_headers``; inspect the signature and use whichever the
        installed version accepts, defaulting to the legacy name.
        """
        kwargs = {"max_size": None}
        try:
            sig = inspect.signature(websockets.connect)
            if "extra_headers" in sig.parameters:
                kwargs["extra_headers"] = {"Content-Type": "application/json"}
            else:
                kwargs["additional_headers"] = {"Content-Type": "application/json"}
        except Exception:
            kwargs["extra_headers"] = {"Content-Type": "application/json"}
        return kwargs

    def _build_monitor_config(self, device_index: int, info: dict) -> dict:
        """Describe a PyAudio-backed capture setup for the given input device.

        NOTE(review): "rate" comes from the device's default sample rate, which
        may differ from RECEIVE_SAMPLE_RATE; channels are clamped to 1..2.
        """
        return {
            "backend": "pyaudio",
            "index": device_index,
            "name": str(info.get("name", "unknown")),
            "rate": int(info.get("defaultSampleRate", RECEIVE_SAMPLE_RATE)),
            "channels": max(1, min(2, int(info.get("maxInputChannels", 1)))),
            "chunk_size": MONITOR_CHUNK_SIZE,
            "sink": _env_sink(),
            "source": _env_source(),
            "monitor_source": _monitor_source(),
        }

    def _set_default_source(self, source_name: str):
        """Point the PulseAudio default source at *source_name* via pactl.

        Raises RuntimeError when pactl is missing or the command fails.
        """
        try:
            subprocess.run(
                ["pactl", "set-default-source", source_name],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                text=True,
            )
        except FileNotFoundError as exc:
            raise RuntimeError("pactl is required but was not found.") from exc
        except subprocess.CalledProcessError as exc:
            stderr = (exc.stderr or "").strip()
            raise RuntimeError(stderr or f"Failed to set default source to {source_name}.") from exc

    def _resolve_capture_device(self) -> dict:
        """Choose the capture backend: parec when available, else a PyAudio device.

        parec is preferred because it records the monitor source directly and
        leaves the system default source untouched.
        """
        monitor_source = _monitor_source()
        if shutil.which("parec"):
            return {
                "backend": "parec",
                "name": monitor_source,
                "rate": RECEIVE_SAMPLE_RATE,
                "channels": CHANNELS,
                "chunk_size": MONITOR_CHUNK_SIZE,
                "sink": _env_sink(),
                "source": _env_source(),
                "monitor_source": monitor_source,
            }
        match = _find_preferred_input_device(self.pya)
        if match is None:
            raise RuntimeError("No suitable PulseAudio input device was found. Expected 'pulse' or 'default'.")
        return self._build_monitor_config(*match)

    async def generate_audio(self, text: str) -> bytes:
        """Synthesize *text* through the Gemini Live WebSocket and return raw PCM.

        Raises RuntimeError when the API key is missing or no audio comes back.
        """
        if not config.GEMINI_API_KEY:
            raise RuntimeError("Gemini API key is missing in config.")

        async with websockets.connect(config.URI, **self._ws_connect_kwargs()) as ws:
            setup_msg = {
                "setup": {
                    "model": config.GEMINI_MODEL,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {
                            "voiceConfig": {
                                "prebuiltVoiceConfig": {"voiceName": config.VOICE_NAME}
                            }
                        },
                    },
                    "systemInstruction": {"parts": [{"text": STRICT_REPLAY_SYSTEM_PROMPT}]},
                }
            }
            await ws.send(json.dumps(setup_msg))
            # The first server message acknowledges setup; its content is unused.
            await ws.recv()

            request_msg = {
                "client_content": {
                    "turns": [
                        {
                            "role": "user",
                            "parts": [{"text": _build_verbatim_request_text(text)}],
                        }
                    ],
                    "turn_complete": True,
                }
            }
            await ws.send(json.dumps(request_msg))

            # Accumulate base64-decoded audio parts until the server signals
            # the turn (or generation) is complete.
            audio_chunks: list[bytes] = []
            while True:
                raw_msg = await ws.recv()
                response = json.loads(raw_msg)
                server_content = response.get("serverContent", {})
                model_turn = server_content.get("modelTurn", {})

                for part in model_turn.get("parts", []):
                    inline_data = part.get("inlineData")
                    if inline_data and inline_data.get("data"):
                        audio_chunks.append(base64.b64decode(inline_data["data"]))

                if server_content.get("turnComplete") or server_content.get("generationComplete"):
                    break

        if not audio_chunks:
            raise RuntimeError("Gemini returned no audio for the provided text.")
        return b"".join(audio_chunks)

    def play_audio_and_capture(self, audio_bytes: bytes) -> bytes:
        """Play raw PCM on the default output while recording the sink monitor.

        Returns the captured PCM bytes. Cleanup (closing the playback stream,
        stopping the recorder, restoring the default source) runs on every
        path, including when opening the stream or writing audio fails —
        previously an error during playback leaked the recorder thread/process
        and left the system default source pointing at the monitor.
        """
        use_parec = self.monitor_config.get("backend") == "parec"
        if use_parec:
            recorder = _ParecMonitorRecorder(self.monitor_config)
        else:
            recorder = _MonitorRecorder(self.pya, self.monitor_config)
            # PyAudio capture reads the default source, so point it at the monitor.
            self._set_default_source(self.monitor_config["monitor_source"])
        stream = None
        recorder_started = False
        captured_audio = b""
        try:
            stream = self.pya.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RECEIVE_SAMPLE_RATE,
                output=True,
                frames_per_buffer=CHUNK_SIZE,
            )
            recorder.start()
            recorder_started = True
            step = CHUNK_SIZE * 2  # bytes per write: CHUNK_SIZE 16-bit mono frames
            for offset in range(0, len(audio_bytes), step):
                stream.write(audio_bytes[offset: offset + step])
        finally:
            try:
                if stream is not None:
                    try:
                        stream.stop_stream()
                    finally:
                        stream.close()
            finally:
                try:
                    # Stop the recorder even on playback errors so its thread
                    # (or parec process) never leaks.
                    if recorder_started:
                        captured_audio = recorder.stop()
                finally:
                    if not use_parec:
                        # Always undo the default-source switch made above.
                        self._set_default_source(self.monitor_config["source"])
        return captured_audio

    def sample_width(self) -> int:
        """Return the size in bytes of one PCM sample for FORMAT."""
        return self.pya.get_sample_size(FORMAT)
|
|
|
|
|
|
def record_prompt_from_text(key: str, text: str, filename: str = "") -> dict:
    """Synthesize *text*, replay it through the speakers, and persist the bundle.

    Returns the metadata dict from save_audio_prompt_bundle, augmented with
    the cleaned text and the raw-audio path. Raises ValueError when *text*
    is empty after stripping.
    """
    prompt_key = audio_prompts._clean_key(key)  # type: ignore[attr-defined]
    prompt_text = str(text or "").strip()
    if not prompt_text:
        raise ValueError("text is required")

    client = AudioPromptRecorderClient()
    try:
        generated_audio = asyncio.run(client.generate_audio(prompt_text))
        replayed_audio = client.play_audio_and_capture(generated_audio)
        monitor_cfg = client.monitor_config
        bundle = audio_prompts.save_audio_prompt_bundle(
            prompt_key,
            replayed_audio,
            filename=filename or audio_prompts.prompt_filename(prompt_key),
            raw_data=generated_audio,
            text=prompt_text,
            model=config.GEMINI_MODEL,
            voice_name=config.VOICE_NAME,
            replay_count=1,
            speaker_rate=int(monitor_cfg["rate"]),
            speaker_channels=int(monitor_cfg["channels"]),
            raw_rate=RECEIVE_SAMPLE_RATE,
            raw_channels=CHANNELS,
            sample_width=client.sample_width(),
            capture_device=str(monitor_cfg.get("name", "")),
            sink=str(monitor_cfg.get("sink", "")),
            source=str(monitor_cfg.get("source", "")),
            monitor_source=str(monitor_cfg.get("monitor_source", "")),
        )
        bundle["text"] = prompt_text
        bundle["raw_path"] = str(audio_prompts.raw_prompt_path(prompt_key))
        return bundle
    finally:
        client.close()
|