参考链接:https://www.jb51.net/article/167899.htm
在训练神经网络时,需要向网络中丢入数据,以供神经网络来学习其中的一些特征,但是对于同样的框架,神经网络如何做到训练各种各样的数据呢?
那么就需要数据按照一定的格式来组织了,即Dataset类,(以便使用已经定义好的特殊数据集接口来加载数据)
1.先来介绍一下pytorch中的数据处理模块torch.utils.data.TensorDataset
class torch.utils.data.TensorDataset(data_tensor, target_tensor):封装成tensor的数据集,每一个样本都通过索引张量来获得。
当使用的数据是pytorch官方给出的数据(及已经做好了框架需要的数据格式),可以直接使用 TensorDataset 来将数据包装成Dataset类。
例如:
dataset = TensorDataset(x_data, y_data)
2.再来介绍torch.utils.data.TensorDataset:数据集接口类
如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。
torch.utils.data.TensorDataset是一个包装类,用来将数据包装为Dataset类,方便我们继承并实现自己的数据集接口,
自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
例:
import torch
class myDataset(torch.nn.data.Dataset):
def __init__(self, dataSource)
self.dataSource = dataSource
def __getitem__(self, index):
element = self.dataSource[index]
return element
def __len__(self):
return len(self.dataSource)
train_data = myDataset(dataSource)
整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;
__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;
__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。
参考代码段:
可以看出,在初始化的时候根据数据特点来设置
3.最后,介绍class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器
将数据集处理成dataset类之后就可以使用数据加载器直接加载入神经网络中进行训练或者测试了。
边学边做笔记,如有不当,还请多多指正