I. Overview
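When the dataset is small (fewer than about 20,000 samples), it is better to leave num_workers at its default of 0, because using worker processes can actually be slower than loading in the main process.
When the dataset is larger, setting num_workers is recommended. In the author's experience the best value is usually around the number of CPU threads, plus or minus 1, and the code below can be used to find the best num_workers empirically. (Windows users: to load data with multiple worker processes, the training code must be placed under if __name__ == '__main__':, otherwise an error is raised.)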
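As a quick starting point for that rule of thumb, the logical thread count can be read with Python's os.cpu_count(). The sketch below uses a hypothetical helper, candidate_num_workers, that returns the values one below, at, and one above that count (never below 0). Note that os.cpu_count() reports the logical CPUs of the whole machine, which may be more than a container or job scheduler actually grants the process.

```python
import os


def candidate_num_workers(threads=None):
    """Return the num_workers values threads - 1, threads and threads + 1.

    `threads` defaults to os.cpu_count(), the number of logical CPUs.
    os.cpu_count() may return None if the count cannot be determined,
    in which case this sketch falls back to 1.
    """
    if threads is None:
        threads = os.cpu_count() or 1
    # Clip at 0: num_workers=0 means data is loaded in the main process
    return sorted({max(threads + delta, 0) for delta in (-1, 0, 1)})


if __name__ == '__main__':
    print(candidate_num_workers())
```

For example, on a 12-thread machine candidate_num_workers(12) returns [11, 12, 13]; these are the values worth timing first.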
II. Code
```python
import time

import torch.utils.data as d
import torchvision
import torchvision.transforms as transforms

# Required on Windows: DataLoader workers are started as new processes,
# so everything that runs the loaders must sit under this guard.
if __name__ == '__main__':
    BATCH_SIZE = 100
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Dataset root directory; set download=True on the first run if it is empty
    train_set = torchvision.datasets.MNIST('./mnist', download=False, train=True,
                                           transform=transform)

    # Time one full pass over the training set for each num_workers value
    for num_workers in range(20):
        train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE,
                                    shuffle=True, num_workers=num_workers)
        start = time.time()
        for epoch in range(1):
            for step, (batch_x, batch_y) in enumerate(train_loader):
                pass  # training ...
        end = time.time()
        print('num_workers is {} and it took {} seconds'.format(num_workers, end - start))
```
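The script above prints one timing per value and leaves the comparison to the reader. The sketch below factors the same measurement into a helper that also returns the fastest setting. The function find_best_num_workers and its make_loader parameter are hypothetical names, not part of PyTorch: make_loader is any callable that takes a num_workers value and returns an iterable of batches. It uses time.perf_counter(), Python's clock intended for measuring short intervals, instead of time.time().

```python
import time


def find_best_num_workers(make_loader, candidates):
    """Time one full pass over make_loader(n) for each n in candidates.

    Returns (best_n, timings), where timings maps each n to seconds elapsed.
    As in the script above, the loader is built before the timer starts.
    """
    timings = {}
    for n in candidates:
        loader = make_loader(n)
        start = time.perf_counter()
        for _batch in loader:
            pass  # only the data-loading cost is measured
        timings[n] = time.perf_counter() - start
        print('num_workers is {} and it took {:.3f} seconds'.format(n, timings[n]))
    # The fastest setting is the one with the smallest elapsed time
    best = min(timings, key=timings.get)
    return best, timings
```

Inside the if __name__ == '__main__': block of the script above, it could be called as best, _ = find_best_num_workers(lambda n: d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=n), range(20)). A single pass is a noisy measurement, so when two values are close it is worth repeating the run before settling on one.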
III. Checking the CPU thread count
1. Number of physical CPUs
grep 'physical id' /proc/cpuinfo | sort -u | wc -l
2. Number of cores (per physical CPU)
grep 'core id' /proc/cpuinfo | sort -u | wc -l
3. Number of threads
grep 'processor' /proc/cpuinfo | sort -u | wc -l
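4. Example
The command output is shown in the figure. According to it, this server has 1 CPU with 6 cores and 2 threads per core, for 12 threads in total.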
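The same three counts can also be computed from Python, for example when a training script should choose num_workers by itself. The sketch below uses a hypothetical function, summarize_cpuinfo, that parses the text of /proc/cpuinfo (available on Linux only). One deliberate difference from the grep commands: it counts distinct (physical id, core id) pairs, so on a machine with several physical CPUs it reports the total number of cores, whereas sort -u on the core id lines alone counts distinct core id values, which equals the cores per CPU when all CPUs are identical. Dividing threads by cores gives the threads per core (12 / 6 = 2 in the example above).

```python
import os


def summarize_cpuinfo(text):
    """Return (physical CPUs, physical cores, hardware threads) from /proc/cpuinfo text.

    Some systems (e.g. many ARM boards and some virtual machines) may omit the
    'physical id' and 'core id' fields, in which case those two counts are 0.
    """
    sockets, cores, threads = set(), set(), 0
    physical_id = None
    for line in text.splitlines():
        if ':' not in line:
            continue  # blank separator lines between processor blocks
        key, _, value = line.partition(':')
        key, value = key.strip(), value.strip()
        if key == 'processor':
            threads += 1        # one block per logical CPU (hardware thread)
            physical_id = None  # a new block starts; forget the previous socket
        elif key == 'physical id':
            physical_id = value
            sockets.add(value)
        elif key == 'core id':
            cores.add((physical_id, value))
    return len(sockets), len(cores), threads


if __name__ == '__main__':
    if os.path.exists('/proc/cpuinfo'):
        with open('/proc/cpuinfo') as f:
            cpus, cores, threads = summarize_cpuinfo(f.read())
        print('CPUs: {}, cores: {}, threads: {}'.format(cpus, cores, threads))
```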