【问题标题】:How to restore pretrained checkpoint for current model in Tensorflow?如何在 Tensorflow 中恢复当前模型的预训练检查点?
【发布时间】:2019-03-21 06:57:48
【问题描述】:

我有一个预训练的检查点。现在我正在尝试将这个预训练模型恢复到当前网络。但是,变量名称是不同的。 Tensorflow document 说使用像这样的字典:

v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2": v2})

但是,当前网络中的变量定义如下:

with tf.variable_scope('a'):
    b=tf.get_variable(......)

所以,变量名似乎是a/b。 如何让字典像"v2": a/b

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    您可以使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)获取当前图表中所有变量名称的列表。您也可以指定范围。

    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')
    

    您可以使用tf.train.list_variables(ckpt_file) 获取检查点中所有变量的列表。

    假设您的检查点中有变量 b,并且您想在名称 a/b 下加载到 tf.variable_scope('a')。为此,您只需定义它

    with tf.variable_scope('a'):
        b=tf.get_variable(......)
    

    并加载

    saver = tf.train.Saver({'v2': b})
    
    with tf.Session() as sess:
        saver.restore(sess, ckpt_file))
        print(b)
    

    这将输出

    <tf.Variable 'a/b:0' shape dtype>
    

    编辑:如前所述,您可以使用

    vars_dict = {}
    for var_current in tf.global_variables():
        print(var_current)
        print(var_current.op.name) # this gets only name
    
    for var_ckpt in tf.train.list_variables(ckpt):
        print(var_ckpt[0]) this gets only name
    

    当您知道所有变量的确切名称时,您可以分配您需要的任何值,前提是变量具有相同的形状和 dtype 所以得到一个字典

    vars_dict[var_ckpt[0]) = tf.get_variable(var_current.op.name, shape) # remember to specify shape, you can always get it from var_current 
    

    您可以显式地或在您认为合适的任何类型的循环中构建此字典。然后你把它传递给 saver

    saver = tf.train.Saver(vars_dict)
    

    【讨论】:

    • 我从检查点获取变量使用:tf.train.list_variables(ckpt_file)。当前图中的变量使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')。 saver 的字典如下所示:{'generator/bn1_audio/batch_normalization/beta': 'mfcc_encoder/bn1_audio/batch_normalization/beta'}。但我收到了这个错误:TypeError: names_to_saveables must be a dict mapping string names to Tensors/Variables。不是变量:Tensor("Const:0", shape=(), dtype=string).
    • 当你指定作用域时,你将只得到这个作用域的变量。要获取所有变量,您不需要它。然后你会得到'scope_name/var_name'
    • 也许我没有说清楚。我的问题类似于this one。但是,我的问题是,要恢复的变量与saver 不在同一个文件中。所以,我不知道如何定义字典的值。
    • 所以你有当前变量,你需要从检查点加载它们的值,其中变量的形状相同但名称不同,对吧?
    • 是的。这就是我想要的。而且我知道关键是检查点中的变量名,值是当前变量。但是字典和变量不在同一个函数中。情况不像TF文件那么简单。我很困惑如何定义字典的值。
    猜你喜欢
    • 2018-02-16
    • 2020-08-11
    • 2017-07-12
    • 2016-09-29
    • 1970-01-01
    • 2019-08-28
    • 2018-08-05
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多