【发布时间】:2019-06-12 10:47:04
【问题描述】:
我有一个包含大约一百万张图片的目录。我想创建一个batch_generator,这样我就可以训练我的 CNN,因为我不能一次将所有这些图像都保存在内存中。
所以,我为此编写了一个生成器函数:
def batch_generator(image_paths, batch_size, isTraining):
while True:
batch_imgs = []
batch_labels = []
type_dir = 'train' if isTraining else 'test'
for i in range(len(image_paths)):
print(i)
print(os.path.join(data_dir_base, type_dir, image_paths[i]))
img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
img = np.divide(img, 255)
img = img.reshape(28, 28, 1)
batch_imgs.append(img)
label = image_paths[i].split('_')[1].split('.')[0]
batch_labels.append(label)
if len(batch_imgs) == batch_size:
yield (np.asarray(batch_imgs), np.asarray(batch_labels))
batch_imgs = []
if batch_imgs:
yield batch_imgs
当我调用这个语句时:
index = next(batch_generator(train_dataset, 10, True))
它打印相同的索引值和路径,因此,它在每次调用next() 时返回相同的批次。
我该如何解决这个问题?
我用这个问题作为代码的参考:how to split an iterable in constant-size chunks
【问题讨论】:
-
@kerwei 不,它的缩进是正确的,如果它的大小是
-
@brunodesthuilliers 是的,乍一看我没有注意到内部 if 块。因此,删除了我的评论:)
标签: python python-3.x generator