【问题标题】:How to deploy locally trained TensorFlow graph file to Google Cloud Platform?如何将本地训练的 TensorFlow 图形文件部署到 Google Cloud Platform?
【发布时间】:2017-11-01 16:09:33
【问题描述】:

我遵循了 TensorFlow for Poets 教程,并用我自己的一些课程替换了现有的flower_photos。现在我的labels.txt 文件和graph.pb 保存在我的本地计算机上。

我有没有办法将此预训练模型部署到 Google Cloud Platform?我一直在阅读文档,我能找到的只是关于如何从他们的 ML 引擎中创建、训练和部署模型的说明。但是我不想花钱在 Google 的服务器上训练我的模型,因为我只需要它们来托管我的模型,以便我可以调用它来进行预测。

还有人遇到同样的问题吗?

【问题讨论】:

    标签: machine-learning deployment tensorflow google-cloud-platform google-cloud-ml-engine


    【解决方案1】:

    部署本地训练的模型是受支持的用例;无论您在哪里训练,instructions 本质上都是相同的:

    要部署您需要的模型版本:

    保存在 Google Cloud Storage 上的 TensorFlow SavedModel。你可以得到一个 模特:

    • 按照 Cloud ML Engine 训练步骤在 云。

    • 在其他地方训练并导出到 SavedModel。

    不幸的是,TensorFlow for Poets 没有显示如何导出 SavedModel(我已提交功能请求以解决该问题)。同时,您可以编写如下所示的“转换器”脚本(您也可以在训练结束时执行此操作,而不是保存 graph.pb 并重新读取):

    input_graph = 'graph.pb'
    saved_model_dir = 'my_model'
    
    with tf.Graph() as graph:
      # Read in the export graph
      with tf.gfile.FastGFile(input_graph, 'rb') as f:
          graph_def = tf.GraphDef()
          graph_def.ParseFromString(f.read())
          tf.import_graph_def(graph_def, name='')
    
      # CloudML Engine and early versions of TensorFlow Serving do
      # not currently support graphs without variables. Add a
      # prosthetic variable.
      dummy_var = tf.Variable(0)
    
      # Define SavedModel Signature (inputs and outputs)
      in_image = graph.get_tensor_by_name('DecodeJpeg/contents:0')
      inputs = {'image_bytes': 
    tf.saved_model.utils.build_tensor_info(in_image)}
    
      out_classes = graph.get_tensor_by_name('final_result:0')
      outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
    
      signature = tf.saved_model.signature_def_utils.build_signature_def(
          inputs=inputs,
          outputs=outputs,
          method_name='tensorflow/serving/predict'
      )
    
      # Save out the SavedModel.
      b = saved_model_builder.SavedModelBuilder(saved_model_dir)
      b.add_meta_graph_and_variables(sess,
                                     [tf.saved_model.tag_constants.SERVING],
                                     signature_def_map={'predict_images': signature})
      b.save() 
    

    (基于this codelabthis SO post 的未经测试的代码)。

    如果您希望输出使用字符串标签而不是整数索引,请进行以下更改:

      # Loads label file, strips off carriage return
      label_lines = [line.rstrip() for line 
                     in tf.gfile.GFile("retrained_labels.txt")]
      out_classes = graph.get_tensor_by_name('final_result:0')
      out_labels = tf.gather(label_lines, ot_classes)
      outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_labels)}
    

    【讨论】:

      【解决方案2】:

      不幸的是,仅部分回答,但我已经能够做到这一点......但有一些我尚未解决的持续问题。我将经过训练的 pb 和 txt 文件移植到我的服务器上,安装了 Tensorflow,并通过 HTTP 请求调用经过训练的模型。它完美地工作......在第一次运行时。然后每隔一段时间就失败一次。

      tensorflow deployment on openshift, errors with gunicorn and mod_wsgi

      很惊讶没有更多的人试图解决这个普遍问题。

      【讨论】:

        猜你喜欢
        • 2018-07-16
        • 2018-12-24
        • 1970-01-01
        • 2020-11-14
        • 2023-04-07
        • 2022-01-08
        • 2022-01-17
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多