Source code for ttslearn.wavenet.tts

import json
from pathlib import Path

import numpy as np
import pyopenjtalk
import torch
from hydra.utils import instantiate
from nnmnkwii.io import hts
from omegaconf import OmegaConf
from tqdm import tqdm
from ttslearn.dnntts.gen import predict_acoustic, predict_duration
from ttslearn.pretrained import retrieve_pretrained_model
from ttslearn.util import StandardScaler
from ttslearn.wavenet.gen import gen_waveform


class WaveNetTTS(object):
    """WaveNet-based text-to-speech

    Args:
        model_dir (str): model directory. A pre-trained model (ID: ``wavenettts``)
            is used if None.
        device (str): cpu or cuda.

    Examples:
        >>> from ttslearn.wavenet import WaveNetTTS
        >>> engine = WaveNetTTS()
        >>> wav, sr = engine.tts("今日もいい天気ですね。")
    """

    def __init__(self, model_dir=None, device="cpu"):
        self.device = device

        if model_dir is None:
            model_dir = retrieve_pretrained_model("wavenettts")
        if isinstance(model_dir, str):
            model_dir = Path(model_dir)

        # Search for config.yaml
        if (model_dir / "config.yaml").exists():
            config = OmegaConf.load(model_dir / "config.yaml")
            self.sample_rate = config.sample_rate
            self.mu = config.mu
        else:
            self.sample_rate = 16000
            self.mu = 255

        # HTS-style question set (qst) for linguistic feature extraction
        self.binary_dict, self.numeric_dict = hts.load_question_set(
            model_dir / "qst.hed"
        )

        # Duration model
        self.duration_config = OmegaConf.load(model_dir / "duration_model.yaml")
        self.duration_model = instantiate(self.duration_config.netG).to(device)
        checkpoint = torch.load(
            model_dir / "duration_model.pth",
            map_location=device,
        )
        self.duration_model.load_state_dict(checkpoint["state_dict"])
        self.duration_in_scaler = StandardScaler(
            np.load(model_dir / "in_duration_scaler_mean.npy"),
            np.load(model_dir / "in_duration_scaler_var.npy"),
            np.load(model_dir / "in_duration_scaler_scale.npy"),
        )
        self.duration_out_scaler = StandardScaler(
            np.load(model_dir / "out_duration_scaler_mean.npy"),
            np.load(model_dir / "out_duration_scaler_var.npy"),
            np.load(model_dir / "out_duration_scaler_scale.npy"),
        )
        self.duration_model.eval()

        # Log-F0 prediction model
        self.logf0_config = OmegaConf.load(model_dir / "logf0_model.yaml")
        self.logf0_model = instantiate(self.logf0_config.netG).to(device)
        checkpoint = torch.load(
            model_dir / "logf0_model.pth",
            map_location=device,
        )
        self.logf0_model.load_state_dict(checkpoint["state_dict"])
        self.logf0_in_scaler = StandardScaler(
            np.load(model_dir / "in_logf0_scaler_mean.npy"),
            np.load(model_dir / "in_logf0_scaler_var.npy"),
            np.load(model_dir / "in_logf0_scaler_scale.npy"),
        )
        self.logf0_out_scaler = StandardScaler(
            np.load(model_dir / "out_logf0_scaler_mean.npy"),
            np.load(model_dir / "out_logf0_scaler_var.npy"),
            np.load(model_dir / "out_logf0_scaler_scale.npy"),
        )
        self.logf0_model.eval()

        # WaveNet vocoder
        self.wavenet_config = OmegaConf.load(model_dir / "wavenet_model.yaml")
        self.wavenet_model = instantiate(self.wavenet_config.netG).to(device)
        checkpoint = torch.load(
            model_dir / "wavenet_model.pth",
            map_location=device,
        )
        self.wavenet_model.load_state_dict(checkpoint["state_dict"])
        self.wavenet_in_scaler = StandardScaler(
            np.load(model_dir / "in_wavenet_scaler_mean.npy"),
            np.load(model_dir / "in_wavenet_scaler_var.npy"),
            np.load(model_dir / "in_wavenet_scaler_scale.npy"),
        )
        self.wavenet_model.eval()
        self.wavenet_model.remove_weight_norm_()

    def __repr__(self):
        duration_str = json.dumps(
            OmegaConf.to_container(self.duration_config.netG),
            sort_keys=False,
            indent=4,
        )
        logf0_str = json.dumps(
            OmegaConf.to_container(self.logf0_config.netG), sort_keys=False, indent=4
        )
        wavenet_str = json.dumps(
            OmegaConf.to_container(self.wavenet_config.netG),
            sort_keys=False,
            indent=4,
        )

        return f"""WaveNet TTS (sampling rate: {self.sample_rate})
Duration model: {duration_str}
Log-f0 model: {logf0_str}
WaveNet: {wavenet_str}
"""

    def set_device(self, device):
        """Set device for the TTS models

        Args:
            device (str): cpu or cuda.
        """
        self.device = device
        self.duration_model.to(device)
        self.logf0_model.to(device)
        self.wavenet_model.to(device)

    @torch.no_grad()
    def tts(self, text, tqdm=tqdm):
        """Run TTS

        Args:
            text (str): Input text
            tqdm (object, optional): tqdm object. Defaults to :func:`tqdm.tqdm`.

        Returns:
            tuple: audio array (np.int16) and sampling rate (int)
        """
        # Extract full-context linguistic features with Open JTalk
        contexts = pyopenjtalk.extract_fullcontext(text)
        # Convert to the HTS label format
        if hasattr(hts.HTSLabelFile, "create_from_contexts"):
            labels = hts.HTSLabelFile.create_from_contexts(contexts)
        else:
            labels = hts.load(None, contexts)

        # Predict phoneme durations
        durations = predict_duration(
            self.device,
            labels,
            self.duration_model,
            self.duration_config,
            self.duration_in_scaler,
            self.duration_out_scaler,
            self.binary_dict,
            self.numeric_dict,
        )
        labels.set_durations(durations)

        # Predict log-F0 (and voiced/unvoiced flags)
        logf0_vuv = predict_acoustic(
            self.device,
            labels,
            self.logf0_model,
            self.logf0_config,
            self.logf0_in_scaler,
            self.logf0_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            mlpg=False,
        )

        # Generate the waveform with WaveNet
        wav = gen_waveform(
            self.device,
            labels,
            logf0_vuv,
            self.wavenet_model,
            self.wavenet_in_scaler,
            self.binary_dict,
            self.numeric_dict,
            tqdm,
        )

        return self.post_process(wav), self.sample_rate

    def post_process(self, wav):
        # Clip to [-1, 1] and convert the float waveform to 16-bit PCM
        wav = np.clip(wav, -1.0, 1.0)
        wav = (wav * 32767.0).astype(np.int16)
        return wav
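

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API above): synthesize one
# sentence end-to-end and write it to disk. The pretrained "wavenettts" model
# is fetched by WaveNetTTS() when no model_dir is given; writing the result
# with scipy.io.wavfile and the output path "output.wav" are illustrative
# assumptions, not part of ttslearn itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from scipy.io import wavfile

    engine = WaveNetTTS()
    # engine.set_device("cuda")  # optional: move all three models to a GPU
    wav, sr = engine.tts("今日もいい天気ですね。")
    # `wav` is already np.int16 (see post_process), so it can be written as-is
    wavfile.write("output.wav", sr, wav)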