【问题标题】:How can I shuffle the labels of a dataset?如何打乱数据集的标签?
【发布时间】:2020-02-19 22:52:51
【问题描述】:

我已经下载了 MNIST 数据集,使用以下命令:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

我现在需要在这个数据集 (MNIST) 上运行一些实验,但要打乱训练集的标签。如何随机洗牌/重新分配它们?我尝试了以下方法:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: torch.randint(0, 10, (1,)).item(),
                            download=True)

但我注意到 lambda 函数之后的内容会在训练过程中使标签随机排列,例如它们在每个时代都在变化。这样,我不会达到 100% 的训练准确率,而这正是我的目标。如何以完全随机的方式打乱这些标签,确保这些标签在训练过程中不会改变?

谢谢!!

【问题讨论】:

    标签: pytorch mnist


    【解决方案1】:

    如果您的目标是创建标签的随机映射,则需要在定义目标变换之前定义映射以保持变换不变。像下面这样的东西应该可以解决问题

    import random
    label_mapping = list(range(10))
    random.shuffle(label_mapping)
    train_dataset = dsets.MNIST(root='./data', 
                                train=True, 
                                transform=transforms.ToTensor(),
                                target_transform=lambda y: label_mapping[y],
                                download=True)
    

    为了在每个 epoch 获得新的 shuffle,您需要重新定义每个 epoch 的标签映射、训练数据集和数据加载器。

    更新 要生成一个独立于真实标签但与给定索引一致的随机标签,那么您可能需要做一些非常仔细的播种或重新实现数据集类的某些功能。

    例如,后一种情况可能看起来像这样

    import random
    class RandomMNIST(dsets.MNIST):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.targets = [random.randint(0, 9) for _ in range(len(self.data))]
    
    train_dataset = RandomMNIST(root='./data', 
                                train=True, 
                                transform=transforms.ToTensor(),
                                download=True)
    

    或等效

    import random
    train_dataset = dsets.MNIST(root='./data', 
                                train=True, 
                                transform=transforms.ToTensor(),
                                download=True)
    train_dataset.targets = [random.randint(0, 9) for _ in range(len(train_dataset))]
    

    【讨论】:

    • 感谢您的回答!这里的问题是,如果我做对了,随机数取决于 y。换句话说,它将所有相等的标签映射到相同的数字,例如它将每 3 个映射到一个始终相同的随机数。我需要的是完全随机的洗牌。
    • 啊,我明白了,所以您想要一个独立于真实标签的随机标签?
    • 是的,那太好了:)
    • @Alfred 我更新了答案。我在手机上打字所以没有测试,但这个想法应该很清楚。另一种方法是在定义 train_dataset.targets 后对其进行更改。
    猜你喜欢
    • 2019-12-05
    • 2019-01-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-01-19
    • 1970-01-01
    • 1970-01-01
    • 2020-11-25
    相关资源
    最近更新 更多