【发布时间】:2018-07-13 12:57:02
【问题描述】:
如何在train(...) 完成后从tf.estimator.Estimator 获取最后一个global_step?例如,一个典型的基于 Estimator 的训练例程可以这样设置:
n_epochs = 10
model_dir = '/path/to/model_dir'
def model_fn(features, labels, mode, params):
# some code to build the model
pass
def input_fn():
ds = tf.data.Dataset() # obviously with specifying a data source
# manipulate the dataset
return ds
run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
for epoch in range(n_epochs):
estimator.train(input_fn=input_fn)
# Now I want to do something which requires to know the last global step, how to get it?
my_custom_eval_method(global_step)
只有evaluate() 方法返回包含global_step 作为字段的字典。如果由于某种原因我不能或不想使用此方法,我如何获得global_step?
【问题讨论】:
标签: python tensorflow tensorflow-estimator