【问题标题】:Using 'read_batch_record_features' with an Estimator将“read_batch_record_features”与 Estimator 一起使用
【发布时间】:2017-07-19 06:49:26
【问题描述】:

(我使用的是 tensorflow 1.0 和 Python 2.7)

我无法让 Estimator 处理队列。实际上,如果我将已弃用的 SKCompat 接口与自定义数据文件和给定的批量大小一起使用,则模型可以正确训练。我正在尝试将新接口与 input_fn 一起使用,该接口从 TFRecord 文件(相当于我的自定义数据文件)中批量处理功能。脚本运行正常,但损失值在 200 或 300 步后不会改变。似乎模型在小批量输入上循环(这可以解释为什么损失收敛得如此之快)。

我有一个如下所示的“run.py”脚本:

import tensorflow as tf
from tensorflow.contrib import learn, metrics

#[...]
evalMetrics = {'accuracy':learn.MetricSpec(metric_fn=metrics.streaming_accuracy)}
runConfig = learn.RunConfig(save_summary_steps=10)
estimator = learn.Estimator(model_fn=myModel,
                            params=myParams,
                            modelDir='/tmp/myDir',
                            config=runConfig)

session = tf.Session(graph=tf.get_default_graph())

with session.as_default():
  tf.global_variables_initializer()
  coordinator = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=session,coord=coordinator)

  estimator.fit(input_fn=lambda: inputToModel(trainingFileList),steps=10000)

  estimator.evaluate(input_fn=lambda: inputToModel(evalFileList),steps=10000,metrics=evalMetrics)

  coordinator.request_stop()
  coordinator.join(threads)
session.close()

我的 inputToModel 函数如下所示:

import tensorflow as tf

def inputToModel(fileList):
  features = {'rawData': tf.FixedLenFeature([100],tf.float32),
              'label': tf.FixedLenFeature([],tf.int64)}
  tensorDict = tf.contrib.learn.read_batch_record_features(fileList,
                                batch_size=100,
                                features=features,
                                randomize_input=True,
                                reader_num_threads=4,
                                num_epochs=1,
                                name='inputPipeline')
  tf.local_variables_initializer()
  data = tensorDict['rawData']
  labelTensor = tensorDict['label']
  inputTensor = tf.reshape(data,[-1,10,10,1])

  return inputTensor,labelTensor

欢迎任何帮助或建议!

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    尝试使用:tf.global_variables_initializer().run()

    我想做类似的事情,但我不知道如何将 Estimator API 与多线程一起使用。还有一个实验类也可以提供服务 - 可能有用

    删除session = tf.Session(graph=tf.get_default_graph())session.close() 并尝试:

    with tf.Session() as sess:
      tf.global_variables_initializer().run()
    

    【讨论】:

    • 它不起作用:我收到一条错误消息:“ValueError:无法使用给定的会话执行操作:操作的图表与会话的图表不同。”
    • 我用新的想法为你更新了答案。异常可能是由于双会话初始化...我无法测试您的代码。
    • 好的。我不再遇到异常,但不幸的是它并没有解决我的问题。
    • 我能给你的其他建议是使用 Experiment 类,而不是 Estimator。请在此处查看详细信息tensorflow.org/api_docs/python/tf/contrib/learn/…
    • 我发布了类似的更通用的问题:stackoverflow.com/questions/42529598/…
    猜你喜欢
    • 1970-01-01
    • 2022-11-24
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多