Source code for ttslearn.wavenet.modules

import torch
from torch import nn
from ttslearn.wavenet import conv


def Conv1d(in_channels, out_channels, kernel_size, *args, **kwargs):
    """Weight-normalized Conv1d layer."""
    m = conv.Conv1d(in_channels, out_channels, kernel_size, *args, **kwargs)
    return nn.utils.weight_norm(m)


def Conv1d1x1(in_channels, out_channels, bias=True):
    """1x1 Weight-normalized Conv1d layer."""
    return Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
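

# A minimal usage sketch (an illustrative addition, not part of the original
# module): because the layers above are wrapped in ``nn.utils.weight_norm``,
# the convolution weight is reparameterized into a direction (``weight_v``)
# and a magnitude (``weight_g``):
#
#     >>> layer = Conv1d1x1(80, 240)
#     >>> sorted(name for name, _ in layer.named_parameters())
#     ['bias', 'weight_g', 'weight_v']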


class ResSkipBlock(nn.Module):
    """Convolution block with residual and skip connections.

    Args:
        residual_channels (int): Residual connection channels.
        gate_channels (int): Gated activation channels.
        kernel_size (int): Kernel size of convolution layers.
        skip_out_channels (int): Skip connection channels.
        dilation (int): Dilation factor.
        cin_channels (int): Local conditioning channels.
        args (list): Additional arguments for Conv1d.
        kwargs (dict): Additional arguments for Conv1d.
    """

    def __init__(
        self,
        residual_channels,  # number of channels in the residual connection
        gate_channels,  # number of channels in the gated activation
        kernel_size,  # kernel size
        skip_out_channels,  # number of channels in the skip connection
        dilation=1,  # dilation factor
        cin_channels=80,  # number of channels of the conditioning features
        *args,
        **kwargs,
    ):
        super().__init__()
        self.padding = (kernel_size - 1) * dilation

        # 1-D dilated convolution (an ordinary 1-D convolution when dilation == 1)
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation,
            *args,
            **kwargs,
        )

        # 1x1 convolution for local conditioning
        self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)

        # Note that the output of the 1-D convolution is split in two
        # for the gated activation function
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels)
    def forward(self, x, c):
        """Forward step

        Args:
            x (torch.Tensor): Input signal.
            c (torch.Tensor): Local conditioning signal.

        Returns:
            tuple: Tuple of output signal and skip connection signal
        """
        return self._forward(x, c, False)
    def incremental_forward(self, x, c):
        """Incremental forward

        Args:
            x (torch.Tensor): Input signal.
            c (torch.Tensor): Local conditioning signal.

        Returns:
            tuple: Tuple of output signal and skip connection signal
        """
        return self._forward(x, c, True)
    def _forward(self, x, c, is_incremental):
        # Keep the input for the residual connection
        residual = x

        # Main dilated convolution.
        # Note that the shape of the input tensor differs between training
        # and incremental inference.
        if is_incremental:
            splitdim = -1  # (B, T, C)
            x = self.conv.incremental_forward(x)
        else:
            splitdim = 1  # (B, C, T)
            x = self.conv(x)
            # Shift the output to guarantee causality
            x = x[:, :, : -self.padding]

        # Split the output along the channel dimension
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

        # Local conditioning
        c = self._conv1x1_forward(self.conv1x1c, c, is_incremental)
        ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
        a, b = a + ca, b + cb

        # Gated activation function
        x = torch.tanh(a) * torch.sigmoid(b)

        # Output for the skip connection
        s = self._conv1x1_forward(self.conv1x1_skip, x, is_incremental)

        # Match the number of channels before the element-wise sum
        # of the residual connection
        x = self._conv1x1_forward(self.conv1x1_out, x, is_incremental)
        x = x + residual

        return x, s

    def _conv1x1_forward(self, conv, x, is_incremental):
        if is_incremental:
            x = conv.incremental_forward(x)
        else:
            x = conv(x)
        return x
    def clear_buffer(self):
        """Clear input buffer."""
        for c in [
            self.conv,
            self.conv1x1_out,
            self.conv1x1_skip,
            self.conv1x1c,
        ]:
            if c is not None:
                c.clear_buffer()
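

# A minimal usage sketch (illustrative only; the channel sizes below are
# assumptions, not values from the original module). It builds one ResSkipBlock
# and runs a batched forward pass with a dummy input sequence and a dummy
# mel-spectrogram-like conditioning sequence, then checks the output shapes.
if __name__ == "__main__":
    B, C, T = 2, 64, 100  # batch size, residual channels, number of time steps
    block = ResSkipBlock(
        residual_channels=64,
        gate_channels=128,
        kernel_size=3,
        skip_out_channels=64,
        dilation=2,
        cin_channels=80,
    )
    x = torch.randn(B, C, T)   # block input: (B, residual_channels, T)
    c = torch.randn(B, 80, T)  # local conditioning: (B, cin_channels, T)
    out, skip = block(x, c)
    # The residual output keeps the input resolution and channel count,
    # and the skip output has skip_out_channels channels.
    assert out.shape == (B, 64, T)
    assert skip.shape == (B, 64, T)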