【问题标题】:How to save and load selected and all variables in tensorflow 2.0 using tf.train.Checkpoint?如何使用 tf.train.Checkpoint 在 tensorflow 2.0 中保存和加载所选变量和所有变量?
【发布时间】:2019-03-20 14:03:44
【问题描述】:

如何将下面显示的 tensorflow 2.0 中的选定变量保存在文件中,并使用 tf.train.Checkpoint 将它们加载到另一个代码中的一些已定义变量中?

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element

另外,请展示如何使用 tf.train.Checkpoint 保存所有变量并重新加载。提前致谢。

【问题讨论】:

  • 我不确定我是否理解这个问题。我假设您已经阅读了checkpoints in 2.0 上的信息。如果您为您希望它应该工作的特定变量创建tf.train.Checkpoint,对吗?或者,是什么阻止你这样做?
  • 看不懂上面的官方链接,太复杂了。另外,我不明白为什么每个 tf 教程都需要为 keras 废话编写。如果您可以简单地保存上述 3 个变量并恢复,我会很高兴。此外,保存所有 110 个文件并使用 tf.train.Checkpoint 以一种与文档不同的简单方式恢复。

标签: python tensorflow tensorflow2.0


【解决方案1】:

我不确定这是否是您的意思,但您可以专门为要保存和恢复的变量创建一个tf.train.Checkpoint 对象。请参阅以下示例:

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()
        self.ckpt = self.makeCheckpoint()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def makeCheckpoint(self):
        return tf.train.Checkpoint(
            init3=self.initList[3], init55=self.initList[55],
            init60=self.initList[60], more4=self.moreList[4])

    def saveVariables(self):
        self.ckpt.save('./ckpt')

    def restoreVariables(self):
        status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
        status.assert_consumed()  # Optional check

# Create variables
v1 = manyVariables()
# Assigned fixed values
for i, v in enumerate(v1.initList):
    v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
    v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()

# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594  -0.18425298 -0.4898144  -1.2330629   0.08798431]
#  [ 1.5002227   0.99475247  0.7817361   0.3849587  -0.59548247]
#  [-0.57121766 -1.277224    0.6957546  -0.67618763  0.0510064 ]
#  [ 0.85491985  0.13310803 -0.93152267  0.10205163  0.57520276]
#  [-1.0606447  -0.16966362 -1.0448577   0.56799036 -0.90726566]]

# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]]

如果您想保存列表中的所有变量,可以将makeCheckpoint 替换为以下内容:

def makeCheckpoint(self):
    return tf.train.Checkpoint(
        **{f'init{i}': v for i, v in enumerate(self.initList)},
        **{f'more{i}': v for i, v in enumerate(self.moreList)})

请注意,您可以拥有“嵌套”检查点,因此更一般地说,您可以拥有一个为变量列表创建检查点的函数,例如:

def listCheckpoint(varList):
    # Use 'item{}'.format(i) if using Python <3.6
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})

那么你可以有这个:

def makeCheckpoint(self):
    return tf.train.Checkpoint(init=listCheckpoint(self.initList),
                               more=listCheckpoint(self.moreList))

【讨论】:

  • 非常感谢@jdehesa。这正是我想要的。还请说明如何保存所有变量并恢复,因为此方法不适用于许多变量。我会接受你的回答并适当地编辑我的问题,以便其他人也能从中受益。
  • @caissalover 我已经编辑了答案,看看是否涵盖了您要查找的内容。
  • 不完全。这是最后一个的蛮力。 “管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。tf.train.Checkpoint、tf.keras.layers.Layer 和 tf.keras.Model 的子类自动跟踪分配给它们的属性的变量”在官方文档中,创建一个检查点并给出 v1,我们的对象作为它的参数应该保存它。仍然没有得到这些人所做的事情,但是这样的事情应该将所有内容保存在任何 manyVariables 对象中。之前,将会话对象提供给 tf.train.Saver 更简单
  • @caissalover 我同意你的观点,许多变化似乎有点“Keras 方式或高速公路”。过去检查点只是默认将变量保存在全局变量的集合中,但是 2.x 上没有集合,所以这已经消失了。正如它所说,如果你围绕 Keras 模型/层构建你的东西​​,那么一切都应该是 Just Work™,但如果你想以不同的方式做事,那么你在很大程度上只能靠自己。当然,您可以让自己的类为您管理它,使用可覆盖的方法来创建检查点变量......这将重复 TF/Keras 所做的事情。
  • 那么我想你的方法就是要走的路。我认为让 manyVariables 成为 tf.train.Checkpoint 的子类应该做点什么,但继承不是我的事。我现在会坚持你的方式。如果没有全局变量的集合,保存和恢复肯定会更加困难。
【解决方案2】:

在下面的代码中,我将一个名为 variables 的数组保存到一个您选择的名称的 .txt 文件中。该文件将与您的 python 文件位于同一文件夹中。 open 函数中的“wb”表示使用截断写入(因此删除文件中以前的所有内容)并使用字节格式。我使用 pickle 来处理保存/解析列表。

import pickle

    def saveVariables(self, variables): #where 'variables' is a list of variables
        with open("nameOfYourFile.txt", 'wb+') as file:
           pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        variables = []
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables

要将特定内容保存到文件中,只需将其作为变量参数添加到 saveVariables 中,如下所示:

myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)

从具有特定名称的文本文件中检索变量:

myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]

您可以在类中的任何位置添加这些函数。

然而,为了能够在您的示例中访问 initList/moreList,您应该从它们的函数中返回它们(就像我对 variables 列表所做的那样)或将它们设为全局。 p>

【讨论】:

  • 直截了当,得到一个错误,即变量必须是 file.write() 中的 str。能够将 str(variables) 保存为 file.write 参数,但无法加载并转换回张量。 tf.train.Checkpoint() 一定有更好的方法。
  • 哎呀你就在那里,我忘了仔细检查。您可以使用 pickle 使其工作,它允许您将许多格式保存为字节。我会更新答案。 Tensorflow 可能确实有一个自动功能,我从未使用过它,所以我不确定。
  • 完成了这项工作。非常感谢你。虽然也应该有一些特定于 tensorflow 的方式,所以如果可能的话,我会再等一会儿,等待 tf.train.Checkpoint 的另一个答案..
猜你喜欢
  • 2018-12-03
  • 1970-01-01
  • 2018-11-17
  • 2016-05-02
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多