[Posted]: 2018-02-18 12:04:13
[Question]:
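Keras's fit_generator() model method expects a generator that produces tuples of the form (input, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error: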
import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:

    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

    datagen = create_data_generator()

    input_vals = Input(shape=(1,))
    output = Dense(1, activation='relu')(input_vals)
    model = Model(inputs=input_vals, outputs=output)
    model.compile('rmsprop', 'mean_squared_error')
    model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                        verbose=2, max_queue_size=2)
This is the error I get:
Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
fetch, allow_tensor=True, allow_operation=True))
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
generator_output = next(self._generator)
File "./datagen_test.py", line 25, in create_data_generator
yield sess.run(next_val)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
return _ListFetchMapper(fetch)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
return _ElementFetchMapper(fetches, contraction_fn)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)
Traceback (most recent call last):
File "./datagen_test.py", line 34, in <module>
verbose=2, max_queue_size=2)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
generator_output = next(output_generator)
StopIteration
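Strangely, adding a line containing next(datagen) directly after the line where I initialize datagen makes the code run just fine, with no errors.
Why does my original code not work? Why does it start working when I add that line? And is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?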
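One Python detail may be relevant to the effect of the extra next(datagen) line: a generator function runs none of its body until the first next() call, and the body then executes in whichever thread makes that call. The sketch below uses only the standard library; make_generator, first_item_in_worker and the setup_thread key are illustrative names, not part of Keras or TensorFlow. It suggests that, without the priming call, the Dataset and iterator ops in create_data_generator would be created inside the fit_generator worker thread rather than in the main thread. Whether that is the actual root cause of the error is an assumption that the post itself does not confirm.

```python
import threading


def make_generator(log):
    # Nothing below runs until the first next() call on the generator.
    log["setup_thread"] = threading.current_thread().name
    while True:
        yield threading.current_thread().name


def first_item_in_worker(gen):
    """Pull one item from gen in a separate thread, like a background loader."""
    result = {}
    worker = threading.Thread(
        target=lambda: result.update(item=next(gen)), name="worker")
    worker.start()
    worker.join()
    return result["item"]


# Case 1: the generator is first advanced by the worker thread.
log_lazy = {}
gen_lazy = make_generator(log_lazy)
assert "setup_thread" not in log_lazy  # the body has not started yet
first_item_in_worker(gen_lazy)

# Case 2: the generator is primed in the calling thread before the worker uses it.
log_primed = {}
gen_primed = make_generator(log_primed)
next(gen_primed)
first_item_in_worker(gen_primed)

print(log_lazy["setup_thread"])    # setup ran in the worker thread
print(log_primed["setup_thread"])  # setup ran in the thread that called next() first
```

In the second case only the code before the loop and the first iteration run in the calling thread; later iterations still run in the worker.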
[Comments]:
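-
I'm not sure whether this is the cause, but I find it really strange that you define a function inside a with block.
-
Apparently, putting the with block inside the generator definition does make the code work both with and without the extra line, although I could have sworn I tried it that way first. Still, given how (I think) TensorFlow Sessions work, I don't see why it should make a difference. Another mystery.
-
Doesn't the with block close the session when the block ends? I really don't think it should contain definitions that will be used outside of it... If I post this as an answer to the question, will it be marked as answered?
-
I don't think that would answer the question. If we put sess = tf.InteractiveSession() at the top of the file and change the with block to with sess.as_default() (and put it inside the generator definition), we get the same error as before. Switching to an interactive session and removing the with block entirely (since it sets itself as the default session) also produces the same error. It's not clear to me that this is the crux of the problem.
-
I think it really is a "disconnection" of the graph. Once you convert a tensor into a numpy array, you lose the connection (it's no longer a tensor). Is there a way to create parallel sessions? Maybe your generator should create a subsession inside it (independent of the session running the model), so that it doesn't expect a connection?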
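The with-block experiments in the comments touch on a related point. According to the TensorFlow 1.x documentation for Session.as_default and Graph.as_default, the default session and default graph are properties of the current thread, so a default installed by a with block in the main thread is not automatically visible in a thread started elsewhere. The following is a standard-library analogy only: default_scope, current_default and the "main-session" string are hypothetical stand-ins, not TensorFlow APIs.

```python
import threading
from contextlib import contextmanager

_state = threading.local()  # each thread sees its own attributes


def current_default():
    """Return the default installed in the calling thread, or None."""
    return getattr(_state, "default", None)


@contextmanager
def default_scope(value):
    """Install value as the calling thread's default for the with block."""
    previous = current_default()
    _state.default = value
    try:
        yield value
    finally:
        _state.default = previous


seen = {}


def worker():
    # Runs in a new thread: the creating thread's default is not inherited.
    seen["implicit"] = current_default()
    # Re-entering the scope inside the thread makes the default visible here.
    with default_scope("main-session"):
        seen["explicit"] = current_default()


with default_scope("main-session"):
    seen["main"] = current_default()
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(seen)
```

This analogy is consistent with the report that moving the with tf.Session() block into the generator body made the code run, but it does not settle why with sess.as_default() alone still failed; that part remains unresolved in the comments.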
Tags: tensorflow keras