【发布时间】:2016-07-06 07:12:20
【问题描述】:
我想查看保存在 TensorFlow 检查点中的变量及其值。如何找到保存在 TensorFlow 检查点中的变量名称?
我使用了tf.train.NewCheckpointReader,解释了here。但是,TensorFlow 的文档中没有给出。有没有其他办法?
【问题讨论】:
标签: tensorflow
我想查看保存在 TensorFlow 检查点中的变量及其值。如何找到保存在 TensorFlow 检查点中的变量名称?
我使用了tf.train.NewCheckpointReader,解释了here。但是,TensorFlow 的文档中没有给出。有没有其他办法?
【问题讨论】:
标签: tensorflow
使用示例:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
更新: all_tensors 参数已添加到 print_tensors_in_checkpoint_file,因为 Tensorflow 0.12.0-rc0 因此您可能需要添加 all_tensors=False 或 all_tensors=True(如果需要)。
替代方法:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
希望对你有帮助。
【讨论】:
您可以使用inspect_checkpoint.py 工具。
因此,例如,如果您将检查点存储在当前目录中,那么您可以按如下方式打印变量及其值
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
【讨论】:
更多细节。
如果你的模型是用V2格式保存的,比如我们在/my/dir/目录下有如下文件
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
那么file_name参数应该只是前缀,即
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
请参阅https://github.com/tensorflow/tensorflow/issues/7696 进行讨论。
【讨论】:
上述答案的更新
对于最新的 Tensorflow 版本(在 TF 1.13+ 上验证),更简洁的方法如下
ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
value = ckpt_reader.get_tensor(name_of_the_tensor)
name_of_the_tensor 应该对应于变量名称(您要检查其值)。要获取检查点中的变量名称和形状列表,您可以通过
vars_list = tf.train.list_variables(ckpt_dir_or_file)
【讨论】:
向print_tensors_in_checkpoint_file添加更多参数详情
file_name:不是物理文件,只是文件名的前缀
如果没有提供tensor_name,则打印张量名称和形状
在检查点文件中。如果提供了tensor_name,则打印张量的内容。(inspect_checkpoint.py)
如果all_tensor_names 是True,则打印所有张量名称
如果all_tensor 为 'True`,则打印所有张量名称和对应的内容。
注意 all_tensor 和 all_tensor_names 将覆盖 tensor_name
【讨论】:
补充说明,print_tensors_in_checkpoint_file 不能打印大张量中的所有值(某些值将被省略为'...')。要查看所有值,您可以使用如下代码
import tensorflow as tf
tf.enable_eager_execution()
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('/dir/to/ckpt/model.ckpt-81230')
t = reader.get_tensor('YOUR_TENSOR_NAME_HERE')
# t is an numpy array, and you can check it like print(list(t))
【讨论】: