【问题标题】:PyTorch custom DataLoader working with multiple CSVsPyTorch 自定义 DataLoader 处理多个 CSV
【发布时间】:2021-10-23 03:42:58
【问题描述】:

我正在尝试定义一个定制的 PyTorch DataLoader,它能够有效地从不同的 huge CSV 中读取,而无需将它们加载到内存中。问题定义如下。为简单起见,假设我有两个 CSV

1.csv:

1, 2, 3
4, 5, 6
7, 8, 9
2.csv:

10,11,12
13,14,15
16,17,18

为简单起见,我们还假设批量大小为 1。生成器应该产生两个张量:

Tensor_1: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Tensor_2: [10, 11, 12, 13, 14, 15, 16, 17, 18]

这是因为对于每个有效索引,我应用的历史窗口等于 2,然后我将样本展平。

按照What is the fastest way to load data from multiple csv files的答案,我编写了以下代码:

import numpy as np
import pandas as pd
import glob
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
import torch

@lru_cache()
def get_sample_count_by_file(path: Path) -> int:
    c = 0
    with path.open() as f:
        for line in f:
            c += 1
    return c

class CSVDataset:
    def __init__(self, csv_directory: str, extension: str = ".csv"):
        self.directory = Path(csv_directory)
        self.files = sorted((f, get_sample_count_by_file(f)) for f in self.directory.iterdir() if f.suffix == extension)
        self._sample_count = sum(f[-1] for f in self.files)

    def __len__(self):
        return self._sample_count

    def __getitem__(self, idx):
        current_count = 0

        history_window = 2
        my_idx=idx+2

        for file_, sample_count in self.files:
            if current_count <= my_idx < current_count + sample_count:
                break  
            current_count += sample_count

        file_idx = my_idx - current_count # the index we want to access in file_
        if file_idx < 2:
            file_idx += 2

        with file_.open() as f:
            data = []
            for i, line in enumerate(f):
                if i >= file_idx-history_window and i <= file_idx:
                    for v in line.split(","):
                        data.append(float(v))

            data = np.array(data)
            return torch.from_numpy(data)


dataset = CSVDataset("<PATH CONTAINING CSVs>")
loader = DataLoader(dataset, batch_size=1)

pprint(list(enumerate(loader)))

它对第一个文件非常有效,但是当它切换到第二个 CSV 时会出现问题(由于索引管理错误,有一些重复)。我该如何解决这个问题?

【问题讨论】:

    标签: pytorch iterator pytorch-dataloader


    【解决方案1】:

    如何将您的 CSVDataset 仅用于一个 csv,然后使用 torch.utils.data.ConcatDataset 将所有单独的 csv 数据集合并为一个数据集。 Pytorch 将为您处理索引,只要每个 CSVDataset 中的索引是一致的。

    【讨论】:

    • 感谢您的回答@Shai!您能否提供一个代码示例来更好地解释您的想法?
    • 编辑:似乎torch.utils.data.ConcatDataset 将文件连接起来,将它们全部保存在内存中。我的文件是巨大的 CSV,所以这不是一个解决方案。
    猜你喜欢
    • 2020-07-07
    • 1970-01-01
    • 2022-01-23
    • 2021-10-01
    • 2019-04-27
    • 2019-07-29
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多