# Source code for ttslearn.tacotron.tts

import json
from pathlib import Path

import numpy as np
import pyopenjtalk
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tqdm import tqdm
from ttslearn.dsp import inv_mulaw_quantize, logmelspectrogram_to_audio
from ttslearn.pretrained import retrieve_pretrained_model
from ttslearn.tacotron.frontend.openjtalk import pp_symbols, text_to_sequence
from ttslearn.util import StandardScaler

class Tacotron2TTS(object):
    """Tacotron 2 based text-to-speech

    Loads a Tacotron 2 acoustic model and a WaveNet vocoder from a model
    directory and exposes :meth:`tts` to turn text into a waveform.

    Args:
        model_dir (str): model directory. A pre-trained model (ID: ``tacotron2``)
            is used if None.
        device (str): cpu or cuda.

    Examples:

        >>> from ttslearn.tacotron import Tacotron2TTS
        >>> engine = Tacotron2TTS()
        >>> wav, sr = engine.tts("一貫学習にチャレンジしましょう!")
    """

    def __init__(self, model_dir=None, device="cpu"):
        self.device = device

        if model_dir is None:
            model_dir = retrieve_pretrained_model("tacotron2")
        # Normalise to Path so that both str and path-like inputs support "/".
        model_dir = Path(model_dir)

        # search for config.yaml
        # NOTE(review): the attribute name on the left-hand side of the
        # ``config.*`` / ``255`` assignments was lost in the extracted source;
        # ``mu`` (mu-law quantization parameter) is a reconstruction -- confirm.
        if (model_dir / "config.yaml").exists():
            config = OmegaConf.load(model_dir / "config.yaml")
            self.sample_rate = config.sample_rate
            self.mu = config.mu
        else:
            self.sample_rate = 16000
            self.mu = 255

        # Acoustic model
        self.acoustic_config = OmegaConf.load(model_dir / "acoustic_model.yaml")
        self.acoustic_model = instantiate(self.acoustic_config.netG).to(device)
        checkpoint = torch.load(
            model_dir / "acoustic_model.pth",
            map_location=device,
        )
        self.acoustic_model.load_state_dict(checkpoint["state_dict"])
        # Statistics used to undo the normalization of the acoustic model output.
        self.acoustic_out_scaler = StandardScaler(
            np.load(model_dir / "out_tacotron_scaler_mean.npy"),
            np.load(model_dir / "out_tacotron_scaler_var.npy"),
            np.load(model_dir / "out_tacotron_scaler_scale.npy"),
        )
        self.acoustic_model.eval()

        # WaveNet vocoder
        self.wavenet_config = OmegaConf.load(model_dir / "wavenet_model.yaml")
        self.wavenet_model = instantiate(self.wavenet_config.netG).to(device)
        checkpoint = torch.load(
            model_dir / "wavenet_model.pth",
            map_location=device,
        )
        self.wavenet_model.load_state_dict(checkpoint["state_dict"])
        self.wavenet_model.eval()
        self.wavenet_model.remove_weight_norm_()

    def __repr__(self):
        acoustic_str = json.dumps(
            OmegaConf.to_container(self.acoustic_config["netG"]),
            sort_keys=False,
            indent=4,
        )
        wavenet_str = json.dumps(
            OmegaConf.to_container(self.wavenet_config["netG"]),
            sort_keys=False,
            indent=4,
        )

        return f"""Tacotron2 TTS (sampling rate: {self.sample_rate})

Acoustic model: {acoustic_str}
Vocoder model: {wavenet_str}
"""

    def set_device(self, device):
        """Set device for the TTS models

        Moves both the acoustic model and the vocoder to ``device`` so that
        they stay on the same device as the inputs created in :meth:`tts`.

        Args:
            device (str): cpu or cuda.
        """
        self.device = device
        self.acoustic_model.to(device)
        self.wavenet_model.to(device)

    @torch.no_grad()
    def tts(self, text, griffin_lim=False, tqdm=tqdm):
        """Run TTS

        Args:
            text (str): Input text
            griffin_lim (bool, optional): Use Griffin-Lim algorithm or not.
                Defaults to False.
            tqdm (object, optional): tqdm object used as the progress bar of
                the WaveNet vocoder. Defaults to ``tqdm.tqdm``.

        Returns:
            tuple: audio array (np.int16) and sampling rate (int)
        """
        # Extract linguistic features (full-context labels) with OpenJTalk
        contexts = pyopenjtalk.extract_fullcontext(text)

        # Convert to a phoneme sequence with prosody symbols
        in_feats = text_to_sequence(pp_symbols(contexts))
        in_feats = torch.tensor(in_feats, dtype=torch.long).to(self.device)

        # (T, C)
        _, out_feats, _, _ = self.acoustic_model.inference(in_feats)

        if griffin_lim:
            # Waveform synthesis based on the Griffin-Lim algorithm
            out_feats = out_feats.cpu().data.numpy()
            # Undo the normalization
            logmel = self.acoustic_out_scaler.inverse_transform(out_feats)
            gen_wav = logmelspectrogram_to_audio(logmel, self.sample_rate)
        else:
            # (B, T, C) -> (B, C, T)
            c = out_feats.view(1, -1, out_feats.size(-1)).transpose(1, 2)

            # Compute the length of the waveform to generate
            # NOTE(review): the right-hand side of ``upsample_scale`` was lost
            # in the extracted source; the product of the vocoder's upsampling
            # scales is a reconstruction -- confirm against the WaveNet model.
            upsample_scale = np.prod(self.wavenet_model.upsample_scales)
            T = (
                c.shape[-1] - self.wavenet_model.aux_context_window * 2
            ) * upsample_scale

            # Generate the waveform with the WaveNet vocoder
            # NOTE: this is slow, so a tqdm progress bar is used
            gen_wav = self.wavenet_model.inference(c, T, tqdm)

            # Convert one-hot vectors to a 1-D signal (class index per sample)
            gen_wav = gen_wav.max(1)[1].float().cpu().numpy().reshape(-1)

            # Inverse mu-law quantization
            # NOTE: mu is assumed to be (number of output channels - 1)
            gen_wav = inv_mulaw_quantize(gen_wav, self.wavenet_model.out_channels - 1)

        return self.post_process(gen_wav), self.sample_rate

    def post_process(self, wav):
        """Convert a floating-point waveform to 16-bit integer samples

        Args:
            wav (numpy.ndarray): waveform; values outside [-1, 1] are clipped.

        Returns:
            numpy.ndarray: waveform scaled by 32767 and cast to np.int16
        """
        wav = np.clip(wav, -1.0, 1.0)
        wav = (wav * 32767.0).astype(np.int16)
        return wav
def randomize_tts_engine_(engine: Tacotron2TTS) -> Tacotron2TTS:
    """Deliberately break a trained TTS engine by randomizing attention weights

    The engine is modified in place (hence the trailing underscore) and also
    returned for convenience.

    Args:
        engine (Tacotron2TTS): TTS engine to corrupt.

    Returns:
        Tacotron2TTS: the same engine instance that was passed in.
    """
    # Forcibly re-initialize part of the attention parameters with random
    # numbers, which destroys the trained model.
    # NOTE(review): the argument of ``normal_`` was truncated in the extracted
    # source; the attention-layer weight path below is a reconstruction --
    # confirm against the acoustic model definition.
    torch.nn.init.normal_(engine.acoustic_model.decoder.attention.mlp_dec.weight.data)
    return engine