【发布时间】:2021-01-14 20:03:24
【问题描述】:
我遇到了这个主要做我想要的功能,但我需要稍微调整一下。
我的数据是这样的:
import torch
import torch.nn as nn
actual_x = torch.randn(13, 16, 64, 768)
但要在下面的函数中工作,我需要将其置换为:
x = torch.randn(16, 64, 768, 13)
在函数内部,我无法操作*args 的值。因此,如果我想添加这一行以在函数内正确重塑我的数据:args[0] = args[0].permute(1, 2, 3, 0)
我收到'tuple' object does not support item assignment。
class TimeDistributed(nn.Module):
'''
'''
def __init__(self):
super(TimeDistributed, self).__init__()
self.n_layers = 13
self.n_tokens = 64
self.module = torch.nn.Linear(self.n_layers, self.n_tokens)
def forward(self, *args, **kwargs):
#only support tdim=1
#args[0] = args[0].permute(1, 2, 3, 0)
args = list(args[0])
args = args.permute(1, 2, 3, 0)
inp_shape = args[0].shape
bs, seq_len = inp_shape[0], inp_shape[1]
out = self.module(*[x.reshape(bs*seq_len, *x.shape[2:]) for x in args], **kwargs)
out_shape = out.shape
return out.view(bs, seq_len,*out_shape[1:])
它的运行者:
TD1 = TimeDistributed()
out = TD1(x)
out.shape
它失败了:
TD1 = TimeDistributed()
out = TD1(actual_x)
out.shape
【问题讨论】:
-
在课堂
TimeDistributed中没有看到任何对torch.randn或args[0] = args[0].permute(1, 2, 3, 0)的引用。为什么会显示这个类? -
我已经对其进行了编辑以准确显示问题所在。
-
*args是一个元组,您需要将其转换为列表以支持项目分配 -
args没有什么特别之处:它是一个元组,元组是不可变的。 -
这更没有意义。
TD1 = TimeDistributed()缺少参数,out = TD1(x)需要类`TimeDistributed` 定义方法__call__。还是我错过了什么?
标签: python