【发布时间】:2018-07-11 02:15:14
【问题描述】:
我正在尝试实现一个自定义数据生成器,它使用pandas.read_csv 以块的形式从 csv 文件中读取数据。我使用model.predict_generator 对其进行了测试,但返回的预测数量少于预期(在我的情况下,253457 中有 248192)。
自定义生成器
class TestDataGenerator:
def __init__(self, directory, batch_size=1024):
self.directory = directory
self.batch_size = batch_size
self.chunk_size=10000
self.samples = 0
def _to_movie_id(self, ids):
ids = ast.literal_eval(ids)
if ids == []:
return [EMB_MATRIX_SIZE-1]
else:
return [movie2idx[str(movie_id)] for movie_id in ids]
def generate(self):
csv_files = glob.glob(self.directory + '/*.csv')
while True:
for file in csv_files:
df = pd.read_csv(file, chunksize=self.chunk_size)
for df_chunk in df:
chunk_steps = math.ceil(len(df_chunk) / self.batch_size)
for i in range(chunk_steps):
batch = df_chunk[i * self.batch_size:(i + 1) * self.batch_size]
X_batch, y_batch = self.preprocess(batch)
self.samples += len(batch)
yield X_batch, y_batch
def preprocess(self, df):
X_user = df['user'].apply(lambda x: user2idx[str(x)]).values
X_watched = df['watched'].apply(self._to_movie_id).values
X_watched_padded = pad_sequences(X_watched, maxlen=SEQ_LENGTH, value=0)
ohe = df['movie'].apply(lambda x: to_categorical(movie2idx[x], num_classes=len(movie2idx)))
X = [X_user, X_watched_padded]
y = np.array([o.tolist() for o in ohe])
return X, y
运行model.predict_generator
batch_size=1024
n_samples_test = 253457
test_dir = 'folder/'
test_gen = TestDataGenerator(test_dir, batch_size=batch_size)
next_test_gen = test_gen.generate()
preds = model.predict_generator(next_test_gen, steps=math.ceil(n_samples_test/batch_size))
运行model.predict_generator 后,preds 的行数为248192,小于实际的253457。看起来它缺少几个时代。我还单独测试了generate 而没有与 Keras 交互,它的行为与预期的一样,在 csv 文件中返回了正确数量的样本。此外,在generate 产生值之前,我会跟踪使用samples 处理的样本数量。令人惊讶的是,samples 的值是 250000。所以,我很确定我可能对 Keras 做了一些事情。
请注意,我还尝试设置max_queue_size=1,并使generate 线程安全,但没有成功。为了简单起见,我只在test_dir 下放置了 1 个 csv 文件。我正在使用嵌入在 Tensorflow 1.5.0 中的 Keras 2.1.2-tf。
我对如何做到这一点进行了一些研究,但还没有遇到有用的示例。这个实现有什么问题?
谢谢
Peeranat F.
【问题讨论】:
-
只是出于好奇 - 你能检查一下
preds[::1024]向量的样子吗?
标签: python pandas tensorflow keras