【发布时间】:2019-04-29 06:38:56
【问题描述】:
我在使用 TF 估计器时遇到了一个奇怪的问题,并试图在我的输入函数中使用 tf.Dataset。
首先,我的模型如下所示:
model = tf.estimator.DNNClassifier(
feature_columns=my_feature_column,
hidden_units=[hidden_layers, hidden_layers],
n_classes=n_classes)
我的特征列是这样的
my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]
现在,如果我这样训练,一切正常,训练会在几秒钟内完成:
model.train(
input_fn=tf.estimator.inputs.numpy_input_fn(
dict({'image':X_train}),
y_train,
shuffle=True),
steps=nb_epoch)
但是当我尝试在输入函数中添加 tf.Datasets 时,它需要永远运行:
def input_fn(features, labels, batch_size):
dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
return dataset.shuffle(1000).batch(batch_size).repeat()
model.train(
input_fn=lambda:input_fn(X_train, y_train, batch_size),
steps=nb_epoch)
谁能看看我做错了什么?应该是一样的吧?
谢谢, 保罗
【问题讨论】:
-
欢迎您!你有没有提到input pipeline performance guide?例如,您可以尝试使用 prefetch 之类的方法,如图所示。如果您不想在每次重复数据时对数据进行随机播放,您也可以在随机播放中将reshuffle_each_iteration 参数设置为 False。也许这些改进会有所帮助。此外,如果您可以使用融合操作对数据进行混洗和重复,也可能会带来更好的性能!
标签: python tensorflow dataset