【问题标题】:How to get shape of tensor from TFRecordDataset如何从 TFRecordDataset 获取张量的形状
【发布时间】:2019-01-30 19:14:03
【问题描述】:

我的训练 TFRecord 中写入了以下功能:

feature = {'label': _int64_feature(gt),
           'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
           'height': _int64_feature(h),
           'width': _int64_feature(w)}

我正在阅读它:

train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)

而我的 parse_func 看起来像这样:

def parse_func(ex):
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               'height': tf.FixedLenFeature([], tf.int64),
               'width': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(ex, features=feature)
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    im_shape = tf.stack([width, height])
    image = tf.reshape(image, im_shape)
    label = tf.cast(features['label'], tf.int32)
    return image, label

现在,我想获得 imagelabel 的形状,例如:

image.get_shape().as_list()

打印
[无,无,无]
而不是
[None, 224, 224] (图片大小(batch, width, height))

有没有什么函数可以给我这些张量的大小?

【问题讨论】:

    标签: tensorflow tensor tensorflow-datasets tfrecord


    【解决方案1】:

    由于您的地图函数“parse_func”作为操作是图形的一部分,并且它不知道您输入的固定大小和先验标签,因此使用 get_shape() 不会返回预期的固定形状。

    如果您的图像,标签形状是固定的,作为一个 hack,您可以尝试重塑您的图像,具有已知大小的标签(这实际上不会做任何事情,但会显式设置输出张量)。

    例如。 image = tf.reshape(image, [224,224])

    有了这个,你应该能够得到预期的 get_shape() 结果。

    【讨论】:

    • 硬编码目前确实解决了这个问题,但我正在寻找更通用的解决方案。
    【解决方案2】:

    另一种解决方案是存储编码的图像而不是解码的原始字节,这样你只需要在读取 tfrecords 时使用 tensorflow 将图像解码回来,这也将帮助你节省存储空间,这样你就可以得到从张量中得到图像形状。

        # Load your image as you would normally do then do:
    
        # Convert the image to raw bytes.
        img_bytes = tf.io.encode_jpeg(img).numpy()
    
        # Create a dict with the data we want to save in the
        # TFRecords file. You can add more relevant data here.
        data = \
        {'image': wrap_bytes(img_bytes),
         'label': wrap_int64(label)}
    
        # Wrap the data as TensorFlow Features.
        feature = tf.train.Features(feature=data)
    
        # Wrap again as a TensorFlow Example.
        example = tf.train.Example(features=feature)
    
        # Serialize the data.
        serialized = example.SerializeToString()
                
        # Write the serialized data to the TFRecords file.
        writer.write(serialized) 
    

    然后阅读你可以使用:

        features = \
            {
                'image': tf.io.FixedLenFeature([], tf.string),
                'label': tf.io.FixedLenFeature([], tf.int64)            
            }
    
        # Parse the serialized data so we get a dict with our data.
        parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                 features=features)
    
        # Get the image as raw bytes.
        image_raw = parsed_example['image']
    
        # Decode the raw bytes so it becomes a tensor with type.
        image = tf.io.decode_jpeg(image_raw)
        
        image = tf.cast(image, tf.float32) # optional
        
        # Get the label associated with the image.
        label = parsed_example['label']
        
        # The image and label are now correct TensorFlow types.
        return image, label
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-05-21
      • 2021-06-12
      • 1970-01-01
      • 2023-01-13
      • 2018-03-31
      • 1970-01-01
      • 2019-09-01
      相关资源
      最近更新 更多