【发布时间】:2021-03-20 20:21:45
【问题描述】:
我最近接触了 PyTorch,并开始浏览该库的文档和教程。
在"Creating extensions using numpy and scipy" 教程中,在“无参数示例”下,使用名为BadFFTFunction 的numpy 创建了一个示例函数。
函数说明:
“这个层并没有做任何有用的或数学上的事情 正确。
它被恰当地命名为 BadFFTFunction"
函数及其用法如下:
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)
def incorrect_fft(input):
return BadFFTFunction()(input)
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)
不幸的是,我也是最近才被介绍到信号处理的,我不确定这个函数的(可能很明显)错误在哪里。
我想知道,如何修复此函数以使其前向和后向输出正确?
如何修复 BadFFTFunction 以便在 PyTorch 中使用可微分的 FFT 函数?
【问题讨论】: