【问题标题】:Why is Pytorch giving me a datatype error: Float vs Double?为什么 Pytorch 给我一个数据类型错误:Float vs Double?
【发布时间】:2020-10-16 22:02:12
【问题描述】:

我正在将我在网上找到的一些工作 Pytorch 代码(这是一个使用 MNIST 数据的 2D 图像分类示例;很抱歉我丢失了原始来源并且无法找到它)迁移到我需要的东西,这正在将一维值集合转换为数值分数。我创建了自己的 Dataset 类。当我调用 model() 时,我收到一个错误:RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat1' in call to _th_addmm。我的第一层困惑是,即使有 Double 数据类型,我也找不到任何对 Python 的引用。我的第二个是为什么我得到错误 - 当我输入调试代码以显示 mat1 的数据类型及其元素时,我被告知它是一个声称是 float64 的张量。我也想知道为什么它期望 mat1 的标量,文档将其描述为矩阵/张量。

完整的错误转储是

Traceback (most recent call last):
  File "mlalan.py", line 174, in <module>
    outputs = model(images)
  File "/usr/home/adf/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "mlalan.py", line 80, in forward
    x = activate(self.fc1(x))
  File "/usr/home/adf/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/home/adf/.local/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/home/adf/.local/lib/python3.7/site-packages/torch/nn/functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat1' in call to _th_addmm

我的 Dataset 类中的一些关键代码是

class RandomDataset(Dataset):

    def __init__(self, csv_file, transform=None):
        self.data_frame = pd.read_csv(csv_file, dtype=float)

    def __getitem__(self, idx):
        raw = self.data_frame.values[idx]
        sample = raw[0:6], raw[6:8]
        return sample

完整的源代码在http://8wheels.org/mlalan.py

【问题讨论】:

  • 试着把它变成一个浮点数,float(num)
  • 顺便说一句,float64 是 Double;浮点数是 float32。 Pandas 可能正在以双精度加载数据。
  • 您的模型需要一个 FloatTensor。调用 (some_tensor).float() 以作为 nn.Module 实例的输入的任何张量。

标签: python types pytorch


【解决方案1】:

默认情况下,在 Python 中,float 表示 float32。然而,在 Pandas 和 Numpy 中,float 表示 float64。我能够通过添加对 astype 的调用来解决问题,如下所示。它需要“32”才能工作。

raw = self.data_frame.values[idx].astype(np.float32)

谢谢!

现在我可以继续下一次崩溃 :-)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-06-02
    • 2021-05-26
    • 2021-05-21
    相关资源
    最近更新 更多