【问题标题】:How to structure Tensorflow model code?如何构建 TensorFlow 模型代码?
【发布时间】:2017-07-04 00:23:31
【问题描述】:

我很难找到如何构建我的 Tensorflow 模型代码。我想以 Class 的形式构造它,以方便将来重用。另外,我目前的结构很乱,张量板图输出里面有多个“模型”。

以下是我目前拥有的:

import tensorflow as tf
import os

from utils import Utils as utils

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

class Neural_Network:
    # Neural Network Setup
    num_of_epoch = 50

    n_nodes_hl1 = 500
    n_nodes_hl2 = 500
    n_nodes_hl3 = 500

    def __init__(self):
        self.num_of_classes = utils.get_num_of_classes()
        self.num_of_words = utils.get_num_of_words()

        # placeholders
        self.x = tf.placeholder(tf.float32, [None, self.num_of_words])
        self.y = tf.placeholder(tf.int32, [None, self.num_of_classes])

        with tf.name_scope("model"):
            self.h1_layer = tf.layers.dense(self.x, self.n_nodes_hl1, activation=tf.nn.relu, name="h1")
            self.h2_layer = tf.layers.dense(self.h1_layer, self.n_nodes_hl2, activation=tf.nn.relu, name="h2")
            self.h3_layer = tf.layers.dense(self.h2_layer, self.n_nodes_hl3, activation=tf.nn.relu, name="h3")

            self.logits = tf.layers.dense(self.h3_layer, self.num_of_classes, name="output")

    def predict(self):
        return self.logits

    def make_prediction(self, query):
        result = None

        with tf.Session() as sess:
            saver = tf.train.import_meta_graph('saved_models/testing.meta')
            saver.restore(sess, 'saved_models/testing')

            sess.run(tf.global_variables_initializer())

            prediction = self.predict()
            prediction = sess.run(prediction, feed_dict={self.x : query})
            prediction = prediction.tolist()
            prediction = tf.nn.softmax(prediction)
            prediction = sess.run(prediction)
            print prediction

            return utils.get_label_from_encoding(prediction[0])

    def train(self, data):

        print len(data['values'])
        print len(data['labels'])

        prediction = self.predict()

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=self.y))
        optimizer = tf.train.AdamOptimizer().minimize(cost)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())

            for epoch in range(self.num_of_epoch):
                optimised, loss = sess.run([optimizer, cost],
                                           feed_dict={self.x: data['values'], self.y: data['labels']})

                if epoch % 1 == 0:
                    print("Completed Training Cycle: " + str(epoch) + " out of " + str(self.num_of_epoch))
                    print("Current Loss: " + str(loss))

                    saver = tf.train.Saver()
                    saver.save(sess, 'saved_models/testing')
                    print("Model saved")

我在网上发现很多使用低级代码,例如 tf.Variablestf.Constant,因此,他们更能够拆分他们的代码。但是,由于我对 Tensorflow 比较陌生,所以我想先使用更高级别的代码。

谁能告诉我如何构建我的代码?

【问题讨论】:

标签: machine-learning tensorflow


【解决方案1】:

正如评论所言,对您最初问题的简短回答是阅读this,但当您在 ​​cmets 中提出后续问题时,我认为它需要一个更完整的答案。


谁能告诉我如何构建我的代码?

显然,构建代码是一个品味问题。但是,为了帮助您制作自己的口味,您需要牢记以下主要事项:TensorFlow 中有 2 个不同的层,不要混淆它们。

  • 第一个是Graph 层,它包含所有 TensorFlow 节点,例如
    • tensors(例如tf.placeholdertf.constanttf.Variables 等),或
    • operationstf.addtf.matmul 等)。 Graph 包含您的模型 本身,并且可能包含更多内容,例如:损失函数、训练模型的优化器、输入数据管道等。

每个节点都有一个名称,您可以使用它直接从图中检索它(例如,使用tf.get_variable 方法或tf.Graph.get_tensor_by_name)。

  • 第二层是您使用 Python(或 C++ 或 Java,...)API 构建 TensorFlow Graph 的方式。这可能是您在提问时想到的这一层。但是,在某种程度上,这一层实际上更像是一个模型工厂,而不是一个模型。

该格式是否支持模型的保存和恢复?

这取决于您对 model 的含义,即使两种情况的答案都是肯定的。

  • 如果您想到了 TensorFlow Graph,答案是,您可以保存和恢复您的 Graph,因为它不取决于您如何构建它。只需查看此document保存和恢复部分,了解如何操作或查看此answer,其中仅恢复了Graph
  • 如果您想到了 Python 类,简短的回答是但是你可以编造一些东西来使它成为
    如上一项所述,TensorFlow 检查点不保存 Python(也不是 C++ 或 Java)对象,而只保存图形。但是您的 model 作为 Python 类的结构存在于其他地方:它存在于您的代码中。

    因此,如果您重新创建 Python 类的实例,并且确保在 Graph 中重新创建所有 TensorFlow 节点(因此您将获得等效的 Graph),那么,当您恢复TensorFlow Graph 从检查点开始,您的模型作为 Python-instance-linked-to-a-TensorFlow-Graph 将被恢复。

    请参阅document恢复变量 部分,其中 Python-instances-linked-to-a-TensorFlow-Graph 实际上是 Python 变量(即 v1v2) 位于模块范围内。

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
      # Restore variables from disk.
      saver.restore(sess, "/tmp/model.ckpt")
      print("Model restored.")
      # Do some work with the model
      ...
    

我只能推荐阅读(并点赞 :))question 及其答案,因为您将学到很多关于如何在 TensorFlow 中保存/恢复的知识。


希望现在清楚一点。

【讨论】:

  • 这是否意味着假设我有一个定义网络结构的 Network() 类。所以要恢复这个模型,我所要做的就是创建一个 Network() 实例并继续使用tf.train.Saver().restore() 来恢复图形?不需要tf.train.import_meta_graph(),因为创建实例已经定义了图的结构。
  • 没错。无需重新导入图的结构,因为您将通过实例定义它。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2020-04-16
  • 2017-06-07
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多