Chapter 9 Tacotron 2: Speech Synthesis Aiming at End-to-End Learning


Preparation

Python version

[1]:
!python -VV
Python 3.8.6 | packaged by conda-forge | (default, Dec 26 2020, 05:05:16)
[GCC 9.3.0]

Installing ttslearn

[2]:
%%capture
try:
    import ttslearn
except ImportError:
    !pip install ttslearn
[3]:
import ttslearn
ttslearn.__version__
[3]:
'0.2.1'

Importing packages

[4]:
%pylab inline
%load_ext autoreload
%load_ext tensorboard
%autoreload
import IPython
from IPython.display import Audio
import tensorboard as tb
import os
Populating the interactive namespace from numpy and matplotlib
[5]:
# Numerical operations
import numpy as np
import torch
from torch import nn
# Loading audio waveforms
from scipy.io import wavfile
# Loading full-context labels and question files
from nnmnkwii.io import hts
# Speech analysis
import pyworld
# Speech analysis and visualization
import librosa
import librosa.display
# This book's library ("Speech Synthesis with Python")
import ttslearn
[6]:
# Fix the random seed
from ttslearn.util import init_seed
init_seed(773)
[7]:
torch.__version__
[7]:
'1.8.1'

Plotting settings

[8]:
from ttslearn.notebook import get_cmap, init_plot_style, savefig
cmap = get_cmap()
init_plot_style()

9.3 Encoder

Converting text to a sequence of IDs

[9]:
# Define the vocabulary
characters = "abcdefghijklmnopqrstuvwxyz!'(),-.:;? "
# Additional special symbols
extra_symbols = [
    "^",  # special symbol marking the start of a sentence <SOS>
    "$",  # special symbol marking the end of a sentence <EOS>
]
_pad = "~"

# NOTE: place the padding symbol at index 0
symbols = [_pad] + extra_symbols + list(characters)

# Dictionaries for converting between characters and IDs
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
[10]:
len(symbols)
[10]:
40
[11]:
def text_to_sequence(text):
    # For simplicity, ignore case: convert all uppercase letters to lowercase
    text = text.lower()

    # <SOS>
    seq = [_symbol_to_id["^"]]

    # Body of the text
    seq += [_symbol_to_id[s] for s in text]

    # <EOS>
    seq.append(_symbol_to_id["$"])

    return seq


def sequence_to_text(seq):
    return [_id_to_symbol[s] for s in seq]
[12]:
seq = text_to_sequence("Hello!")
print(f"Text to sequence: {seq}")
print(f"Sequence back to text: {sequence_to_text(seq)}")
Text to sequence: [1, 10, 7, 14, 14, 17, 29, 2]
Sequence back to text: ['^', 'h', 'e', 'l', 'l', 'o', '!', '$']

Character embedding

[13]:
class SimplestEncoder(nn.Module):
    def __init__(self, num_vocab=40, embed_dim=256):
        super().__init__()
        self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)

    def forward(self, seqs):
        return self.embed(seqs)
[14]:
SimplestEncoder()
[14]:
SimplestEncoder(
  (embed): Embedding(40, 256, padding_idx=0)
)
[15]:
from ttslearn.util import pad_1d

def get_dummy_input():
    # Create example strings, assuming a batch size of 2
    seqs = [
        text_to_sequence("What is your favorite language?"),
        text_to_sequence("Hello world."),
    ]
    in_lens = torch.tensor([len(x) for x in seqs], dtype=torch.long)
    max_len = max(len(x) for x in seqs)
    seqs = torch.stack([torch.from_numpy(pad_1d(seq, max_len)) for seq in seqs])

    return seqs, in_lens
[16]:
seqs, in_lens = get_dummy_input()
print("Input:", seqs)
print("Sequence lengths:", in_lens)
Input: tensor([[ 1, 25, 10,  3, 22, 39, 11, 21, 39, 27, 17, 23, 20, 39,  8,  3, 24, 17,
         20, 11, 22,  7, 39, 14,  3, 16,  9, 23,  3,  9,  7, 38,  2],
        [ 1, 10,  7, 14, 14, 17, 39, 25, 17, 20, 14,  6, 35,  2,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
Sequence lengths: tensor([33, 14])
[17]:
encoder = SimplestEncoder(num_vocab=40, embed_dim=256)
seqs, in_lens = get_dummy_input()
encoder_outs = encoder(seqs)
print(f"Input size: {tuple(seqs.shape)}")
print(f"Output size: {tuple(encoder_outs.shape)}")
Input size: (2, 33)
Output size: (2, 33, 256)
[18]:
# Padded positions are zero vectors; all other positions take continuous values
encoder_outs
[18]:
tensor([[[ 0.7055,  0.6891,  0.0332,  ...,  0.7174,  0.4686,  1.1468],
         [-0.1568, -0.3719, -1.0086,  ..., -0.9326, -1.2187, -0.0714],
         [-0.1901, -0.1983,  0.2274,  ...,  0.2284,  1.6452, -0.3408],
         ...,
         [-1.9353,  0.2628, -0.1449,  ...,  1.6056, -0.3912, -0.0740],
         [ 0.4687,  0.3258, -0.6565,  ...,  1.0895,  0.9105,  0.2814],
         [ 0.8940,  0.3002, -0.2105,  ...,  0.7973,  0.2230, -0.1975]],

        [[ 0.7055,  0.6891,  0.0332,  ...,  0.7174,  0.4686,  1.1468],
         [-0.1901, -0.1983,  0.2274,  ...,  0.2284,  1.6452, -0.3408],
         [-1.9353,  0.2628, -0.1449,  ...,  1.6056, -0.3912, -0.0740],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<EmbeddingBackward>)
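
As a quick sanity check, the following sketch (assuming the seqs, in_lens, and encoder_outs computed in the cells above) verifies that every padded position of the shorter sequence is mapped to the zero vector by padding_idx=0:

# The second sequence is padded beyond its true length (in_lens[1] == 14);
# padding_idx=0 guarantees those embedding rows are exactly zero.
assert (encoder_outs[1, int(in_lens[1]):] == 0).all()
print("All padded positions are zero vectors.")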

Introducing 1-D convolutions

[19]:
class ConvEncoder(nn.Module):
    def __init__(
        self,
        num_vocab=40,
        embed_dim=256,
        conv_layers=3,
        conv_channels=256,
        conv_kernel_size=5,
    ):
        super().__init__()
        # Character embedding
        self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)

        # Stack of 1-D convolutions: modeling local dependencies
        self.convs = nn.ModuleList()
        for layer in range(conv_layers):
            in_channels = embed_dim if layer == 0 else conv_channels
            self.convs += [
                nn.Conv1d(
                    in_channels,
                    conv_channels,
                    conv_kernel_size,
                    padding=(conv_kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm1d(conv_channels),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
        self.convs = nn.Sequential(*self.convs)

    def forward(self, seqs):
        emb = self.embed(seqs)
        # Note that nn.Conv1d and nn.Embedding use different input layouts
        out = self.convs(emb.transpose(1, 2)).transpose(1, 2)
        return out
[20]:
ConvEncoder()
[20]:
ConvEncoder(
  (embed): Embedding(40, 256, padding_idx=0)
  (convs): Sequential(
    (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.5, inplace=False)
  )
)
[21]:
encoder = ConvEncoder(num_vocab=40, embed_dim=256)
seqs, in_lens = get_dummy_input()
encoder_outs = encoder(seqs)
print(f"Input size: {tuple(seqs.shape)}")
print(f"Output size: {tuple(encoder_outs.shape)}")
Input size: (2, 33)
Output size: (2, 33, 256)
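
As a side note on the transposes in ConvEncoder.forward: nn.Embedding returns tensors of shape (batch, time, channels), while nn.Conv1d expects (batch, channels, time). A small shape check, reusing the encoder and seqs from the cell above:

# The embedding output is channels-last; Conv1d needs channels-first, hence transpose(1, 2)
emb = encoder.embed(seqs)
print(tuple(emb.shape))                  # (2, 33, 256)
print(tuple(emb.transpose(1, 2).shape))  # (2, 256, 33)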

Introducing a bidirectional LSTM

[22]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(ConvEncoder):
    def __init__(
        self,
        num_vocab=40,
        embed_dim=512,
        hidden_dim=512,
        conv_layers=3,
        conv_channels=512,
        conv_kernel_size=5,
    ):
        super().__init__(
            num_vocab, embed_dim, conv_layers, conv_channels, conv_kernel_size
        )
        # Bidirectional LSTM for modeling long-term dependencies
        self.blstm = nn.LSTM(
            conv_channels, hidden_dim // 2, 1, batch_first=True, bidirectional=True
        )

    def forward(self, seqs, in_lens):
        emb = self.embed(seqs)
        # Note that nn.Conv1d and nn.Embedding use different input layouts
        out = self.convs(emb.transpose(1, 2)).transpose(1, 2)

        # Run the bidirectional LSTM
        out = pack_padded_sequence(out, in_lens, batch_first=True)
        out, _ = self.blstm(out)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out
[23]:
Encoder()
[23]:
Encoder(
  (embed): Embedding(40, 512, padding_idx=0)
  (convs): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.5, inplace=False)
  )
  (blstm): LSTM(512, 256, batch_first=True, bidirectional=True)
)
[24]:
encoder = Encoder(num_vocab=40, embed_dim=256)
seqs, in_lens = get_dummy_input()
in_lens, indices = torch.sort(in_lens, dim=0, descending=True)
seqs = seqs[indices]

encoder_outs = encoder(seqs, in_lens)
print(f"Input size: {tuple(seqs.shape)}")
print(f"Output size: {tuple(encoder_outs.shape)}")
Input size: (2, 33)
Output size: (2, 33, 512)
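
The manual sort by length in the cell above is needed only because pack_padded_sequence assumes descending lengths by default. A minimal sketch of an alternative that skips the sorting by passing enforce_sorted=False (available since PyTorch 1.1), reusing the encoder, seqs, and in_lens above; the variable names conv_out and blstm_out are just for illustration:

emb = encoder.embed(seqs)
conv_out = encoder.convs(emb.transpose(1, 2)).transpose(1, 2)
# enforce_sorted=False lets pack_padded_sequence handle unsorted lengths internally
packed = pack_padded_sequence(conv_out, in_lens, batch_first=True, enforce_sorted=False)
blstm_out, _ = encoder.blstm(packed)
blstm_out, _ = pad_packed_sequence(blstm_out, batch_first=True)
print(tuple(blstm_out.shape))  # (2, 33, 512)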

9.4 Attention Mechanism

Content-based attention

[25]:
from torch.nn import functional as F

# An implementation that follows the equations in the book, with an emphasis on readability
class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim=512, decoder_dim=1024, hidden_dim=128):
        super().__init__()
        self.V = nn.Linear(encoder_dim, hidden_dim)
        self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)
        # NOTE: bias=False matches the equations in the book, but bias=True also works fine in practice
        self.w = nn.Linear(hidden_dim, 1)

    def forward(self, encoder_outs, decoder_state, mask=None):
        # Compute Eq. (9.11) in the book
        erg = self.w(
            torch.tanh(self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs))
        ).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum along the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights
[26]:
BahdanauAttention()
[26]:
BahdanauAttention(
  (V): Linear(in_features=512, out_features=128, bias=True)
  (W): Linear(in_features=1024, out_features=128, bias=False)
  (w): Linear(in_features=128, out_features=1, bias=True)
)
[27]:
from ttslearn.util import make_pad_mask

mask =  make_pad_mask(in_lens).to(encoder_outs.device)
attention = BahdanauAttention()

decoder_input = torch.ones(len(seqs), 1024)

attention_context, attention_weights = attention(encoder_outs, decoder_input, mask)

print(f"Encoder output size: {tuple(encoder_outs.shape)}")
print(f"Decoder hidden state size: {tuple(decoder_input.shape)}")
print(f"Context vector size: {tuple(attention_context.shape)}")
print(f"Attention weight size: {tuple(attention_weights.shape)}")
Encoder output size: (2, 33, 512)
Decoder hidden state size: (2, 1024)
Context vector size: (2, 512)
Attention weight size: (2, 33)
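
The attention weights form a probability distribution over the encoder time steps, and the mask ensures that padded positions receive zero weight. A quick check, reusing the tensors from the cell above:

# Each row of attention_weights sums to 1 (softmax over encoder time steps)
print(attention_weights.sum(dim=1))
# Padded positions of the shorter sequence receive zero weight
print(attention_weights[1, int(in_lens[1]):])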

Hybrid attention mechanism

[28]:
class LocationSensitiveAttention(nn.Module):
    def __init__(
        self,
        encoder_dim=512,
        decoder_dim=1024,
        hidden_dim=128,
        conv_channels=32,
        conv_kernel_size=31,
    ):
        super().__init__()
        self.V = nn.Linear(encoder_dim, hidden_dim)
        self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)
        self.U = nn.Linear(conv_channels, hidden_dim, bias=False)
        self.F = nn.Conv1d(
            1,
            conv_channels,
            conv_kernel_size,
            padding=(conv_kernel_size - 1) // 2,
            bias=False,
        )
        # NOTE: bias=False matches the equations in the book, but bias=True also works fine in practice
        self.w = nn.Linear(hidden_dim, 1)

    def forward(self, encoder_outs, src_lens, decoder_state, att_prev, mask=None):
        # Initialize the attention weights with a uniform distribution
        if att_prev is None:
            att_prev = 1.0 - make_pad_mask(src_lens).to(
                device=decoder_state.device, dtype=decoder_state.dtype
            )
            att_prev = att_prev / src_lens.unsqueeze(-1).to(encoder_outs.device)

        # (B x T_enc) -> (B x 1 x T_enc) -> (B x conv_channels x T_enc) ->
        # (B x T_enc x conv_channels)
        f = self.F(att_prev.unsqueeze(1)).transpose(1, 2)

        # Compute Eq. (9.13) in the book
        erg = self.w(
            torch.tanh(
                self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs) + self.U(f)
            )
        ).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum along the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights
[29]:
LocationSensitiveAttention()
[29]:
LocationSensitiveAttention(
  (V): Linear(in_features=512, out_features=128, bias=True)
  (W): Linear(in_features=1024, out_features=128, bias=False)
  (U): Linear(in_features=32, out_features=128, bias=False)
  (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)
  (w): Linear(in_features=128, out_features=1, bias=True)
)
[30]:
from ttslearn.util import make_pad_mask

mask =  make_pad_mask(in_lens).to(encoder_outs.device)
attention = LocationSensitiveAttention()

decoder_input = torch.ones(len(seqs), 1024)

attention_context, attention_weights = attention(encoder_outs, in_lens, decoder_input, None, mask)

print(f"Encoder output size: {tuple(encoder_outs.shape)}")
print(f"Decoder hidden state size: {tuple(decoder_input.shape)}")
print(f"Context vector size: {tuple(attention_context.shape)}")
print(f"Attention weight size: {tuple(attention_weights.shape)}")
Encoder output size: (2, 33, 512)
Decoder hidden state size: (2, 1024)
Context vector size: (2, 512)
Attention weight size: (2, 33)
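
Inside the decoder, the location-sensitive attention receives the accumulated attention weights of the previous time steps through att_prev. A minimal sketch of a second attention step, feeding back the weights computed above (attention_context2 and attention_weights2 are illustrative names):

# Pass the previous weights as att_prev; in the decoder this argument is the
# running sum of all attention weights computed so far
attention_context2, attention_weights2 = attention(
    encoder_outs, in_lens, decoder_input, attention_weights, mask
)
print(tuple(attention_context2.shape))  # (2, 512)
print(tuple(attention_weights2.shape))  # (2, 33)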

9.5 Decoder

Pre-Net

[31]:
class Prenet(nn.Module):
    def __init__(self, in_dim, layers=2, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        prenet = nn.ModuleList()
        for layer in range(layers):
            prenet += [
                nn.Linear(in_dim if layer == 0 else hidden_dim, hidden_dim),
                nn.ReLU(),
            ]
        self.prenet = nn.Sequential(*prenet)

    def forward(self, x):
        for layer in self.prenet:
            # Dropout is applied at both training and inference time
            x = F.dropout(layer(x), self.dropout, training=True)
        return x
[32]:
Prenet(80)
[32]:
Prenet(
  (prenet): Sequential(
    (0): Linear(in_features=80, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
  )
)
[33]:
decoder_input = torch.ones(len(seqs), 80)

prenet = Prenet(80)
out = prenet(decoder_input)
print(f"Decoder input size: {tuple(decoder_input.shape)}")
print(f"Pre-Net output size: {tuple(out.shape)}")
Decoder input size: (2, 80)
Pre-Net output size: (2, 256)
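
Because Prenet calls F.dropout with training=True, dropout stays active even at inference time, so repeated forward passes on the same input give different outputs. A quick check with the prenet and decoder_input defined above:

# Dropout remains active even in eval mode, so two passes differ
prenet.eval()
out1 = prenet(decoder_input)
out2 = prenet(decoder_input)
print(torch.allclose(out1, out2))  # False (with very high probability)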

Decoder with an attention mechanism

[34]:
from ttslearn.tacotron.decoder import ZoneOutCell

class Decoder(nn.Module):
    def __init__(
        self,
        encoder_hidden_dim=512,
        out_dim=80,
        layers=2,
        hidden_dim=1024,
        prenet_layers=2,
        prenet_hidden_dim=256,
        prenet_dropout=0.5,
        zoneout=0.1,
        reduction_factor=1,
        attention_hidden_dim=128,
        attention_conv_channels=32,
        attention_conv_kernel_size=31,
    ):
        super().__init__()
        self.out_dim = out_dim

        # Attention mechanism
        self.attention = LocationSensitiveAttention(
            encoder_hidden_dim,
            hidden_dim,
            attention_hidden_dim,
            attention_conv_channels,
            attention_conv_kernel_size,
        )
        self.reduction_factor = reduction_factor

        # Prenet
        self.prenet = Prenet(out_dim, prenet_layers, prenet_hidden_dim, prenet_dropout)

        # Unidirectional LSTM
        self.lstm = nn.ModuleList()
        for layer in range(layers):
            lstm = nn.LSTMCell(
                encoder_hidden_dim + prenet_hidden_dim if layer == 0 else hidden_dim,
                hidden_dim,
            )
            lstm = ZoneOutCell(lstm, zoneout)
            self.lstm += [lstm]

        # Projection layers to the outputs
        proj_in_dim = encoder_hidden_dim + hidden_dim
        self.feat_out = nn.Linear(proj_in_dim, out_dim * reduction_factor, bias=False)
        self.prob_out = nn.Linear(proj_in_dim, reduction_factor)

    def _zero_state(self, hs):
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        is_inference = decoder_targets is None

        # Adjust the number of frames according to the reduction factor
        # (B, Lmax, out_dim) ->  (B, Lmax/r, out_dim)
        if self.reduction_factor > 1 and not is_inference:
            decoder_targets = decoder_targets[
                :, self.reduction_factor - 1 :: self.reduction_factor
            ]

        # Number of decoder time steps
        # At inference time, set an empirical upper bound based on the encoder length
        if is_inference:
            max_decoder_time_steps = int(encoder_outs.shape[1] * 10.0)
        else:
            max_decoder_time_steps = decoder_targets.shape[1]

        # Mask for the zero-padded positions
        mask = make_pad_mask(in_lens).to(encoder_outs.device)

        # Initialize the LSTM states with zeros
        h_list, c_list = [], []
        for _ in range(len(self.lstm)):
            h_list.append(self._zero_state(encoder_outs))
            c_list.append(self._zero_state(encoder_outs))

        # Initial decoder input (all-zero go frame)
        go_frame = encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)
        prev_out = go_frame

        # Attention weights of the previous time step
        prev_att_w = None

        # Main loop
        outs, logits, att_ws = [], [], []
        t = 0
        while True:
            # Compute the context vector and the attention weights
            att_c, att_w = self.attention(
                encoder_outs, in_lens, h_list[0], prev_att_w, mask
            )

            # Pre-Net
            prenet_out = self.prenet(prev_out)

            # LSTM
            xs = torch.cat([att_c, prenet_out], dim=1)
            h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                h_list[i], c_list[i] = self.lstm[i](
                    h_list[i - 1], (h_list[i], c_list[i])
                )
            # Compute the outputs
            hcs = torch.cat([h_list[-1], att_c], dim=1)
            outs.append(self.feat_out(hcs).view(encoder_outs.size(0), self.out_dim, -1))
            logits.append(self.prob_out(hcs))
            att_ws.append(att_w)

            # Update the decoder input for the next time step
            if is_inference:
                prev_out = outs[-1][:, :, -1]  # (B, out_dim)
            else:
                # Teacher forcing
                prev_out = decoder_targets[:, t, :]

            # Accumulate the attention weights
            prev_att_w = att_w if prev_att_w is None else prev_att_w + att_w

            t += 1
            # Check the stop conditions
            if t >= max_decoder_time_steps:
                break
            if is_inference and (torch.sigmoid(logits[-1]) >= 0.5).any():
                break

        # Concatenate the outputs of all time steps
        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        outs = torch.cat(outs, dim=2)  # (B, out_dim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.reduction_factor > 1:
            outs = outs.view(outs.size(0), self.out_dim, -1)  # (B, out_dim, Lmax)

        return outs, logits, att_ws
[35]:
Decoder()
[35]:
Decoder(
  (attention): LocationSensitiveAttention(
    (V): Linear(in_features=512, out_features=128, bias=True)
    (W): Linear(in_features=1024, out_features=128, bias=False)
    (U): Linear(in_features=32, out_features=128, bias=False)
    (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)
    (w): Linear(in_features=128, out_features=1, bias=True)
  )
  (prenet): Prenet(
    (prenet): Sequential(
      (0): Linear(in_features=80, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (lstm): ModuleList(
    (0): ZoneOutCell(
      (cell): LSTMCell(768, 1024)
    )
    (1): ZoneOutCell(
      (cell): LSTMCell(1024, 1024)
    )
  )
  (feat_out): Linear(in_features=1536, out_features=80, bias=False)
  (prob_out): Linear(in_features=1536, out_features=1, bias=True)
)
[36]:
decoder_targets = torch.ones(encoder_outs.shape[0], 120, 80)
decoder = Decoder(encoder_outs.shape[-1], 80)

# Teacher forcing: provide decoder_targets (the ground-truth frames)
with torch.no_grad():
    outs, logits, att_ws = decoder(encoder_outs, in_lens, decoder_targets)

print(f"Decoder input size: {tuple(decoder_input.shape)}")
print(f"Decoder output size: {tuple(outs.shape)}")
print(f"Stop token (logits) size: {tuple(logits.shape)}")
print(f"Attention weight size: {tuple(att_ws.shape)}")
Decoder input size: (2, 80)
Decoder output size: (2, 80, 120)
Stop token (logits) size: (2, 120)
Attention weight size: (2, 120, 33)
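
For diagnosis, the attention weights are often visualized as an alignment matrix of decoder steps versus encoder steps. A plotting sketch for the first utterance (the decoder here is untrained, so no clean diagonal alignment should be expected yet):

import matplotlib.pyplot as plt

# att_ws: (batch, decoder time steps, encoder time steps)
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(att_ws[0].numpy(), aspect="auto", origin="lower", interpolation="none")
ax.set_xlabel("Encoder time step")
ax.set_ylabel("Decoder time step")
fig.colorbar(im, ax=ax)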
[37]:
# Autoregressive inference
with torch.no_grad():
    # Keep the batch dimension with [0:1] so the decoder receives a (1, T, C) tensor
    decoder(encoder_outs[0:1], torch.tensor([in_lens[0]]))

9.6 Post-Net

[38]:
class Postnet(nn.Module):
    def __init__(
        self,
        in_dim=80,
        layers=5,
        channels=512,
        kernel_size=5,
        dropout=0.5,
    ):
        super().__init__()
        postnet = nn.ModuleList()
        for layer in range(layers):
            in_channels = in_dim if layer == 0 else channels
            out_channels = in_dim if layer == layers - 1 else channels
            postnet += [
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=(kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm1d(out_channels),
            ]
            if layer != layers - 1:
                postnet += [nn.Tanh()]
            postnet += [nn.Dropout(dropout)]
        self.postnet = nn.Sequential(*postnet)

    def forward(self, xs):
        return self.postnet(xs)
[39]:
Postnet()
[39]:
Postnet(
  (postnet): Sequential(
    (0): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Tanh()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Tanh()
    (11): Dropout(p=0.5, inplace=False)
    (12): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Tanh()
    (15): Dropout(p=0.5, inplace=False)
    (16): Conv1d(512, 80, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (17): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (18): Dropout(p=0.5, inplace=False)
  )
)
[40]:
postnet = Postnet(80)
residual = postnet(outs)

print(f"Input size: {tuple(outs.shape)}")
print(f"Output size: {tuple(residual.shape)}")
Input size: (2, 80, 120)
Output size: (2, 80, 120)
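
Post-Net predicts a residual that is added to the decoder output to refine the mel-spectrogram, exactly as done inside the Tacotron2 class below. A one-line sketch with the tensors above:

# The refined mel-spectrogram is the decoder output plus the Post-Net residual
outs_fine = outs + residual
print(tuple(outs_fine.shape))  # (2, 80, 120)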

9.7 Implementing Tacotron 2

Tacotron 2 model definition

[41]:
class Tacotron2(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.postnet = Postnet()

    def forward(self, seq, in_lens, decoder_targets):
        # Encoder: latent representations of the text
        encoder_outs = self.encoder(seq, in_lens)

        # Decoder: predict the mel-spectrogram and stop tokens
        outs, logits, att_ws = self.decoder(encoder_outs, in_lens, decoder_targets)

        # Post-Net: predict the residual of the mel-spectrogram
        outs_fine = outs + self.postnet(outs)

        # (B, C, T) -> (B, T, C)
        outs = outs.transpose(2, 1)
        outs_fine = outs_fine.transpose(2, 1)

        return outs, outs_fine, logits, att_ws

    def inference(self, seq):
        seq = seq.unsqueeze(0) if len(seq.shape) == 1 else seq
        in_lens = torch.tensor([seq.shape[-1]], dtype=torch.long, device=seq.device)

        return self.forward(seq, in_lens, None)
[42]:
seqs, in_lens = get_dummy_input()
model = Tacotron2()

# Run Tacotron 2
outs, outs_fine, logits, att_ws = model(seqs, in_lens, decoder_targets)

print(f"Input size: {tuple(seqs.shape)}")
print(f"Decoder output size: {tuple(outs.shape)}")
print(f"Post-Net output size: {tuple(outs_fine.shape)}")
print(f"Stop token (logits) size: {tuple(logits.shape)}")
print(f"Attention weight size: {tuple(att_ws.shape)}")
Input size: (2, 33)
Decoder output size: (2, 120, 80)
Post-Net output size: (2, 120, 80)
Stop token (logits) size: (2, 120)
Attention weight size: (2, 120, 33)
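
The inference method defined above can also be exercised on a single input sequence; with decoder_targets=None the decoder runs autoregressively. The weights are random, so the output is meaningless, but the shapes can be confirmed. A sketch (the variable names are illustrative only):

model.eval()  # eval mode so BatchNorm uses running statistics during inference
with torch.no_grad():
    seq = torch.tensor(text_to_sequence("Hello."), dtype=torch.long)
    outs_i, outs_fine_i, logits_i, att_ws_i = model.inference(seq)
print(tuple(outs_i.shape))  # (1, number of decoder steps, 80)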
[43]:
model
[43]:
Tacotron2(
  (encoder): Encoder(
    (embed): Embedding(40, 512, padding_idx=0)
    (convs): Sequential(
      (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.5, inplace=False)
      (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.5, inplace=False)
      (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.5, inplace=False)
    )
    (blstm): LSTM(512, 256, batch_first=True, bidirectional=True)
  )
  (decoder): Decoder(
    (attention): LocationSensitiveAttention(
      (V): Linear(in_features=512, out_features=128, bias=True)
      (W): Linear(in_features=1024, out_features=128, bias=False)
      (U): Linear(in_features=32, out_features=128, bias=False)
      (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)
      (w): Linear(in_features=128, out_features=1, bias=True)
    )
    (prenet): Prenet(
      (prenet): Sequential(
        (0): Linear(in_features=80, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): ReLU()
      )
    )
    (lstm): ModuleList(
      (0): ZoneOutCell(
        (cell): LSTMCell(768, 1024)
      )
      (1): ZoneOutCell(
        (cell): LSTMCell(1024, 1024)
      )
    )
    (feat_out): Linear(in_features=1536, out_features=80, bias=False)
    (prob_out): Linear(in_features=1536, out_features=1, bias=True)
  )
  (postnet): Postnet(
    (postnet): Sequential(
      (0): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
      (3): Dropout(p=0.5, inplace=False)
      (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Tanh()
      (7): Dropout(p=0.5, inplace=False)
      (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Tanh()
      (11): Dropout(p=0.5, inplace=False)
      (12): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): Tanh()
      (15): Dropout(p=0.5, inplace=False)
      (16): Conv1d(512, 80, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
      (17): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (18): Dropout(p=0.5, inplace=False)
    )
  )
)

Checking Tacotron 2 with a toy model

[44]:
from ttslearn.tacotron import Tacotron2
model = Tacotron2(encoder_conv_layers=1, decoder_prenet_layers=1, decoder_layers=1, postnet_layers=1)
[45]:
def get_dummy_inout():
    seqs, in_lens = get_dummy_input()

    # Target data for the decoder output (mel-spectrogram)
    decoder_targets = torch.ones(2, 120, 80)

    # Target data for the stop token
    # The stop token prediction is a probability, but the target is a binary label
    # A value of 1 means that the decoder output has finished
    stop_tokens = torch.zeros(2, 120)
    stop_tokens[:, -1:] = 1.0

    return seqs, in_lens, decoder_targets, stop_tokens
[46]:
# Generate dummy inputs and targets
seqs, in_lens, decoder_targets, stop_tokens = get_dummy_inout()

# Compute the Tacotron 2 outputs
# NOTE: for teacher forcing, explicitly pass the decoder targets
outs, outs_fine, logits, att_ws = model(seqs, in_lens, decoder_targets)

print("Input size:", tuple(seqs.shape))
print("Decoder output size:", tuple(outs.shape))
print("Stop token size:", tuple(logits.shape))
print("Attention weight size:", tuple(att_ws.shape))
Input size: (2, 33)
Decoder output size: (2, 120, 80)
Stop token size: (2, 120)
Attention weight size: (2, 120, 33)

Computing the Tacotron 2 loss functions

[47]:
# 1. Loss on the decoder outputs
out_loss = nn.MSELoss()(outs, decoder_targets)
# 2. Loss on the outputs after the Post-Net
out_fine_loss = nn.MSELoss()(outs_fine, decoder_targets)
# 3. Loss on the stop tokens
stop_token_loss = nn.BCEWithLogitsLoss()(logits, stop_tokens)
[48]:
print("out_loss: ", out_loss.item())
print("out_fine_loss: ", out_fine_loss.item())
print("stop_token_loss: ", stop_token_loss.item())
out_loss:  0.9949015378952026
out_fine_loss:  2.896300792694092
stop_token_loss:  0.6844527125358582
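
During training, these three terms are summed and minimized jointly. A minimal training-step sketch under that assumption (the optimizer and learning rate here are illustrative, not the recipe used by the book's training scripts):

from torch import optim

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Joint loss: decoder output + Post-Net output + stop token
loss = out_loss + out_fine_loss + stop_token_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("total loss:", loss.item())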