【问题标题】:load multiple models in Tensorflow在 TensorFlow 中加载多个模型
【发布时间】:2020-01-08 18:25:56
【问题描述】:

我在 Tensorflow 中编写了以下卷积神经网络 (CNN) 类[为了清楚起见,我尝试省略一些代码行。]

class CNN:
def __init__(self,
                num_filters=16,        # initial number of convolution filters
             num_layers=5,           # number of convolution layers
             num_input=2,           # number of channels in input
             num_output=5,          # number of channels in output
             learning_rate=1e-4,    # learning rate for the optimizer
             display_step = 5000,   # displays training results every display_step epochs
             num_epoch = 10000,     # number of epochs for training
             batch_size= 64,        # batch size for mini-batch processing
             restore_file=None,      # restore file (default: None)

            ):

                # define placeholders
                self.image = tf.placeholder(tf.float32, shape = (None, None, None,self.num_input))  
                self.groundtruth = tf.placeholder(tf.float32, shape = (None, None, None,self.num_output)) 

                # builds CNN and compute prediction
                self.pred = self._build()

                # I have already created a tensorflow session and saver objects
                self.sess = tf.Session()
                self.saver = tf.train.Saver()

                # also, I have defined the loss function and optimizer as
                self.loss = self._loss_function()
                self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)

                if restore_file is not None:
                    print("model exists...loading from the model")
                    self.saver.restore(self.sess,restore_file)
                else:
                    print("model does not exist...initializing")
                    self.sess.run(tf.initialize_all_variables())

def _build(self):
    #builds CNN

def _loss_function(self):
    # computes loss


# 
def train(self, train_x, train_y, val_x, val_y):
    # uses mini batch to minimize the loss
    self.sess.run(self.optimizer, feed_dict = {self.image:sample, self.groundtruth:gt})


    # I save the session after n=10 epochs as:
    if epoch%n==0:
        self.saver.save(sess,'snapshot',global_step = epoch)

# finally my predict function is
def predict(self, X):
    return self.sess.run(self.pred, feed_dict={self.image:X})

我已经为两个独立的任务分别训练了两个 CNN。每个大约需要 1 天。比如说,model1 和 model2 分别保存为“snapshot-model1-10000”和“snapshot-model2-10000”(及其对应的元文件)。我可以分别测试每个模型并计算其性能。

现在,我想在一个脚本中加载这两个模型。我自然会尝试如下:

cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........) 
cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........)

我遇到错误 [错误消息很长。我只是复制/粘贴了它的一个sn-p。]

NotFoundError: Tensor name "Variable_26/Adam_1" not found in checkpoint files /home/amitkrkc/codes/A549_models/snapshot-hela-95000
     [[Node: save_1/restore_slice_85 = RestoreSlice[dt=DT_FLOAT, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/restore_slice_85/tensor_name, save_1/restore_slice_85/shape_and_slice)]]

有没有办法从这两个文件中加载两个单独的 CNN?欢迎任何建议/评论/反馈。

谢谢,

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    是的,有。使用单独的图表。

    g1 = tf.Graph()
    g2 = tf.Graph()
    
    with g1.as_default():
        cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........) 
    with g2.as_default():
        cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........)
    

    编辑:

    如果你想让它们进入同一个图表。您必须重命名一些变量。一个想法是让每个 CNN 处于单独的范围内,并让 saver 处理该范围内的变量,例如:

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), scope='model1')
    

    并在 cnn 中将您的所有构造都包含在范围内:

    with tf.variable_scope('model1'):
        ...
    

    EDIT2:

    其他想法是重命名 saver 管理的变量(因为我假设您想使用保存的检查点而不重新训练所有内容。保存允许在图形和检查点中使用不同的变量名称,请查看初始化文档。

    【讨论】:

    • 非常感谢。您的第一个建议适用于我的情况。
    【解决方案2】:

    这应该是对投票最多的答案的评论。但我没有足够的声誉来做这件事。

    无论如何。 如果您(任何人搜索并到达这一点)仍然无法使用 lpp 提供的解决方案并且您正在使用 Keras,请查看来自 github 的以下引用。

    这是因为如果没有提供默认的 tf 会话,keras 共享一个全局会话

    model1创建时,在graph1上 当model1加载权重时,权重在与graph1关联的keras全局会话上

    model2创建时,在graph2上 model2加载权重时,全局session不知道graph2

    下面的解决方案可能会有所帮助,

    graph1 = Graph()
    with graph1.as_default():
        session1 = Session()
        with session1.as_default():
            with open('model1_arch.json') as arch_file:
                model1 = model_from_json(arch_file.read())
            model1.load_weights('model1_weights.h5')
            # K.get_session() is session1
    
    # do the same for graph2, session2, model2
    

    【讨论】:

      【解决方案3】:

      您需要创建 2 个会话并分别恢复 2 个模型。为此,您需要执行以下操作:

      1a。当您保存模型时,您需要将作用域添加到变量名称中。这样你就会知道哪些变量属于哪个模型:

      # The first model
      tf.Variable(tf.zeros([self.batch_size]), name="model_1/Weights")
      ...
      
      # The second model 
      tf.Variable(tf.zeros([self.batch_size]), name="model_2/Weights")
      ...
      

      1b。或者,如果您已经保存了模型,您可以通过使用this script 添加范围来重命名变量。

      2.. 当你恢复不同的模型时,你需要像这样按变量名过滤:

      # The first model
      sess_1 = tf.Session()
      sess_1.run(tf.initialize_all_variables())
      saver_1 = tf.train.Saver([v for v in tf.all_variables() if 'model_1' in v.name])
      saver_1.restore(sess_1, weights_1_file)
      sess_1.run(pred, feed_dict={image: X})
      
      # The second model
      sess_2 = tf.Session()
      sess_2.run(tf.initialize_all_variables())
      saver_2 = tf.train.Saver([v for v in tf.all_variables() if 'model_2' in v.name])
      saver_2.restore(sess_2, weights_2_file)
      sess_2.run(pred, feed_dict={image: X})
      

      【讨论】:

        【解决方案4】:

        我遇到了同样的问题,但我在互联网上找到的任何解决方案都无法解决问题(无需重新培训)。所以我所做的是将每个模型加载到两个与主线程通信的单独线程中。编写代码很简单,您只需要在同步线程时小心。 在我的例子中,每个线程都接收到其问题的输入并将输出返回给主线程。它可以在没有任何可观察到的开销的情况下工作。

        【讨论】:

        • 你能提供一个解决方案的例子吗?
        【解决方案5】:

        如果您想连续训练或加载多个模型,一种方法是清除会话。您可以使用

        轻松完成此操作
        from keras import backend as K 
        
        # load and use model 1
        
        K.clear_session()
        
        # load and use  model 2
        
        K.clear_session()`
        

        K.clear_session() 销毁当前的 TF 图并创建一个新的。 有助于避免旧模型/图层造成混乱。

        【讨论】:

        • 这在您尝试一个接一个地加载模型时有效,但在您想将它们一起加载时无效。另外,这就是 Keras,Tensorflow 的替代品是 tf.reset_default_graph()
        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2021-09-08
        • 1970-01-01
        • 2023-03-30
        • 2017-03-29
        • 2022-12-30
        • 2021-03-08
        • 1970-01-01
        相关资源
        最近更新 更多