# AI_Photographer/Core/audio_prompt_recorder.py
# Last modified: 2026-04-12 18:52:37 +04:00
# 404 lines, 14 KiB, Python
from __future__ import annotations
import asyncio
import base64
import inspect
import json
import os
import shutil
import subprocess
import threading
import time
import pyaudio
import websockets
from Core import audio_prompts
from Core import settings as config
# --- PCM format shared by playback and capture (16-bit mono) ---
FORMAT = pyaudio.paInt16  # sample format for every stream opened in this module
CHANNELS = 1  # mono playback/raw capture
RECEIVE_SAMPLE_RATE = 24000  # playback rate used for audio received from the API
CHUNK_SIZE = 512  # frames per playback write
MONITOR_CHUNK_SIZE = 1024  # frames per monitor-capture read
MONITOR_TAIL_SEC = 0.20  # keep capturing this long after playback so the tail isn't clipped
# Default PulseAudio device names (Anker PowerConf speakerphone); the SINK and
# SOURCE environment variables override these at runtime.
DEFAULT_SINK = "alsa_output.usb-Anker_PowerConf_A3321-DEV-SN1-01.analog-stereo"
DEFAULT_SOURCE = "alsa_input.usb-Anker_PowerConf_A3321-DEV-SN1-01.mono-fallback"
# System instruction sent in the Gemini session setup: constrains the model to
# speak the user-supplied text verbatim (no answering, translating, rephrasing).
STRICT_REPLAY_SYSTEM_PROMPT = (
    "You are Sanad (Bousandah), a wise and friendly Emirati assistant. "
    "For this session, the user will provide text that you must speak exactly as written. "
    "Do not translate it. Do not summarize it. Do not answer it. "
    "Do not rephrase it into another dialect or style. "
    "Do not add greetings, punctuation changes, or extra words. "
    "Keep the same word order and language as the provided text. "
    "Your job is only to speak the exact user text verbatim. "
    "Speak only the user text. Do not add filler words. "
    "Do not change tense, grammar, dialect, or wording."
)
def _env_sink() -> str:
    """Name of the PulseAudio sink to play through (``SINK`` env var, else default)."""
    configured = os.environ.get("SINK", DEFAULT_SINK)
    if not configured:
        configured = DEFAULT_SINK
    return str(configured).strip()
def _env_source() -> str:
    """Name of the PulseAudio capture source (``SOURCE`` env var, else default)."""
    configured = os.environ.get("SOURCE", DEFAULT_SOURCE)
    if not configured:
        configured = DEFAULT_SOURCE
    return str(configured).strip()
def _monitor_source() -> str:
    """Monitor-source name used to capture what the sink is playing.

    The ``MONITOR_SOURCE`` env var wins; otherwise derive ``<sink>.monitor``.
    """
    env_value = os.environ.get("MONITOR_SOURCE") or ""
    override = str(env_value).strip()
    if override:
        return override
    return _env_sink() + ".monitor"
def _build_verbatim_request_text(text: str) -> str:
return (
"Speak exactly the text inside the <verbatim> block.\n"
"Do not add or remove any words.\n"
"Do not translate or paraphrase.\n"
"Do not change the language or dialect.\n"
"Read only the content inside the tags.\n"
"<verbatim>\n"
f"{text}\n"
"</verbatim>"
)
def _iter_input_devices(audio: pyaudio.PyAudio):
for index in range(audio.get_device_count()):
info = audio.get_device_info_by_index(index)
if int(info.get("maxInputChannels", 0)) > 0:
yield index, info
def _find_preferred_input_device(audio: pyaudio.PyAudio):
    """Locate the 'pulse' (preferred) or 'default' PyAudio input device.

    For each candidate name, an exact (case-insensitive) match wins over the
    first substring match. Returns ``(index, info)`` or ``None`` when neither
    candidate is present.
    """
    for candidate in ("pulse", "default"):
        wanted = candidate.strip().lower()
        exact_match = None
        substring_match = None
        for device_index, device_info in _iter_input_devices(audio):
            device_name = str(device_info.get("name", "")).strip().lower()
            if device_name == wanted:
                exact_match = (device_index, device_info)
                break
            if substring_match is None and wanted in device_name:
                substring_match = (device_index, device_info)
        if exact_match is not None:
            return exact_match
        if substring_match is not None:
            return substring_match
    return None
class _MonitorRecorder:
def __init__(self, audio: pyaudio.PyAudio, device_config: dict):
self.audio = audio
self.device_config = device_config
self.frames: list[bytes] = []
self._stop_event = threading.Event()
self._thread = None
self._stream = None
self._error = None
def start(self):
self._stream = self.audio.open(
format=FORMAT,
channels=self.device_config["channels"],
rate=self.device_config["rate"],
input=True,
input_device_index=self.device_config["index"],
frames_per_buffer=self.device_config["chunk_size"],
)
self._thread = threading.Thread(target=self._record_loop, daemon=True)
self._thread.start()
time.sleep(0.05)
def _record_loop(self):
while not self._stop_event.is_set():
try:
data = self._stream.read(
self.device_config["chunk_size"],
exception_on_overflow=False,
)
self.frames.append(data)
except Exception as exc:
if not self._stop_event.is_set():
self._error = exc
break
def stop(self) -> bytes:
time.sleep(MONITOR_TAIL_SEC)
self._stop_event.set()
if self._stream is not None:
try:
self._stream.stop_stream()
except Exception:
pass
try:
self._stream.close()
except Exception:
pass
if self._thread is not None:
self._thread.join(timeout=1.0)
if self._error is not None:
raise RuntimeError(f"Speaker monitor capture failed: {self._error}")
return b"".join(self.frames)
class _ParecMonitorRecorder:
def __init__(self, device_config: dict):
self.device_config = device_config
self.frames: list[bytes] = []
self._stop_event = threading.Event()
self._thread = None
self._proc = None
self._error = None
def start(self):
command = [
"parec",
f"--device={self.device_config['monitor_source']}",
"--format=s16le",
f"--rate={self.device_config['rate']}",
f"--channels={self.device_config['channels']}",
]
self._proc = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
)
self._thread = threading.Thread(target=self._record_loop, daemon=True)
self._thread.start()
time.sleep(0.05)
def _record_loop(self):
if self._proc is None or self._proc.stdout is None:
self._error = RuntimeError("parec process did not start correctly.")
return
read_size = self.device_config["chunk_size"] * self.device_config["channels"] * 2
while not self._stop_event.is_set():
try:
data = self._proc.stdout.read(read_size)
if data:
self.frames.append(data)
continue
if self._proc.poll() is not None:
break
except Exception as exc:
if not self._stop_event.is_set():
self._error = exc
break
def stop(self) -> bytes:
time.sleep(MONITOR_TAIL_SEC)
self._stop_event.set()
if self._proc is not None and self._proc.poll() is None:
self._proc.terminate()
try:
self._proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
self._proc.kill()
try:
self._proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
pass
if self._thread is not None:
self._thread.join(timeout=1.0)
if self._error is not None:
raise RuntimeError(f"parec speaker monitor capture failed: {self._error}")
return b"".join(self.frames)
class AudioPromptRecorderClient:
    """Generate speech with the Gemini live API and re-record it via loopback.

    The client plays the synthesized audio out of the default sink while
    capturing the sink's monitor, so the stored prompt sounds the way it does
    on the actual speaker hardware.
    """

    def __init__(self):
        self.pya = pyaudio.PyAudio()
        self.monitor_config = self._resolve_capture_device()

    def close(self):
        """Release the underlying PyAudio instance."""
        self.pya.terminate()

    def _ws_connect_kwargs(self):
        """Return kwargs for ``websockets.connect`` that work across versions.

        Newer websockets releases renamed ``extra_headers`` to
        ``additional_headers``; pick whichever the installed version accepts.
        """
        headers = {"Content-Type": "application/json"}
        kwargs = {"max_size": None}
        try:
            sig = inspect.signature(websockets.connect)
            if "extra_headers" in sig.parameters:
                kwargs["extra_headers"] = headers
            else:
                kwargs["additional_headers"] = headers
        except Exception:
            # Signature introspection failed; fall back to the legacy name.
            kwargs["extra_headers"] = headers
        return kwargs

    def _build_monitor_config(self, device_index: int, info: dict) -> dict:
        """Describe a PyAudio capture device in the shape the recorders expect."""
        return {
            "backend": "pyaudio",
            "index": device_index,
            "name": str(info.get("name", "unknown")),
            "rate": int(info.get("defaultSampleRate", RECEIVE_SAMPLE_RATE)),
            # Clamp to 1..2 channels regardless of what the device reports.
            "channels": max(1, min(2, int(info.get("maxInputChannels", 1)))),
            "chunk_size": MONITOR_CHUNK_SIZE,
            "sink": _env_sink(),
            "source": _env_source(),
            "monitor_source": _monitor_source(),
        }

    def _set_default_source(self, source_name: str):
        """Point the PulseAudio default source at *source_name* via ``pactl``.

        Raises RuntimeError when pactl is missing or the command fails.
        """
        try:
            subprocess.run(
                ["pactl", "set-default-source", source_name],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                text=True,
            )
        except FileNotFoundError as exc:
            raise RuntimeError("pactl is required but was not found.") from exc
        except subprocess.CalledProcessError as exc:
            stderr = (exc.stderr or "").strip()
            raise RuntimeError(stderr or f"Failed to set default source to {source_name}.") from exc

    def _resolve_capture_device(self) -> dict:
        """Choose the loopback capture backend.

        Prefers a ``parec`` subprocess reading the sink monitor directly;
        otherwise falls back to a PyAudio input device named 'pulse' or
        'default'. Raises RuntimeError when no fallback device exists.
        """
        monitor_source = _monitor_source()
        if shutil.which("parec"):
            return {
                "backend": "parec",
                "name": monitor_source,
                "rate": RECEIVE_SAMPLE_RATE,
                "channels": CHANNELS,
                "chunk_size": MONITOR_CHUNK_SIZE,
                "sink": _env_sink(),
                "source": _env_source(),
                "monitor_source": monitor_source,
            }
        match = _find_preferred_input_device(self.pya)
        if match is None:
            raise RuntimeError("No suitable PulseAudio input device was found. Expected 'pulse' or 'default'.")
        return self._build_monitor_config(*match)

    async def generate_audio(self, text: str) -> bytes:
        """Ask Gemini to speak *text* verbatim and return the raw PCM bytes.

        Raises RuntimeError when the API key is missing or no audio comes back.
        """
        if not config.GEMINI_API_KEY:
            raise RuntimeError("Gemini API key is missing in config.")
        async with websockets.connect(config.URI, **self._ws_connect_kwargs()) as ws:
            setup_msg = {
                "setup": {
                    "model": config.GEMINI_MODEL,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {
                            "voiceConfig": {
                                "prebuiltVoiceConfig": {"voiceName": config.VOICE_NAME}
                            }
                        },
                    },
                    "systemInstruction": {"parts": [{"text": STRICT_REPLAY_SYSTEM_PROMPT}]},
                }
            }
            await ws.send(json.dumps(setup_msg))
            await ws.recv()  # wait for the setup acknowledgement
            request_msg = {
                "client_content": {
                    "turns": [
                        {
                            "role": "user",
                            "parts": [{"text": _build_verbatim_request_text(text)}],
                        }
                    ],
                    "turn_complete": True,
                }
            }
            await ws.send(json.dumps(request_msg))
            audio_chunks: list[bytes] = []
            # Collect inline audio parts until the server signals end of turn.
            while True:
                response = json.loads(await ws.recv())
                server_content = response.get("serverContent", {})
                model_turn = server_content.get("modelTurn", {})
                for part in model_turn.get("parts", []):
                    inline_data = part.get("inlineData")
                    if inline_data and inline_data.get("data"):
                        audio_chunks.append(base64.b64decode(inline_data["data"]))
                if server_content.get("turnComplete") or server_content.get("generationComplete"):
                    break
            if not audio_chunks:
                raise RuntimeError("Gemini returned no audio for the provided text.")
            return b"".join(audio_chunks)

    def play_audio_and_capture(self, audio_bytes: bytes) -> bytes:
        """Play *audio_bytes* on the default sink while recording its monitor.

        Returns the captured loopback PCM bytes. The recorder is always
        stopped and (for the PyAudio backend) the original default source is
        always restored, even when opening the playback stream or writing
        audio raises — previously a playback failure skipped both steps,
        leaking the capture thread/process and leaving the system default
        source pointing at the sink monitor.
        """
        use_parec = self.monitor_config.get("backend") == "parec"
        if use_parec:
            recorder = _ParecMonitorRecorder(self.monitor_config)
        else:
            recorder = _MonitorRecorder(self.pya, self.monitor_config)
            # The PyAudio fallback records via the 'pulse'/'default' device,
            # which presumably follows the default source — point it at the
            # sink monitor for the duration of the playback.
            self._set_default_source(self.monitor_config["monitor_source"])
        stream = None
        captured_audio = b""
        try:
            try:
                stream = self.pya.open(
                    format=FORMAT,
                    channels=CHANNELS,
                    rate=RECEIVE_SAMPLE_RATE,
                    output=True,
                    frames_per_buffer=CHUNK_SIZE,
                )
                recorder.start()
                step = CHUNK_SIZE * 2  # 16-bit mono: 2 bytes per frame
                for offset in range(0, len(audio_bytes), step):
                    stream.write(audio_bytes[offset: offset + step])
            finally:
                if stream is not None:
                    try:
                        stream.stop_stream()
                    finally:
                        stream.close()
        finally:
            # Stop the recorder and restore the default source no matter how
            # playback ended; an exception from playback still propagates.
            try:
                captured_audio = recorder.stop()
            finally:
                if not use_parec:
                    self._set_default_source(self.monitor_config["source"])
        return captured_audio

    def sample_width(self) -> int:
        """Bytes per sample of the capture format (2 for 16-bit PCM)."""
        return self.pya.get_sample_size(FORMAT)
def record_prompt_from_text(key: str, text: str, filename: str = "") -> dict:
    """Synthesize *text*, replay it through the speaker, and persist the prompt.

    Generates audio via Gemini, captures the loopback recording of it, and
    stores both the raw and speaker-captured audio as a prompt bundle.
    Returns the bundle metadata augmented with the source text and raw path.
    Raises ValueError when *text* is blank.
    """
    clean_key = audio_prompts._clean_key(key)  # type: ignore[attr-defined]
    clean_text = str(text or "").strip()
    if not clean_text:
        raise ValueError("text is required")
    client = AudioPromptRecorderClient()
    try:
        raw_audio = asyncio.run(client.generate_audio(clean_text))
        speaker_audio = client.play_audio_and_capture(raw_audio)
        monitor_cfg = client.monitor_config
        bundle = audio_prompts.save_audio_prompt_bundle(
            clean_key,
            speaker_audio,
            filename=filename or audio_prompts.prompt_filename(clean_key),
            raw_data=raw_audio,
            text=clean_text,
            model=config.GEMINI_MODEL,
            voice_name=config.VOICE_NAME,
            replay_count=1,
            speaker_rate=int(monitor_cfg["rate"]),
            speaker_channels=int(monitor_cfg["channels"]),
            raw_rate=RECEIVE_SAMPLE_RATE,
            raw_channels=CHANNELS,
            sample_width=client.sample_width(),
            capture_device=str(monitor_cfg.get("name", "")),
            sink=str(monitor_cfg.get("sink", "")),
            source=str(monitor_cfg.get("source", "")),
            monitor_source=str(monitor_cfg.get("monitor_source", "")),
        )
        bundle["text"] = clean_text
        bundle["raw_path"] = str(audio_prompts.raw_prompt_path(clean_key))
        return bundle
    finally:
        client.close()