【Question title】: RuntimeError: expected scalar type Float but found Double
【Posted】: 2021-10-10 15:07:50
【Question】:

My code is as follows:

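import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# X, y, lr and niter are defined elsewhere in the original script
net = nn.Linear(54, 7)
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0)
logloss = torch.nn.CrossEntropyLoss()
for i in range(niter):
    optimizer.zero_grad()
    y_2 = torch.from_numpy(np.array(y, dtype='float64'))
    X_2 = torch.from_numpy(np.array(X, dtype='float64'))  # becomes a torch.float64 (Double) tensor
    outputs = net(X_2)  # the error below is raised here
    loss = logloss(outputs, y_2)
    print(loss)
    loss.backward()
    optimizer.step()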

I get the following error message:

---> 57             outputs = net(X_2)
     58             print(np.shape(outputs))
     59             loss = logloss(outputs, y_2)

~\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     94 
     95     def forward(self, input: Tensor) -> Tensor:
---> 96         return F.linear(input, self.weight, self.bias)
     97 
     98     def extra_repr(self) -> str:

~\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1845     if has_torch_function_variadic(input, weight):
   1846         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1847     return torch._C._nn.linear(input, weight, bias)
   1848 
   1849 

RuntimeError: expected scalar type Float but found Double

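Can you explain specifically what the problem is? Thanks. I thought I had already converted the data to floating point with torch.from_numpy(np.array(y, dtype='float64')), but it doesn't work.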
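Here is why dtype='float64' does not help. NumPy's float64 becomes torch.float64, which PyTorch reports as "Double". The parameters of nn.Linear default to torch.float32, which PyTorch reports as "Float". The check below assumes the default dtype has not been changed with torch.set_default_dtype. It does not assert the exact error text, because that message differs between PyTorch versions.

```python
import numpy as np
import torch

# NumPy float64 data becomes a torch.float64 ("Double") tensor
x64 = torch.from_numpy(np.array([[1.0, 2.0]], dtype='float64'))
print(x64.dtype)  # torch.float64

# nn.Linear creates its parameters as torch.float32 ("Float") by default
layer = torch.nn.Linear(2, 1)
print(layer.weight.dtype)  # torch.float32

# Mixing the two raises a RuntimeError (the exact message varies by PyTorch version)
raised = False
try:
    layer(x64)
except RuntimeError:
    raised = True
print(raised)  # True

# Casting the input to float32 makes the dtypes agree
x32 = x64.float()
print(layer(x32).dtype)  # torch.float32
```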

【Discussion】:

    Tags: pytorch logistic-regression


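    【Answer 1】:

    You need to convert your tensors to float32. You can either pass dtype='float32' when you build the NumPy arrays, or call .float() on your input tensor. In PyTorch, "Float" means float32 and "Double" means float64. nn.Linear creates its weights as float32 by default, so a float64 input does not match them.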
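The sketch below is a corrected version of the question's loop. The data X and y, and the values of lr and niter, are hypothetical stand-ins, because the post does not show them.

It converts the input to float32, as the answer suggests. It converts the targets to int64 instead. That is because nn.CrossEntropyLoss expects 1-D targets to be integer class indices (torch.long), so casting y to float32 would not be enough.

The conversion is also moved out of the loop, since the data does not change between iterations.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in data: 100 samples, 54 features, 7 classes (random values)
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 54))          # float64 by default, like the original arrays
y = rng.integers(0, 7, size=100)        # integer class indices 0..6
lr, niter = 0.1, 50                     # hypothetical hyperparameters

net = nn.Linear(54, 7)
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0)
logloss = nn.CrossEntropyLoss()

# Convert once, outside the loop: inputs as float32, targets as int64 class indices
X_2 = torch.from_numpy(np.asarray(X, dtype='float32'))
y_2 = torch.from_numpy(np.asarray(y, dtype='int64'))

losses = []
for i in range(niter):
    optimizer.zero_grad()
    outputs = net(X_2)                  # float32 input matches the float32 weights
    loss = logloss(outputs, y_2)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

An alternative is to keep the data in float64 and convert the model with net.double(). That also makes the dtypes agree, but it doubles the parameter memory and is usually slower. Most PyTorch code therefore uses float32.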
