Question title: Wavelet 2D Scattering transform of an input image
Posted: 2020-02-12 15:34:31
Question:

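I am trying to apply a 2D scattering transform to an input image. When I run the following code, I get the error "The filters are not compatible for multiplication!". Can anyone help? Thanks!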

import torch
from kymatio import Scattering2D
import numpy as np
import PIL
from PIL import Image

FILENAME = "add a png file path"
image = PIL.Image.open(FILENAME).convert("L")

a = np.array(image).astype(np.float32)
x = torch.from_numpy(a)
imageSize=x.shape

scattering = Scattering2D(J=2, shape=imageSize, L=8)
Sx = scattering.forward(x)

print(Sx.size()) 

Tags: python numpy torch wavelet wavelet-transform


    Answer 1:

    This seems to work fine for a 1 KB .png whose width and height are equal (square, not rectangular):

    import torch
    from kymatio import Scattering2D
    import numpy as np
    import PIL
    from PIL import Image
    
    FILENAME = "/path/to/dir/small_size_1_KB.png"
    image = PIL.Image.open(FILENAME).convert("L")
    
    a = np.array(image).astype(np.float64)
    x = torch.from_numpy(a)
    imageSize = x.shape
    
    scattering = Scattering2D(J=2, shape=imageSize, L=8)
    
    Sx = scattering.forward(x)
    
    print(Sx.size())
    

    Output:

    torch.Size([81, 19, 19])
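The output shape can be checked by hand. A second-order 2D scattering transform with J scales and L orientations has 1 zeroth-order channel, J·L first-order channels and L²·J(J−1)/2 second-order channels (pairs of scales j1 < j2), and each channel is subsampled by 2**J. With J=2 and L=8 that gives 1 + 16 + 64 = 81 channels, and a 19×19 map would correspond to a 76×76 input (that input size is inferred from the output, not stated in the answer). The helper name `scattering2d_output_shape` is hypothetical and not part of kymatio.

```python
def scattering2d_output_shape(M: int, N: int, J: int, L: int) -> tuple:
    """Expected per-image output shape of an order-2 2D scattering transform.

    Hypothetical helper (not a kymatio API).
    Channels: 1 (order 0) + J*L (order 1) + L*L*J*(J-1)/2 (order 2, j1 < j2).
    Spatial size: input subsampled by 2**J (assumes M, N divisible by 2**J).
    """
    n_channels = 1 + J * L + L * L * J * (J - 1) // 2
    return (n_channels, M // 2 ** J, N // 2 ** J)


print(scattering2d_output_shape(76, 76, J=2, L=8))  # (81, 19, 19)
```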
    

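    The error you are seeing is raised in this method (backend_torch.py). It fires when the last three dimensions of the input A, (M, N, 2), differ from the size of the filter B, so it is related to the tensor sizes: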

    def cdgmm(A, B, inplace=False):
        """
            Complex pointwise multiplication between (batched) tensor A and tensor B.
    
            Parameters
            ----------
            A : tensor
                input tensor with size (B, C, M, N, 2)
            B : tensor
                B is a complex tensor of size (M, N, 2)
            inplace : boolean, optional
                if set to True, all the operations are performed inplace
    
            Returns
            -------
            C : tensor
                output tensor of size (B, C, M, N, 2) such that:
                C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]
        """
        A, B = A.contiguous(), B.contiguous()
        if A.size()[-3:] != B.size():
            raise RuntimeError('The filters are not compatible for multiplication!')
    
        if not iscomplex(A) or not iscomplex(B):
            raise TypeError('The input, filter and output should be complex')
    
        if B.ndimension() != 3:
            raise RuntimeError('The filters must be simply a complex array!')
    
        if type(A) is not type(B):
            raise RuntimeError('A and B should be same type!')
    
    
        C = A.new(A.size())
    
        A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3))
        A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3))
    
        B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i)
        B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r)
    
        C[..., 0].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_r - A_i * B_i
        C[..., 1].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_i + A_i * B_r
    
        return C if not inplace else A.copy_(C)
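One plausible reading of the size check `A.size()[-3:] != B.size()` is sketched below. It rests on an assumption about the library version, not on anything stated above: in early kymatio releases (0.1.x) the filters are built at ((M + 2**J) // 2**J + 1) · 2**J along each axis, while the input is reflect-padded by 2**J on each side, giving M + 2·2**J. Under that assumption the two sizes agree only when the side length is a multiple of 2**J, so divisibility by 2**J, rather than squareness, would decide whether `cdgmm` raises. `padded_sizes` is a hypothetical helper.

```python
def padded_sizes(M: int, J: int) -> tuple:
    """Return (filter_size, padded_input_size) along one axis.

    Assumption: mirrors early kymatio (0.1.x) padding, where filters are
    built at ((M + 2**J) // 2**J + 1) * 2**J while the input is
    reflect-padded by 2**J on each side.
    """
    filter_size = ((M + 2 ** J) // 2 ** J + 1) * 2 ** J
    input_size = M + 2 * 2 ** J
    return filter_size, input_size


for M in (76, 77, 100, 101):
    f, i = padded_sizes(M, J=2)
    status = "OK" if f == i else "mismatch, cdgmm would raise"
    print(f"M={M}: filter {f}, padded input {i}: {status}")
```

With J=2, sides of 76 and 100 give matching sizes (84 and 108), while 77 and 101 give 84 vs 85 and 108 vs 109.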
    

    Source:

    https://github.com/edouardoyallon/pyscatwave/blob/master/scatwave/utils.py
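If the mismatch does come from image sides that are not multiples of 2**J (an assumption, not confirmed by the answer), one simple workaround is to center-crop the grayscale array before building the transform. `crop_to_multiple` is a hypothetical helper written for this sketch.

```python
import numpy as np


def crop_to_multiple(img: np.ndarray, J: int) -> np.ndarray:
    """Center-crop the last two axes so each is a multiple of 2**J.

    Hypothetical workaround, assuming the error comes from image sides
    that are not divisible by 2**J.
    """
    step = 2 ** J
    h, w = img.shape[-2:]
    new_h, new_w = (h // step) * step, (w // step) * step
    top, left = (h - new_h) // 2, (w - new_w) // 2
    # Copy into a contiguous array so it can be handed to torch.from_numpy.
    return np.ascontiguousarray(img[..., top:top + new_h, left:left + new_w])


img = np.random.rand(77, 101).astype(np.float32)  # stand-in for np.array(image)
cropped = crop_to_multiple(img, J=2)
print(cropped.shape)  # (76, 100)
```

The cropped array would then be passed as `x = torch.from_numpy(cropped)` with `shape=cropped.shape` when constructing `Scattering2D(J=2, shape=..., L=8)`. Cropping discards at most 2**J − 1 rows and columns; padding the image up to the next multiple would keep every pixel instead.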
