【问题标题】:add new output for pre-trained model为预训练模型添加新输出
【发布时间】:2017-08-18 00:41:22
【问题描述】:

我很困惑为预训练模型添加新类,到目前为止我所做的是恢复预训练检查点并创建大小为 m * C+1 的矩阵和长度为 C+1 的向量,然后从现有权重初始化这些的前 C 行/元素,并通过仅在 Optimizer.minimize() 中训练 FC 层来冻结前一层。但是当我运行代码时,我得到了这个错误:

Traceback (most recent call last):
File "/home/tensorflow/tensorflow/models/image/mnist/new_dataset/Nets.py", line 482, in <module>
new_op_w = optimizer_new.minimize(loss, var_list = resize_var_w)
File "/home/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 279, in minimize
grad_loss=grad_loss)
File "/home/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 337, in compute_gradients
processors = [_get_processor(v) for v in var_list]
File "/home/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 502, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

这就是代码:

with tf.Session(graph=graph) as sess:
  if os.path.isfile(ckpt):
       aver.restore(sess, 'path_to_checkpoint.ckpt')
       w_b_new = {

        'weight_4': tf.Variable(tf.random_normal([num_hidden, 1], stddev=0.1), name = 'weight_4'),
        'bias_4'  : tf.Variable(tf.constant(1.0, shape=[1]), name = 'bias_4'),}
       change_1 = tf.unstack(w_b_not['weight_4'])
       change_2 = tf.unstack(w_b_not['bias_4'])
       change_3 = tf.unstack(w_b_new['weight_4'])
       change_4 = tf.unstack(w_b_new['bias_4'])
       changestep1 = []
       for i in range(len(change_1)):
        changestep1.append(tf.unstack(change_1[i]))          
       changestep3 = []
       for i in range(len(change_3)):
        changestep3.append(tf.unstack(change_3[i]))
        for j in range(len(changestep3[i])):
          changestep1[i].append(changestep3[i][j])

        changestep1[i] = tf.stack(changestep1[i])
       final1 = tf.stack(changestep1)
       resize_var_w = tf.assign(w_b_not['weight_4'], final1, validate_shape=False)
       final2 = tf.concat([w_b_not['bias_4'] ,  w_b_new['bias_4']], axis=0)
       resize_var = tf.assign(w_b_not['bias_4'], final2, validate_shape=False)

       optimizer_new = tf.train.GradientDescentOptimizer(0.01)
       new_op_w = optimizer_new.minimize(loss, var_list = resize_var_w)
       new_op_b = optimizer_new.minimize(loss, var_list = resize_var)       
       for step in range(num_steps,num_steps + num_train_steps):
          offset = (step * batch_size) % (train_labels.shape[0] - batch_size)       
          batch_data = train_dataset[offset:(offset + batch_size), :, :, :]      
          batch_labels = train_labels[offset:(offset + batch_size), :]
          feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels , keep_prob:0.5}        
          _,_, l, predictions = sess.run([new_op_w,new_op_b, loss, train_prediction ], feed_dict=feed_dict)
          if (step % 50 == 0):
            print('%d\t%f\t%.1f%%\t%.1f%%' % (step, l, accuracy(predictions, batch_labels), accuracy(valid_prediction.eval(), valid_labels)))    
       print('Test accuracy: %.1f%%' % accuracy(test_prediction.eval() , test_labels))

       save_path_w_b = saver.save(sess, "path_checkpoint.ckpt")
       print("Model saved in file: %s" % save_path_w_b)

【问题讨论】:

    标签: python tensorflow deep-learning conv-neural-network


    【解决方案1】:

    根据 GradientDescentOptimizer 的 minimize method 上的 TensorFlow 文档,“var_list”必须是变量对象的列表。根据您的代码,resize_var_w 是单个张量。

    编辑 具体来说:

    如果给优化器var_list,顾名思义,这一定是一个变量列表。在反向传播期间,优化器将循环 var_list 并仅更新列表中的变量,而不是图中的所有可训练变量。单个变量不可迭代。

    如果您只想更新单个 Tensor,您可以简单地尝试:

    resize_var_w = [tf.assign(w_b_not['weight_4'], final1, validate_shape=False)]
    

    我没有测试,但应该可以。

    【讨论】:

    • 是的,但我已经在 resize_var_w 中分配了新变量
    • 调试您的脚本并打印type(resize_var_w)。它是打印 TensorFlow 张量变量还是 python 列表?
    • 这是为 resize_var_w 打印的输出
    • 根据assign doctf.assign的输入是单个张量并返回单个张量。您需要一个list() 的变量来更新。基本上,如果将var_list 传递给优化器,则仅更新该列表中的变量。默认是更新图中的所有变量。所以var_list 必须是一个列表,而不是单个张量。
    • 有没有办法把它改成变量?
    猜你喜欢
    • 2019-08-15
    • 1970-01-01
    • 1970-01-01
    • 2018-09-25
    • 2017-08-19
    • 2018-11-02
    • 2023-01-03
    • 1970-01-01
    • 2019-01-16
    相关资源
    最近更新 更多