[Posted]: 2019-01-11 20:32:27
[Question]:
I'm getting this error:
RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'
But what does argument #3 'index' mean? I can't find an 'index' parameter in torch.embedding (source here: https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding).
It looks like I'm passing the wrong arguments to the embedding.
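A hedged sketch of where 'index' most likely comes from: in the dense (non-sparse) path, the embedding lookup behaves like `torch.index_select(weight, 0, indices)`, whose third argument is named `index`. Assuming that is what the underlying kernel does in this PyTorch version (the traceback itself stops at `torch.embedding`), 'argument #3 index' would be the token-id tensor passed to the embedding, and the message says that tensor is on the GPU while the weight it indexes is on the CPU. The toy tensors below are illustrative only.

```python
import torch
import torch.nn.functional as F

# Toy embedding table: 10 token ids, 4-dimensional vectors (illustrative values)
weight = torch.randn(10, 4)
# Token ids shaped like a (batch, seq_len) batch
tokens = torch.tensor([[1, 2], [3, 9]], dtype=torch.long)

# The normal embedding lookup
via_embedding = F.embedding(tokens, weight)

# The same lookup written as index_select(input, dim, index):
# the token ids fill the third slot, `index`
flat = torch.index_select(weight, 0, tokens.reshape(-1))
via_index_select = flat.view(tokens.shape[0], tokens.shape[1], -1)

print(torch.equal(via_embedding, via_index_select))  # True
```

Under that reading, a CPU `weight` combined with CUDA `tokens` fails the type check on exactly this `index` slot.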
I even changed the data type of the inputs as shown below, but the error persists.
batch['doc_tok'] = batch['doc_tok'].long()
batch['query_tok'] = batch['query_tok'].long()
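Casting with `.long()` changes the dtype but not the device, which fits the error message: `torch.LongTensor` and `torch.cuda.LongTensor` are both int64 and differ only in where the tensor lives. A minimal sketch, using a toy `nn.Embedding` with arbitrary sizes rather than the project's model, that checks this and aligns the indices with the embedding weight's device instead:

```python
import torch
import torch.nn as nn

# Toy embedding table standing in for the model's word embedding
emb = nn.Embedding(10, 4)
tokens = torch.tensor([[1, 2, 3]], dtype=torch.int32)

# .long() only changes the dtype; the tensor stays on its current device
tokens_long = tokens.long()
print(tokens_long.dtype, tokens_long.device, tokens_long.type())
# torch.int64 cpu torch.LongTensor

if torch.cuda.is_available():
    # Same dtype, different device -> a different legacy type name
    print(tokens.cuda().long().type())  # torch.cuda.LongTensor

# Align both device and dtype with the embedding weight before the lookup
tokens_ready = tokens.to(device=emb.weight.device, dtype=torch.long)
out = emb(tokens_ready)
print(out.shape)  # torch.Size([1, 3, 4])
```

Using `emb.weight.device` as the target avoids hard-coding `.cuda()` or `.cpu()` on the inputs, so they follow wherever the model was moved.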
Any comment (even a short one!) or just a list of keywords to look into would be highly appreciated!
Here is the full traceback:
Traceback (most recent call last):
  File "train_v2.py", line 110, in <module>
    main()
  File "train_v2.py", line 81, in main
    model.update(batch)
  File "/home/aerin/Desktop/squad_vteam/src/model.py", line 129, in update
    loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y)
  File "/home/aerin/Desktop/squad_vteam/src/model.py", line 104, in adversarial_loss
    start, end, _ = self.network(batch)
  File "/home/aerin/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/aerin/Desktop/squad_vteam/src/dreader.py", line 78, in forward
    doc_mask, query_mask = self.lexicon_encoder(batch)
  File "/home/aerin/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/aerin/Desktop/squad_vteam/src/encoder.py", line 116, in forward
    doc_emb, query_emb = emb(doc_tok), emb(query_tok)
  File "/home/aerin/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/aerin/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 108, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/home/aerin/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/functional.py", line 1076, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'
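Since the failing call is `emb(doc_tok)` in `encoder.py`, one way to see which side is on which device at that exact moment is a forward pre-hook on every `nn.Embedding`. The helper `attach_device_probes` below is a hypothetical debugging aid, not part of the project's code, and the toy `nn.Sequential` merely stands in for `self.network`:

```python
import torch
import torch.nn as nn


def attach_device_probes(model, log):
    """Hypothetical debugging helper: hook every nn.Embedding in `model`.

    Before each lookup, appends (module_name, weight_device, index_device)
    to `log`, so a CPU/GPU mismatch is recorded right where it happens.
    """
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Embedding):
            # Bind `name` as a default argument so each hook keeps its own name
            def probe(mod, inputs, name=name):
                log.append((name, mod.weight.device, inputs[0].device))
            handles.append(module.register_forward_pre_hook(probe))
    return handles  # call .remove() on each handle when done


# Toy stand-in for self.network (CPU only, so both devices match here)
model = nn.Sequential(nn.Embedding(10, 4))
log = []
handles = attach_device_probes(model, log)
model(torch.tensor([[1, 2, 3]]))
for h in handles:
    h.remove()
print(log)
```

Because a pre-hook runs before the lookup, the entry is recorded even when the lookup itself raises; an entry whose two devices differ points at the embedding and input that disagree.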
Update: I even moved the entire model.network to the CPU, but I still get the same error.
batch['doc_tok']=batch['doc_tok'].long().cpu()
batch['query_tok']=batch['query_tok'].long().cpu()
self.network.cpu()
print(batch['doc_tok'].dtype, batch['query_tok'].dtype)  # prints: torch.int64 torch.int64
start, end, _ = self.network(batch)
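One possible reason `self.network.cpu()` alone may not be enough (an assumption worth checking, not something the traceback confirms): `Module.cpu()`, `.cuda()` and `.to()` move only registered parameters and buffers. A tensor stored as a plain attribute keeps its device, and any tensor that code creates with an explicit `.cuda()` will still be created on the GPU. The `Encoder` class here is hypothetical and only lists what a `.cpu()` call would actually touch:

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical module used only to show what .cpu()/.cuda() will move."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)  # parameter: moved by .cpu()/.cuda()/.to()
        self.register_buffer("pad_ids", torch.zeros(2, dtype=torch.long))  # buffer: also moved
        self.extra_ids = torch.zeros(2, dtype=torch.long)  # plain attribute: NOT moved


enc = Encoder()
moved = sorted([name for name, _ in enc.named_parameters()] +
               [name for name, _ in enc.named_buffers()])
print(moved)  # ['embedding.weight', 'pad_ids']
```

If `extra_ids` had been created on the GPU, it would stay there after `enc.cpu()`; registering such tensors with `register_buffer` is the usual way to make them follow the module.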
At this point, I suspect this might be a bug...
model.py code: https://github.com/byorxyz/san_mrc/blob/master/src/model.py
Network definition: https://github.com/byorxyz/san_mrc/blob/master/src/dreader.py
[Comments]:
- @Shai Just removing cuda() doesn't solve my problem.
Tags: python typeerror pytorch word-embedding