【问题标题】:how to store numpy arrays as tfrecord?如何将numpy数组存储为tfrecord?
【发布时间】:2018-05-31 08:38:56
【问题描述】:

我正在尝试从 numpy 数组创建 tfrecord 格式的数据集。我正在尝试存储 2d 和 3d 坐标。

2d 坐标是 float64 类型的形状 (2,10) 的 numpy 数组 3d 坐标是 float64 类型的形状 (3,10) 的 numpy 数组

这是我的代码:

def _floats_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


train_filename = 'train.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)


for c in range(0,1000):

    #get 2d and 3d coordinates and save in c2d and c3d

    feature = {'train/coord2d': _floats_feature(c2d),
                   'train/coord3d': _floats_feature(c3d)}
    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(sample.SerializeToString())

writer.close()

当我运行它时,我得到了错误:

  feature = {'train/coord2d': _floats_feature(c2d),
  File "genData.py", line 19, in _floats_feature
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\google\protobuf\internal\python_message.py", line 510, in init
copy.extend(field_value)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\google\protobuf\internal\containers.py", line 275, in extend
new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\google\protobuf\internal\containers.py", line 275, in <listcomp>
new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\google\protobuf\internal\type_checkers.py", line 109, in CheckValue
raise TypeError(message)
TypeError: array([-163.685,  240.818, -114.05 , -518.554,  107.968,  427.184,
    157.418, -161.798,   87.102,  406.318]) has type <class 'numpy.ndarray'>, but expected one of: ((<class 'numbers.Real'>,),)

我不知道如何解决这个问题。我应该将功能存储为 int64 还是字节?我不知道该怎么做,因为我对 tensorflow 完全陌生。任何帮助都会很棒!谢谢

【问题讨论】:

    标签: python numpy tensorflow tfrecord


    【解决方案1】:

    Tensorflow-Guide 中描述的函数 _floats_feature 需要一个标量(float32 或 float64)作为输入。

    def _float_feature(value):
      """Returns a float_list from a float / double."""
      return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    

    如您所见,输入的标量被写入列表 (value=[value]),随后将其作为输入提供给 tf.train.FloatListtf.train.FloatList 期望迭代器在每次迭代中输出一个浮点数(就像列表一样)。

    如果您的特征不是标量而是向量,则可以重写 _float_feature 以将迭代器直接传递给 tf.train.FloatList(而不是先将其放入列表中)。

    def _float_array_feature(value):
      return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    

    但是,如果您的特征有两个或更多维度,则此解决方案不再适用。就像@mmry 在他的回答中所描述的那样,在这种情况下,将您的特征展平或将其拆分为几个一维特征将是一种解决方案。这两种方法的缺点是,如果不付出更多努力,有关特征实际形状的信息就会丢失。

    为更高维数组编写示例消息的另一种可能性是将数组转换为字节字符串,然后使用 Tensorflow-Guide 中描述的_bytes_feature 函数为其编写示例消息。然后将示例消息序列化并写入 TFRecord 文件。

    import tensorflow as tf
    import numpy as np
    
    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))): # if value ist tensor
            value = value.numpy() # get value of tensor
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    
    def serialize_array(array):
      array = tf.io.serialize_tensor(array)
      return array
    
    
    #----------------------------------------------------------------------------------
    # Create example data
    array_blueprint = np.arange(4, dtype='float64').reshape(2,2)
    arrays = [array_blueprint+1, array_blueprint+2, array_blueprint+3]
    
    #----------------------------------------------------------------------------------
    # Write TFrecord file
    file_path = 'data.tfrecords'
    with tf.io.TFRecordWriter(file_path) as writer:
      for array in arrays:
        serialized_array = serialize_array(array)
        feature = {'b_feature': _bytes_feature(serialized_array)}
        example_message = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example_message.SerializeToString())
    

    可以通过tf.data.TFRecordDataset 访问存储在 TFRecord 文件中的序列化示例消息。解析示例消息后,需要从转换为的字节字符串中恢复原始数组。这可以通过tf.io.parse_tensor 实现。

    # Read TFRecord file
    def _parse_tfr_element(element):
      parse_dic = {
        'b_feature': tf.io.FixedLenFeature([], tf.string), # Note that it is tf.string, not tf.float32
        }
      example_message = tf.io.parse_single_example(element, parse_dic)
    
      b_feature = example_message['b_feature'] # get byte string
      feature = tf.io.parse_tensor(b_feature, out_type=tf.float64) # restore 2D array from byte string
      return feature
    
    
    tfr_dataset = tf.data.TFRecordDataset('data.tfrecords') 
    for serialized_instance in tfr_dataset:
      print(serialized_instance) # print serialized example messages
    
    dataset = tfr_dataset.map(_parse_tfr_element)
    for instance in dataset:
      print()
      print(instance) # print parsed example messages with restored arrays
    

    【讨论】:

      【解决方案2】:

      tf.train.Feature 类仅在使用float_list 参数时支持列表(或一维数组)。根据您的数据,您可以尝试以下方法之一:

      1. 在将数组中的数据传递给 tf.train.Feature 之前将其展平:

        def _floats_feature(value):
          return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))
        

        请注意,您可能需要添加另一个功能来指示当您再次解析此数据时应如何对其进行整形(您可以为此使用int64_list 功能)。

      2. 将多维特征拆分为多个一维特征。例如,如果 c2d 包含 x 和 y 坐标的 N * 2 数组,则可以将该特征拆分为单独的 train/coord2d/xtrain/coord2d/y 特征,每个特征分别包含 x 和 y 坐标数据。

      【讨论】:

        【解决方案3】:

        关于Tfrecord的文档推荐使用serialize_tensor

        TFRecord and tf.train.Example

        注意:为简单起见,此示例仅使用标量输入。处理非标量特征的最简单方法是使用 tf.io.serialize_tensor 将张量转换为二进制字符串。字符串是张量流中的标量。使用 tf.io.parse_tensor 将二进制字符串转换回张量。

        2 行代码就可以解决问题:

        tensor = tf.convert_to_tensor(array)
        result = tf.io.serialize_tensor(tensor)
        

        【讨论】:

          猜你喜欢
          • 2018-10-29
          • 2020-01-01
          • 1970-01-01
          • 2020-02-21
          • 2013-07-08
          • 1970-01-01
          • 2015-08-23
          • 1970-01-01
          • 2023-03-14
          相关资源
          最近更新 更多