Sanad/voice/local_tts.py

129 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Local Arabic TTS using MBZUAI/speecht5_tts_clartts_ar (SpeechT5 fine-tuned on CLArTTS).
Loads model/vocoder/speaker-embedding from the local model/ directory.
Lazy-loads on first call so the webserver starts quickly.
Output: 16 kHz mono int16 PCM bytes (matching WAV conventions).
"""
from __future__ import annotations
import re
import threading
from pathlib import Path
from typing import Any
# ── Local paths (all pre-downloaded under model/) — sourced from config ──
# The config lookup is best-effort: if the loader cannot be imported or the
# lookup raises, every setting below falls back to its built-in default.
try:
    from Project.Sanad.core.config_loader import section as _cfg_section

    _TTS = _cfg_section("voice", "local_tts")
except Exception:
    _TTS = {}

# Two directory levels above this file is the project directory (Sanad/).
_PROJECT_DIR = Path(__file__).resolve().parents[1]
_MODEL_ROOT = _PROJECT_DIR.joinpath("model")

# On-disk locations of the acoustic model, the vocoder and the speaker x-vector.
MODEL_DIR = _MODEL_ROOT.joinpath(_TTS.get("model_subdir", "speecht5_tts_clartts_ar"))
VOCODER_DIR = _MODEL_ROOT.joinpath(_TTS.get("vocoder_subdir", "speecht5_hifigan"))
XVECTOR_PATH = _MODEL_ROOT.joinpath(
    _TTS.get("xvector_filename", "arabic_xvector_embedding.pt")
)

# String forms of the two directories, as handed to ``from_pretrained``.
MODEL_ID, VOCODER_ID = str(MODEL_DIR), str(VOCODER_DIR)

# Audio format written into WAV headers.
SAMPLE_RATE = _TTS.get("sample_rate", 16000)
CHANNELS = _TTS.get("channels", 1)
# Unicode ranges covering Arabic diacritics (tashkeel); the model was trained
# without them, so they are removed from the text before synthesis.
_DIACRITICS_RE = re.compile(r"[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]")


def strip_diacritics(text: str) -> str:
    """Return *text* with every character matched by ``_DIACRITICS_RE`` removed."""
    return re.sub(_DIACRITICS_RE, "", text)
class LocalTTSEngine:
    """Text-to-speech engine that loads its SpeechT5 model files on first use.

    Construction only creates a lock and empty slots; the processor, model,
    vocoder and speaker embedding are read from ``MODEL_DIR``, ``VOCODER_DIR``
    and ``XVECTOR_PATH`` the first time ``synthesize`` needs them.
    """

    def __init__(self):
        self._lock = threading.Lock()  # guards the one-time load below
        self._loaded = False  # flipped to True only after every asset is loaded
        self._processor = None
        self._model = None
        self._vocoder = None
        self._speaker_embedding = None

    def _ensure_loaded(self):
        """Load processor, model, vocoder and speaker embedding if not yet loaded.

        Raises:
            RuntimeError: if the model directory, the vocoder directory or
                the x-vector file does not exist.
        """
        if self._loaded:
            return
        with self._lock:
            # Re-check under the lock: another thread may have completed the
            # load while this one was waiting to acquire it.
            if self._loaded:
                return
            for label, p in [
                ("Model", MODEL_DIR),
                ("Vocoder", VOCODER_DIR),
                ("XVector", XVECTOR_PATH),
            ]:
                if not p.exists():
                    raise RuntimeError(f"{label} not found at {p}")

            # Imported here rather than at module level so that importing
            # this module does not pull in torch/transformers.
            import torch
            from transformers import (
                SpeechT5ForTextToSpeech,
                SpeechT5HifiGan,
                SpeechT5Processor,
            )

            self._processor = SpeechT5Processor.from_pretrained(MODEL_ID)
            self._model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
            self._vocoder = SpeechT5HifiGan.from_pretrained(VOCODER_ID)
            # NOTE(review): torch.load may unpickle arbitrary objects unless
            # weights-only loading is in effect; XVECTOR_PATH must only ever
            # point at a trusted file.
            self._speaker_embedding = torch.load(str(XVECTOR_PATH), map_location="cpu")
            # Set last so `ready` never reports True for a half-loaded engine.
            self._loaded = True

    @property
    def ready(self) -> bool:
        """Whether the model assets have been loaded."""
        return self._loaded

    def status(self) -> dict[str, Any]:
        """Return the load state, configured paths, whether each path exists, and the sample rate."""
        return {
            "loaded": self._loaded,
            "model_dir": str(MODEL_DIR),
            "vocoder_dir": str(VOCODER_DIR),
            "xvector_path": str(XVECTOR_PATH),
            "model_exists": MODEL_DIR.exists(),
            "vocoder_exists": VOCODER_DIR.exists(),
            "xvector_exists": XVECTOR_PATH.exists(),
            "sample_rate": SAMPLE_RATE,
        }

    def synthesize(self, text: str) -> bytes:
        """Convert Arabic text to 16 kHz mono int16 PCM bytes.

        The text is validated before the model is loaded, so blank input is
        rejected immediately and never triggers the model load.

        Args:
            text: Text to speak. Surrounding whitespace and Arabic diacritics
                are removed before synthesis.

        Returns:
            Raw int16 PCM samples as bytes (no WAV header).

        Raises:
            RuntimeError: if nothing is left of ``text`` after cleaning, or
                if the model assets are missing on disk.
        """
        clean_text = text.strip()
        if clean_text:
            # Strip a second time: removing diacritics can leave behind
            # whitespace that was previously sitting between marks.
            clean_text = strip_diacritics(clean_text).strip()
        if not clean_text:
            raise RuntimeError("Text is empty after stripping diacritics.")

        self._ensure_loaded()
        import torch

        inputs = self._processor(text=clean_text, return_tensors="pt")
        with torch.no_grad():
            speech = self._model.generate_speech(
                inputs["input_ids"],
                self._speaker_embedding,
                vocoder=self._vocoder,
            )
        # Per the original author's note, speech is a 1-D float32 tensor in
        # [-1, 1] at 16 kHz.
        pcm_float = speech.numpy()
        # Scale float32 to the int16 range, clamping anything that overshoots.
        pcm_int16 = (pcm_float * 32767).clip(-32768, 32767).astype("int16")
        return pcm_int16.tobytes()

    def synthesize_wav(self, text: str) -> bytes:
        """Return a complete WAV file (bytes) for the given text.

        Raises:
            RuntimeError: propagated from ``synthesize``.
        """
        import io
        import wave

        pcm = self.synthesize(text)
        buf = io.BytesIO()
        # NOTE(review): the header takes CHANNELS and SAMPLE_RATE from config,
        # while synthesize() documents its output as 16 kHz mono; a config
        # override that differs would mislabel the audio. Confirm the config
        # always matches the model.
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(2)  # int16
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(pcm)
        return buf.getvalue()