【发布时间】:2023-03-25 15:28:01
【问题描述】:
我有大约 200,000 张高分辨率图像,每次加载如此高质量的图像非常耗时。 预加载所有图像可能会占用太多内存。 如何将每个图像保存为 .npz 文件格式并加载 .npz 而不是 .jpg?会不会提高速度?
【问题讨论】:
标签: pytorch
我有大约 200,000 张高分辨率图像,每次加载如此高质量的图像非常耗时。 预加载所有图像可能会占用太多内存。 如何将每个图像保存为 .npz 文件格式并加载 .npz 而不是 .jpg?会不会提高速度?
【问题讨论】:
标签: pytorch
您不需要一次将所有图像加载到内存中。还要考虑到我们在模型训练的时候需要对数据集进行数据扩充,所以不可能加载所有的图片。
在 PyTorch 中,您可以使用 Dataset 来存储您的训练和验证集。 Dataset 类有一个参数transforms(例如,Scale、RandomCrop 等),用于在训练期间动态变换训练图像。 torchvision包还提供了几个现成的数据集,见here。
PyTorch 的内置Dataloader 有一个num_worker,用于控制您使用多少子进程来加载数据。由于您的数据集不是那么大,因此足以满足您的需要。关于如何设置合适的worker数量,见here。
【讨论】:
img 是一个火炬张量,model 就是你的网络。您可以使用out = model(img) 来获取模型的输出。