【问题标题】:How to extract patches of same size from images of different size and batch them together with tensorflow dataset api?如何从不同大小的图像中提取相同大小的块并将它们与 tensorflow 数据集 api 一起批处理?
【发布时间】:2018-12-16 06:15:27
【问题描述】:

我正在尝试为一组不同大小的图像制作一个 tensorflow 数据集 api(tf 版本 1.8)。为此,我从图像中提取相同大小的块并将其馈送到我的神经网络。

问题出在 tf.extract_patches_from_images 中,补丁存储在通道维度中。由于每个图像的大小不同,因此每个图像的补丁数量不同。因此,每个生成的图像的形状都是不同的。因此,我无法使用 tf dataset api 将它们批处理在一起。

有人可以建议对我的以下 modify_image 函数进行更改以解决该问题吗? 我想将补丁分成不同的图像,然后将它们批处理在一起会起作用。但我不明白该怎么做。

我想扫描整个图像,因此随机选择相同数量的补丁对我不起作用。

def modify_image(image):
'''add preprocessing functions here'''
    image = tf.expand_dims(image,0)
    image = tf.extract_image_patches(
        image,
        ksizes=[1,patch_size,patch_size,1],
        strides=[1,patch_size,patch_size,1],
        rates=[1,1,1,1],
        padding='SAME',
        name=None
    )
    image = tf.reshape(image,shape=[-1,patch_size,patch_size,1])

return image;

def parse_function(image,labels):
    image= tf.read_file(image)
    image = tf.image.decode_image(image)
    labels = tf.read_file(labels)
    labels = tf.image.decode_image(labels)
    image = modify_image(image)
    labels = modify_image(labels)
    return image,labels


def list_files(directory):
    files = glob.glob(directory)
    return files

def load_dataset(img_dir,labels_dir):
    images = list_files(img_dir)
    images = tf.constant(images)
    labels = list_files(labels_dir)
    labels = tf.constant(labels)

    dataset = tf.data.Dataset.from_tensor_slices((images,labels))
    dataset = dataset.map(parse_function)
    return dataset




def make_batches(home_dir,img_dir,labels_dir,batch_size):

    img_dir = home_dir + img_dir
    labels_dir = home_dir +labels_dir

    dataset = load_dataset(img_dir,labels_dir)
    batched_dataset = dataset.batch(batch_size)
    return batched_dataset  

【问题讨论】:

    标签: python tensorflow machine-learning deep-learning computer-vision


    【解决方案1】:

    tf.contrib.data.unbatch() 转换在这里可能会有所帮助,因为它可以将单个图像中的补丁分成不同的元素:

    dataset = tf.data.Dataset.from_tensor_slices((images,labels))
    dataset = dataset.map(parse_function)
    patches_dataset = dataset.apply(tf.contrib.data.unbatch())
    batched_dataset = dataset.batch(batch_size)
    

    请注意,要使tf.contrib.data.unbatch() 起作用,图像中的补丁数必须与labels 中的元素/行数相匹配。例如,如果每个补丁都应该获得相同的标签,您可以通过将parse_function() 修改为tf.tile() 标签适当的次数来实现:

    def parse_function(images, labels):
      # ...
      return image, tf.tile([labels], tf.shape(image)[0:1])
    

    【讨论】:

    • 是的,这修复了它。谢谢,
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-11-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-07-12
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多