【问题标题】:How do I find the variable names and values that are saved in a checkpoint?如何找到保存在检查点中的变量名称和值?
【发布时间】:2016-07-06 07:12:20
【问题描述】:

我想查看保存在 TensorFlow 检查点中的变量及其值。如何找到保存在 TensorFlow 检查点中的变量名称?

我使用了tf.train.NewCheckpointReader,解释了here。但是,TensorFlow 的文档中没有给出。有没有其他办法?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    使用示例:

    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    import os
    checkpoint_path = os.path.join(model_dir, "model.ckpt")
    
    # List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
    
    # List contents of v0 tensor.
    # Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
    
    # List contents of v1 tensor.
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
    

    更新: all_tensors 参数已添加到 print_tensors_in_checkpoint_file,因为 Tensorflow 0.12.0-rc0 因此您可能需要添加 all_tensors=Falseall_tensors=True(如果需要)。

    替代方法:

    from tensorflow.python import pywrap_tensorflow
    import os
    
    checkpoint_path = os.path.join(model_dir, "model.ckpt")
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key)) # Remove this is you want to print only variable names
    

    希望对你有帮助。

    【讨论】:

      【解决方案2】:

      您可以使用inspect_checkpoint.py 工具。

      因此,例如,如果您将检查点存储在当前目录中,那么您可以按如下方式打印变量及其值

      import tensorflow as tf
      from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
      
      
      latest_ckp = tf.train.latest_checkpoint('./')
      print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
      

      【讨论】:

      • 它打印变量名称和值,但不返回它们...如何只返回存储在检查点中的变量名称?
      • @kurtosis 您可以更改代码以返回名称和权重。
      【解决方案3】:

      更多细节。

      如果你的模型是用V2格式保存的,比如我们在/my/dir/目录下有如下文件

      model-10000.data-00000-of-00001
      model-10000.index
      model-10000.meta
      

      那么file_name参数应该只是前缀,即

      print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
      

      请参阅https://github.com/tensorflow/tensorflow/issues/7696 进行讨论。

      【讨论】:

        【解决方案4】:

        上述答案的更新

        对于最新的 Tensorflow 版本(在 TF 1.13+ 上验证),更简洁的方法如下

        ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
        value = ckpt_reader.get_tensor(name_of_the_tensor)
        

        name_of_the_tensor 应该对应于变量名称(您要检查其值)。要获取检查点中的变量名称和形状列表,您可以通过

        vars_list = tf.train.list_variables(ckpt_dir_or_file)
        

        【讨论】:

          【解决方案5】:

          print_tensors_in_checkpoint_file添加更多参数详情

          file_name:不是物理文件,只是文件名的前缀

          如果没有提供tensor_name,则打印张量名称和形状 在检查点文件中。如果提供了tensor_name,则打印张量的内容。(inspect_checkpoint.py)

          如果all_tensor_namesTrue,则打印所有张量名称

          如果all_tensor 为 'True`,则打印所有张量名称和对应的内容。

          注意 all_tensorall_tensor_names 将覆盖 tensor_name

          【讨论】:

            【解决方案6】:

            补充说明,print_tensors_in_checkpoint_file 不能打印大张量中的所有值(某些值将被省略为'...')。要查看所有值,您可以使用如下代码

            import tensorflow as tf
            tf.enable_eager_execution()
            from tensorflow.python import pywrap_tensorflow
            reader = pywrap_tensorflow.NewCheckpointReader('/dir/to/ckpt/model.ckpt-81230')
            t = reader.get_tensor('YOUR_TENSOR_NAME_HERE')
            # t is an numpy array, and you can check it like print(list(t))
            

            【讨论】:

              猜你喜欢
              • 1970-01-01
              • 1970-01-01
              • 1970-01-01
              • 1970-01-01
              • 1970-01-01
              • 1970-01-01
              • 2010-09-12
              • 1970-01-01
              • 1970-01-01
              相关资源
              最近更新 更多