【Question title】: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0
【Posted】: 2021-12-03 21:55:18
【Question】:

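I have a wrapper around a Hugging Face model. Inside this wrapper I have several encoders, which are mostly a series of embeddings. In the forward of the wrapped model, I want to call the forward of each encoder in a loop, but I get this error: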

Traceback (most recent call last):
  File "/home/pouramini/mt5-comet/comet/train/train.py", line 1275, in <module>
    run()
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 716, in __call__
    return self.main(*args, **kwargs)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 696, in main
    rv = self.invoke(ctx)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 1060, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 889, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 534, in invoke
    return callback(*args, **kwargs)
  File "/home/pouramini/mt5-comet/comet/train/train.py", line 1069, in train
    result = wrapped_model(**batch)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 135, in forward
    prompt_embeds = encoder(prompt_input_ids,\
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 238, in forward
    return self.embedding(prompt_token_ids)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2043, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument index in method wrapper_index_select)

Here is the code that causes the error:

    for encoder in self.prompt_encoders:
        #encoder = self.prompt_encoders[0]
        wlog.info("********** offset: %s, length: %s", encoder.id_offset, encoder.length)
        prompt_token_fn = encoder.get_prompt_token_fn()
        encoder_masks = prompt_token_fn(input_ids)
        wlog.info("Encoder masks: %s", encoder_masks)
        if encoder_masks.any():
            # find input ids for prompt tokens
            prompt_input_ids = input_ids[encoder_masks]
            wlog.info("Prompt Input ids: %s", prompt_input_ids)
            # call forward on the prompt encoder, whose outputs are the prompt embeddings
            prompt_embeds = encoder(prompt_input_ids,
                                    prompt_ids).to(device=inputs_embeds.device)

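If I use only the CPU as the device, the code runs. Also, with a single encoder the code runs on CUDA, but when there are multiple encoders it seems to expect all of them to be transferred to the device, and I don't know how to do that.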
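The symptom is consistent with the explanation the thread converges on: if `self.prompt_encoders` is a plain Python list, `nn.Module` does not register its elements as submodules, so `model.to(device)` never moves them. The sketch below checks this on CPU with two hypothetical wrappers (`ListWrapper`, `ModuleListWrapper`), using `nn.Embedding` as a stand-in for the real prompt encoders; it is an illustration under that assumption, not the author's code.

```python
import torch
import torch.nn as nn


class ListWrapper(nn.Module):
    """Hypothetical wrapper that keeps its encoders in a plain Python list."""

    def __init__(self, n_encoders: int):
        super().__init__()
        # A plain list is invisible to nn.Module: these embeddings are NOT registered.
        self.prompt_encoders = [nn.Embedding(10, 4) for _ in range(n_encoders)]


class ModuleListWrapper(nn.Module):
    """Same wrapper, but the encoders live in an nn.ModuleList."""

    def __init__(self, n_encoders: int):
        super().__init__()
        # nn.ModuleList registers every element as a submodule.
        self.prompt_encoders = nn.ModuleList(
            [nn.Embedding(10, 4) for _ in range(n_encoders)]
        )


def count_params(model: nn.Module) -> int:
    # Only registered parameters are visible here, and likewise to .to() and state_dict().
    return len(list(model.parameters()))


print(count_params(ListWrapper(3)))        # 0
print(count_params(ModuleListWrapper(3)))  # 3
```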

【Discussion】:

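  • Make sure both the model and its inputs are moved to the same device. You can simply call input = input.to("cuda") and model.to("cuda"). In your case, you could do something like for encoder in self.prompt_encoders: encoder.to("cuda") before the training loop.
  • @aretor Thanks, what you pointed out is exactly the solution; I'm posting it as an answer.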
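Following the first comment, both the parameters and the inputs have to end up on one device. A minimal sketch, assuming the batch is a dict of tensors like the one unpacked in `wrapped_model(**batch)` in the traceback; `move_batch` is a hypothetical helper, and the code falls back to CPU when CUDA is unavailable.

```python
import torch
import torch.nn as nn


def move_batch(batch: dict, device: torch.device) -> dict:
    # Move every tensor in the batch; leave non-tensor values (e.g. strings) untouched.
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Embedding(10, 4).to(device)  # the embedding weight now lives on `device`
batch = move_batch({"input": torch.tensor([1, 2, 3]), "note": "demo"}, device)
out = model(batch["input"])  # weight and indices share a device, so no RuntimeError
print(out.device == model.weight.device)  # True
```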

Tags: pytorch


【Solution 1】:

Based on the comments, I added the following code before training:

    wrapped_model.to(device=device)
    for encoder in wrapped_model.prompt_encoders:
        encoder.to(device=device)
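The answer's fix can be exercised in isolation. This sketch assumes the encoders sit in a plain Python list on a hypothetical `Wrapper` class, next to one registered layer; after the explicit loop, every parameter, registered or not, is on the same device.

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    # Hypothetical stand-in for the question's wrapper: encoders kept in a plain list.
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)  # registered: moved by wrapped_model.to()
        self.prompt_encoders = [nn.Embedding(10, 4) for _ in range(2)]  # not registered


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wrapped_model = Wrapper()

# The answer's fix: move the wrapper, then each unregistered encoder explicitly.
wrapped_model.to(device=device)
for encoder in wrapped_model.prompt_encoders:
    encoder.to(device=device)

# Collect the devices of all parameters, including the unregistered encoders.
devices = {p.device for enc in wrapped_model.prompt_encoders for p in enc.parameters()}
devices.add(wrapped_model.base.weight.device)
print(len(devices))  # 1
```

This only fixes placement: with a plain list, `wrapped_model.parameters()` still omits the encoders, so an optimizer built from it would not update them.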

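Interestingly, with a single encoder, or an encoder list containing only one encoder, I don't need to put it on the device explicitly, but with a list of several encoders it seems I have to.

The reason may be that I move the single encoder to the device inside the forward function.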

【Discussion】:

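  • If you haven't already, consider using nn.ModuleList instead of a plain Python list: it registers the encoders as submodules, so you can call to("cuda") directly on the ModuleList object (and wrapped_model.to(device) moves them as well).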
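With `nn.ModuleList`, as the last comment suggests, a single `.to(device)` on the wrapper reaches every encoder, and the loop over encoders in `forward` works unchanged. A minimal sketch; `PromptWrapper`, its sizes, and the stacking of outputs are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn


class PromptWrapper(nn.Module):
    # Hypothetical wrapper: encoders in nn.ModuleList, so .to() reaches all of them.
    def __init__(self, n_encoders: int, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.prompt_encoders = nn.ModuleList(
            [nn.Embedding(vocab, dim) for _ in range(n_encoders)]
        )

    def forward(self, prompt_input_ids: torch.Tensor) -> torch.Tensor:
        # Call each encoder in a loop, as in the question, and stack the results.
        return torch.stack([enc(prompt_input_ids) for enc in self.prompt_encoders])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PromptWrapper(n_encoders=3).to(device)  # one call moves every encoder
ids = torch.tensor([1, 2, 3], device=device)
out = model(ids)
print(out.shape)  # torch.Size([3, 3, 4])
```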