【问题标题】:I get this error using PyTorch: RuntimeError: gather_out_cpu(): Expected dtype int64 for index我使用 PyTorch 收到此错误:RuntimeError:gather_out_cpu(): Expected dtype int64 for index
【发布时间】:2021-07-30 23:35:21
【问题描述】:

我正在尝试使用 PyTorch 制作 AI,但出现此错误:

RuntimeError: gather_out_cpu(): Expected dtype int64 for index

这是我的功能:

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_variables = True)
    self.optimizer.step()

【问题讨论】:

    标签: python-3.x pytorch artificial-intelligence


    【解决方案1】:

    在将 batch_action 张量传递给 torch.gather 之前,您需要更改其数据类型。

    def learn(...):
        batch_action = batch_action.type(torch.int64) 
        outputs = ...
        ...
    
    # or
    outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
    

    【讨论】:

    • 我遇到了同样的问题,这也对我有用。但为什么会这样呢?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-07-29
    • 2020-05-05
    • 1970-01-01
    • 2013-01-06
    • 2020-11-11
    相关资源
    最近更新 更多