【Question title】: TensorFlow: how can I restore a model when some new layers are added?
【Posted】: 2017-08-07 09:20:13
【Question】:

I have trained a model and saved a checkpoint. My model code is:

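import tensorflow as tf  # TensorFlow 1.x
import tensorflow.contrib.slim as slim

# scope is the id of the process (e.g. 'worker_15'), or 'global' for the global network
with tf.variable_scope(scope):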
    self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32)
    self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                              kernel_size=[8, 8], stride=4, padding='SAME')
    self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                              kernel_size=[4, 4], stride=2, padding='SAME')
    self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                              kernel_size=[3, 3], stride=1, padding='SAME')
    self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu)

    # Output layers for policy and value estimations
    self.policy = slim.fully_connected(self.fc,
                                       cfg.ACTION_DIM,
                                       activation_fn=tf.nn.softmax,
                                       biases_initializer=None)
    self.value = slim.fully_connected(self.fc,
                                      1,
                                      activation_fn=None,
                                      biases_initializer=None)

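About 32 processes run at the same time. Each process holds its own copy of the global network defined by the code above, and scope is the id of that process; the scope of the global network itself is 'global'.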

Then I wanted to add more layers after the self.fc layer:

with tf.variable_scope(scope):
    self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32)
    self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                              kernel_size=[8, 8], stride=4, padding='SAME')
    self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                              kernel_size=[4, 4], stride=2, padding='SAME')
    self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                              kernel_size=[3, 3], stride=1, padding='SAME')
    self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu)

    # Output layers for policy and value estimations
    self.policy = slim.fully_connected(self.fc,
                                       cfg.ACTION_DIM,
                                       activation_fn=tf.nn.softmax,
                                       biases_initializer=None)
    self.value = slim.fully_connected(self.fc,
                                      1,
                                      activation_fn=None,
                                      biases_initializer=None)

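    # New layer: slim names it fully_connected_3, because three unnamed
    # fully_connected layers already exist in this scope
    self.new_fc_1 = slim.fully_connected(self.fc, 512, activation_fn=tf.nn.elu)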

However, when I restore the model, it reports the following errors:

2017-08-03 22:23:43.473157: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
2017-08-03 22:23:43.477197: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 379803423 vs. calculated on the restored bytes 2648422677
2017-08-03 22:23:43.477210: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3963326522 vs. calculated on the restored bytes 3154501583
2017-08-03 22:23:43.477200: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3893236466 vs. calculated on the restored bytes 1767411214
2017-08-03 22:23:43.478276: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 4239176201 vs. calculated on the restored bytes 3213118706
2017-08-03 22:23:43.480438: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 442335910 vs. calculated on the restored bytes 4248164641
2017-08-03 22:23:43.483885: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3105262865 vs. calculated on the restored bytes 2648422677
2017-08-03 22:23:43.483953: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
2017-08-03 22:23:43.486987: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
2017-08-03 22:23:43.490616: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
2017-08-03 22:23:43.491951: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
2017-08-03 22:23:43.491957: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
2017-08-03 22:23:43.494310: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint
     [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]]
.... ....
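
The missing key in the log is consistent with how slim names layers created without an explicit scope argument: inside one variable scope, the first slim.fully_connected call gets the scope `fully_connected`, later ones get `fully_connected_1`, `fully_connected_2`, and so on (slim.conv2d does the same with `Conv`). The plain-Python sketch below only imitates that numbering; `slim_variable_names` is a hypothetical helper, not a TensorFlow function. It shows that `self.new_fc_1` ends up as `fully_connected_3`, a key that a checkpoint written before the layer existed cannot contain.

```python
from collections import Counter


def slim_variable_names(scope, layers):
    """Hypothetical helper that imitates slim's automatic layer naming.

    `layers` lists (default_name, has_bias) in creation order. The first layer
    with a given default name keeps it as is; later ones get _1, _2, ...
    """
    counts = Counter()
    names = []
    for default_name, has_bias in layers:
        index = counts[default_name]
        counts[default_name] += 1
        layer_scope = default_name if index == 0 else f"{default_name}_{index}"
        names.append(f"{scope}/{layer_scope}/weights")
        if has_bias:
            names.append(f"{scope}/{layer_scope}/biases")
    return names


old_layers = [
    ("Conv", True), ("Conv", True), ("Conv", True),  # conv_1 .. conv_3
    ("fully_connected", True),    # self.fc
    ("fully_connected", False),   # self.policy (biases_initializer=None)
    ("fully_connected", False),   # self.value  (biases_initializer=None)
]
new_layers = old_layers + [("fully_connected", True)]  # self.new_fc_1

# Keys the old checkpoint holds vs. keys the extended graph asks for
checkpoint_keys = set(slim_variable_names("worker_15", old_layers))
missing = [key for key in slim_variable_names("worker_15", new_layers)
           if key not in checkpoint_keys]
print(missing)
```

Because the numbering depends only on creation order within a scope, a new unnamed layer created before the policy/value layers would even shift their names. Giving each group of layers its own tf.variable_scope (as both answers below do), or passing an explicit scope= name to each slim layer, keeps the old keys stable.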

I save the model with the following code:

saver.save(sess, self.model_path+'/model-'+str(episode_count)+'.ckpt')

This is the code that defines the saver:

value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope')
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic'))
saver = tf.train.Saver(value_list, max_to_keep=100)

with tf.Session(config=tf_configs) as sess:
    coord = tf.train.Coordinator()
    if load_model:
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(model_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

How can I restore a pretrained model when some new layers with randomly initialized parameters are added to the current network?

【Discussion】:

  • Restore the checkpoint with the old model first, and add the new tensors after that.

Tags: tensorflow


【Answer 1】:

You can use two separate variable scopes: one for the layers you save and load, and one for the new layers.

Then you can tell the saver to use only the variables in the first scope:

saver = tf.train.Saver(
    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="save_scope")
)
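
This approach relies on how tf.get_collection filters by scope: per its documentation, it keeps the items whose name matches the scope string via `re.match`, so the scope is a regular expression anchored at the start of the name rather than an exact scope test. The plain-Python sketch below imitates that filter on hypothetical variable names for a re-scoped network; `select` is an illustrative helper, and `global_step:0` is a made-up extra variable that is not part of the network.

```python
import re

# Hypothetical variable names for the re-scoped network (TF appends ':0').
graph_vars = [
    "global/old_scope/Conv/weights:0",
    "global/old_scope/fully_connected/weights:0",
    "global/added_layer/fully_connected/weights:0",
    "global/actor_critic/fully_connected/weights:0",
    "global_step:0",  # made-up variable outside the network scopes
]


def select(names, scope):
    """Imitate the scope filter of tf.get_collection: keep the names for
    which re.match(scope, name) succeeds (a regex anchored at the start)."""
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]


# Variables for a restore-only saver: the old layers, not the new one.
restore_vars = (select(graph_vars, "global/old_scope")
                + select(graph_vars, "global/actor_critic"))
print(restore_vars)

# A bare 'global' also matches 'global_step:0'; a trailing slash avoids that.
print(select(graph_vars, "global"))
print(select(graph_vars, "global/"))
```

In real code the items would be tf.Variable objects and tf.get_collection would do the filtering itself; the sketch only illustrates that scope='global' behaves as a prefix regex, so scope='global/' is a cheap way to avoid catching unrelated variables.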

【Discussion】:

  • Could you give a more complete example?
  • value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope') value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')) value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/added_layer')) saver = tf.train.Saver(value_list, max_to_keep=100). I added a new layer, and later, when I restored the model, it reported: Key global/added_layer/fully_connected/weights not found in checkpoint
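  • Is the new layer part of the 'global/added_layer' scope? If so, why do you extend your value_list with the added_layer scope? You must create the Saver only with the scopes that contain the old layers you are trying to load, e.g. saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="old_scope"), max_to_keep=100). Otherwise the Saver tries to load layers that did not exist before.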
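  • If you also want to save snapshots of all layers later, just define more than one tf.train.Saver: one with a restricted scope, used only to load the previous layers, and one that saves all variables.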
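
Following the two-saver idea, a slightly more general variant restores whatever the checkpoint actually contains and leaves everything else to the initializer, so the same code works both before and after the new layer has been saved once. The sketch below shows only the selection step in plain Python; the function name `split_for_restore` and all shapes (including ACTION_DIM = 4) are assumptions for illustration. In TF 1.x the checkpoint side could plausibly come from tf.train.NewCheckpointReader(path).get_variable_to_shape_map(), and the graph side from v.op.name and v.shape.as_list() for each v in tf.global_variables(); the to_restore names would then pick the variables for a restore-only tf.train.Saver, while a second saver covers every variable for saving.

```python
def split_for_restore(graph_shapes, ckpt_shapes):
    """Split the graph's variables into (to_restore, to_initialize).

    graph_shapes: {variable name: shape} for the current graph
    ckpt_shapes:  {checkpoint key: shape} for the saved checkpoint
    A variable is restored only if the checkpoint has the same name with the
    same shape; everything else keeps its freshly initialized value.
    """
    to_restore, to_initialize = [], []
    for name, shape in sorted(graph_shapes.items()):
        if name in ckpt_shapes and list(ckpt_shapes[name]) == list(shape):
            to_restore.append(name)
        else:
            to_initialize.append(name)
    return to_restore, to_initialize


# Hypothetical shapes: conv_3 flattens to 10*10*64 = 6400, ACTION_DIM = 4.
ckpt_shapes = {
    "global/old_scope/fully_connected/weights": [6400, 512],
    "global/old_scope/fully_connected/biases": [512],
    "global/actor_critic/fully_connected/weights": [512, 4],
}
graph_shapes = dict(ckpt_shapes)
graph_shapes.update({
    "global/added_layer/fully_connected/weights": [512, 512],
    "global/added_layer/fully_connected/biases": [512],
})

to_restore, to_initialize = split_for_restore(graph_shapes, ckpt_shapes)
print("restore:", to_restore)
print("initialize:", to_initialize)
```

Comparing shapes as well as names guards against restoring a tensor whose layer was resized under the same name.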

【Answer 2】:

After a lot of googling, and with help from @BlueSun, I found that the following approach solves the problem.

Before adding the new scope, first save the model using only the variables in the existing scopes:

value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope')
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic'))
saver = tf.train.Saver(value_list, max_to_keep=100)

Then train the new network.

Later, add the new scope and define a new saver before running the model, as follows:

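# Saver used only for restoring: it covers the scopes that already exist in the
# checkpoint. 'global/added_layer' is left out on purpose; including it causes
# "Key global/added_layer/... not found in checkpoint". Once a checkpoint that
# already contains 'global/added_layer' has been saved, add that scope here too.
value_list = []
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope'))
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic'))
saver = tf.train.Saver(value_list, max_to_keep=100)

with tf.Session(config=tf_configs) as sess:
    coord = tf.train.Coordinator()
    # Initialize every variable first, so the new layer gets random initial values
    sess.run(tf.global_variables_initializer())
    if load_model:
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(model_path)
        # Overwrite the old layers with the values stored in the checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)

    # From now on, save every variable under 'global', including the new layer
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="global"), max_to_keep=100)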
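
The order inside the session matters: a Saver built from a subset of variables only writes that subset on restore, so the remaining (new) variables must already be initialized, and calling tf.global_variables_initializer() after the restore would overwrite the restored values. The toy model below uses a plain dict as a stand-in for session state; `run_initializer` and `restore` are hypothetical helpers that imitate the two TensorFlow calls in simplified form, and the layer names are shortened for illustration.

```python
import random


def run_initializer(names, rng):
    """Stand-in for sess.run(tf.global_variables_initializer()):
    every variable gets a fresh random value."""
    return {name: rng.random() for name in names}


def restore(state, checkpoint, names):
    """Stand-in for a Saver built from `names`: copy only those entries from
    the checkpoint, failing on a missing key as TensorFlow's restore does."""
    for name in names:
        if name not in checkpoint:
            raise KeyError(f"Key {name} not found in checkpoint")
        state[name] = checkpoint[name]
    return state


graph = ["old_scope/fc", "added_layer/fc_1", "actor_critic/policy"]
checkpoint = {"old_scope/fc": 1.0, "actor_critic/policy": 2.0}  # saved before fc_1 existed

state = run_initializer(graph, random.Random(0))      # 1) initialize everything
state = restore(state, checkpoint, list(checkpoint))  # 2) overwrite the old layers
print(state)

try:
    restore(dict(state), checkpoint, graph)  # a saver that also lists the new layer
except KeyError as err:
    print("restore failed:", err)
```

Swapping steps 1 and 2 would leave every value random, which is why the initializer runs unconditionally before the optional restore.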

The network code looks like this:

with tf.variable_scope(scope):
    with tf.variable_scope('old_scope'):
        self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32)
        self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                                  kernel_size=[8, 8], stride=4, padding='SAME')
        self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                                  kernel_size=[4, 4], stride=2, padding='SAME')
        self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                                  kernel_size=[3, 3], stride=1, padding='SAME')
        self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu)

    with tf.variable_scope('added_layer'):
        self.fc_1 = slim.fully_connected(self.fc, 512, activation_fn=tf.nn.elu)

    with tf.variable_scope('actor_critic'):
        # Output layers for policy and value estimations
        self.policy = slim.fully_connected(self.fc_1,
                                           cfg.ACTION_DIM,
                                           activation_fn=tf.nn.softmax,
                                           biases_initializer=None)
        self.value = slim.fully_connected(self.fc_1,
                                          1,
                                          activation_fn=None,
                                          biases_initializer=None)

It works now, although the code looks a bit inelegant.

【Discussion】:
