【问题标题】:How can I read in image dataset as numpy array in a memory efficient manner?如何以内存有效的方式将图像数据集作为 numpy 数组读取?
【发布时间】:2019-05-01 01:49:19
【问题描述】:

我正在尝试将图像数据集作为 numpy 数组加载。我怎样才能做到这一点,以免我在本地机器上强调 RAM 的限制,或者创建一个需要太多内存的数组?较大的图像集是训练集,总共有大约 2GB 的图像。

这是为了训练一个残差神经网络,要求输入数据是一个 numpy 数组。我曾使用模块 glob、PIL、skimage、sklearn 和 numpy 来尝试加载图像,但我这样做的方式可能很幼稚,因为 ~2GB 的图像变成了 ~17(!) GB numpy 数组。我曾尝试搜索解决方案、示例等,但对 Python 比较陌生,因此过程非常缓慢。

用于加载图片的代码是

import glob
from skimage.transform import resize
import numpy as np
from sklearn import datasets
from PIL import Image

def root_2_numpy(data_root):
    """
    Load raw images and output a numpy array of all images and numpy array of labels
    Also preprocesses each image to (224,224) using anti-aliasing
    """
    # load images into numpy array
    all_image_paths = list(data_root.glob('*/*'))  # get image paths
    all_image_paths = [str(path) for path in all_image_paths]  # convert to string
    image_ds = np.zeros([len(all_image_paths), 224, 224,3])  # initialize image dataset
    for i in range(len(all_image_paths)):
        print(i)
        im = Image.open(all_image_paths[i])  # read image as RGB using matplotlib
        if im.mode == 'RGBA' or im.mode == 'L' or im.mode == 'CMYK':
            im = im.convert('RGB')
        elif im.mode =='P':
            im = im.convert('RGBA')
            im = im.convert('RGB')
        im = np.array(im)
        im = resize(im, (224,224), anti_aliasing=True)  # resize image using skimage
        image_ds[i,:,:,:] = im

    # load labels into numpy array
    label_ds = datasets.load_files(data_root, load_content=False, shuffle=False)  # get labels
    n_classes = len(label_ds.target_names)
    Y_ds = np.eye(len(label_ds.target_names))[label_ds.target.reshape(-1)]

    return image_ds, Y_ds, n_classes

我希望这会返回一个约 2GB 的 numpy 数组,该数组的尺寸为(N、W、H、C),用于表示图像数量、图像宽度、图像高度和 3 个图像通道。这不是手头的问题,但我也希望有标签的数据,它们是根目录中的类别名称。

除了帮助我有效地加载数据之外,我还非常感谢了解我的代码如何创建如此大的 numpy 数组。在我写这篇文章时,我有一种感觉,当转换非 RBG 图像的图像类型时,可能会创建比预期更多的图像。

【问题讨论】:

  • len(all_image_paths)的值是多少?
  • len(all_image_paths) 返回 11

标签: python numpy memory


【解决方案1】:

numpy.zeros 创建的数组的默认数据类型是 64 位浮点数。所以image_ds = np.zeros([len(all_image_paths), 224, 224,3]) 创建了一个比你需要的大 8 倍的数组。添加dtype 参数,使image_ds 的数据类型为uint8(8 位无符号整数):

image_ds = np.zeros([len(all_image_paths), 224, 224,3], dtype=np.uint8)

【讨论】:

  • 感谢您的回答!我已经尝试过您的解决方案并且它有效。但是,在调用 X_train.nbytes 之后,我仍然得到了一个 17gb 的数组。不过,使用 np.uint8 作为数据类型的 RAM 似乎没有问题。
猜你喜欢
  • 1970-01-01
  • 2021-02-20
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2022-06-10
  • 2016-08-23
  • 2019-10-19
  • 2020-07-07
相关资源
最近更新 更多