import shutil
from pathlib import Path
import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
from torch import nn, optim
from torch.utils import data as data_utils
from torch.utils.tensorboard import SummaryWriter
from ttslearn.logger import getLogger
from ttslearn.util import init_seed, load_utt_list, pad_1d, pad_2d
def get_epochs_with_optional_tqdm(tqdm_mode, nepochs):
    """Return an iterable over epoch numbers ``1 .. nepochs``.

    Args:
        tqdm_mode (str): If ``"tqdm"``, wrap the range in a tqdm progress bar.
        nepochs (int): Number of epochs.

    Returns:
        iterable: Epoch numbers, optionally wrapped in a progress bar.
    """
    epoch_range = range(1, nepochs + 1)
    if tqdm_mode != "tqdm":
        return epoch_range
    # Import lazily so tqdm is only required when the progress bar is used.
    from tqdm import tqdm

    return tqdm(epoch_range, desc="epoch")
def moving_average_(model, model_test, beta=0.9999):
    """Update ``model_test`` in place as an exponential moving average (EMA).

    Each parameter of ``model_test`` becomes
    ``(1 - beta) * param + beta * param_test``.

    Args:
        model (torch.nn.Module): Model holding the current training weights.
        model_test (torch.nn.Module): Model holding the averaged weights,
            used for the test phase.
        beta (float, optional): EMA decay coefficient. Defaults to 0.9999.
    """
    param_pairs = zip(model.parameters(), model_test.parameters())
    for src_param, ema_param in param_pairs:
        # lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b
        ema_param.data = torch.lerp(src_param.data, ema_param.data, beta)
def num_trainable_params(model):
    """Count the trainable parameters of a model.

    Args:
        model (torch.nn.Module): Model whose parameters are counted.

    Returns:
        int: Total number of elements across parameters with
        ``requires_grad=True``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += np.prod(param.size())
    return total
class Dataset(data_utils.Dataset):  # type: ignore
    """Dataset backed by paired numpy files.

    Args:
        in_paths (list): List of paths to input feature files (.npy).
        out_paths (list): List of paths to output feature files (.npy).
    """

    def __init__(self, in_paths, out_paths):
        self.in_paths = in_paths
        self.out_paths = out_paths

    def __getitem__(self, idx):
        """Load the ``idx``-th (input, target) pair from disk.

        Args:
            idx (int): Index of the pair.

        Returns:
            tuple: Input and target arrays in numpy format.
        """
        in_feats = np.load(self.in_paths[idx])
        out_feats = np.load(self.out_paths[idx])
        return in_feats, out_feats

    def __len__(self):
        """Return the size of the dataset.

        Returns:
            int: Number of (input, target) pairs.
        """
        return len(self.in_paths)
def get_data_loaders(data_config, collate_fn):
    """Build data loaders for the training and validation phases.

    Args:
        data_config (dict): Data configuration (per-phase utt lists and
            feature directories, plus batch size and worker count).
        collate_fn (callable): Collate function for mini-batching.

    Returns:
        dict: Data loaders keyed by phase name ("train" and "dev").
    """
    loaders = {}
    for phase in ("train", "dev"):
        phase_config = data_config[phase]
        utt_ids = load_utt_list(to_absolute_path(phase_config.utt_list))
        in_dir = Path(to_absolute_path(phase_config.in_dir))
        out_dir = Path(to_absolute_path(phase_config.out_dir))
        # Feature files follow the "<utt_id>-feats.npy" naming convention.
        in_feats_paths = [in_dir / f"{utt_id}-feats.npy" for utt_id in utt_ids]
        out_feats_paths = [out_dir / f"{utt_id}-feats.npy" for utt_id in utt_ids]
        loaders[phase] = data_utils.DataLoader(
            Dataset(in_feats_paths, out_feats_paths),
            batch_size=data_config.batch_size,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=data_config.num_workers,
            shuffle=phase.startswith("train"),
        )
    return loaders
def collate_fn_dnntts(batch):
    """Collate function for DNN-TTS.

    Pads all inputs and targets to the length of the longest input in the
    batch and stacks them into tensors.

    Args:
        batch (list): List of (inputs, targets) tuples.

    Returns:
        tuple: Batch of inputs, targets, and original input lengths.
    """
    lengths = [len(pair[0]) for pair in batch]
    max_len = max(lengths)
    xs = [torch.from_numpy(pad_2d(pair[0], max_len)) for pair in batch]
    ys = [torch.from_numpy(pad_2d(pair[1], max_len)) for pair in batch]
    x_batch = torch.stack(xs)
    y_batch = torch.stack(ys)
    l_batch = torch.tensor(lengths, dtype=torch.long)
    return x_batch, y_batch, l_batch
def collate_fn_wavenet(batch, max_time_frames=100, hop_size=80, aux_context_window=2):
    """Collate function for WaveNet.

    Randomly crops a fixed-length window of conditioning features from each
    utterance, together with the corresponding waveform segment.

    Args:
        batch (list): List of (inputs, targets) tuples.
        max_time_frames (int, optional): Number of time frames. Defaults to 100.
        hop_size (int, optional): Hop size. Defaults to 80.
        aux_context_window (int, optional): Auxiliary context window. Defaults to 2.

    Returns:
        tuple: Batch of waveforms (B, T) and conditioning features (B, C, T').
    """
    max_time_steps = max_time_frames * hop_size
    cs = [pair[0] for pair in batch]
    xs = [pair[1] for pair in batch]

    # Randomly choose the start frame of the conditioning features for each
    # utterance, then cut out the matching short waveform segment.
    start_frames = np.array(
        [
            np.random.randint(
                aux_context_window, len(c) - aux_context_window - max_time_frames
            )
            for c in cs
        ]
    )
    x_starts = start_frames * hop_size
    c_starts = start_frames - aux_context_window
    c_ends = start_frames + max_time_frames + aux_context_window
    x_cut = [x[s : s + max_time_steps] for x, s in zip(xs, x_starts)]
    c_cut = [c[s:e] for c, s, e in zip(cs, c_starts, c_ends)]

    # Convert the lists of numpy arrays into torch tensors.
    x_batch = torch.tensor(x_cut, dtype=torch.long)  # (B, T)
    c_batch = torch.tensor(c_cut, dtype=torch.float).transpose(2, 1)  # (B, C, T')
    return x_batch, c_batch
def ensure_divisible_by(feats, N):
    """Trim trailing frames so the frame count is divisible by ``N``.

    Args:
        feats (np.ndarray): Input features, frames along the first axis.
        N (int): Divisor.

    Returns:
        np.ndarray: Features whose length is a multiple of ``N``; the input
        is returned unchanged when it already is (or when ``N == 1``).
    """
    if N == 1:
        return feats
    remainder = len(feats) % N
    return feats if remainder == 0 else feats[: len(feats) - remainder]
def collate_fn_tacotron(batch, reduction_factor=1):
    """Collate function for Tacotron.

    Args:
        batch (list): List of (inputs, targets) tuples.
        reduction_factor (int, optional): Reduction factor. Defaults to 1.

    Returns:
        tuple: Batch of inputs, input lengths, targets, target lengths,
        and stop flags.
    """
    xs = [pair[0] for pair in batch]
    # Trim targets so their length is a multiple of the reduction factor.
    ys = [ensure_divisible_by(pair[1], reduction_factor) for pair in batch]
    in_lens = [len(x) for x in xs]
    out_lens = [len(y) for y in ys]
    max_in = max(in_lens)
    max_out = max(out_lens)
    x_batch = torch.stack([torch.from_numpy(pad_1d(x, max_in)) for x in xs])
    y_batch = torch.stack([torch.from_numpy(pad_2d(y, max_out)) for y in ys])
    il_batch = torch.tensor(in_lens, dtype=torch.long)
    ol_batch = torch.tensor(out_lens, dtype=torch.long)
    # Stop flags are 1.0 from the last valid output frame onwards.
    stop_flags = torch.zeros(y_batch.shape[0], y_batch.shape[1])
    for row, out_len in enumerate(out_lens):
        stop_flags[row, out_len - 1 :] = 1.0
    return x_batch, il_batch, y_batch, ol_batch, stop_flags
def set_epochs_based_on_max_steps_(train_config, steps_per_epoch, logger):
    """Reconcile ``nepochs`` and ``max_train_steps`` in the train config.

    If ``max_train_steps`` is negative it is derived from ``nepochs``;
    otherwise ``nepochs`` is derived from ``max_train_steps`` (rounded up).
    The config is modified in place.

    Args:
        train_config (TrainConfig): Train config (mutated in place).
        steps_per_epoch (int): Number of steps per epoch.
        logger (logging.Logger): Logger.
    """
    logger.info(f"Number of iterations per epoch: {steps_per_epoch}")
    if train_config.max_train_steps < 0:
        # Derive max_train_steps from the requested number of epochs.
        train_config.max_train_steps = train_config.nepochs * steps_per_epoch
        logger.info(
            "Number of max_train_steps is set based on nepochs: {}".format(
                train_config.max_train_steps
            )
        )
    else:
        # Derive nepochs from the requested number of steps, rounding up so
        # training covers at least max_train_steps.
        train_config.nepochs = int(
            np.ceil(train_config.max_train_steps / steps_per_epoch)
        )
        logger.info(
            "Number of epochs is set based on max_train_steps: {}".format(
                train_config.nepochs
            )
        )
    logger.info(f"Number of epochs: {train_config.nepochs}")
    logger.info(f"Number of iterations: {train_config.max_train_steps}")
def save_checkpoint(
    logger, out_dir, model, optimizer, epoch, is_best=False, postfix=""
):
    """Save a training checkpoint (model and optimizer state).

    Args:
        logger (logging.Logger): Logger.
        out_dir (Path): Output directory.
        model (nn.Module): Model to save.
        optimizer (Optimizer): Optimizer whose state is saved alongside.
        epoch (int): Current epoch (used in the checkpoint filename).
        is_best (bool, optional): Whether or not the current model is the
            best. Defaults to False.
        postfix (str, optional): Filename postfix. Defaults to "".
    """
    # Unwrap DataParallel so the saved state_dict has plain parameter names.
    if isinstance(model, nn.DataParallel):
        model = model.module

    out_dir.mkdir(parents=True, exist_ok=True)
    if is_best:
        path = out_dir / f"best_loss{postfix}.pth"
    else:
        path = out_dir / "epoch{:04d}{}.pth".format(epoch, postfix)

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    logger.info(f"Saved checkpoint at {path}")

    # Keep a "latest" copy for regular (non-best) checkpoints.
    if not is_best:
        shutil.copyfile(path, out_dir / f"latest{postfix}.pth")
def plot_attention(alignment):
    """Plot an attention alignment matrix.

    Args:
        alignment (torch.Tensor): Attention weights
            (decoder steps x encoder steps).

    Returns:
        matplotlib.figure.Figure: Figure containing the plot.
    """
    matrix = alignment.cpu().data.numpy().T
    fig, ax = plt.subplots()
    image = ax.imshow(matrix, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(image, ax=ax)
    plt.xlabel("Decoder time step")
    plt.ylabel("Encoder time step")
    return fig
def plot_2d_feats(feats, title=None):
    """Plot a 2D feature matrix (e.g. a spectrogram).

    Args:
        feats (torch.Tensor): Input features (frames x feature dim).
        title (str, optional): Title. Defaults to None.

    Returns:
        matplotlib.figure.Figure: Figure containing the plot.
    """
    matrix = feats.cpu().data.numpy().T
    fig, ax = plt.subplots()
    image = ax.imshow(
        matrix, aspect="auto", origin="lower", interpolation="none", cmap="viridis"
    )
    fig.colorbar(image, ax=ax)
    if title is not None:
        ax.set_title(title)
    return fig
def setup(config, device, collate_fn):
    """Set up everything needed for training.

    Instantiates the model, optimizer, learning rate scheduler, and data
    loaders from the hydra config, configures cuDNN and the random seed,
    optionally loads a pretrained checkpoint for fine-tuning, and prepares
    the TensorBoard writer and output directory.

    Args:
        config (dict): Configuration for training.
        device (torch.device): Device to use for training.
        collate_fn (callable): Function to collate mini-batches.

    Returns:
        (tuple): tuple containing model, optimizer, learning rate scheduler,
            data loaders, tensorboard writer, and logger.

    .. note::
        The code listed in the book is a slightly simplified version of
        this function.
    """
    # NOTE: hydra adds a stream logger internally, so be careful not to add
    # a second one here.
    logger = getLogger(config.verbose, add_stream_handler=False)

    logger.info(f"PyTorch version: {torch.__version__}")

    # CUDA / cuDNN settings
    if torch.cuda.is_available():
        from torch.backends import cudnn

        cudnn.benchmark = config.cudnn.benchmark
        cudnn.deterministic = config.cudnn.deterministic
        logger.info(f"cudnn.deterministic: {cudnn.deterministic}")
        logger.info(f"cudnn.benchmark: {cudnn.benchmark}")
        if torch.backends.cudnn.version() is not None:
            logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")

    logger.info(f"Random seed: {config.seed}")
    init_seed(config.seed)

    # Instantiate the model from the config
    model = hydra.utils.instantiate(config.model.netG).to(device)
    logger.info(model)
    logger.info(
        "Number of trainable params: {:.3f} million".format(
            num_trainable_params(model) / 1000000.0
        )
    )

    # (optional) Load a pretrained model for fine-tuning.
    pretrained_checkpoint = config.train.pretrained.checkpoint
    if pretrained_checkpoint is not None and len(pretrained_checkpoint) > 0:
        logger.info(
            "Fine-tuning! Loading a checkpoint: {}".format(pretrained_checkpoint)
        )
        checkpoint = torch.load(pretrained_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])

    # Multi-GPU support
    if config.data_parallel:
        model = nn.DataParallel(model)

    # Optimizer (class and params both come from the config)
    optimizer_class = getattr(optim, config.train.optim.optimizer.name)
    optimizer = optimizer_class(
        model.parameters(), **config.train.optim.optimizer.params
    )

    # Learning rate scheduler
    lr_scheduler_class = getattr(
        optim.lr_scheduler, config.train.optim.lr_scheduler.name
    )
    lr_scheduler = lr_scheduler_class(
        optimizer, **config.train.optim.lr_scheduler.params
    )

    # DataLoader
    data_loaders = get_data_loaders(config.data, collate_fn)

    set_epochs_based_on_max_steps_(config.train, len(data_loaders["train"]), logger)

    # Tensorboard setup
    writer = SummaryWriter(to_absolute_path(config.train.log_dir))

    # Save the config files alongside the training outputs for reproducibility.
    out_dir = Path(to_absolute_path(config.train.out_dir))
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "model.yaml", "w") as f:
        OmegaConf.save(config.model, f)
    with open(out_dir / "config.yaml", "w") as f:
        OmegaConf.save(config, f)

    return model, optimizer, lr_scheduler, data_loaders, writer, logger