from __future__ import absolute_import, print_function, with_statement
import numpy as np
import torch
from nnmnkwii.preprocessing.modspec import modspec as _modspec
from torch.autograd import Function
[docs]class ModSpec(Function):
"""Modulation spectrum computation ``f : (T, D) -> (N//2+1, D)``.
Args:
n (int): DFT length.
norm (bool): Normalize DFT output or not. See :obj:`numpy.fft.fft`.
"""
[docs] @staticmethod
def forward(ctx, y, n, norm):
ctx.n = n
ctx.norm = norm
assert y.dim() == 2
ctx.save_for_backward(y)
y_np = y.detach().numpy()
ms = torch.from_numpy(_modspec(y_np, n=n, norm=norm))
return ms
[docs] @staticmethod
def backward(ctx, grad_output):
(y,) = ctx.saved_tensors
T, D = y.size()
assert grad_output.size() == torch.Size((ctx.n // 2 + 1, D))
y_np = y.detach().numpy()
kt = -2 * np.pi / ctx.n * np.arange(ctx.n // 2 + 1)[:, None] * np.arange(T)
assert kt.shape == (ctx.n // 2 + 1, T)
cos_table = np.cos(kt)
sin_table = np.sin(kt)
s_complex = np.fft.rfft(
y_np, n=ctx.n, axis=0, norm=ctx.norm
) # DFT against time axis
assert s_complex.shape == (ctx.n // 2 + 1, D)
R, im = s_complex.real, s_complex.imag
grads = torch.zeros(T, D)
C = 2 # normalization constant
if ctx.norm == "ortho":
C /= np.sqrt(ctx.n)
for d in range(D):
r = R[:, d][:, None]
i = im[:, d][:, None]
grad = C * (r * cos_table + i * sin_table)
assert grad.shape == sin_table.shape
grads[:, d] = torch.from_numpy(grad_output[:, d].numpy().T.dot(grad))
return grads, None, None
[docs]def modspec(y, n=2048, norm=None):
"""Moduration spectrum computation.
Args:
y (torch.autograd.Variable): Parameter trajectory.
n (int): DFT length.
norm (bool): Normalize DFT output or not. See :obj:`numpy.fft.fft`.
"""
return ModSpec.apply(y, n, norm)