【问题标题】:How to correctly use Numpy's FFT function in PyTorch?如何在 PyTorch 中正确使用 Numpy 的 FFT 函数?
【发布时间】:2021-03-20 20:21:45
【问题描述】:

我最近接触了 PyTorch,并开始浏览该库的文档和教程。 在"Creating extensions using numpy and scipy" 教程中,在“无参数示例”下,使用名为BadFFTFunction 的numpy 创建了一个示例函数。

函数说明:

“这个层并没有做任何有用的或数学上的事情 正确。

它被恰当地命名为 BadFFTFunction"

函数及其用法如下:

from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    def forward(self, input):
        numpy_input = input.numpy()
        result = abs(rfft2(numpy_input))
        return torch.FloatTensor(result)

    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return torch.FloatTensor(result)

def incorrect_fft(input):
    return BadFFTFunction()(input)

input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

不幸的是,我也是最近才被介绍到信号处理的,我不确定这个函数的(可能很明显)错误在哪里。

我想知道,如何修复此函数以使其前向和后向输出正确?

如何修复 BadFFTFunction 以便在 PyTorch 中使用可微分的 FFT 函数?

【问题讨论】:

    标签: python numpy pytorch fft


    【解决方案1】:

    我认为错误是:首先,该函数尽管名称中有 FFT,但仅返回 FFT 输出的幅度/绝对值,而不是完整的复系数。此外,仅使用逆 FFT 来计算幅度的梯度在数学上可能没有多大意义(?)。

    有一个名为 pytorch-fft 的包试图在 pytorch 中提供 FFT 函数。你可以看到一些 autograd 功能的实验代码here。另请注意此issue 中的讨论。

    【讨论】:

      【解决方案2】:

      截至 1.8 版,PyTorch 有 torch.fft:

      torch.fft.fft(input)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2021-07-29
        • 1970-01-01
        • 1970-01-01
        • 2019-10-17
        相关资源
        最近更新 更多