【发布时间】:2017-07-19 06:49:26
【问题描述】:
(我使用的是 tensorflow 1.0 和 Python 2.7)
我无法让 Estimator 处理队列。实际上,如果我将已弃用的 SKCompat 接口与自定义数据文件和给定的批量大小一起使用,则模型可以正确训练。我正在尝试将新接口与 input_fn 一起使用,该接口从 TFRecord 文件(相当于我的自定义数据文件)中批量处理功能。脚本运行正常,但损失值在 200 或 300 步后不会改变。似乎模型在小批量输入上循环(这可以解释为什么损失收敛得如此之快)。
我有一个如下所示的“run.py”脚本:
import tensorflow as tf
from tensorflow.contrib import learn, metrics
#[...]
evalMetrics = {'accuracy':learn.MetricSpec(metric_fn=metrics.streaming_accuracy)}
runConfig = learn.RunConfig(save_summary_steps=10)
estimator = learn.Estimator(model_fn=myModel,
params=myParams,
modelDir='/tmp/myDir',
config=runConfig)
session = tf.Session(graph=tf.get_default_graph())
with session.as_default():
tf.global_variables_initializer()
coordinator = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=session,coord=coordinator)
estimator.fit(input_fn=lambda: inputToModel(trainingFileList),steps=10000)
estimator.evaluate(input_fn=lambda: inputToModel(evalFileList),steps=10000,metrics=evalMetrics)
coordinator.request_stop()
coordinator.join(threads)
session.close()
我的 inputToModel 函数如下所示:
import tensorflow as tf
def inputToModel(fileList):
features = {'rawData': tf.FixedLenFeature([100],tf.float32),
'label': tf.FixedLenFeature([],tf.int64)}
tensorDict = tf.contrib.learn.read_batch_record_features(fileList,
batch_size=100,
features=features,
randomize_input=True,
reader_num_threads=4,
num_epochs=1,
name='inputPipeline')
tf.local_variables_initializer()
data = tensorDict['rawData']
labelTensor = tensorDict['label']
inputTensor = tf.reshape(data,[-1,10,10,1])
return inputTensor,labelTensor
欢迎任何帮助或建议!
【问题讨论】:
标签: python tensorflow