【问题标题】:Tensorflow estimator issue with datasets数据集的 TensorFlow 估计器问题
【发布时间】:2019-04-29 06:38:56
【问题描述】:

我在使用 TF 估计器时遇到了一个奇怪的问题,并试图在我的输入函数中使用 tf.Dataset。

首先,我的模型如下所示:

    model = tf.estimator.DNNClassifier(
        feature_columns=my_feature_column,
        hidden_units=[hidden_layers, hidden_layers],
        n_classes=n_classes)

我的特征列是这样的

    my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]

现在,如果我这样训练,一切正常,训练会在几秒钟内完成:

    model.train(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            dict({'image':X_train}),
            y_train,
            shuffle=True),
        steps=nb_epoch)

但是当我尝试在输入函数中添加 tf.Datasets 时,它需要永远运行:

def input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
    return dataset.shuffle(1000).batch(batch_size).repeat()

model.train(
    input_fn=lambda:input_fn(X_train, y_train, batch_size),
    steps=nb_epoch)

谁能看看我做错了什么?应该是一样的吧?

谢谢, 保罗

【问题讨论】:

  • 欢迎您!你有没有提到input pipeline performance guide?例如,您可以尝试使用 prefetch 之类的方法,如图所示。如果您不想在每次重复数据时对数据进行随机播放,您也可以在随机播放中将reshuffle_each_iteration 参数设置为 False。也许这些改进会有所帮助。此外,如果您可以使用融合操作对数据进行混洗和重复,也可能会带来更好的性能!

标签: python tensorflow dataset


【解决方案1】:

您的数据集无限重复,并且没有默认的最大迭代次数,因此 tensorflow 不知道何时停止。

return dataset.shuffle(1000).batch(batch_size).repeat() 的行替换为return dataset.shuffle(1000).batch(batch_size).repeat(10) 之类的内容,它将训练10 个epoch,你会没事的。

【讨论】:

  • 谢谢亚历山大!
猜你喜欢
  • 1970-01-01
  • 2018-10-05
  • 2018-05-03
  • 1970-01-01
  • 1970-01-01
  • 2018-06-15
  • 2018-05-06
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多