DataLoader是pytorch提供的,一般我们要写的是Dataset,也就是DataLoader中的一个参数,其基本框架是:

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

由此可见,需要暴露的API只有__getitem____len__,还有一个构造函数

相关文章:

  • 2021-06-09
  • 2021-11-02
  • 2021-09-28
  • 2022-12-23
  • 2022-12-23
  • 2022-02-10
  • 2022-12-23
猜你喜欢
  • 2021-02-24
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-10-13
相关资源
相似解决方案