【问题标题】:How do I make custom pytorch datasets structured like the torchvision datasets?如何制作像 torchvision 数据集一样结构化的自定义 pytorch 数据集?
【发布时间】:2020-11-17 08:45:24
【问题描述】:

我是 pytorch 的新手,我正在尝试重用 Fashion MNIST CNN (from deeplizard) 对时间序列数据进行分类。我发现很难理解数据集的结构,因为尽我所能遵循this official tutorialthis SO question,我得到的东西太简单了。我认为这是因为我不太了解 OOP。我制作的数据集在我的 CNN 中可以很好地进行训练,但是尝试使用他们的代码分析结果时却卡住了。

所以我从两个称为特征 [4050, 1, 150, 6] 和目标 [4050] 的 pytorch 张量创建了一个数据集:

train_dataset = TensorDataset(features,targets) # create your datset
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False) # create your dataloader
print(train_dataset.__dict__.keys()) # list the attributes

我通过检查属性得到这个打印输出

dict_keys(['tensors'])

但在 Fashion MNIST 教程中,他们访问数据的方式如下:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
print(train_set.__dict__.keys()) # list the attributes

你通过检查属性得到这个打印输出

dict_keys(['root', 'transform', 'target_transform', 'transforms', “火车”、“数据”、“目标”])

我的数据集可以很好地用于训练,但是当我进入教程的后续分析部分时,他们希望我访问数据集的部分内容并且我收到错误:

# Analytics
prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
train_preds = get_all_preds(network, prediction_loader)
preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_set))

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-73-daa87335a92a> in <module>
      4 prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
      5 train_preds = get_all_preds(network, prediction_loader)
----> 6 preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()
      7 
      8 print('total correct:', preds_correct)

AttributeError: 'TensorDataset' object has no attribute 'targets'

谁能告诉我这里发生了什么?这是我制作数据集的方式需要改变的地方,还是我可以重写分析代码以访问数据集的正确部分?

【问题讨论】:

    标签: python pytorch torch


    【解决方案1】:

    TensorDatasets 对应的.targets 将是train_dataset.tensors[1]

    TensorDataset的实现很简单:

    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
        r"""Dataset wrapping tensors.
        Each sample will be retrieved by indexing tensors along the first dimension.
        Arguments:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        """
        tensors: Tuple[Tensor, ...]
    
        def __init__(self, *tensors: Tensor) -> None:
            assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
            self.tensors = tensors
    
        def __getitem__(self, index):
            return tuple(tensor[index] for tensor in self.tensors)
    
        def __len__(self):
            return self.tensors[0].size(0)
    

    【讨论】:

      猜你喜欢
      • 2020-09-28
      • 2019-06-12
      • 2020-07-30
      • 2020-02-15
      • 2010-11-18
      • 2019-12-30
      • 1970-01-01
      • 1970-01-01
      • 2021-07-21
      相关资源
      最近更新 更多