Source code for ttslearn.tacotron.attention

# Acknowledgement: some of the code was adapted from ESPnet
#  Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch
from torch import nn
from torch.nn import functional as F
from ttslearn.util import make_pad_mask


class BahdanauAttention(nn.Module):
    """Bahdanau-style attention

    This is an attention mechanism originally used in Tacotron.

    Args:
        encoder_dim (int): dimension of encoder outputs
        decoder_dim (int): dimension of decoder outputs
        hidden_dim (int): dimension of hidden state
    """

    def __init__(self, encoder_dim=512, decoder_dim=1024, hidden_dim=128):
        super().__init__()
        self.mlp_enc = nn.Linear(encoder_dim, hidden_dim)
        self.mlp_dec = nn.Linear(decoder_dim, hidden_dim, bias=False)
        self.w = nn.Linear(hidden_dim, 1)
        self.processed_memory = None
    def reset(self):
        """Reset the internal buffer"""
        self.processed_memory = None
    def forward(
        self,
        encoder_outs,
        src_lens,
        decoder_state,
        mask=None,
    ):
        """Forward step

        Args:
            encoder_outs (torch.FloatTensor): encoder outputs
            src_lens (list): length of each input batch
            decoder_state (torch.FloatTensor): decoder hidden state
            mask (torch.BoolTensor): mask for padding

        Returns:
            tuple: context vector and attention weights
        """
        # Cache the result of applying the fully connected layer to the encoder outputs
        if self.processed_memory is None:
            self.processed_memory = self.mlp_enc(encoder_outs)

        # (B, 1, hidden_dim)
        decoder_state = self.mlp_dec(decoder_state).unsqueeze(1)

        # NOTE: given the decoder hidden state as input, the attention energy
        # depends only on the encoder features
        erg = self.w(torch.tanh(self.processed_memory + decoder_state)).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum over the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights
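
# Usage sketch (added for illustration; not part of the original module): shows
# the tensor shapes expected for a single decoder step with the default
# dimensions. The helper name `_example_bahdanau_attention` and the batch size,
# lengths, and random inputs are assumptions made for this example only.
def _example_bahdanau_attention():
    B, T_enc = 2, 45  # hypothetical batch size and encoder length
    src_lens = torch.tensor([45, 30])
    encoder_outs = torch.randn(B, T_enc, 512)
    decoder_state = torch.randn(B, 1024)
    mask = make_pad_mask(src_lens)  # True on padded frames

    attn = BahdanauAttention(encoder_dim=512, decoder_dim=1024, hidden_dim=128)
    attn.reset()  # clear the cached processed_memory before a new utterance
    context, weights = attn(encoder_outs, src_lens, decoder_state, mask=mask)
    # context: (B, 512), weights: (B, T_enc); weights sum to 1 over valid frames
    return context, weights
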
class LocationSensitiveAttention(nn.Module):
    """Location-sensitive attention

    This is an attention mechanism used in Tacotron 2.

    Args:
        encoder_dim (int): dimension of encoder outputs
        decoder_dim (int): dimension of decoder outputs
        hidden_dim (int): dimension of hidden state
        conv_channels (int): number of channels of convolutional layer
        conv_kernel_size (int): size of convolutional kernel
    """

    def __init__(
        self,
        encoder_dim=512,
        decoder_dim=1024,
        hidden_dim=128,
        conv_channels=32,
        conv_kernel_size=31,
    ):
        super().__init__()
        self.mlp_enc = nn.Linear(encoder_dim, hidden_dim)
        self.mlp_dec = nn.Linear(decoder_dim, hidden_dim, bias=False)
        self.mlp_att = nn.Linear(conv_channels, hidden_dim, bias=False)
        assert conv_kernel_size % 2 == 1
        self.loc_conv = nn.Conv1d(
            1,
            conv_channels,
            conv_kernel_size,
            padding=(conv_kernel_size - 1) // 2,
            bias=False,
        )
        self.w = nn.Linear(hidden_dim, 1)
        self.processed_memory = None
    def reset(self):
        """Reset the internal buffer"""
        self.processed_memory = None
    def forward(
        self,
        encoder_outs,
        src_lens,
        decoder_state,
        att_prev,
        mask=None,
    ):
        """Forward step

        Args:
            encoder_outs (torch.FloatTensor): encoder outputs
            src_lens (torch.LongTensor): length of each input batch
            decoder_state (torch.FloatTensor): decoder hidden state
            att_prev (torch.FloatTensor): previous attention weights
            mask (torch.BoolTensor): mask for padding

        Returns:
            tuple: context vector and attention weights
        """
        # Cache the result of applying the fully connected layer to the encoder outputs
        if self.processed_memory is None:
            self.processed_memory = self.mlp_enc(encoder_outs)

        # Initialize the attention weights with a uniform distribution
        if att_prev is None:
            att_prev = 1.0 - make_pad_mask(src_lens).to(
                device=decoder_state.device, dtype=decoder_state.dtype
            )
            att_prev = att_prev / src_lens.unsqueeze(-1).to(encoder_outs.device)

        # (B, T_enc) -> (B, 1, T_enc) -> (B, conv_channels, T_enc) ->
        # (B, T_enc, conv_channels)
        att_conv = self.loc_conv(att_prev.unsqueeze(1)).transpose(1, 2)
        # (B, T_enc, hidden_dim)
        att_conv = self.mlp_att(att_conv)

        # (B, 1, hidden_dim)
        decoder_state = self.mlp_dec(decoder_state).unsqueeze(1)

        # NOTE: given the decoder hidden state as input, the attention energy
        # depends on the following two factors:
        # 1) the attention weights at the previous decoder time step
        # 2) the encoder hidden states
        erg = self.w(
            torch.tanh(att_conv + self.processed_memory + decoder_state)
        ).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum over the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights
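
# Usage sketch (added for illustration; not part of the original module): runs two
# hypothetical decoder steps, feeding the attention weights of one step back in as
# `att_prev` for the next, which is what makes the attention location sensitive.
# The helper name and all shapes/values are assumptions for this example only.
def _example_location_sensitive_attention():
    B, T_enc = 2, 45  # hypothetical batch size and encoder length
    src_lens = torch.tensor([45, 30])
    encoder_outs = torch.randn(B, T_enc, 512)
    mask = make_pad_mask(src_lens)  # True on padded frames

    attn = LocationSensitiveAttention()
    attn.reset()
    att_prev = None  # first step: initialized internally to a uniform distribution
    for _ in range(2):
        decoder_state = torch.randn(B, 1024)
        context, att_prev = attn(
            encoder_outs, src_lens, decoder_state, att_prev, mask=mask
        )
    # context: (B, 512), att_prev: (B, T_enc)
    return context, att_prev
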
# Implementation that follows the equations in the book, prioritizing clarity
class BahdanauAttentionNaive(nn.Module):
    def __init__(self, encoder_dim=512, decoder_dim=1024, hidden_dim=128):
        super().__init__()
        self.V = nn.Linear(encoder_dim, hidden_dim)
        self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)
        # NOTE: to match the equations in the book exactly, bias=False is used
        # above; in practice, bias=True also works fine
        self.w = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        encoder_outs,
        decoder_state,
        mask=None,
    ):
        # Compute Eq. (9.11)
        erg = self.w(
            torch.tanh(self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs))
        ).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum over the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights


# Implementation that follows the equations in the book, prioritizing clarity
class LocationSensitiveAttentionNaive(nn.Module):
    def __init__(
        self,
        encoder_dim=512,
        decoder_dim=1024,
        hidden_dim=128,
        conv_channels=32,
        conv_kernel_size=31,
    ):
        super().__init__()
        self.V = nn.Linear(encoder_dim, hidden_dim)
        self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)
        self.U = nn.Linear(conv_channels, hidden_dim, bias=False)
        self.F = nn.Conv1d(
            1,
            conv_channels,
            conv_kernel_size,
            padding=(conv_kernel_size - 1) // 2,
            bias=False,
        )
        # NOTE: to match the equations in the book exactly, bias=False is used
        # above; in practice, bias=True also works fine
        self.w = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        encoder_outs,
        src_lens,
        decoder_state,
        att_prev,
        mask=None,
    ):
        # Initialize the attention weights with a uniform distribution
        if att_prev is None:
            att_prev = 1.0 - make_pad_mask(src_lens).to(
                device=decoder_state.device, dtype=decoder_state.dtype
            )
            att_prev = att_prev / src_lens.unsqueeze(-1).to(encoder_outs.device)

        # (B, T_enc) -> (B, 1, T_enc) -> (B, conv_channels, T_enc) ->
        # (B, T_enc, conv_channels)
        f = self.F(att_prev.unsqueeze(1)).transpose(1, 2)

        # Compute Eq. (9.13)
        erg = self.w(
            torch.tanh(
                self.W(decoder_state).unsqueeze(1) + self.V(encoder_outs) + self.U(f)
            )
        ).squeeze(-1)

        if mask is not None:
            erg.masked_fill_(mask, -float("inf"))

        attention_weights = F.softmax(erg, dim=1)

        # Weighted sum over the time axis of the encoder outputs
        attention_context = torch.sum(
            encoder_outs * attention_weights.unsqueeze(-1), dim=1
        )

        return attention_context, attention_weights
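
# Consistency sketch (added for illustration; not part of the original module):
# if the corresponding weights are copied across, BahdanauAttentionNaive should
# compute the same outputs as BahdanauAttention, since the two classes differ
# only in layer naming and in the caching of processed_memory. The helper name
# and the shapes/values below are assumptions made for this example only.
def _example_compare_bahdanau_implementations():
    B, T_enc = 2, 45
    src_lens = torch.tensor([45, 30])
    encoder_outs = torch.randn(B, T_enc, 512)
    decoder_state = torch.randn(B, 1024)
    mask = make_pad_mask(src_lens)

    attn = BahdanauAttention()
    naive = BahdanauAttentionNaive()
    # V / W / w in the naive version correspond to mlp_enc / mlp_dec / w above
    naive.V.load_state_dict(attn.mlp_enc.state_dict())
    naive.W.load_state_dict(attn.mlp_dec.state_dict())
    naive.w.load_state_dict(attn.w.state_dict())

    ctx_a, w_a = attn(encoder_outs, src_lens, decoder_state, mask=mask)
    ctx_b, w_b = naive(encoder_outs, decoder_state, mask=mask)
    assert torch.allclose(ctx_a, ctx_b) and torch.allclose(w_a, w_b)
    return ctx_a, w_a
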