【问题标题】:Extending Pytorch: Python vs. C++ vs. CUDA扩展 Pytorch:Python vs. C++ vs. CUDA
【发布时间】:2021-09-02 06:17:20
【问题描述】:

我一直在尝试实现一个自定义 Conv2d 模块,其中 grad_input (dx) 和 grad_weight (dw) 是通过使用不同的 grad_output (dy) 值来计算的。我通过在 Pytorch 教程中扩展 torch.autograd 来实现这一点。

但是我对in this link的信息感到困惑。

  • 扩展 autograd.Function 还不够吗?
  • 有什么区别 在 Python 和 C++ 中编写新的 autograd 函数之间?
  • 怎么样 CUDA 实现 /torch/nn/blob/master/lib/THNN/generic/SpatialConvolutionMM.c 其中 dx 和 dw 计算出来的?我也应该更改它们吗?

这是我的自定义函数:

class myCustomConv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w, bias=None, stride=1, padding=0, dilation=1, groups=1):
    ctx.save_for_backward(x, w, bias)
    ctx.stride = stride
    ctx.padding = padding
    ctx.dilation = dilation
    ctx.groups = groups
    out = F.conv2d(x, w, bias, stride, padding, dilation, groups)
    return out

@staticmethod
def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    stride = ctx.stride
    padding = ctx.padding
    dilation = ctx.dilation
    groups = ctx.groups
    grad_input = grad_weight = grad_bias = None

    dy_for_inputs = myspecialfunction1(grad_output)
    dy_for_weights = myspecialfunction2(grad_output)

    grad_input = torch.nn.grad.conv2d_input(input.shape, weight, dy_for_inputs , stride, padding, dilation, groups)
    grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, dy_for_weights , stride, padding, dilation, groups)

    if bias is not None and ctx.needs_input_grad[2]:
        grad_bias = dy_for_weights .sum((0,2,3)).squeeze(0)

    return grad_input, grad_weight, grad_bias, None, None, None, None

【问题讨论】:

    标签: pytorch conv-neural-network autograd


    【解决方案1】:

    扩展 autograd.Function 还不够吗?

    如果您的代码重用包装在 Python 接口中的 Pytorch 组件就足够了(似乎是这样)。渐变是自动合成的。

    用 Python 和 C++ 编写新的 autograd 函数有什么区别?

    性能,您的操作越自定义(并且从现有 Pytorch 操作中组合它的难度越大),您获得的性能提升就越大。

    /torch/nn/blob/master/lib/THNN/generic/SpatialConvolutionMM.c 中的 CUDA 实现如何计算 dx 和 dw?我也应该改变它们吗?

    没有必要,除非你想为 CUDA 创建专门的操作

    【讨论】:

      猜你喜欢
      • 2011-02-28
      • 1970-01-01
      • 1970-01-01
      • 2023-03-09
      • 1970-01-01
      • 1970-01-01
      • 2013-04-02
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多