【问题标题】:How to define a batch generator?如何定义批处理生成器?
【发布时间】:2019-06-12 10:47:04
【问题描述】:

我有一个包含大约一百万张图片的目录。我想创建一个batch_generator,这样我就可以训练我的 CNN,因为我不能一次将所有这些图像都保存在内存中。

所以,我为此编写了一个生成器函数:

def batch_generator(image_paths, batch_size, isTraining):
    while True:
        batch_imgs = []
        batch_labels = []
        
        type_dir = 'train' if isTraining else 'test'
        
        for i in range(len(image_paths)):
            print(i)
            print(os.path.join(data_dir_base, type_dir, image_paths[i]))
            img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
            img  = np.divide(img, 255)
            img = img.reshape(28, 28, 1)
            batch_imgs.append(img)
            label = image_paths[i].split('_')[1].split('.')[0]
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield (np.asarray(batch_imgs), np.asarray(batch_labels))
                batch_imgs = []
        if batch_imgs:
            yield batch_imgs

当我调用这个语句时:

index = next(batch_generator(train_dataset, 10, True))

它打印相同的索引值和路径,因此,它在每次调用next() 时返回相同的批次。 我该如何解决这个问题?

我用这个问题作为代码的参考:how to split an iterable in constant-size chunks

【问题讨论】:

  • @kerwei 不,它的缩进是正确的,如果它的大小是
  • @brunodesthuilliers 是的,乍一看我没有注意到内部 if 块。因此,删除了我的评论:)

标签: python python-3.x generator


【解决方案1】:
# batch generator
def get_batches(dataset, batch_size):
    X, Y = dataset
    n_samples = X.shape[0]

    # Shuffle at the start of epoch
    indices = np.arange(n_samples)
    np.random.shuffle(indices)

    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)

        batch_idx = indices[start:end]

        yield X[batch_idx], Y[batch_idx]

【讨论】:

    【解决方案2】:

    生成器函数本身不是生成器,而是“生成器工厂” - 每次调用 batch_generator(...) 时,它都会返回一个全新的生成器,准备重新开始。 IOW,你想要:

    gen = batch_generator(...)
    for batch in gen:       
        do_something_with(batch)
    

    还有:

    1/ 您编写生成器函数的方式将创建一个无限生成器 - 外部 while 循环将永远重复 - 这可能是也可能不是您所期望的(我想我最好警告您)。

    2/ 您的代码中有两个逻辑错误:首先,您没有重置batch_labels 列表,然后在最后一个yield 上您只产生batch_imgs,这与内部@987654326 不一致@。 FWIW,与其维护两个列表(一个用于图像,另一个用于标签),不如使用一个 (img, label) 元组列表。

    最后附注:你不需要使用range(len(lst)) 来迭代一个列表——Python 的for 循环属于foreach 类型,它直接迭代可迭代的项目,即:

    for path image_paths:
        print(path)
    

    工作原理相同,可读性更强,速度更快...

    【讨论】:

    • 关于外循环,我将使用keras中的生成器来训练一个CNN。所以,我在这门课上使用了类似的批处理生成器实现。您能详细解释一下无限生成器的优缺点吗?
    • 您将用来“训练 CNN”的是(直接或间接)迭代生成器的结果,而不是生成器本身。无限生成器的原理是迭代永远不会停止——next(iterator)总是返回一些东西,for item in iterator 循环将永远运行。如果没有确切地看到它是如何使用的,就不可能判断无限生成器是否适合您自己的用例,我只是想您可能想被警告这一点,因为您似乎并没有真正完全掌握生成器是什么以及它们是如何工作的.
    • 是的,我对生成器没有完全的了解。这是我第一次。但我已经让它工作了。感谢您的帮助。
    【解决方案3】:

    在我看来,您正试图实现这一目标:

    def batch_generator(image_paths, batch_size, isTraining):
        your_code_here
    

    调用生成器 - 而不是你所拥有的:

    index = next(batch_generator(train_dataset, 10, True))
    

    你可以试试:

    index = iter(batch_generator(train_dataset, 10, True))
    index.__next__()
    

    【讨论】:

    • 1/ 您不需要在可迭代对象上调用 iter()(在这种情况下,它实际上只会返回未更改的参数),2/ __next__() 是一种“魔术方法”(通用运算符或类似运算符的函数的实现),不应直接调用,而应通过 next() 函数调用。
    • @brunodesthuilliers 感谢您的指点!诚然,我对生成器还是比较陌生。参与这些讨论有助于我学习和提高。
    【解决方案4】:

    我制作了自己的生成器,它支持限制、批处理或简单的第 1 步迭代:

    def gen(batch = None, limit = None):
        ret = []
        for i in range(1, 11): # put your data reading here and i counter (i += 1) under for
            if batch:
                ret.append(i)
                if limit and i == limit:
                    if len(ret):            
                        yield ret
                    return
                if len(ret) == batch:
                    yield ret
                    ret = []
            else:
                if limit and i > limit:
                    break
                yield i
        if batch and len(ret): # yield the rest of the list
            yield ret
                
    g = gen(batch=5, limit=8) # batches with limit
    #g = gen(batch=5) # batches
    #g = gen(limit=5) # step 1 with limit
    #g = gen() # step 1 with limit
    for i in g:
        print(i)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-10-28
      • 2011-12-18
      • 2020-06-05
      • 1970-01-01
      • 2018-06-02
      • 1970-01-01
      • 2020-03-01
      相关资源
      最近更新 更多