【Question title】: Keras: deal with threads and large datasets
【Posted】: 2018-04-05 10:35:23
【Question】:

I am trying to handle a large training dataset in Keras.

I am using model.fit_generator with a custom generator that reads the data from an SQLite database file.

I get an error telling me that SQLite objects cannot be used from two different threads:

ProgrammingError: SQLite objects created in a thread can only be used in that same thread. The object was created in thread id 140736714019776 and this is thread id 123145449209856

I tried the same with an HDF5 file but got a segmentation fault, which I now believe is also related to the multithreaded nature of fit_generator (see the bug reported here).

What is the correct way to use such generators? Reading data from files in batches must be very common for datasets that do not fit in memory.
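For this situation Keras documents keras.utils.Sequence as a safer alternative to plain generators: an object with __len__ and __getitem__(idx) that returns batch idx, so workers request batches by index instead of sharing one iterator. Below is a minimal sketch of that protocol backed by SQLite, assuming a hypothetical table data(id, x1, x2, label). The class does not subclass keras.utils.Sequence only so that it runs without Keras installed, and it returns plain lists where Keras code would return NumPy arrays.

```python
import math
import os
import sqlite3
import tempfile


class SQLiteBatches:
    """Index-based batch source following the keras.utils.Sequence protocol.

    In a Keras script this would subclass keras.utils.Sequence; it is left
    as a plain class here so the sketch runs without Keras. The table
    schema data(id, x1, x2, label) is a hypothetical example.
    """

    def __init__(self, db_path, ids, batch_size):
        self.db_path = db_path
        self.ids = list(ids)
        self.batch_size = batch_size

    def __len__(self):
        # Batches per epoch, including a final shorter batch.
        return math.ceil(len(self.ids) / self.batch_size)

    def __getitem__(self, idx):
        chunk = self.ids[idx * self.batch_size:(idx + 1) * self.batch_size]
        placeholders = ", ".join(["?"] * len(chunk))
        query = ("SELECT x1, x2, label FROM data "
                 "WHERE id IN ({}) ORDER BY id".format(placeholders))
        # A fresh connection per call, opened in whichever thread requests the
        # batch, so no SQLite object is ever shared between threads.
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(query, chunk).fetchall()
        finally:
            conn.close()
        X = [list(row[:-1]) for row in rows]
        Y = [row[-1] for row in rows]
        return X, Y


# Demo database with 10 rows (hypothetical data).
db_path = os.path.join(tempfile.mkdtemp(), "train.db")
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE data (id INTEGER PRIMARY KEY, x1 REAL, x2 REAL, label INTEGER)")
conn.executemany("INSERT INTO data VALUES (?, ?, ?, ?)",
                 [(i, i * 0.5, i * 2.0, i % 2) for i in range(10)])
conn.commit()
conn.close()

batches = SQLiteBatches(db_path, ids=range(10), batch_size=4)
print(len(batches))  # 3
print(batches[2])    # ([[4.0, 16.0], [4.5, 18.0]], [0, 1])
```

In a Keras script an instance of such a class would be passed to model.fit_generator in place of the generator. Opening a connection per batch adds a little overhead, but it removes any shared SQLite state between threads.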

Here is the code of the generator:

import pandas

class DataGenerator:
    def __init__(self, inputfile, batch_size, **kwargs):
        self.inputfile = inputfile
        self.batch_size = batch_size

    def generate(self, labels, idlist):
        # Loop forever: Keras decides where an epoch ends via steps_per_epoch.
        while True:
            for batch in self._read_data_from_hdf(idlist):
                # Attach each row's label by its id.
                batch = pandas.merge(batch, labels, how='left', on=['id'])
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)

    def _read_data_from_hdf(self, idlist):
        # Split idlist into consecutive chunks of batch_size ids (the last one may be shorter).
        chunklist = [idlist[i:i + self.batch_size] for i in range(0, len(idlist), self.batch_size)]
        for chunk in chunklist:
            # Read only the rows of this chunk from the HDF5 table.
            yield pandas.read_hdf(self.inputfile, key='data', where='id in {}'.format(chunk))

# [...]

model.fit_generator(generator=training_generator,
                    steps_per_epoch=len(partitions['train']) // config['batch_size'],
                    validation_data=validation_generator,
                    validation_steps=len(partitions['validation']) // config['batch_size'],
                    epochs=config['epochs'])
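Separate from the threading error, it is worth checking how the two counts in the call above line up: steps_per_epoch uses floor division, while _read_data_from_hdf also yields a final, shorter chunk. The pure-Python simulation below mirrors the chunking and the endless loop of generate (with 10 hypothetical ids and batch_size 3) and shows that Keras epochs then drift relative to passes over idlist.

```python
import itertools


def chunks(idlist, batch_size):
    # Same slicing as _read_data_from_hdf in the question.
    return [idlist[i:i + batch_size] for i in range(0, len(idlist), batch_size)]


def endless(idlist, batch_size):
    # Mirrors the infinite loop in generate(): cycle over the chunks forever.
    while True:
        for chunk in chunks(idlist, batch_size):
            yield chunk


idlist = list(range(10))
batch_size = 3
steps_per_epoch = len(idlist) // batch_size  # 3, as in the fit_generator call
gen = endless(idlist, batch_size)
epoch1 = list(itertools.islice(gen, steps_per_epoch))
epoch2 = list(itertools.islice(gen, steps_per_epoch))
print(epoch1)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(epoch2)  # [[9], [0, 1, 2], [3, 4, 5]]
```

Using math.ceil(len(idlist) / batch_size) for steps_per_epoch (and likewise for validation_steps) makes one Keras epoch equal to one pass over the ids.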

See the full example repository here.

Thanks for your support.

Cheers,

【Question comments】:

  • Did you find a solution to your problem? I have the same one...
  • The solution I have ended up with so far is to use the model.train_on_batch method.
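The train_on_batch workaround avoids the problem by reading every batch in the calling thread, so the SQLite connection never changes threads. A sketch of such a loop follows; make_batches is a hypothetical zero-argument callable returning one epoch of (X, Y) batches, and FakeModel is a stand-in so the example runs without Keras (a compiled Keras model provides train_on_batch(x, y), which returns a scalar loss when no extra metrics are compiled).

```python
def train_single_threaded(model, make_batches, epochs):
    """Manual training loop: every batch is read and consumed in the calling
    thread. Assumes train_on_batch returns a scalar loss; returns the mean
    loss per epoch."""
    history = []
    for _ in range(epochs):
        losses = [model.train_on_batch(X, Y) for X, Y in make_batches()]
        history.append(sum(losses) / len(losses) if losses else float("nan"))
    return history


class FakeModel:
    """Stand-in for a compiled Keras model, only for this demo."""

    def __init__(self):
        self.calls = 0

    def train_on_batch(self, X, Y):
        self.calls += 1
        return float(len(X))  # pretend loss: the batch size


def batches():
    # Hypothetical epoch of two tiny batches.
    return [([1, 2], [0, 1]), ([3], [1])]


model = FakeModel()
print(train_single_threaded(model, batches, epochs=2))  # [1.5, 1.5]
```

The trade-off is that this loop gives up fit_generator's background prefetching, so training waits while each batch is read from disk.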

Tags: python multithreading dataset deep-learning keras


【Solution 1】:

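Facing the same problem, I came up with a solution that combines a thread-safe generator decorator with an sqlalchemy engine, which can manage concurrent access to the database: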

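import threading

import pandas
from sqlalchemy import bindparam, create_engine, text


class threadsafe_iter:
    """Wrap an iterator so that concurrent next() calls are serialized by a lock."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def threadsafe_generator(f):
    """Decorator: make a generator function return a thread-safe iterator."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


class DataGenerator:
    def __init__(self, inputfile, batch_size, **kwargs):
        self.inputfile = inputfile
        self.batch_size = batch_size
        # The engine checks out a connection per query instead of sharing one
        # sqlite3 connection object between threads.
        self.sqlengine = create_engine('sqlite:///' + self.inputfile)

    def __del__(self):
        self.sqlengine.dispose()

    @threadsafe_generator
    def generate(self, labels, idlist):
        while True:
            for batch in self._read_data_from_sql(idlist):
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)

    def _read_data_from_sql(self, idlist):
        chunklist = [idlist[i:i + self.batch_size]
                     for i in range(0, len(idlist), self.batch_size)]
        # Bind the ids as an expanding parameter: formatting tuple(chunk) into the
        # SQL string gives invalid SQL such as "in (5,)" for a one-element chunk.
        query = text('select * from data where id in :ids').bindparams(
            bindparam('ids', expanding=True))
        for chunk in chunklist:
            # ids must be plain Python values (convert NumPy scalars first).
            df = pandas.read_sql(query, self.sqlengine, params={'ids': list(chunk)})
            yield df

# Build keras model and instantiate generators

# With workers > 1 (and the default use_multiprocessing=False), Keras calls
# next() on the generator from several threads, hence the lock above.
model.fit_generator(generator=training_generator,
                    steps_per_epoch=train_steps,
                    validation_data=validation_generator,
                    validation_steps=valid_steps,
                    epochs=10,
                    workers=4)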
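Why the decorator matters: a plain Python generator raises "ValueError: generator already executing" if two threads call next() on it at the same time, which can happen with workers > 1. The check below copies threadsafe_iter and threadsafe_generator from the answer and drains a decorated generator from four threads (numbers and drain are hypothetical helpers for the demo); every item should be delivered exactly once with no errors.

```python
import threading


class threadsafe_iter:
    """Copied from the answer: serialize next() calls with a lock."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def threadsafe_generator(f):
    """Copied from the answer: wrap each generator object in threadsafe_iter."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


@threadsafe_generator
def numbers(n):
    # Hypothetical stand-in for DataGenerator.generate.
    for i in range(n):
        yield i


def drain(iterator, n_threads=4):
    """Consume `iterator` from several threads; return (items, errors)."""
    items, errors = [], []
    record = threading.Lock()  # protects the two result lists

    def worker():
        while True:
            try:
                item = next(iterator)
            except StopIteration:
                return
            except ValueError as exc:  # e.g. "generator already executing"
                with record:
                    errors.append(exc)
                return
            with record:
                items.append(item)

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return items, errors


items, errors = drain(numbers(10_000))
print(sorted(items) == list(range(10_000)), errors)  # True []
```

Because the lock is held for the whole next() call, the SQL query inside generate also runs under it, so the workers never read from the database in parallel: the decorator makes sharing the generator safe rather than making loading faster.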

Hope this helps!

【Comments】:

  • Very nice indeed.