【Posted】: 2018-03-20 04:15:25
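【Question】:
My CSV is too large to read into memory at once, so I want to split it into chunks and fit a Keras model on one chunk at a time. I think I misunderstand how fit_generator works, because I keep getting a StopIteration error even though chunksize and steps_per_epoch correctly account for the number of rows in my CSV.
Code: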
import pandas as pd
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout
np.random.seed(26)
x_train_generator = pd.read_csv('X_train.csv', header=None, chunksize=150000)
y_train_generator = pd.read_csv('Y_train.csv', header=None, chunksize=150000)
x_test_generator = pd.read_csv('X_test.csv', header=None, chunksize=50000)
y_test_generator = pd.read_csv('Y_test.csv', header=None, chunksize=50000)
model = Sequential()
model.add(Dense(500, input_dim=1132, activation='tanh'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', metrics=['accuracy'],
              optimizer='adam')
model.fit_generator((x_train_generator.get_chunk().as_matrix(),
                     y_train_generator.get_chunk().as_matrix()),
                    steps_per_epoch=37,
                    epochs=1,
                    verbose=2,
                    validation_data=(x_test_generator.get_chunk().as_matrix(),
                                     y_test_generator.get_chunk().as_matrix()),
                    validation_steps=37
                    )
Error output:
Exception in thread Thread-107:
Traceback (most recent call last):
File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/usr/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "/home/user/myenv/local/lib/python2.7/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
generator_output = next(self._generator)
TypeError: tuple object is not an iterator
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
/home/user/tmp_keras.py in <module>()
22 verbose=2,
23 validation_data=(x_test_generator.get_chunk().as_matrix(), y_test_generator.get_chunk().as_matrix()),
---> 24 validation_steps=37
25 )
26
/home/user/myenv/local/lib/python2.7/site-packages/keras/legacy/interfaces.pyc in wrapper(*args, **kwargs)
85 warnings.warn('Update your `' + object_name +
86 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87 return func(*args, **kwargs)
88 wrapper._original_function = func
89 return wrapper
/home/user/myenv/local/lib/python2.7/site-packages/keras/models.pyc in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, initial_epoch)
1119 workers=workers,
1120 use_multiprocessing=use_multiprocessing,
-> 1121 initial_epoch=initial_epoch)
1122
1123 @interfaces.legacy_generator_methods_support
/home/user/myenv/local/lib/python2.7/site-packages/keras/legacy/interfaces.pyc in wrapper(*args, **kwargs)
85 warnings.warn('Update your `' + object_name +
86 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87 return func(*args, **kwargs)
88 wrapper._original_function = func
89 return wrapper
/home/user/myenv/local/lib/python2.7/site-packages/keras/engine/training.pyc in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
2009 batch_index = 0
2010 while steps_done < steps_per_epoch:
-> 2011 generator_output = next(output_generator)
2012
2013 if not hasattr(generator_output, '__len__'):
StopIteration:
Strangely, if I wrap fit_generator() in while 1: try: ... except StopIteration:, it manages to run.
I also tried passing x/y_train_generator to fit_generator directly, without calling get_chunk().as_matrix(), but that failed because I was not passing Keras a NumPy array.
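One way to reconcile the two attempts above is to give fit_generator an actual Python generator that yields (inputs, targets) tuples of NumPy arrays, rather than a single tuple or the raw pandas readers. The traceback's "tuple object is not an iterator" is consistent with a tuple being passed where an iterator is expected. The following is a minimal sketch, not a definitive fix: `csv_batch_generator` is a hypothetical helper name, and it assumes the X and Y files have the same number of rows in the same order. It uses `.values` rather than `as_matrix()`, which was deprecated and later removed from pandas.

```python
import pandas as pd


def csv_batch_generator(x_path, y_path, chunksize):
    """Yield (features, labels) NumPy arrays forever, one CSV chunk at a time.

    Hypothetical helper for illustration; assumes both files are row-aligned.
    """
    while True:  # loop indefinitely so the consumer never hits StopIteration
        # Re-open both files at the start of every pass over the data.
        x_reader = pd.read_csv(x_path, header=None, chunksize=chunksize)
        y_reader = pd.read_csv(y_path, header=None, chunksize=chunksize)
        # Each reader is an iterator of DataFrames; pair the chunks up.
        for x_chunk, y_chunk in zip(x_reader, y_reader):
            yield x_chunk.values, y_chunk.values


# Possible usage with the model from the question (left commented out because
# it needs Keras and the real CSV files):
# train_gen = csv_batch_generator('X_train.csv', 'Y_train.csv', 150000)
# model.fit_generator(train_gen, steps_per_epoch=37, epochs=1, verbose=2)
```

Because the generator restarts the readers itself, steps_per_epoch only decides how many chunks count as one epoch; it no longer has to match the file length exactly to avoid StopIteration.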
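【Comments】:
- Do you know what chunksize=150000 does? Also, do you know whether you need it? If you don't know whether you need it, you probably don't.
- It fetches the next 150000 rows of the dataframe, right? The CSV has over 5 million rows and is >20 GB, so the only ways I know to read it are chunksize or specifying iterator=True.
- It returns an iterator object; you still need to iterate over it.
- Doesn't the .get_chunk() call inside fit_generator take care of that?
- Yes, your use of get_chunk() and your thinking about how it works are reasonable. See the pandas IO tools documentation. The problem is the x_train_generator.get_chunk().as_matrix() call, which calls as_matrix() on the pandas IO TextFileReader object (a generator, not a dataframe).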
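The iterator-versus-single-value distinction raised in this thread can be reproduced without Keras or pandas. The sketch below uses a hypothetical stand-in generator, `batches`, with placeholder strings in place of real chunks. It shows three behaviours: `next()` works on a generator, raises TypeError on a plain tuple, and raises StopIteration once a finite generator is exhausted.

```python
def batches():
    """Stand-in generator; the yielded strings are placeholders, not real data."""
    yield ("x_chunk_0", "y_chunk_0")
    yield ("x_chunk_1", "y_chunk_1")


# A generator is an iterator, so next() pulls one (x, y) pair per call.
gen = batches()
first = next(gen)

# A single (x, y) tuple is iterable but is not an iterator, so next() fails.
one_batch = ("x_chunk_0", "y_chunk_0")
try:
    next(one_batch)
    tuple_error = None
except TypeError as exc:
    tuple_error = exc

# Asking a finite generator for more items than it holds raises StopIteration.
next(gen)  # consumes the second and last pair
try:
    next(gen)
    exhausted = False
except StopIteration:
    exhausted = True
```

This is only an analogy for what the training loop does internally: it calls `next()` on whatever it is given, once per step.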
Tags: python pandas keras generator