Source code for ttslearn.pretrained

import os
import shutil
import tarfile
from os.path import join
from pathlib import Path
from urllib.request import urlretrieve

from tqdm.auto import tqdm
from ttslearn.util import dynamic_import

_urls = {
    "v0.2.0": "https://github.com/r9y9/ttslearn/releases/download/v0.2.0",
    "v0.2.1": "https://github.com/r9y9/ttslearn/releases/download/v0.2.1",
}

DEFAULT_CACHE_DIR = join(os.path.expanduser("~"), ".cache", "ttslearn")
CACHE_DIR = os.environ.get("TTSLEARN_CACHE_DIR", DEFAULT_CACHE_DIR)


model_registry = {
    # v0.2.0
    "dnntts": {
        "url": f"{_urls['v0.2.0']}/dnntts.tar.gz",
        "_target_": "ttslearn.dnntts:DNNTTS",
    },
    "wavenettts": {
        "url": f"{_urls['v0.2.0']}/wavenettts.tar.gz",
        "_target_": "ttslearn.wavenet:WaveNetTTS",
    },
    "tacotron2": {
        "url": f"{_urls['v0.2.0']}/tacotron2.tar.gz",
        "_target_": "ttslearn.tacotron:Tacotron2TTS",
    },
    "tacotron2_pwg_jsut16k": {
        "url": f"{_urls['v0.2.0']}/tacotron2_pwg_jsut16k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "tacotron2_pwg_jsut24k": {
        "url": f"{_urls['v0.2.0']}/tacotron2_pwg_jsut24k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "multspk_tacotron2_pwg_jvs16k": {
        "url": f"{_urls['v0.2.0']}/multspk_tacotron2_pwg_jvs16k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "multspk_tacotron2_pwg_jvs24k": {
        "url": f"{_urls['v0.2.0']}/multspk_tacotron2_pwg_jvs24k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    # v0.2.1
    "tacotron2_hifipwg_jsut24k": {
        "url": f"{_urls['v0.2.1']}/tacotron2_hifipwg_jsut24k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "multspk_tacotron2_hifipwg_jvs24k": {
        "url": f"{_urls['v0.2.1']}/multspk_tacotron2_hifipwg_jvs24k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "multspk_tacotron2_pwg_cv16k": {
        "url": f"{_urls['v0.2.1']}/multspk_tacotron2_pwg_cv16k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
    "multspk_tacotron2_pwg_cv24k": {
        "url": f"{_urls['v0.2.1']}/multspk_tacotron2_pwg_cv24k.tar.gz",
        "_target_": "ttslearn.contrib:Tacotron2PWGTTS",
    },
}


[docs]def create_tts_engine(name, *args, **kwargs): """Create TTS engine from official pretrained models. Args: name (str): Pre-trained model name args (list): Additional args for instantiation kwargs (dict): Additional kwargs for instantiation Returns: object: instance of TTS engine Examples: >>> from ttslearn.pretrained import create_tts_engine >>> create_tts_engine("dnntts") DNNTTS (sampling rate: 16000) """ if name not in model_registry: s = "" for model_id in get_available_model_ids(): s += f"'{model_id}'\n" raise ValueError( f""" Pretrained model '{name}' does not exist! Available models: {s[:-1]}""" ) # download if not exists model_dir = retrieve_pretrained_model(name) # create an instance return dynamic_import(model_registry[name]["_target_"])(model_dir, *args, **kwargs)
[docs]def get_available_model_ids(): """Get available pretrained model names. Returns: list: List of available pretrained model names. Examples: >>> from ttslearn.pretrained import get_available_model_ids >>> get_available_model_ids()[:3] ['dnntts', 'wavenettts', 'tacotron2'] """ return list(model_registry.keys())
# https://github.com/tqdm/tqdm#hooks-and-callbacks class _TqdmUpTo(tqdm): # type: ignore def update_to(self, b=1, bsize=1, tsize=None): if tsize is not None: self.total = tsize return self.update(b * bsize - self.n) def is_pretrained_model_ready(name): out_dir = Path(CACHE_DIR) / name if out_dir.exists() and len(list(out_dir.glob("*.pth"))) == 0: return False return out_dir.exists()
[docs]def retrieve_pretrained_model(name): """Retrieve pretrained model from local cache or download from GitHub. Args: name (str): Name of pretrained model. Returns: str: Path to the pretrained model. Raises: ValueError: If the pretrained model is not found. Examples: >>> from ttslearn.pretrained import retrieve_pretrained_model >>> from ttslearn.contrib import Tacotron2PWGTTS >>> model_dir = retrieve_pretrained_model("tacotron2_pwg_jsut24k") >>> engine = Tacotron2PWGTTS(model_dir=model_dir, device="cpu") >>> wav, sr = engine.tts("センパイ、かっこいいです、ほれちゃいます!") """ global model_registry if name not in model_registry: s = "" for model_id in get_available_model_ids(): s += f"'{model_id}'\n" raise ValueError( f""" Pretrained model '{name}' does not exist! Available models: {s[:-1]}""" ) url = model_registry[name]["url"] # NOTE: assuming that filename and extracted is the same out_dir = Path(CACHE_DIR) / name out_dir.mkdir(parents=True, exist_ok=True) filename = Path(CACHE_DIR) / f"{name}.tar.gz" # re-download models if out_dir.exists() and len(list(out_dir.glob("*.pth"))) == 0: shutil.rmtree(out_dir) if not out_dir.exists(): print( """The use of pre-trained models is permitted for non-commercial use only. Please visit https://github.com/r9y9/ttslearn to confirm the license.""" ) print('Downloading: "{}"'.format(url)) with _TqdmUpTo( unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"{name}.tar.gz", ) as t: # all optional kwargs urlretrieve(url, filename, reporthook=t.update_to) t.total = t.n with tarfile.open(filename, mode="r|gz") as f: f.extractall(path=CACHE_DIR) os.remove(filename) return out_dir