【问题标题】:How to mask weights in PyTorch weight parameters?如何掩盖 PyTorch 权重参数中的权重?
【发布时间】:2019-05-01 20:43:23
【问题描述】:

我试图在 PyTorch 中屏蔽(强制为零)特定的权重值。我试图掩盖的权重在def __init__

class LSTM_MASK(nn.Module):
        def __init__(self, options, inp_dim):
            super(LSTM_MASK, self).__init__()
            ....
            self.wfx = nn.Linear(input_dim, curernt_output, bias=add_bias)

掩码也在def __init__中定义为

self.mask_use = torch.Tensor(curernt_output, input_dim)

掩码是一个常数,.requires_grad_()False 的掩码参数。现在在课程的def forward 部分中,我尝试在线性运算完成之前对权重参数和掩码进行元素乘法

def forward(self, x):
    ....
    self.wfx.weight = self.wfx.weight * self.mask_use
    wfx_out = self.wfx(x)

我收到一条错误消息:

self.wfx.weight = self.wfx.weight * self.mask_use
  File "/home/xyz/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 537, in __setattr__
    .format(torch.typename(value), name))
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

但是当我使用.type() 检查这两个参数时,它们都显示为torch.cuda.FloatTensor。我不知道为什么这里有错误。

【问题讨论】:

    标签: python pytorch lstm


    【解决方案1】:

    逐元素操作总是返回FloatTensor。无法将正常张量分配为层的weight

    有两种可能的选择来处理它。您可以将其分配给您体重的data 属性,在那里可以分配正常张量。

    或者,您也可以将结果转换为 nn.Parameter 本身,然后您可以将其分配给 wfx.weight

    这是一个显示两种方式的示例:

    import torch
    import torch.nn as nn
    
    wfx = nn.Linear(10, 10)
    mask_use = torch.rand(10, 10)
    #wfx.weight = wfx.weight * mask_use #your example - this raises an error
    
    # Option 1: write directly to data
    wfx.weight.data = wfx.weight * mask_use
    
    # Option 2: convert result to nn.Parameter and write to weight
    wfx.weight = nn.Parameter(wfx.weight * mask_use)
    

    免责声明:在权重上使用 =(赋值)时,您将替换参数的权重张量。这可能会对图形产生不良影响。优化步骤。

    【讨论】:

    • 谢谢@blue-phoenox,也将把它换成self.wfx.weight.data.mul_(self.mask_use),以规避使用=可能出现的问题
    • @DeepakKadetotad 是的,这好多了,实际上我只是在提出这样的操作的路上。这种方式比更换权重要好得多。但我还是有点怀疑,在训练期间对权重进行这样的操作是否会导致平滑转换,但我还没有尝试过:)
    【解决方案2】:

    将 Tensorfloat 变量更改为参数变量的一种有效方法:

    self.wfx.weight = torch.nn.parameter.Parameter((self.wfx.weight.data * self.mask_use))
    

    我希望这会有用。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-10-19
      • 1970-01-01
      • 2022-07-29
      • 2020-08-15
      • 2018-09-01
      • 2019-06-24
      • 2020-03-15
      • 2019-04-27
      相关资源
      最近更新 更多