【问题标题】:Iterable dataset exhausts after a single epoch可迭代数据集在单个 epoch 后耗尽
【发布时间】:2021-09-03 22:36:00
【问题描述】:

我想在情感分析任务上训练一个 RNN,对于这个任务,我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 Python 迭代器。我用了split=('train', 'test')

我首先使用torchtext.vocab.Vocab 构建了一个词汇表,并对每个句子进行了标记,然后进行了数字化。

为了将序列填充到相同的长度,我使用了torch.nn.utils.rnn.pad_sequence,还使用了collate_fnbatch_sampler。然后我使用 torch.utils.data.DataLoader 加载数据。

RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。

我是否采用了正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个时期后耗尽,我该如何克服这个问题。

如果您想查看我的实现,请参阅共享的 colab 笔记本。

附言。我在关注来自github的torchtext官方changelog

你可以找到我的实现here

【问题讨论】:

    标签: python nlp pytorch torchtext


    【解决方案1】:

    解决方案是使用torchtext.data.functional.to_map_style_dataset(iter_data) (official doc) 将您的可迭代式数据集转换为地图式数据集。

    像这样:

    from torchtext.data.functional import to_map_style_dataset
    train_iter = IMDB(split='train')
    train_dataset = to_map_style_dataset(train_iter)  #Map-style dataset
    

    然后制作一个数据加载器。

    from torch.utils.data import DataLoader
    train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
    

    为什么会这样?

    我用上面例子的命名约定来解释。

    传递给Dataloadertrain_iter 是一个可迭代样式的数据集,这意味着它没有实现__getitem__。它只有 __iter____next__ dunders - 这使它成为可迭代的。

    因此,如果我将一个可迭代对象传递给Dataloader,则数据加载器会在发生StopIteration 异常后停止——这将由可迭代样式数据集的__next__ dunder(在本例中为train_iter)抛出数据集(可迭代)已用尽。

    所以我们使用to_map_style_dataset 函数将Iterable-style 转换为map-style 数据集。它通过实现__getitem__ dunder 来实现,因此Dataloader 默认使用索引从数据集中获取项目。

    做同样事情的另一种可能的方式也可以是

    如果我要使用可迭代式数据集 - 我需要在每个时期创建 Dataloader 对象。因此,在每个 epoch 之后,新的 dataloader 对象将在 for 循环中从头开始运行。

    为了更好地理解 Pytorch 中 Iterable-style 和 Map-style 数据集的区别和用例,请参阅https://yizhepku.github.io/2020/12/26/dataloader.html

    【讨论】:

    • 如果我回答了你的问题,请告诉我。如果您认为我的理解不正确,也请建议编辑。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-04-08
    • 2016-10-13
    • 2016-12-25
    • 2022-01-03
    • 1970-01-01
    相关资源
    最近更新 更多