{ "cells": [ { "cell_type": "markdown", "id": "celtic-greek", "metadata": {}, "source": [ "# 第9章 Tacotron 2: 一貫学習を狙った音声合成\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/r9y9/ttslearn/blob/master/notebooks/ch09_Tacotron.ipynb)" ] }, { "cell_type": "markdown", "id": "modular-biography", "metadata": { "tags": [] }, "source": [ "## 準備" ] }, { "cell_type": "markdown", "id": "configured-cause", "metadata": {}, "source": [ "### Python version" ] }, { "cell_type": "code", "execution_count": 1, "id": "comfortable-conference", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:50.388734Z", "iopub.status.busy": "2021-08-21T07:51:50.388441Z", "iopub.status.idle": "2021-08-21T07:51:50.497189Z", "shell.execute_reply": "2021-08-21T07:51:50.497485Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python 3.8.6 | packaged by conda-forge | (default, Dec 26 2020, 05:05:16) \r\n", "[GCC 9.3.0]\r\n" ] } ], "source": [ "!python -VV" ] }, { "cell_type": "markdown", "id": "limited-senator", "metadata": {}, "source": [ "### ttslearn のインストール" ] }, { "cell_type": "code", "execution_count": 2, "id": "innovative-conservative", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:50.500837Z", "iopub.status.busy": "2021-08-21T07:51:50.500529Z", "iopub.status.idle": "2021-08-21T07:51:51.125478Z", "shell.execute_reply": "2021-08-21T07:51:51.125152Z" } }, "outputs": [], "source": [ "%%capture\n", "try:\n", " import ttslearn\n", "except ImportError:\n", " !pip install ttslearn" ] }, { "cell_type": "code", "execution_count": 3, "id": "physical-fleece", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:51.130402Z", "iopub.status.busy": "2021-08-21T07:51:51.130110Z", "iopub.status.idle": "2021-08-21T07:51:51.132270Z", "shell.execute_reply": "2021-08-21T07:51:51.131975Z" } }, "outputs": [ { "data": { "text/plain": [ "'0.2.1'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ttslearn\n", "ttslearn.__version__" ] }, { "cell_type": "markdown", "id": "executive-bargain", "metadata": {}, "source": [ "### パッケージのインポート" ] }, { "cell_type": "code", "execution_count": 4, "id": "legal-livestock", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:51.136887Z", "iopub.status.busy": "2021-08-21T07:51:51.136604Z", "iopub.status.idle": "2021-08-21T07:51:51.357279Z", "shell.execute_reply": "2021-08-21T07:51:51.356972Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Populating the interactive namespace from numpy and matplotlib\n" ] } ], "source": [ "%pylab inline\n", "%load_ext autoreload\n", "%load_ext tensorboard\n", "%autoreload\n", "import IPython\n", "from IPython.display import Audio\n", "import tensorboard as tb\n", "import os" ] }, { "cell_type": "code", "execution_count": 5, "id": "hungry-consolidation", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:51.359483Z", "iopub.status.busy": "2021-08-21T07:51:51.359180Z", "iopub.status.idle": "2021-08-21T07:51:52.244964Z", "shell.execute_reply": "2021-08-21T07:51:52.245263Z" } }, "outputs": [], "source": [ "# 数値演算\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "# 音声波形の読み込み\n", "from scipy.io import wavfile\n", "# フルコンテキストラベル、質問ファイルの読み込み\n", "from nnmnkwii.io import hts\n", "# 音声分析\n", "import pyworld\n", "# 音声分析、可視化\n", "import librosa\n", "import librosa.display\n", "# Pythonで学ぶ音声合成\n", 
"import ttslearn" ] }, { "cell_type": "code", "execution_count": 6, "id": "constitutional-compound", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.247474Z", "iopub.status.busy": "2021-08-21T07:51:52.247141Z", "iopub.status.idle": "2021-08-21T07:51:52.258587Z", "shell.execute_reply": "2021-08-21T07:51:52.258895Z" } }, "outputs": [], "source": [ "# シードの固定\n", "from ttslearn.util import init_seed\n", "init_seed(773)" ] }, { "cell_type": "code", "execution_count": 7, "id": "decent-oliver", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.261220Z", "iopub.status.busy": "2021-08-21T07:51:52.260842Z", "iopub.status.idle": "2021-08-21T07:51:52.262809Z", "shell.execute_reply": "2021-08-21T07:51:52.263081Z" } }, "outputs": [ { "data": { "text/plain": [ "'1.8.1'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.__version__" ] }, { "cell_type": "markdown", "id": "needed-reader", "metadata": {}, "source": [ "### 描画周りの設定" ] }, { "cell_type": "code", "execution_count": 8, "id": "fundamental-formula", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.265268Z", "iopub.status.busy": "2021-08-21T07:51:52.264860Z", "iopub.status.idle": "2021-08-21T07:51:52.267216Z", "shell.execute_reply": "2021-08-21T07:51:52.266911Z" } }, "outputs": [], "source": [ "from ttslearn.notebook import get_cmap, init_plot_style, savefig\n", "cmap = get_cmap()\n", "init_plot_style()" ] }, { "cell_type": "markdown", "id": "prescribed-advantage", "metadata": {}, "source": [ "## 9.3 エンコーダ" ] }, { "cell_type": "markdown", "id": "opened-effort", "metadata": {}, "source": [ "### 文字列から数値列への変換" ] }, { "cell_type": "code", "execution_count": 9, "id": "palestinian-basement", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.269976Z", "iopub.status.busy": "2021-08-21T07:51:52.269692Z", "iopub.status.idle": "2021-08-21T07:51:52.270945Z", "shell.execute_reply": "2021-08-21T07:51:52.271217Z" } }, "outputs": [], "source": [ "# 語彙の定義\n", "characters = \"abcdefghijklmnopqrstuvwxyz!'(),-.:;? 
\"\n", "# その他特殊記号\n", "extra_symbols = [\n", " \"^\", # 文の先頭を表す特殊記号 \n", " \"$\", # 文の末尾を表す特殊記号 \n", "]\n", "_pad = \"~\"\n", "\n", "# NOTE: パディングを 0 番目に配置\n", "symbols = [_pad] + extra_symbols + list(characters)\n", "\n", "# 文字列⇔数値の相互変換のための辞書\n", "_symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", "_id_to_symbol = {i: s for i, s in enumerate(symbols)}" ] }, { "cell_type": "code", "execution_count": 10, "id": "future-dialogue", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.273405Z", "iopub.status.busy": "2021-08-21T07:51:52.273014Z", "iopub.status.idle": "2021-08-21T07:51:52.275253Z", "shell.execute_reply": "2021-08-21T07:51:52.274953Z" } }, "outputs": [ { "data": { "text/plain": [ "40" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(symbols)" ] }, { "cell_type": "code", "execution_count": 11, "id": "broad-madison", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.277820Z", "iopub.status.busy": "2021-08-21T07:51:52.277540Z", "iopub.status.idle": "2021-08-21T07:51:52.279086Z", "shell.execute_reply": "2021-08-21T07:51:52.278796Z" } }, "outputs": [], "source": [ "def text_to_sequence(text):\n", " # 簡易のため、大文字と小文字を区別せず、全ての大文字を小文字に変換\n", " text = text.lower()\n", "\n", " # \n", " seq = [_symbol_to_id[\"^\"]]\n", "\n", " # 本文\n", " seq += [_symbol_to_id[s] for s in text]\n", "\n", " # \n", " seq.append(_symbol_to_id[\"$\"])\n", "\n", " return seq\n", "\n", "\n", "def sequence_to_text(seq):\n", " return [_id_to_symbol[s] for s in seq]" ] }, { "cell_type": "code", "execution_count": 12, "id": "partial-burst", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.281301Z", "iopub.status.busy": "2021-08-21T07:51:52.280986Z", "iopub.status.idle": "2021-08-21T07:51:52.282625Z", "shell.execute_reply": "2021-08-21T07:51:52.282847Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "文字列から数値列への変換: [1, 10, 7, 14, 14, 17, 29, 2]\n", "数値列から文字列への逆変換: ['^', 'h', 'e', 'l', 'l', 'o', '!', '$']\n" ] } ], "source": [ "seq = text_to_sequence(\"Hello!\")\n", "print(f\"文字列から数値列への変換: {seq}\")\n", "print(f\"数値列から文字列への逆変換: {sequence_to_text(seq)}\")" ] }, { "cell_type": "markdown", "id": "surface-sister", "metadata": {}, "source": [ "### 文字埋め込み" ] }, { "cell_type": "code", "execution_count": 13, "id": "fallen-broadcasting", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.285364Z", "iopub.status.busy": "2021-08-21T07:51:52.285090Z", "iopub.status.idle": "2021-08-21T07:51:52.286631Z", "shell.execute_reply": "2021-08-21T07:51:52.286390Z" } }, "outputs": [], "source": [ "class SimplestEncoder(nn.Module):\n", " def __init__(self, num_vocab=40, embed_dim=256):\n", " super().__init__()\n", " self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)\n", " \n", " def forward(self, seqs):\n", " return self.embed(seqs)" ] }, { "cell_type": "code", "execution_count": 14, "id": "welcome-theater", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.291319Z", "iopub.status.busy": "2021-08-21T07:51:52.291029Z", "iopub.status.idle": "2021-08-21T07:51:52.293413Z", "shell.execute_reply": "2021-08-21T07:51:52.293637Z" } }, "outputs": [ { "data": { "text/plain": [ "SimplestEncoder(\n", " (embed): Embedding(40, 256, padding_idx=0)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SimplestEncoder()" ] }, { "cell_type": "code", "execution_count": 15, "id": "comparative-wiring", "metadata": { 
"execution": { "iopub.execute_input": "2021-08-21T07:51:52.296598Z", "iopub.status.busy": "2021-08-21T07:51:52.296318Z", "iopub.status.idle": "2021-08-21T07:51:52.297868Z", "shell.execute_reply": "2021-08-21T07:51:52.297629Z" } }, "outputs": [], "source": [ "from ttslearn.util import pad_1d\n", "\n", "def get_dummy_input():\n", " # バッチサイズに 2 を想定して、適当な文字列を作成\n", " seqs = [\n", " text_to_sequence(\"What is your favorite language?\"),\n", " text_to_sequence(\"Hello world.\"),\n", " ]\n", " in_lens = torch.tensor([len(x) for x in seqs], dtype=torch.long)\n", " max_len = max(len(x) for x in seqs)\n", " seqs = torch.stack([torch.from_numpy(pad_1d(seq, max_len)) for seq in seqs])\n", " \n", " return seqs, in_lens" ] }, { "cell_type": "code", "execution_count": 16, "id": "confidential-expert", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.299944Z", "iopub.status.busy": "2021-08-21T07:51:52.299655Z", "iopub.status.idle": "2021-08-21T07:51:52.302438Z", "shell.execute_reply": "2021-08-21T07:51:52.302196Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力 tensor([[ 1, 25, 10, 3, 22, 39, 11, 21, 39, 27, 17, 23, 20, 39, 8, 3, 24, 17,\n", " 20, 11, 22, 7, 39, 14, 3, 16, 9, 23, 3, 9, 7, 38, 2],\n", " [ 1, 10, 7, 14, 14, 17, 39, 25, 17, 20, 14, 6, 35, 2, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", "系列長: tensor([33, 14])\n" ] } ], "source": [ "seqs, in_lens = get_dummy_input()\n", "print(\"入力\", seqs)\n", "print(\"系列長:\", in_lens)" ] }, { "cell_type": "code", "execution_count": 17, "id": "prospective-aggregate", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.304702Z", "iopub.status.busy": "2021-08-21T07:51:52.304426Z", "iopub.status.idle": "2021-08-21T07:51:52.307795Z", "shell.execute_reply": "2021-08-21T07:51:52.308018Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 33)\n", "出力のサイズ: (2, 33, 256)\n" ] } ], "source": [ "encoder = SimplestEncoder(num_vocab=40, embed_dim=256)\n", "seqs, in_lens = get_dummy_input()\n", "encoder_outs = encoder(seqs)\n", "print(f\"入力のサイズ: {tuple(seqs.shape)}\")\n", "print(f\"出力のサイズ: {tuple(encoder_outs.shape)}\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "irish-paintball", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.310384Z", "iopub.status.busy": "2021-08-21T07:51:52.310111Z", "iopub.status.idle": "2021-08-21T07:51:52.316966Z", "shell.execute_reply": "2021-08-21T07:51:52.317188Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.7055, 0.6891, 0.0332, ..., 0.7174, 0.4686, 1.1468],\n", " [-0.1568, -0.3719, -1.0086, ..., -0.9326, -1.2187, -0.0714],\n", " [-0.1901, -0.1983, 0.2274, ..., 0.2284, 1.6452, -0.3408],\n", " ...,\n", " [-1.9353, 0.2628, -0.1449, ..., 1.6056, -0.3912, -0.0740],\n", " [ 0.4687, 0.3258, -0.6565, ..., 1.0895, 0.9105, 0.2814],\n", " [ 0.8940, 0.3002, -0.2105, ..., 0.7973, 0.2230, -0.1975]],\n", "\n", " [[ 0.7055, 0.6891, 0.0332, ..., 0.7174, 0.4686, 1.1468],\n", " [-0.1901, -0.1983, 0.2274, ..., 0.2284, 1.6452, -0.3408],\n", " [-1.9353, 0.2628, -0.1449, ..., 1.6056, -0.3912, -0.0740],\n", " ...,\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],\n", " grad_fn=)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# パディングの部分は0を取り、それ以外は連続値で表されます\n", "encoder_outs" ] }, { "cell_type": 
"markdown", "id": "quality-thanks", "metadata": {}, "source": [ "### 1次元畳み込みの導入" ] }, { "cell_type": "code", "execution_count": 19, "id": "chubby-commander", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.321164Z", "iopub.status.busy": "2021-08-21T07:51:52.320867Z", "iopub.status.idle": "2021-08-21T07:51:52.322154Z", "shell.execute_reply": "2021-08-21T07:51:52.322431Z" } }, "outputs": [], "source": [ "class ConvEncoder(nn.Module):\n", " def __init__(\n", " self,\n", " num_vocab=40,\n", " embed_dim=256,\n", " conv_layers=3,\n", " conv_channels=256,\n", " conv_kernel_size=5,\n", " ):\n", " super().__init__()\n", " # 文字埋め込み\n", " self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)\n", "\n", " # 1次元畳み込みの重ね合わせ:局所的な依存関係のモデル化\n", " self.convs = nn.ModuleList()\n", " for layer in range(conv_layers):\n", " in_channels = embed_dim if layer == 0 else conv_channels\n", " self.convs += [\n", " nn.Conv1d(\n", " in_channels,\n", " conv_channels,\n", " conv_kernel_size,\n", " padding=(conv_kernel_size - 1) // 2,\n", " bias=False,\n", " ),\n", " nn.BatchNorm1d(conv_channels),\n", " nn.ReLU(),\n", " nn.Dropout(0.5),\n", " ]\n", " self.convs = nn.Sequential(*self.convs)\n", "\n", " def forward(self, seqs):\n", " emb = self.embed(seqs)\n", " # 1 次元畳み込みと embedding では、入力のサイズが異なるので注意\n", " out = self.convs(emb.transpose(1, 2)).transpose(1, 2)\n", " return out" ] }, { "cell_type": "code", "execution_count": 20, "id": "precise-leonard", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.324253Z", "iopub.status.busy": "2021-08-21T07:51:52.323985Z", "iopub.status.idle": "2021-08-21T07:51:52.332493Z", "shell.execute_reply": "2021-08-21T07:51:52.332199Z" } }, "outputs": [ { "data": { "text/plain": [ "ConvEncoder(\n", " (embed): Embedding(40, 256, padding_idx=0)\n", " (convs): Sequential(\n", " (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Dropout(p=0.5, inplace=False)\n", " (4): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): ReLU()\n", " (7): Dropout(p=0.5, inplace=False)\n", " (8): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (10): ReLU()\n", " (11): Dropout(p=0.5, inplace=False)\n", " )\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ConvEncoder()" ] }, { "cell_type": "code", "execution_count": 21, "id": "interim-helmet", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.334801Z", "iopub.status.busy": "2021-08-21T07:51:52.334521Z", "iopub.status.idle": "2021-08-21T07:51:52.355067Z", "shell.execute_reply": "2021-08-21T07:51:52.355395Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 33)\n", "出力のサイズ: (2, 33, 256)\n" ] } ], "source": [ "encoder = ConvEncoder(num_vocab=40, embed_dim=256)\n", "seqs, in_lens = get_dummy_input()\n", "encoder_outs = encoder(seqs)\n", "print(f\"入力のサイズ: {tuple(seqs.shape)}\")\n", "print(f\"出力のサイズ: {tuple(encoder_outs.shape)}\")" ] }, { "cell_type": "markdown", "id": "understanding-track", "metadata": {}, "source": [ "### 双方向LSTM の導入" ] }, { "cell_type": "code", "execution_count": 22, "id": "driven-sullivan", "metadata": 
{ "execution": { "iopub.execute_input": "2021-08-21T07:51:52.359839Z", "iopub.status.busy": "2021-08-21T07:51:52.359490Z", "iopub.status.idle": "2021-08-21T07:51:52.361485Z", "shell.execute_reply": "2021-08-21T07:51:52.361887Z" } }, "outputs": [], "source": [ "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n", "\n", "class Encoder(ConvEncoder):\n", " def __init__(\n", " self,\n", " num_vocab=40,\n", " embed_dim=512,\n", " hidden_dim=512,\n", " conv_layers=3,\n", " conv_channels=512,\n", " conv_kernel_size=5,\n", " ):\n", " super().__init__(\n", " num_vocab, embed_dim, conv_layers, conv_channels, conv_kernel_size\n", " )\n", " # 双方向 LSTM による長期依存関係のモデル化\n", " self.blstm = nn.LSTM(\n", " conv_channels, hidden_dim // 2, 1, batch_first=True, bidirectional=True\n", " )\n", "\n", " def forward(self, seqs, in_lens):\n", " emb = self.embed(seqs)\n", " # 1 次元畳み込みと embedding では、入力のサイズ が異なるので注意\n", " out = self.convs(emb.transpose(1, 2)).transpose(1, 2)\n", "\n", " # 双方向 LSTM の計算\n", " out = pack_padded_sequence(out, in_lens, batch_first=True)\n", " out, _ = self.blstm(out)\n", " out, _ = pad_packed_sequence(out, batch_first=True)\n", " return out" ] }, { "cell_type": "code", "execution_count": 23, "id": "geological-billion", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.363847Z", "iopub.status.busy": "2021-08-21T07:51:52.363578Z", "iopub.status.idle": "2021-08-21T07:51:52.404083Z", "shell.execute_reply": "2021-08-21T07:51:52.403718Z" } }, "outputs": [ { "data": { "text/plain": [ "Encoder(\n", " (embed): Embedding(40, 512, padding_idx=0)\n", " (convs): Sequential(\n", " (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Dropout(p=0.5, inplace=False)\n", " (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): ReLU()\n", " (7): Dropout(p=0.5, inplace=False)\n", " (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (10): ReLU()\n", " (11): Dropout(p=0.5, inplace=False)\n", " )\n", " (blstm): LSTM(512, 256, batch_first=True, bidirectional=True)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Encoder()" ] }, { "cell_type": "code", "execution_count": 24, "id": "detected-growth", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.406525Z", "iopub.status.busy": "2021-08-21T07:51:52.406185Z", "iopub.status.idle": "2021-08-21T07:51:52.446541Z", "shell.execute_reply": "2021-08-21T07:51:52.446236Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 33)\n", "出力のサイズ: (2, 33, 512)\n" ] } ], "source": [ "encoder = Encoder(num_vocab=40, embed_dim=256)\n", "seqs, in_lens = get_dummy_input()\n", "in_lens, indices = torch.sort(in_lens, dim=0, descending=True)\n", "seqs = seqs[indices]\n", "\n", "encoder_outs = encoder(seqs, in_lens)\n", "print(f\"入力のサイズ: {tuple(seqs.shape)}\")\n", "print(f\"出力のサイズ: {tuple(encoder_outs.shape)}\")" ] }, { "cell_type": "markdown", "id": "complicated-female", "metadata": {}, "source": [ "## 9.4 注意機構" ] }, { "cell_type": "markdown", "id": "therapeutic-kenya", "metadata": {}, "source": [ "### 内容依存の注意機構" ] }, { "cell_type": "code", 
"execution_count": 25, "id": "realistic-range", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.450310Z", "iopub.status.busy": "2021-08-21T07:51:52.450024Z", "iopub.status.idle": "2021-08-21T07:51:52.451795Z", "shell.execute_reply": "2021-08-21T07:51:52.451437Z" } }, "outputs": [], "source": [ "from torch.nn import functional as F\n", "\n", "# 書籍中の数式に沿って、わかりやすさを重視した実装\n", "class BahdanauAttention(nn.Module):\n", " def __init__(self, encoder_dim=512, decoder_dim=1024, hidden_dim=128):\n", " super().__init__()\n", " self.V = nn.Linear(encoder_dim, hidden_dim)\n", " self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)\n", " # NOTE: 本書の数式通りに実装するなら bias=False ですが、実用上は bias=True としても問題ありません\n", " self.w = nn.Linear(hidden_dim, 1)\n", "\n", " def forward(self, encoder_out, decoder_state, mask=None):\n", " # 式 (9.11) の計算\n", " erg = self.w(\n", " torch.tanh(self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs))\n", " ).squeeze(-1)\n", "\n", " if mask is not None:\n", " erg.masked_fill_(mask, -float(\"inf\"))\n", "\n", " attention_weights = F.softmax(erg, dim=1)\n", "\n", " # エンコーダ出力の長さ方向に対して重み付き和を取ります\n", " attention_context = torch.sum(\n", " encoder_outs * attention_weights.unsqueeze(-1), dim=1\n", " )\n", "\n", " return attention_context, attention_weights" ] }, { "cell_type": "code", "execution_count": 26, "id": "southeast-advisory", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.453826Z", "iopub.status.busy": "2021-08-21T07:51:52.453509Z", "iopub.status.idle": "2021-08-21T07:51:52.457103Z", "shell.execute_reply": "2021-08-21T07:51:52.457440Z" } }, "outputs": [ { "data": { "text/plain": [ "BahdanauAttention(\n", " (V): Linear(in_features=512, out_features=128, bias=True)\n", " (W): Linear(in_features=1024, out_features=128, bias=False)\n", " (w): Linear(in_features=128, out_features=1, bias=True)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "BahdanauAttention()" ] }, { "cell_type": "code", "execution_count": 27, "id": "defensive-satellite", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.460396Z", "iopub.status.busy": "2021-08-21T07:51:52.460052Z", "iopub.status.idle": "2021-08-21T07:51:52.464303Z", "shell.execute_reply": "2021-08-21T07:51:52.464581Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "エンコーダの出力のサイズ: (2, 33, 512)\n", "デコーダの隠れ状態のサイズ: (2, 1024)\n", "コンテキストベクトルのサイズ: (2, 512)\n", "アテンション重みのサイズ: (2, 33)\n" ] } ], "source": [ "from ttslearn.util import make_pad_mask\n", "\n", "mask = make_pad_mask(in_lens).to(encoder_outs.device)\n", "attention = BahdanauAttention()\n", "\n", "decoder_input = torch.ones(len(seqs), 1024)\n", "\n", "attention_context, attention_weights = attention(encoder_outs, decoder_input, mask)\n", "\n", "print(f\"エンコーダの出力のサイズ: {tuple(encoder_outs.shape)}\")\n", "print(f\"デコーダの隠れ状態のサイズ: {tuple(decoder_input.shape)}\")\n", "print(f\"コンテキストベクトルのサイズ: {tuple(attention_context.shape)}\")\n", "print(f\"アテンション重みのサイズ: {tuple(attention_weights.shape)}\")" ] }, { "cell_type": "markdown", "id": "played-bacteria", "metadata": {}, "source": [ "### ハイブリッド注意機構" ] }, { "cell_type": "code", "execution_count": 28, "id": "empirical-storm", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.470553Z", "iopub.status.busy": "2021-08-21T07:51:52.470242Z", "iopub.status.idle": "2021-08-21T07:51:52.471997Z", "shell.execute_reply": "2021-08-21T07:51:52.471702Z" } }, "outputs": [], "source": [ "class 
LocationSensitiveAttention(nn.Module):\n", " def __init__(\n", " self,\n", " encoder_dim=512,\n", " decoder_dim=1024,\n", " hidden_dim=128,\n", " conv_channels=32,\n", " conv_kernel_size=31,\n", " ):\n", " super().__init__()\n", " self.V = nn.Linear(encoder_dim, hidden_dim)\n", " self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)\n", " self.U = nn.Linear(conv_channels, hidden_dim, bias=False)\n", " self.F = nn.Conv1d(\n", " 1,\n", " conv_channels,\n", " conv_kernel_size,\n", " padding=(conv_kernel_size - 1) // 2,\n", " bias=False,\n", " )\n", " # NOTE: 本書の数式通りに実装するなら bias=False ですが、実用上は bias=True としても問題ありません\n", " self.w = nn.Linear(hidden_dim, 1)\n", "\n", " def forward(self, encoder_outs, src_lens, decoder_state, att_prev, mask=None):\n", " # アテンション重みを一様分布で初期化\n", " if att_prev is None:\n", " att_prev = 1.0 - make_pad_mask(src_lens).to(\n", " device=decoder_state.device, dtype=decoder_state.dtype\n", " )\n", " att_prev = att_prev / src_lens.unsqueeze(-1).to(encoder_outs.device)\n", "\n", " # (B x T_enc) -> (B x 1 x T_enc) -> (B x conv_channels x T_enc) ->\n", " # (B x T_enc x conv_channels)\n", " f = self.F(att_prev.unsqueeze(1)).transpose(1, 2)\n", "\n", " # 式 (9.13) の計算\n", " erg = self.w(\n", " torch.tanh(\n", " self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs) + self.U(f)\n", " )\n", " ).squeeze(-1)\n", "\n", " if mask is not None:\n", " erg.masked_fill_(mask, -float(\"inf\"))\n", "\n", " attention_weights = F.softmax(erg, dim=1)\n", "\n", " # エンコーダ出力の長さ方向に対して重み付き和を取ります\n", " attention_context = torch.sum(\n", " encoder_outs * attention_weights.unsqueeze(-1), dim=1\n", " )\n", "\n", " return attention_context, attention_weights" ] }, { "cell_type": "code", "execution_count": 29, "id": "virtual-walker", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.473918Z", "iopub.status.busy": "2021-08-21T07:51:52.473574Z", "iopub.status.idle": "2021-08-21T07:51:52.477060Z", "shell.execute_reply": "2021-08-21T07:51:52.476758Z" } }, "outputs": [ { "data": { "text/plain": [ "LocationSensitiveAttention(\n", " (V): Linear(in_features=512, out_features=128, bias=True)\n", " (W): Linear(in_features=1024, out_features=128, bias=False)\n", " (U): Linear(in_features=32, out_features=128, bias=False)\n", " (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)\n", " (w): Linear(in_features=128, out_features=1, bias=True)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LocationSensitiveAttention()" ] }, { "cell_type": "code", "execution_count": 30, "id": "fourth-bulletin", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.479609Z", "iopub.status.busy": "2021-08-21T07:51:52.479255Z", "iopub.status.idle": "2021-08-21T07:51:52.505154Z", "shell.execute_reply": "2021-08-21T07:51:52.505495Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "エンコーダの出力のサイズ: (2, 33, 512)\n", "デコーダの隠れ状態のサイズ: (2, 1024)\n", "コンテキストベクトルのサイズ: (2, 512)\n", "アテンション重みのサイズ: (2, 33)\n" ] } ], "source": [ "from ttslearn.util import make_pad_mask\n", "\n", "mask = make_pad_mask(in_lens).to(encoder_outs.device)\n", "attention = LocationSensitiveAttention()\n", "\n", "decoder_input = torch.ones(len(seqs), 1024)\n", "\n", "attention_context, attention_weights = attention(encoder_outs, in_lens, decoder_input, None, mask)\n", "\n", "print(f\"エンコーダの出力のサイズ: {tuple(encoder_outs.shape)}\")\n", "print(f\"デコーダの隠れ状態のサイズ: {tuple(decoder_input.shape)}\")\n", "print(f\"コンテキストベクトルのサイズ: 
{tuple(attention_context.shape)}\")\n", "print(f\"アテンション重みのサイズ: {tuple(attention_weights.shape)}\")" ] }, { "cell_type": "markdown", "id": "peaceful-tackle", "metadata": {}, "source": [ "## 9.5 デコーダ" ] }, { "cell_type": "markdown", "id": "formal-facility", "metadata": {}, "source": [ "### Pre-Net" ] }, { "cell_type": "code", "execution_count": 31, "id": "medieval-specific", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.509519Z", "iopub.status.busy": "2021-08-21T07:51:52.509158Z", "iopub.status.idle": "2021-08-21T07:51:52.510822Z", "shell.execute_reply": "2021-08-21T07:51:52.510463Z" } }, "outputs": [], "source": [ "class Prenet(nn.Module):\n", " def __init__(self, in_dim, layers=2, hidden_dim=256, dropout=0.5):\n", " super().__init__()\n", " self.dropout = dropout\n", " prenet = nn.ModuleList()\n", " for layer in range(layers):\n", " prenet += [\n", " nn.Linear(in_dim if layer == 0 else hidden_dim, hidden_dim),\n", " nn.ReLU(),\n", " ]\n", " self.prenet = nn.Sequential(*prenet)\n", "\n", " def forward(self, x):\n", " for layer in self.prenet:\n", " # 学習時、推論時の両方で Dropout を適用します\n", " x = F.dropout(layer(x), self.dropout, training=True)\n", " return x" ] }, { "cell_type": "code", "execution_count": 32, "id": "thousand-holocaust", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.512838Z", "iopub.status.busy": "2021-08-21T07:51:52.512444Z", "iopub.status.idle": "2021-08-21T07:51:52.515897Z", "shell.execute_reply": "2021-08-21T07:51:52.516239Z" } }, "outputs": [ { "data": { "text/plain": [ "Prenet(\n", " (prenet): Sequential(\n", " (0): Linear(in_features=80, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=256, out_features=256, bias=True)\n", " (3): ReLU()\n", " )\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Prenet(80)" ] }, { "cell_type": "code", "execution_count": 33, "id": "running-broadcasting", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.518638Z", "iopub.status.busy": "2021-08-21T07:51:52.518256Z", "iopub.status.idle": "2021-08-21T07:51:52.521584Z", "shell.execute_reply": "2021-08-21T07:51:52.521223Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "デコーダの入力のサイズ: (2, 80)\n", "Pre-Net の出力のサイズ: (2, 256)\n" ] } ], "source": [ "decoder_input = torch.ones(len(seqs), 80)\n", "\n", "prenet = Prenet(80)\n", "out = prenet(decoder_input)\n", "print(f\"デコーダの入力のサイズ: {tuple(decoder_input.shape)}\")\n", "print(f\"Pre-Net の出力のサイズ: {tuple(out.shape)}\")" ] }, { "cell_type": "markdown", "id": "second-junction", "metadata": {}, "source": [ "### 注意機構付きデコーダ" ] }, { "cell_type": "code", "execution_count": 34, "id": "apparent-consensus", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.530333Z", "iopub.status.busy": "2021-08-21T07:51:52.529926Z", "iopub.status.idle": "2021-08-21T07:51:52.669350Z", "shell.execute_reply": "2021-08-21T07:51:52.669743Z" } }, "outputs": [], "source": [ "from ttslearn.tacotron.decoder import ZoneOutCell\n", "\n", "class Decoder(nn.Module):\n", " def __init__(\n", " self,\n", " encoder_hidden_dim=512,\n", " out_dim=80,\n", " layers=2,\n", " hidden_dim=1024,\n", " prenet_layers=2,\n", " prenet_hidden_dim=256,\n", " prenet_dropout=0.5,\n", " zoneout=0.1,\n", " reduction_factor=1,\n", " attention_hidden_dim=128,\n", " attention_conv_channels=32,\n", " attention_conv_kernel_size=31,\n", " ):\n", " super().__init__()\n", " self.out_dim = out_dim\n", "\n", " # 注意機構\n", " 
self.attention = LocationSensitiveAttention(\n", " encoder_hidden_dim,\n", " hidden_dim,\n", " attention_hidden_dim,\n", " attention_conv_channels,\n", " attention_conv_kernel_size,\n", " )\n", " self.reduction_factor = reduction_factor\n", "\n", " # Prenet\n", " self.prenet = Prenet(out_dim, prenet_layers, prenet_hidden_dim, prenet_dropout)\n", "\n", " # 片方向LSTM\n", " self.lstm = nn.ModuleList()\n", " for layer in range(layers):\n", " lstm = nn.LSTMCell(\n", " encoder_hidden_dim + prenet_hidden_dim if layer == 0 else hidden_dim,\n", " hidden_dim,\n", " )\n", " lstm = ZoneOutCell(lstm, zoneout)\n", " self.lstm += [lstm]\n", "\n", " # 出力への projection 層\n", " proj_in_dim = encoder_hidden_dim + hidden_dim\n", " self.feat_out = nn.Linear(proj_in_dim, out_dim * reduction_factor, bias=False)\n", " self.prob_out = nn.Linear(proj_in_dim, reduction_factor)\n", "\n", " def _zero_state(self, hs):\n", " init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)\n", " return init_hs\n", "\n", " def forward(self, encoder_outs, in_lens, decoder_targets=None):\n", " is_inference = decoder_targets is None\n", "\n", " # Reduction factor に基づくフレーム数の調整\n", " # (B, Lmax, out_dim) -> (B, Lmax/r, out_dim)\n", " if self.reduction_factor > 1 and not is_inference:\n", " decoder_targets = decoder_targets[\n", " :, self.reduction_factor - 1 :: self.reduction_factor\n", " ]\n", "\n", " # デコーダの系列長を保持\n", " # 推論時は、エンコーダの系列長から経験的に上限を定める\n", " if is_inference:\n", " max_decoder_time_steps = int(encoder_outs.shape[1] * 10.0)\n", " else:\n", " max_decoder_time_steps = decoder_targets.shape[1]\n", "\n", " # ゼロパディングされた部分に対するマスク\n", " mask = make_pad_mask(in_lens).to(encoder_outs.device)\n", "\n", " # LSTM の状態をゼロで初期化\n", " h_list, c_list = [], []\n", " for _ in range(len(self.lstm)):\n", " h_list.append(self._zero_state(encoder_outs))\n", " c_list.append(self._zero_state(encoder_outs))\n", "\n", " # デコーダの最初の入力\n", " go_frame = encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)\n", " prev_out = go_frame\n", "\n", " # 1つ前の時刻のアテンション重み\n", " prev_att_w = None\n", "\n", " # メインループ\n", " outs, logits, att_ws = [], [], []\n", " t = 0\n", " while True:\n", " # コンテキストベクトル、アテンション重みの計算\n", " att_c, att_w = self.attention(\n", " encoder_outs, in_lens, h_list[0], prev_att_w, mask\n", " )\n", "\n", " # Pre-Net\n", " prenet_out = self.prenet(prev_out)\n", "\n", " # LSTM\n", " xs = torch.cat([att_c, prenet_out], dim=1)\n", " h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))\n", " for i in range(1, len(self.lstm)):\n", " h_list[i], c_list[i] = self.lstm[i](\n", " h_list[i - 1], (h_list[i], c_list[i])\n", " )\n", " # 出力の計算\n", " hcs = torch.cat([h_list[-1], att_c], dim=1)\n", " outs.append(self.feat_out(hcs).view(encoder_outs.size(0), self.out_dim, -1))\n", " logits.append(self.prob_out(hcs))\n", " att_ws.append(att_w)\n", "\n", " # 次の時刻のデコーダの入力を更新\n", " if is_inference:\n", " prev_out = outs[-1][:, :, -1] # (1, out_dim)\n", " else:\n", " # Teacher forcing\n", " prev_out = decoder_targets[:, t, :]\n", "\n", " # 累積アテンション重み\n", " prev_att_w = att_w if prev_att_w is None else prev_att_w + att_w\n", "\n", " t += 1\n", " # 停止条件のチェック\n", " if t >= max_decoder_time_steps:\n", " break\n", " if is_inference and (torch.sigmoid(logits[-1]) >= 0.5).any():\n", " break\n", " \n", " # 各時刻の出力を結合\n", " logits = torch.cat(logits, dim=1) # (B, Lmax)\n", " outs = torch.cat(outs, dim=2) # (B, out_dim, Lmax)\n", " att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)\n", "\n", " if self.reduction_factor > 1:\n", " outs = 
outs.view(outs.size(0), self.out_dim, -1) # (B, out_dim, Lmax)\n", "\n", " return outs, logits, att_ws" ] }, { "cell_type": "code", "execution_count": 35, "id": "integrated-lesbian", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.672090Z", "iopub.status.busy": "2021-08-21T07:51:52.671712Z", "iopub.status.idle": "2021-08-21T07:51:52.760681Z", "shell.execute_reply": "2021-08-21T07:51:52.760383Z" } }, "outputs": [ { "data": { "text/plain": [ "Decoder(\n", " (attention): LocationSensitiveAttention(\n", " (V): Linear(in_features=512, out_features=128, bias=True)\n", " (W): Linear(in_features=1024, out_features=128, bias=False)\n", " (U): Linear(in_features=32, out_features=128, bias=False)\n", " (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)\n", " (w): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " (prenet): Prenet(\n", " (prenet): Sequential(\n", " (0): Linear(in_features=80, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=256, out_features=256, bias=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (lstm): ModuleList(\n", " (0): ZoneOutCell(\n", " (cell): LSTMCell(768, 1024)\n", " )\n", " (1): ZoneOutCell(\n", " (cell): LSTMCell(1024, 1024)\n", " )\n", " )\n", " (feat_out): Linear(in_features=1536, out_features=80, bias=False)\n", " (prob_out): Linear(in_features=1536, out_features=1, bias=True)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Decoder()" ] }, { "cell_type": "code", "execution_count": 36, "id": "pharmaceutical-reasoning", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:52.763608Z", "iopub.status.busy": "2021-08-21T07:51:52.763211Z", "iopub.status.idle": "2021-08-21T07:51:53.364703Z", "shell.execute_reply": "2021-08-21T07:51:53.365084Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "デコーダの入力のサイズ: (2, 80)\n", "デコーダの出力のサイズ: (2, 80, 120)\n", "stop token (logits) のサイズ: (2, 120)\n", "アテンション重みのサイズ: (2, 120, 33)\n" ] } ], "source": [ "decoder_targets = torch.ones(encoder_outs.shape[0], 120, 80)\n", "decoder = Decoder(encoder_outs.shape[-1], 80)\n", "\n", "# Teacher forcing: decoder_targets (教師データ) を与える\n", "with torch.no_grad():\n", " outs, logits, att_ws = decoder(encoder_outs, in_lens, decoder_targets);\n", "\n", "print(f\"デコーダの入力のサイズ: {tuple(decoder_input.shape)}\")\n", "print(f\"デコーダの出力のサイズ: {tuple(outs.shape)}\")\n", "print(f\"stop token (logits) のサイズ: {tuple(logits.shape)}\")\n", "print(f\"アテンション重みのサイズ: {tuple(att_ws.shape)}\")" ] }, { "cell_type": "code", "execution_count": 37, "id": "organic-roller", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.367215Z", "iopub.status.busy": "2021-08-21T07:51:53.366870Z", "iopub.status.idle": "2021-08-21T07:51:53.378744Z", "shell.execute_reply": "2021-08-21T07:51:53.378369Z" } }, "outputs": [], "source": [ "# 自己回帰に基づく推論\n", "with torch.no_grad():\n", " decoder(encoder_outs[0], torch.tensor([in_lens[0]]))" ] },
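{ "cell_type": "markdown", "id": "decoder-inference-note", "metadata": {}, "source": [ "補足: 上のセルでは推論結果を捨てています。バッチ次元を保ったまま 1 文だけ推論し、出力の形状を確認する例を以下に示します（動作確認用の仮のスケッチです。モデルは未学習のため、出力の中身に意味はなく、生成フレーム数は stop token 次第で変わります）。" ] }, { "cell_type": "code", "execution_count": null, "id": "decoder-inference-check", "metadata": {}, "outputs": [], "source": [ "# 補足例: バッチ次元を保った自己回帰推論と出力形状の確認\n", "# (動作確認用の仮のコードです。変数名 outs_inf なども説明用の仮のものです)\n", "with torch.no_grad():\n", "    outs_inf, logits_inf, att_ws_inf = decoder(encoder_outs[:1], in_lens[:1])\n", "print(tuple(outs_inf.shape))  # (1, 80, 生成されたフレーム数)" ] }, { "cell_type": "markdown", "id": "handled-pixel", "metadata": {}, "source": [ "## 9.6 Post-Net" ] }, { "cell_type": "code", "execution_count": 38, "id": "further-apparel", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.382602Z", "iopub.status.busy": "2021-08-21T07:51:53.382257Z", "iopub.status.idle": "2021-08-21T07:51:53.383541Z", "shell.execute_reply": "2021-08-21T07:51:53.383879Z" } }, "outputs": [], "source": [ "class Postnet(nn.Module):\n", " def __init__(\n", " self,\n", " 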
in_dim=80,\n", " layers=5,\n", " channels=512,\n", " kernel_size=5,\n", " dropout=0.5,\n", " ):\n", " super().__init__()\n", " postnet = nn.ModuleList()\n", " for layer in range(layers):\n", " in_channels = in_dim if layer == 0 else channels\n", " out_channels = in_dim if layer == layers - 1 else channels\n", " postnet += [\n", " nn.Conv1d(\n", " in_channels,\n", " out_channels,\n", " kernel_size,\n", " stride=1,\n", " padding=(kernel_size - 1) // 2,\n", " bias=False,\n", " ),\n", " nn.BatchNorm1d(out_channels),\n", " ]\n", " if layer != layers - 1:\n", " postnet += [nn.Tanh()]\n", " postnet += [nn.Dropout(dropout)]\n", " self.postnet = nn.Sequential(*postnet)\n", "\n", " def forward(self, xs):\n", " return self.postnet(xs)" ] }, { "cell_type": "code", "execution_count": 39, "id": "sunrise-thumb", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.385807Z", "iopub.status.busy": "2021-08-21T07:51:53.385479Z", "iopub.status.idle": "2021-08-21T07:51:53.409831Z", "shell.execute_reply": "2021-08-21T07:51:53.409530Z" } }, "outputs": [ { "data": { "text/plain": [ "Postnet(\n", " (postnet): Sequential(\n", " (0): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): Tanh()\n", " (3): Dropout(p=0.5, inplace=False)\n", " (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): Tanh()\n", " (7): Dropout(p=0.5, inplace=False)\n", " (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (10): Tanh()\n", " (11): Dropout(p=0.5, inplace=False)\n", " (12): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (14): Tanh()\n", " (15): Dropout(p=0.5, inplace=False)\n", " (16): Conv1d(512, 80, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (17): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (18): Dropout(p=0.5, inplace=False)\n", " )\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Postnet()" ] }, { "cell_type": "code", "execution_count": 40, "id": "mexican-watson", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.412246Z", "iopub.status.busy": "2021-08-21T07:51:53.411907Z", "iopub.status.idle": "2021-08-21T07:51:53.463029Z", "shell.execute_reply": "2021-08-21T07:51:53.463369Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 80, 120)\n", "出力のサイズ: (2, 80, 120)\n" ] } ], "source": [ "postnet = Postnet(80)\n", "residual = postnet(outs)\n", "\n", "print(f\"入力のサイズ: {tuple(outs.shape)}\")\n", "print(f\"出力のサイズ: {tuple(residual.shape)}\")" ] }, { "cell_type": "markdown", "id": "rational-promotion", "metadata": {}, "source": [ "## 9.7 Tacotron 2 の実装" ] }, { "cell_type": "markdown", "id": "considered-lightning", "metadata": {}, "source": [ "### Tacotron 2 のモデル定義" ] }, { "cell_type": "code", "execution_count": 41, "id": "center-tyler", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.468755Z", "iopub.status.busy": "2021-08-21T07:51:53.468404Z", "iopub.status.idle": "2021-08-21T07:51:53.469710Z", "shell.execute_reply": 
"2021-08-21T07:51:53.470042Z" } }, "outputs": [], "source": [ "class Tacotron2(nn.Module):\n", " def __init__(self\n", " ):\n", " super().__init__()\n", " self.encoder = Encoder()\n", " self.decoder = Decoder()\n", " self.postnet = Postnet()\n", "\n", " def forward(self, seq, in_lens, decoder_targets):\n", " # エンコーダによるテキストに潜在する表現の獲得\n", " encoder_outs = self.encoder(seq, in_lens)\n", "\n", " # デコーダによるメルスペクトログラム、stop token の予測\n", " outs, logits, att_ws = self.decoder(encoder_outs, in_lens, decoder_targets)\n", "\n", " # Post-Net によるメルスペクトログラムの残差の予測\n", " outs_fine = outs + self.postnet(outs)\n", "\n", " # (B, C, T) -> (B, T, C)\n", " outs = outs.transpose(2, 1)\n", " outs_fine = outs_fine.transpose(2, 1)\n", "\n", " return outs, outs_fine, logits, att_ws\n", " \n", " def inference(self, seq):\n", " seq = seq.unsqueeze(0) if len(seq.shape) == 1 else seq\n", " in_lens = torch.tensor([seq.shape[-1]], dtype=torch.long, device=seq.device)\n", "\n", " return self.forward(seq, in_lens, None)" ] }, { "cell_type": "code", "execution_count": 42, "id": "united-rating", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:53.472779Z", "iopub.status.busy": "2021-08-21T07:51:53.472439Z", "iopub.status.idle": "2021-08-21T07:51:54.297932Z", "shell.execute_reply": "2021-08-21T07:51:54.298274Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 33)\n", "デコーダの出力のサイズ: (2, 120, 80)\n", "Post-Netの出力のサイズ: (2, 120, 80)\n", "stop token (logits) のサイズ: (2, 120)\n", "アテンション重みのサイズ: (2, 120, 33)\n" ] } ], "source": [ "seqs, in_lens = get_dummy_input()\n", "model = Tacotron2()\n", "\n", "# Tacotron 2 の計算\n", "outs, outs_fine, logits, att_ws = model(seqs, in_lens, decoder_targets)\n", "\n", "print(f\"入力のサイズ: {tuple(seqs.shape)}\")\n", "print(f\"デコーダの出力のサイズ: {tuple(outs.shape)}\")\n", "print(f\"Post-Netの出力のサイズ: {tuple(outs_fine.shape)}\")\n", "print(f\"stop token (logits) のサイズ: {tuple(logits.shape)}\")\n", "print(f\"アテンション重みのサイズ: {tuple(att_ws.shape)}\")" ] }, { "cell_type": "code", "execution_count": 43, "id": "resistant-nevada", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:54.301165Z", "iopub.status.busy": "2021-08-21T07:51:54.300831Z", "iopub.status.idle": "2021-08-21T07:51:54.303004Z", "shell.execute_reply": "2021-08-21T07:51:54.303345Z" } }, "outputs": [ { "data": { "text/plain": [ "Tacotron2(\n", " (encoder): Encoder(\n", " (embed): Embedding(40, 512, padding_idx=0)\n", " (convs): Sequential(\n", " (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Dropout(p=0.5, inplace=False)\n", " (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): ReLU()\n", " (7): Dropout(p=0.5, inplace=False)\n", " (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (10): ReLU()\n", " (11): Dropout(p=0.5, inplace=False)\n", " )\n", " (blstm): LSTM(512, 256, batch_first=True, bidirectional=True)\n", " )\n", " (decoder): Decoder(\n", " (attention): LocationSensitiveAttention(\n", " (V): Linear(in_features=512, out_features=128, bias=True)\n", " (W): Linear(in_features=1024, out_features=128, bias=False)\n", " (U): Linear(in_features=32, out_features=128, bias=False)\n", " 
(F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)\n", " (w): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " (prenet): Prenet(\n", " (prenet): Sequential(\n", " (0): Linear(in_features=80, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=256, out_features=256, bias=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (lstm): ModuleList(\n", " (0): ZoneOutCell(\n", " (cell): LSTMCell(768, 1024)\n", " )\n", " (1): ZoneOutCell(\n", " (cell): LSTMCell(1024, 1024)\n", " )\n", " )\n", " (feat_out): Linear(in_features=1536, out_features=80, bias=False)\n", " (prob_out): Linear(in_features=1536, out_features=1, bias=True)\n", " )\n", " (postnet): Postnet(\n", " (postnet): Sequential(\n", " (0): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): Tanh()\n", " (3): Dropout(p=0.5, inplace=False)\n", " (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): Tanh()\n", " (7): Dropout(p=0.5, inplace=False)\n", " (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (10): Tanh()\n", " (11): Dropout(p=0.5, inplace=False)\n", " (12): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (14): Tanh()\n", " (15): Dropout(p=0.5, inplace=False)\n", " (16): Conv1d(512, 80, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " (17): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (18): Dropout(p=0.5, inplace=False)\n", " )\n", " )\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "id": "affecting-extraction", "metadata": {}, "source": [ "### トイモデルを利用したTacotron 2の動作確認" ] }, { "cell_type": "code", "execution_count": 44, "id": "expired-style", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:54.305691Z", "iopub.status.busy": "2021-08-21T07:51:54.305353Z", "iopub.status.idle": "2021-08-21T07:51:54.387444Z", "shell.execute_reply": "2021-08-21T07:51:54.387804Z" } }, "outputs": [], "source": [ "from ttslearn.tacotron import Tacotron2\n", "model = Tacotron2(encoder_conv_layers=1, decoder_prenet_layers=1, decoder_layers=1, postnet_layers=1)" ] }, { "cell_type": "code", "execution_count": 45, "id": "color-cabinet", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:54.390837Z", "iopub.status.busy": "2021-08-21T07:51:54.390488Z", "iopub.status.idle": "2021-08-21T07:51:54.391818Z", "shell.execute_reply": "2021-08-21T07:51:54.392150Z" } }, "outputs": [], "source": [ "def get_dummy_inout():\n", " seqs, in_lens = get_dummy_input()\n", " \n", " # デコーダの出力(メルスペクトログラム)の教師データ\n", " decoder_targets = torch.ones(2, 120, 80)\n", " \n", " # stop token の教師データ\n", " # stop token の予測値は確率ですが、教師データは 二値のラベルです\n", " # 1 は、デコーダの出力が完了したことを表します\n", " stop_tokens = torch.zeros(2, 120)\n", " stop_tokens[:, -1:] = 1.0\n", " \n", " return seqs, in_lens, decoder_targets, stop_tokens" ] }, { "cell_type": "code", "execution_count": 46, "id": "cardiac-grant", "metadata": { "execution": { "iopub.execute_input": 
"2021-08-21T07:51:54.395435Z", "iopub.status.busy": "2021-08-21T07:51:54.394983Z", "iopub.status.idle": "2021-08-21T07:51:54.753068Z", "shell.execute_reply": "2021-08-21T07:51:54.752683Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "入力のサイズ: (2, 33)\n", "デコーダの出力のサイズ: (2, 120, 80)\n", "Stop token のサイズ: (2, 120)\n", "アテンション重みのサイズ: (2, 120, 33)\n" ] } ], "source": [ "# 適当な入出力を生成\n", "seqs, in_lens, decoder_targets, stop_tokens = get_dummy_inout()\n", "\n", "# Tacotron 2 の出力を計算\n", "# NOTE: teacher-forcing のため、 decoder targets を明示的に与える\n", "outs, outs_fine, logits, att_ws = model(seqs, in_lens, decoder_targets)\n", "\n", "print(\"入力のサイズ:\", tuple(seqs.shape))\n", "print(\"デコーダの出力のサイズ:\", tuple(outs.shape))\n", "print(\"Stop token のサイズ:\", tuple(logits.shape))\n", "print(\"アテンション重みのサイズ:\", tuple(att_ws.shape))" ] }, { "cell_type": "markdown", "id": "decreased-commercial", "metadata": {}, "source": [ "### Tacotron 2の損失関数の計算" ] }, { "cell_type": "code", "execution_count": 47, "id": "communist-republican", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:54.755565Z", "iopub.status.busy": "2021-08-21T07:51:54.755222Z", "iopub.status.idle": "2021-08-21T07:51:54.757739Z", "shell.execute_reply": "2021-08-21T07:51:54.758079Z" } }, "outputs": [], "source": [ "# 1. デコーダの出力に対する損失\n", "out_loss = nn.MSELoss()(outs, decoder_targets)\n", "# 2. Post-Net のあとの出力に対する損失\n", "out_fine_loss = nn.MSELoss()(outs_fine, decoder_targets)\n", "# 3. Stop token に対する損失\n", "stop_token_loss = nn.BCEWithLogitsLoss()(logits, stop_tokens)" ] }, { "cell_type": "code", "execution_count": 48, "id": "individual-motor", "metadata": { "execution": { "iopub.execute_input": "2021-08-21T07:51:54.760416Z", "iopub.status.busy": "2021-08-21T07:51:54.759949Z", "iopub.status.idle": "2021-08-21T07:51:54.762590Z", "shell.execute_reply": "2021-08-21T07:51:54.763125Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out_loss: 0.9949015378952026\n", "out_fine_loss: 2.896300792694092\n", "stop_token_loss: 0.6844527125358582\n" ] } ], "source": [ "print(\"out_loss: \", out_loss.item())\n", "print(\"out_fine_loss: \", out_fine_loss.item())\n", "print(\"stop_token_loss: \", stop_token_loss.item())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }