Source code for ttslearn.tacotron.decoder

# Acknowledgement: some of the code was adapted from ESPnet
#  Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


import torch
import torch.nn.functional as F
from torch import nn
from ttslearn.tacotron.attention import LocationSensitiveAttention
from ttslearn.util import make_pad_mask


def decoder_init(m):
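    # Initialize Conv1d weights with Xavier-uniform initialization, scaled by the gain for tanh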
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, nn.init.calculate_gain("tanh"))


class ZoneOutCell(nn.Module):
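    """Recurrent cell with zoneout regularization.

    During training, each unit of the hidden and cell states keeps its value from
    the previous time step with probability ``zoneout``; at inference time the
    previous and current states are mixed deterministically with the same rate.

    Args:
        cell (nn.Module) : recurrent cell to wrap (e.g. nn.LSTMCell)
        zoneout (float) : zoneout rate
    """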
    def __init__(self, cell, zoneout=0.1):
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout = zoneout

    def forward(self, inputs, hidden):
        next_hidden = self.cell(inputs, hidden)
        next_hidden = self._zoneout(hidden, next_hidden, self.zoneout)
        return next_hidden

    def _zoneout(self, h, next_h, prob):
        h_0, c_0 = h
        h_1, c_1 = next_h
        h_1 = self._apply_zoneout(h_0, h_1, prob)
        c_1 = self._apply_zoneout(c_0, c_1, prob)
        return h_1, c_1

    def _apply_zoneout(self, h, next_h, prob):
        if self.training:
            mask = h.new(*h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h
        else:
            return prob * h + (1 - prob) * next_h
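

# Illustrative usage sketch (not part of the original source): ZoneOutCell wraps an
# LSTMCell so that, during training, each hidden/cell unit keeps its previous value
# with probability `zoneout`. The sizes below mirror the first decoder LSTM layer
# under the default hyperparameters (encoder 512 + pre-net 256 -> hidden 1024).
def _zoneout_cell_example():
    cell = ZoneOutCell(nn.LSTMCell(512 + 256, 1024), zoneout=0.1)
    x = torch.randn(4, 512 + 256)  # (batch_size, input_dim)
    state = (torch.zeros(4, 1024), torch.zeros(4, 1024))  # initial (h, c)
    h, c = cell(x, state)  # both (4, 1024)
    return h, c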


class Prenet(nn.Module):
    """Pre-Net of Tacotron/Tacotron 2.

    Args:
        in_dim (int) : dimension of input
        layers (int) : number of pre-net layers
        hidden_dim (int) : dimension of hidden layer
        dropout (float) : dropout rate
    """

    def __init__(self, in_dim, layers=2, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        prenet = nn.ModuleList()
        for layer in range(layers):
            prenet += [
                nn.Linear(in_dim if layer == 0 else hidden_dim, hidden_dim),
                nn.ReLU(),
            ]
        self.prenet = nn.Sequential(*prenet)

    def forward(self, x):
        """Forward step

        Args:
            x (torch.Tensor) : input tensor

        Returns:
            torch.Tensor : output tensor
        """
        for layer in self.prenet:
            # Apply dropout at both training and inference time
            x = F.dropout(layer(x), self.dropout, training=True)
        return x

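
# Illustrative usage sketch (not part of the original source): the Pre-Net maps the
# previous acoustic frame to a fixed-size representation. Because dropout is forced
# on with training=True above, the output is stochastic even in eval mode.
# The dimensions below simply mirror the defaults used by Decoder.
def _prenet_usage_example():
    prenet = Prenet(in_dim=80, layers=2, hidden_dim=256, dropout=0.5)
    prev_frame = torch.randn(4, 80)  # (batch_size, out_dim)
    return prenet(prev_frame)  # (4, 256), different on every call due to dropout

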
class Decoder(nn.Module):
    """Decoder of Tacotron 2.

    Args:
        encoder_hidden_dim (int) : dimension of encoder hidden layer
        out_dim (int) : dimension of output
        layers (int) : number of LSTM layers
        hidden_dim (int) : dimension of hidden layer
        prenet_layers (int) : number of pre-net layers
        prenet_hidden_dim (int) : dimension of pre-net hidden layer
        prenet_dropout (float) : dropout rate of pre-net
        zoneout (float) : zoneout rate
        reduction_factor (int) : reduction factor
        attention_hidden_dim (int) : dimension of attention hidden layer
        attention_conv_channels (int) : number of attention convolution channels
        attention_conv_kernel_size (int) : kernel size of attention convolution
    """

    def __init__(
        self,
        encoder_hidden_dim=512,  # dimension of the encoder hidden layer
        out_dim=80,  # dimension of the output
        layers=2,  # number of LSTM layers
        hidden_dim=1024,  # dimension of the LSTM layers
        prenet_layers=2,  # number of Pre-Net layers
        prenet_hidden_dim=256,  # dimension of the Pre-Net hidden layer
        prenet_dropout=0.5,  # Pre-Net dropout rate
        zoneout=0.1,  # Zoneout rate
        reduction_factor=1,  # Reduction factor
        attention_hidden_dim=128,  # dimension of the attention layer
        attention_conv_channels=32,  # number of attention convolution channels
        attention_conv_kernel_size=31,  # kernel size of the attention convolution
    ):
        super().__init__()
        self.out_dim = out_dim

        # Attention mechanism
        self.attention = LocationSensitiveAttention(
            encoder_hidden_dim,
            hidden_dim,
            attention_hidden_dim,
            attention_conv_channels,
            attention_conv_kernel_size,
        )
        self.reduction_factor = reduction_factor

        # Pre-Net
        self.prenet = Prenet(out_dim, prenet_layers, prenet_hidden_dim, prenet_dropout)

        # Unidirectional LSTM
        self.lstm = nn.ModuleList()
        for layer in range(layers):
            lstm = nn.LSTMCell(
                encoder_hidden_dim + prenet_hidden_dim if layer == 0 else hidden_dim,
                hidden_dim,
            )
            self.lstm += [ZoneOutCell(lstm, zoneout)]

        # Projection layers to the outputs
        proj_in_dim = encoder_hidden_dim + hidden_dim
        self.feat_out = nn.Linear(proj_in_dim, out_dim * reduction_factor, bias=False)
        self.prob_out = nn.Linear(proj_in_dim, reduction_factor)

        self.apply(decoder_init)

    def _zero_state(self, hs):
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        """Forward step

        Args:
            encoder_outs (torch.Tensor) : encoder outputs
            in_lens (torch.Tensor) : input lengths
            decoder_targets (torch.Tensor) : decoder targets for teacher-forcing.

        Returns:
            tuple: tuple of outputs, stop token prediction, and attention weights
        """
        is_inference = decoder_targets is None

        # Adjust the number of frames according to the reduction factor
        # (B, Lmax, out_dim) -> (B, Lmax/r, out_dim)
        if self.reduction_factor > 1 and not is_inference:
            decoder_targets = decoder_targets[
                :, self.reduction_factor - 1 :: self.reduction_factor
            ]

        # Number of decoder time steps
        # At inference time, set an empirical upper bound based on the encoder length
        if is_inference:
            max_decoder_time_steps = int(encoder_outs.shape[1] * 10.0)
        else:
            max_decoder_time_steps = decoder_targets.shape[1]

        # Mask for the zero-padded regions
        mask = make_pad_mask(in_lens).to(encoder_outs.device)

        # Initialize the LSTM states with zeros
        h_list, c_list = [], []
        for _ in range(len(self.lstm)):
            h_list.append(self._zero_state(encoder_outs))
            c_list.append(self._zero_state(encoder_outs))

        # First input to the decoder (all-zero <go> frame)
        go_frame = encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)
        prev_out = go_frame

        # Attention weights of the previous time step
        prev_att_w = None
        self.attention.reset()

        # Main loop
        outs, logits, att_ws = [], [], []
        t = 0
        while True:
            # Compute the context vector and the attention weights
            att_c, att_w = self.attention(
                encoder_outs, in_lens, h_list[0], prev_att_w, mask
            )

            # Pre-Net
            prenet_out = self.prenet(prev_out)

            # LSTM
            xs = torch.cat([att_c, prenet_out], dim=1)
            h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                h_list[i], c_list[i] = self.lstm[i](
                    h_list[i - 1], (h_list[i], c_list[i])
                )

            # Compute the outputs
            hcs = torch.cat([h_list[-1], att_c], dim=1)
            outs.append(
                self.feat_out(hcs).view(encoder_outs.size(0), self.out_dim, -1)
            )
            logits.append(self.prob_out(hcs))
            att_ws.append(att_w)

            # Update the decoder input for the next time step
            if is_inference:
                prev_out = outs[-1][:, :, -1]  # (1, out_dim)
            else:
                # Teacher forcing
                prev_out = decoder_targets[:, t, :]

            # Accumulate the attention weights
            prev_att_w = att_w if prev_att_w is None else prev_att_w + att_w

            t += 1
            # Check the stop conditions
            if t >= max_decoder_time_steps:
                break
            if is_inference and (torch.sigmoid(logits[-1]) >= 0.5).any():
                break

        # Concatenate the outputs of all time steps
        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        outs = torch.cat(outs, dim=2)  # (B, out_dim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.reduction_factor > 1:
            outs = outs.view(outs.size(0), self.out_dim, -1)  # (B, out_dim, Lmax)

        return outs, logits, att_ws
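

# Illustrative smoke test (not part of the original source): run the decoder in
# teacher-forcing mode on random tensors. Shapes follow the default hyperparameters
# above; how in_lens is consumed depends on LocationSensitiveAttention and
# make_pad_mask, so a plain length tensor is assumed here.
if __name__ == "__main__":
    decoder = Decoder()
    encoder_outs = torch.randn(2, 25, 512)  # (B, T_enc, encoder_hidden_dim)
    in_lens = torch.tensor([25, 18])  # valid lengths of the encoder outputs
    decoder_targets = torch.randn(2, 60, 80)  # (B, T_dec, out_dim)
    outs, logits, att_ws = decoder(encoder_outs, in_lens, decoder_targets)
    print(outs.shape)  # torch.Size([2, 80, 60]) : (B, out_dim, T_dec)
    print(logits.shape)  # torch.Size([2, 60]) : (B, T_dec)
    print(att_ws.shape)  # torch.Size([2, 60, 25]) : (B, T_dec, T_enc)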