Question title: How to initialize the weights of a network with the weights of another network?
Posted: 2018-10-15 04:09:03
Question:

I want to merge two networks into a single network while keeping the weights of the original networks.

I saved the weights as numpy arrays using the following:

weights = {}
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    weights[i.name] = i.eval()  # i.eval() needs a default session

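I couldn't find a way to load these weights into the variables of the new network. Is there a way to load the weights into all variables?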

I tried the following, but got an error:

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    i.initializer = weights[i.name]

Error:

AttributeError: can't set attribute
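The error occurs because `initializer` on a `tf.Variable` is a read-only property (it returns the op that initializes the variable), so it cannot be assigned to. The plain-Python sketch below does not use TensorFlow; it reproduces the same kind of error with a hypothetical `FakeVariable` class whose `initializer` property has no setter:

```python
class FakeVariable:
    """Hypothetical stand-in for tf.Variable; this is not TensorFlow code."""

    def __init__(self, init_op):
        self._init_op = init_op

    @property
    def initializer(self):
        # Read-only: no @initializer.setter is defined
        return self._init_op


v = FakeVariable(init_op="init-op")
try:
    v.initializer = [1.0, 2.0]
except AttributeError as err:
    # Older Pythons say "can't set attribute"; 3.11+ words it differently
    print(type(err).__name__, err)
```

In TensorFlow 1.x, a variable's value is changed by running an assign op instead, e.g. `sess.run(v.assign(value))`.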

Tags: python numpy tensorflow


    Answer 1:

    You can write the two functions as follows:

    def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
        # dump every variable in `collection` as a {name: numpy array} dict
        return {v.name: sess.run(v) for v in tf.get_collection(collection)}
    
    
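    def load_from_dict(sess, data):
        # assign saved values to variables whose names appear in `data`;
        # variables without a matching entry keep their current values
        for v in tf.global_variables():
            if v.name in data:
                sess.run(v.assign(data[v.name]))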
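`load_from_dict` silently skips variables whose names are missing from the dictionary, and `v.assign` raises an error if a saved array's shape differs from the variable's. When combining two networks it can therefore help to compare names and shapes before loading. The following is a minimal, framework-independent sketch; `check_weights` is a hypothetical helper, and the new network's shapes are passed in as a plain dict (in TF 1.x they could be collected with something like `{v.name: tuple(v.shape.as_list()) for v in tf.global_variables()}`):

```python
import numpy as np


def check_weights(target_shapes, data):
    """Compare saved arrays against the variables of the new network.

    target_shapes: dict mapping variable name -> shape tuple (new network)
    data: dict mapping variable name -> numpy array (saved weights)
    Returns (loadable, missing, unused, mismatched) lists of names.
    """
    loadable, missing, mismatched = [], [], []
    for name, shape in target_shapes.items():
        if name not in data:
            missing.append(name)  # would keep its random initialization
        elif tuple(data[name].shape) != tuple(shape):
            mismatched.append(name)  # assign() would fail for this one
        else:
            loadable.append(name)
    # saved entries that no variable in the new network will consume
    unused = [name for name in data if name not in target_shapes]
    return loadable, missing, unused, mismatched


shapes = {"fc0/kernel:0": (10, 512), "fc0/bias:0": (512,), "fc1/kernel:0": (512, 512)}
saved = {
    "fc0/kernel:0": np.zeros((10, 512)),
    "fc0/bias:0": np.zeros((256,)),
    "old/fc9/bias:0": np.zeros((512,)),
}
print(check_weights(shapes, saved))
# (['fc0/kernel:0'], ['fc1/kernel:0'], ['old/fc9/bias:0'], ['fc0/bias:0'])
```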
    

    The trick is to simply iterate over all variables and check whether each one exists in the dictionary. For example:

    import tensorflow as tf
    import numpy as np
    
    
    def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
        # dump every variable in `collection` as a {name: numpy array} dict
        return {v.name: sess.run(v) for v in tf.get_collection(collection)}
    
    
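    def load_from_dict(sess, data):
        # assign saved values to variables whose names appear in `data`;
        # variables without a matching entry keep their current values
        for v in tf.global_variables():
            if v.name in data:
                sess.run(v.assign(data[v.name]))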
    
    
    def network(x):
        x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc0')
        x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc1')
        x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc2')
        x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc3')
        x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc4')
        return x
    
    
    element = np.random.randn(8, 10)
    weights = None
    
    # first session
    with tf.Session() as sess:
    
        x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
        y = network(x)
        sess.run(tf.global_variables_initializer())
    
        # first evaluation
        expected = sess.run(y, {x: element})
    
        # dump as dict
        weights = save_to_dict(sess)
    
    # destroy session and graph
    tf.reset_default_graph()
    
    # second session
    with tf.Session() as sess:
    
        x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
        y = network(x)
        sess.run(tf.global_variables_initializer())
    
        # use randomly initialized parameters
        actual = sess.run(y, {x: element})
        assert np.sum(np.abs(actual - expected)) > 0  # should NOT match
    
        # load previous parameters
        load_from_dict(sess, weights)
    
        actual = sess.run(y, {x: element})
        assert np.allclose(actual, expected)  # should match
    

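    This way, you can simply remove some parameters from the dictionary, modify the weights before loading them, or even rename the parameters.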
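For the original goal of merging two networks, one option is to build each original network under its own variable scope in the merged graph and rename the saved keys to match. The sketch below assumes hypothetical scope names `net_a/` and `net_b/` and a hypothetical helper `add_prefix`; it only manipulates the dictionaries, and the result could then be passed to `load_from_dict`:

```python
import numpy as np


def add_prefix(weights, prefix, skip=()):
    """Return a copy of `weights` with `prefix` prepended to every name.

    Entries whose original name starts with any string in `skip` are dropped,
    e.g. to leave a layer randomly initialized in the merged network.
    """
    return {
        prefix + name: value
        for name, value in weights.items()
        if not any(name.startswith(s) for s in skip)
    }


# Hypothetical weights dumped from two separately trained networks
weights_a = {"fc0/kernel:0": np.ones((10, 512)), "fc0/bias:0": np.zeros(512)}
weights_b = {"fc0/kernel:0": np.full((10, 512), 2.0), "fc4/bias:0": np.zeros(512)}

merged = {}
merged.update(add_prefix(weights_a, "net_a/"))
merged.update(add_prefix(weights_b, "net_b/", skip=("fc4/",)))
print(sorted(merged))
# ['net_a/fc0/bias:0', 'net_a/fc0/kernel:0', 'net_b/fc0/kernel:0']
```

In TF 1.x, building each copy of `network(x)` inside `with tf.variable_scope('net_a'):` or `with tf.variable_scope('net_b'):` prefixes the variable names this way. The exact names depend on how the merged graph is built, so it is worth printing `v.name` for the new variables before renaming.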
