【Question title】: MNIST dataset fails to transform into a tensor object
【Posted】: 2020-11-02 12:27:23
【Question】:

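How do I correctly convert the MNIST dataset to tensors? I tried the code below, but it does not work. The error message AttributeError: 'int' object has no attribute 'type' indicates that the element is not a tensor.

The code below can be tested in Google Colab.

It seems to run with PyTorch 1.3.1, but not with 1.5.1.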

>>> import torch
>>> import torch.nn as nn
>>> import torchvision.transforms as transforms
>>> import torchvision.datasets as dsets
>>> import numpy as np
>>> torch.__version__
1.5.1+cu101

>>> train_dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100.1%Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
113.5%Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100.4%Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
180.4%Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
/pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.
Done!

>>> print("Print the training dataset:\n ", train_dataset)
Print the training dataset:
  Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
    Transform: ToTensor()

>>> print("Type of data element: ", train_dataset[0][1].type())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'int' object has no attribute 'type'

【Discussion】:

    Tags: python numpy machine-learning pytorch


    【Answer 1】:

    You need to access the 1st element (the image tensor), not the 2nd element (the label), i.e.

    >>> print("Type of data element: ", train_dataset[0][0].type())
    Type of data element:  torch.FloatTensor
    
    >>> print(train_dataset[0][0].shape, train_dataset[0][1])
    torch.Size([1, 28, 28]) 5
    
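As the traceback in the question shows, train_dataset[0][1] is a plain Python int, so it has no .type() method. The sketch below reproduces that (image, int) pair with a small hypothetical stand-in dataset, ToyMNISTLike (not part of torchvision), so it runs without downloading MNIST. It then shows two common ways to obtain a LongTensor label: wrapping a single label with torch.tensor, or letting DataLoader's default collation batch the ints. It assumes only that torch is installed.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMNISTLike(Dataset):
    """Hypothetical stand-in for dsets.MNIST: returns (image_tensor, int_label)
    pairs, mirroring what train_dataset[i] returned in the question."""

    def __init__(self, n=8):
        self.images = torch.rand(n, 1, 28, 28)   # fake 1x28x28 float images
        self.labels = [i % 10 for i in range(n)]  # plain Python ints, like MNIST labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


ds = ToyMNISTLike()
image, label = ds[0]
print(image.type(), image.shape)  # element 0 is the image tensor
print(type(label))                # element 1 is an int, hence the AttributeError

# Option 1: convert a single label explicitly (int -> 0-dim int64 tensor).
label_t = torch.tensor(label)
print(label_t.type())

# Option 2: let DataLoader's default collate batch the ints into one int64 tensor.
loader = DataLoader(ds, batch_size=4)
images, labels = next(iter(loader))
print(images.shape, labels.type())
```

With the real dataset, passing target_transform=torch.tensor to dsets.MNIST should have the same effect as option 1 for every sample. In a training loop the DataLoader route is usually all you need, since the labels already arrive batched as int64 (torch.LongTensor).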

    【Comments】:

    • Thanks. Why is the label not a LongTensor in version 1.5.1?