[Posted]: 2016-08-05 15:17:41
[Question]:
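I have a dataset (71,094 training images and 17,000 test images) and need to train a CNN on it. During preprocessing I tried to build a single NumPy matrix, which turned out to be very large (71094*100*100*3 for the training data; all images are 100 x 100 RGB), so I get a memory error. How can I deal with this? Please help. Here is my code: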
import numpy as np
import cv2
from matplotlib import pyplot as plt

data_dir = './fashion-data/images/'
train_data = './fashion-data/train.txt'
test_data = './fashion-data/test.txt'

# Each line of train.txt is an image path of the form '<label>/<name>'.
f = open(train_data, 'r').read()
ims = f.split('\n')
print(len(ims))

# np.zeros defaults to float64 (8 bytes per value).
train = np.zeros((71094, 100, 100, 3))  # this line causes the error..
for ix in range(train.shape[0]):
    i = cv2.imread(data_dir + ims[ix] + '.jpg')
    label = ims[ix].split('/')[0]
    train[ix, :, :, :] = cv2.resize(i, (100, 100))
print(train[0])

# The label is the first path component of each entry.
train_labels = np.zeros((71094, 1))
for ix in range(train_labels.shape[0]):
    l = ims[ix].split('/')[0]
    train_labels[ix] = int(l)
print(train_labels[0])

np.save('./data/train', train)
np.save('./data/train_labels', train_labels)
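For scale: with the default float64 dtype, the array requested above needs 71094*100*100*3*8 bytes, about 17.1 GB, while the same pixels stored as uint8 need about 2.1 GB. The sketch below reproduces that arithmetic and shows one possible way around the allocation, a disk-backed uint8 array created with np.lib.format.open_memmap. It is a minimal illustration under assumptions, not a tested fix for this dataset: the helper names array_nbytes and create_image_store and the file name train_demo.npy are hypothetical, and the demo writes a placeholder image rather than reading the real files.

```python
import os
import tempfile

import numpy as np


def array_nbytes(shape, dtype):
    """Return the number of bytes a dense array of this shape and dtype needs."""
    count = 1
    for dim in shape:
        count *= dim
    return count * np.dtype(dtype).itemsize


def create_image_store(path, n_images, height=100, width=100, channels=3):
    """Create a disk-backed uint8 array in .npy format (hypothetical helper).

    The data lives in the file at `path`, so the whole array does not have
    to fit in RAM at once.
    """
    return np.lib.format.open_memmap(
        path, mode='w+', dtype=np.uint8,
        shape=(n_images, height, width, channels))


full_shape = (71094, 100, 100, 3)
print(array_nbytes(full_shape, np.float64) / 1e9)  # GB with the np.zeros default dtype
print(array_nbytes(full_shape, np.uint8) / 1e9)    # GB when stored as 8-bit pixels

# Small demonstration with 4 placeholder images so the sketch runs anywhere.
demo_path = os.path.join(tempfile.mkdtemp(), 'train_demo.npy')
store = create_image_store(demo_path, n_images=4)
store[0] = np.full((100, 100, 3), 255, dtype=np.uint8)  # stands in for a resized image
store.flush()  # push pending writes to disk
```

In the question's loop, each cv2.resize result could be assigned to store[ix] in the same way, and the finished file can later be opened lazily with np.load(path, mmap_mode='r'). Conversion to float (for example, scaling to [0, 1]) can then be applied per batch at training time instead of to the whole array.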
[Discussion]:
Tags: python numpy neural-network conv-neural-network