【问题标题】:How to access tensor shape inside map function如何在地图函数中访问张量形状
【发布时间】:2020-11-22 05:05:46
【问题描述】:

尽管通过image.shape[0] and image.shape[1] 访问时,我需要访问图像形状来执行增强管道,但我无法执行增强,因为它输出我的张量形状为无。

相关问题:How to access Tensor shape in .map?

如果有人可以提供帮助,不胜感激。

parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn) # Returns [image,label]
augmented_dataset = parsed_dataset.map(augment_pipeline) 
augmented_dataset = augmented_dataset.unbatch()

映射函数

""" 
    Returns:
      5 Versions of the original image: 4 corner crops + a central crop and the respective labels.
"""
def augment_pipeline(original_image,label):
  central_crop = lambda image: tf.image.central_crop(image,0.5)
  corner_crops = lambda image: tf.image.extract_patches(images=tf.expand_dims(image,0), # Transform image in a batch of single sample
                                                sizes=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1], # 50% of the image's height and width
                                                rates=[1, 1, 1, 1],
                                                strides=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1],
                                                padding="SAME")
  reshaped_patches = tf.reshape(corner_crops(original_image), [-1,int(0.5*original_image.shape[0]),int(0.5*original_image.shape[1]),3])
  images = tf.concat([reshaped_patches,tf.expand_dims(central_crop(original_image),axis=0)],axis=0)
  label = tf.reshape(label,[1,1])
  labels = tf.tile(label,[5,1])
  return images,labels

【问题讨论】:

    标签: tensorflow tensorflow2.x tf.data.dataset


    【解决方案1】:

    经过进一步研究,我能够按照建议的heretf.shape(image)[0] here 使用py_func 进行管理。

    代码:

    """ 
        Returns:
          5 Versions of the original image: 4 corner crops + a central crop and the respective labels.
    """
    def augment_pipeline(original_image,label):
      height  = int(tf.shape(original_image)[0].numpy() * 0.5)  # 50% of the image's height and width
      width = int(tf.shape(original_image)[1].numpy() * 0.5)
      central_crop = lambda image: tf.image.central_crop(image,0.5)
      corner_crops = lambda image: tf.image.extract_patches(images=tf.expand_dims(image,0), # Transform image in a batch of single sample
                                                    sizes=[1, height, width, 1],
                                                    rates=[1, 1, 1, 1],
                                                    strides=[1, height, width, 1],
                                                    padding="SAME")
    
                                                  .
                                                  .
                                                  .
    

    然后我们使用 py_func 来允许访问 map 函数中的 numpy 值:

    parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn) # Returns [image,label]
    augmented_dataset = parsed_dataset.map(lambda image,label: tf.py_function(func=augment_pipeline,
                                                                              inp=[image,label],
                                                                              Tout=[tf.float32,tf.int64])) 
    augmented_dataset = augmented_dataset.unbatch()
    

    【讨论】:

    【解决方案2】:

    每个 Dataset 对象都是可迭代的。现在 Dataset 对象可以是批处理形式或非批处理形式。我会告诉你如何在这两种情况下获得它们的元素形状。

    案例 1. 数据集对象为非批处理形式。

    方法一、使用iter消费其元素

    it = iter(dataset)
    element = next(it)
    image,label = element
    ## element is a tuple
    

    方法2.使用take

    element = dataset.take(1)
    image,label = element
    # element is a tuple
    

    案例 2。当数据集被批处理时。现在我假设数据集包含 (image,label) 元组

    方法一、使用iter

    it = iter(dataset)
    batch = next(it)
    images,labels = batch
    ## batch is a tuple check it using type(batch)
    

    方法2.使用take

    batch = dataset.take(1)
    ## Note here each element of the dataset is a batch and each batch contains some number of 
    ## (image,label) tuples
    batch = next(iter(batch))
    images,labels = batch
    ## batch is again a tuple
    

    【讨论】:

    • 我很欣赏这个答案,但我想这与如何迭代 tf 数据集有关,这样我看不到如何使用您的代码访问数据形状。
    猜你喜欢
    • 2020-10-04
    • 1970-01-01
    • 2021-12-01
    • 2018-12-13
    • 2019-07-19
    • 1970-01-01
    • 1970-01-01
    • 2020-12-07
    • 1970-01-01
    相关资源
    最近更新 更多