【问题标题】:Python Dataset Class + PyTorch Dataloader: Stuck at __getitem__, how to get Index, Label and so on during Testing?Python Dataset Class + PyTorch Dataloader:卡在__getitem__,Testing时如何获取Index、Label等?
【发布时间】:2020-05-18 11:40:55
【问题描述】:

我有一个小问题,但我现在被困了很长一段时间。希望有人可以帮助我。我目前正在使用我喜欢通过深度学习(CNN 网络)进行训练的 Kddcup99 数据集

我有一个包含 Panda Dataframe 的“数据集”类。因此我分成正常和验证数据集。到目前为止,没有问题。 我将它加载到一个 Numpy 向量中,将其火炬传递到 Tensor,然后将其定向到 DataLoader。

Dataset 类有这两个重要的用于迭代的类:

def __len__(self):
        return len(self.val_df)

def __getitem__(self, index):        
        img, target = self.val_df[index][:-1], self.val_df[index][-1]
        return img, target, index

类中没有DataLoader字符串:

test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)

在我的培训师课程中,我有一个 for 循环,它应该遍历 Dataloader:

with torch.no_grad():
            for data in dataloader:
                inputs, labels, idx = data
                inputs = inputs.to(self.device)

但它不会。我无法访问标签、索引等。

我现在的问题是:为什么? 如何通过数据加载器访问给定数据集中的标签、索引?

感谢大家的帮助! 非常感谢。

【问题讨论】:

    标签: python machine-learning dataset pytorch dataloader


    【解决方案1】:

    DataLoader 的第一个参数是您要从中加载数据的数据集,通常是 Dataset,但不限于 Dataset 的任何实例。只要它定义了长度(__len__)并且可以被索引(__getitem__ 允许),就可以接受。

    您将datat.val_df 传递给DataLoader,这可能是一个NumPy 数组。 NumPy 数组有一个长度并且可以被索引,所以它可以在DataLoader 中使用。由于您直接传递该数组,因此您的数据集的 __getitem__ 永远不会被调用,但数组本身已被索引,因此每个项目都只是 data.val_df[index]

    您必须使用数据集本身 (datat),而不是使用 DataLoader 的基础数据:

    test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
    

    【讨论】:

    • 首先:很抱歉回答迟了。我有一段时间没有在这里检查了。您的回答以非常好的方式完全解决了我的问题。谢谢!
    猜你喜欢
    • 2020-03-09
    • 2021-11-26
    • 1970-01-01
    • 2020-08-26
    • 2022-12-14
    • 2019-05-03
    • 2021-07-28
    • 2020-07-14
    • 2022-01-24
    相关资源
    最近更新 更多