【问题标题】:Stop Tensorflow from reloading model per request阻止 Tensorflow 为每个请求重新加载模型
【发布时间】:2018-06-06 14:35:30
【问题描述】:

我继承了一些 TF 代码,每个请求执行以下操作:

def predict_tf(ml_wrapper, prediction_row_df):
    log.debug('request POST {}'.format(prediction_row_df))

    prediction_row_df, _, _ = ml_wrapper._engineer_features(prediction_row_df)
    # As method says; delete stuff we don't want and scale and impute if needed
    ml_wrapper._delete_unused_values_and_scale_and_impute_missing_values(prediction_row_df, single_row_prediction=True)
    features, labels = ml_wrapper._split_features_and_labels(prediction_row_df)
    panda_function_for_prediction = tf.estimator.inputs.pandas_input_fn(
        features,
        labels,
        batch_size=ml_wrapper.batch_size,
        num_epochs=1,
        shuffle=False
    )
    predictions = ml_wrapper.tf_model.predict(
        input_fn=panda_function_for_prediction)
    probas = list(predictions)[0]['probabilities']
    log.warning('PREDICTED: no:{} yes:{}'.format(probas[0], probas[1]))
    return probas

代码似乎可以工作,但在控制台中我看到如下内容:

2018-06-06 16:32:46,767 INFO  [tensorflow:116] Calling model_fn.
2018-06-06 16:32:50,848 INFO  [tensorflow:116] Done calling model_fn.
2018-06-06 16:32:51,082 INFO  [tensorflow:116] Graph was finalized.
2018-06-06 16:32:51,083 INFO  [tensorflow:116] Restoring parameters from /model_tensorflow/model.ckpt-719
2018-06-06 16:32:51,494 INFO  [tensorflow:116] Running local_init_op.
2018-06-06 16:32:51,536 INFO  [tensorflow:116] Done running local_init_op.

这个操作似乎每个请求需要 4 秒 - 有没有办法只加载一次模型/估计器并对其进行预测?

【问题讨论】:

  • 你能确定从 protobuf 加载 tensorflow 模型的位置吗?我假设它在ml_wrapper 方法调用之一中。您绝对应该能够在 predict_tf 端点调用之外一次性加载模型。
  • self.tf_model = tf.estimator.DNNLinearCombinedClassifier(model_dir=self.model_dir, ...) - 这在应用程序生命周期的早期被调用(在应用程序启动时),我预计它会加载数据 - 但似乎模型是在每次 predict() 调用时重新加载的。
  • 我明白了。确实不需要重新加载(对象不会重新初始化,是吗?)但如果不深入研究ml_wrapper 和应用程序主体,我就不能说更多。我会在predict_tf 上进行线路分析,然后从那里向下钻取。抱歉没有更清楚的;也许其他人会。
  • 这似乎是一个常见问题,groups.google.com/a/tensorflow.org/forum/#!topic/discuss/… 并且我认为它与我的代码流程无关。不过,我正在努力使用生成器方法:(

标签: python tensorflow tensorflow-serving tensorflow-estimator


【解决方案1】:

如果您使用的是Estimator,那么您可以试试这个: https://github.com/marcsto/rl/blob/master/src/fast_predict2.py 这基本上可以防止重新加载图表。

【讨论】:

  • 这听起来更像是一条评论。如果您认为它可以回答问题,请提供提问者应如何修改其代码以使其按需要工作。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2015-07-30
  • 2021-12-26
  • 2011-03-17
  • 1970-01-01
  • 2015-06-01
相关资源
最近更新 更多