"""Local Arabic TTS using MBZUAI/speecht5_tts_clartts_ar (SpeechT5 fine-tuned on CLArTTS).

Loads model/vocoder/speaker-embedding from the local Model/ directory.
Lazy-loads on first call so the webserver starts quickly.

Output: 16 kHz mono int16 PCM bytes (matching WAV conventions).
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
import threading
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
# ── Local paths (all pre-downloaded under model/) — sourced from config ──
# Best effort: if the project config loader is unavailable (e.g. this module
# is imported standalone), fall back to the hard-coded defaults below.
try:
    from Project.Sanad.core.config_loader import section as _cfg_section
    _TTS = _cfg_section("voice", "local_tts")
except Exception:
    _TTS = {}

# NOTE(review): the module docstring says "Model/" but the directory used
# here is lowercase "model" — confirm the on-disk casing (matters on
# case-sensitive filesystems).
_PROJECT_DIR = Path(__file__).resolve().parent.parent  # Sanad/
_MODEL_ROOT = _PROJECT_DIR / "model"

# Artefact locations; each subdir/filename key may be overridden in config.
MODEL_DIR = _MODEL_ROOT / _TTS.get("model_subdir", "speecht5_tts_clartts_ar")
VOCODER_DIR = _MODEL_ROOT / _TTS.get("vocoder_subdir", "speecht5_hifigan")
XVECTOR_PATH = _MODEL_ROOT / _TTS.get("xvector_filename", "arabic_xvector_embedding.pt")

# transformers' from_pretrained() accepts a local directory path as a string.
MODEL_ID = str(MODEL_DIR)
VOCODER_ID = str(VOCODER_DIR)

# Output audio format: defaults to 16 kHz mono (int16 PCM downstream).
SAMPLE_RATE = _TTS.get("sample_rate", 16000)
CHANNELS = _TTS.get("channels", 1)
|
||
|
||
# Arabic diacritics (tashkeel) Unicode range – model was trained without them.
_DIACRITICS_RE = re.compile(r"[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]")


def strip_diacritics(text: str) -> str:
    """Return *text* with all Arabic diacritical marks (tashkeel) removed.

    Letters and non-Arabic characters pass through untouched; only the
    combining marks matched by ``_DIACRITICS_RE`` are dropped.
    """
    cleaned = _DIACRITICS_RE.sub("", text)
    return cleaned
|
||
|
||
|
||
class LocalTTSEngine:
    """Thread-safe, lazily initialised SpeechT5 Arabic TTS engine.

    Heavy dependencies (torch, transformers) are imported only when the
    first synthesis request arrives, so importing this module stays cheap
    and the webserver can start quickly.
    """

    def __init__(self):
        # Serialises one-time model loading across concurrent callers.
        self._lock = threading.Lock()
        self._loaded = False
        self._processor = None
        self._model = None
        self._vocoder = None
        self._speaker_embedding = None

    def _ensure_loaded(self):
        """Load processor/model/vocoder/x-vector once; no-op afterwards.

        Raises RuntimeError when any required artefact is missing on disk.
        """
        if self._loaded:
            return
        with self._lock:
            # Re-check under the lock: another thread may have completed
            # loading while this one was waiting.
            if self._loaded:
                return

            required = (
                ("Model", MODEL_DIR),
                ("Vocoder", VOCODER_DIR),
                ("XVector", XVECTOR_PATH),
            )
            for label, path in required:
                if not path.exists():
                    raise RuntimeError(f"{label} not found at {path}")

            # Deferred imports keep module import time small.
            import torch
            from transformers import (
                SpeechT5ForTextToSpeech,
                SpeechT5HifiGan,
                SpeechT5Processor,
            )

            self._processor = SpeechT5Processor.from_pretrained(MODEL_ID)
            self._model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
            self._vocoder = SpeechT5HifiGan.from_pretrained(VOCODER_ID)
            self._speaker_embedding = torch.load(str(XVECTOR_PATH), map_location="cpu")

            self._loaded = True

    @property
    def ready(self) -> bool:
        """True once the model stack has finished loading."""
        return self._loaded

    def status(self) -> dict[str, Any]:
        """Report load state and on-disk availability of each artefact."""
        report: dict[str, Any] = {
            "loaded": self._loaded,
            "model_dir": str(MODEL_DIR),
            "vocoder_dir": str(VOCODER_DIR),
            "xvector_path": str(XVECTOR_PATH),
            "model_exists": MODEL_DIR.exists(),
            "vocoder_exists": VOCODER_DIR.exists(),
            "xvector_exists": XVECTOR_PATH.exists(),
            "sample_rate": SAMPLE_RATE,
        }
        return report

    def synthesize(self, text: str) -> bytes:
        """Convert Arabic text to 16 kHz mono int16 PCM bytes."""
        self._ensure_loaded()
        import torch

        normalized = strip_diacritics(text.strip())
        if not normalized:
            raise RuntimeError("Text is empty after stripping diacritics.")

        encoded = self._processor(text=normalized, return_tensors="pt")

        with torch.no_grad():
            waveform = self._model.generate_speech(
                encoded["input_ids"],
                self._speaker_embedding,
                vocoder=self._vocoder,
            )

        # waveform is a 1-D float32 tensor in [-1, 1] at 16 kHz; rescale to
        # the int16 range and clamp before converting to raw bytes.
        samples = waveform.numpy()
        quantized = (samples * 32767).clip(-32768, 32767).astype("int16")
        return quantized.tobytes()

    def synthesize_wav(self, text: str) -> bytes:
        """Return a complete RIFF/WAV file (bytes) for the given text."""
        import io
        import wave

        pcm = self.synthesize(text)
        out = io.BytesIO()
        with wave.open(out, "wb") as writer:
            writer.setnchannels(CHANNELS)
            writer.setsampwidth(2)  # 16-bit samples
            writer.setframerate(SAMPLE_RATE)
            writer.writeframes(pcm)
        return out.getvalue()
|