【问题标题】:How to extract patches from an image in pytorch?如何从pytorch中的图像中提取补丁?
【发布时间】:2020-08-22 02:59:24
【问题描述】:

我想从补丁大小为 128、步幅为 32 的图像中提取图像补丁,所以我有这段代码,但它给了我一个错误:

from PIL import Image 
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)

我得到的错误是:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128

这是迄今为止我发现的唯一方法。但它给了我这个错误

【问题讨论】:

    标签: python image-processing pytorch


    【解决方案1】:

    x 的大小是[1, 3, height, width]。调用x.unfold(1, size, stride) 会尝试从尺寸为 3 的维度 1 创建大小为 128 的切片,因此它太小而无法创建任何切片。

    您不想创建跨维度 1 的切片,因为这些是图像的通道(在本例中为 RGB),并且对于所有补丁都需要保持原样。补丁只在图像的高度和宽度上创建。

    patches = x.unfold(2, size, stride).unfold(3, size, stride)
    

    生成的张量大小为[1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]。您可以对其进行整形以组合切片以获得补丁列表,即[1, 3, num_patches, 128, 128]的大小:

    patches = patches.reshape(1, 3, -1, size, size)
    

    【讨论】:

      猜你喜欢
      • 2018-01-31
      • 1970-01-01
      • 2017-07-09
      • 2021-01-20
      • 2022-01-19
      • 2019-06-17
      • 2017-04-05
      • 2016-10-20
      • 1970-01-01
      相关资源
      最近更新 更多