【问题标题】:How to efficiently compute the heat map of two Gaussian distribution in Python?如何在 Python 中高效计算两个高斯分布的热图?
【发布时间】:2017-12-10 05:41:14
【问题描述】:

我正在尝试生成一个热图,其中像素值由两个独立的二维高斯分布控制。让它们分别为 Kernel1 (muX1, muY1, sigmaX1, sigmaY1) 和 Kernel2 (muX2, muY2, sigmaX2, sigmaY2)。更具体地说,每个内核的长度是其标准差的三倍。第一个内核的 sigmaX1 = sigmaY1,第二个内核的 sigmaX2

我尝试了以下两种方法,结果都不尽如人意。有人可以给我一些建议吗?

方法1:

遍历地图上的所有像素值对 (i, j) 并计算由 I(i,j) = P(i, j | Kernel1, Kernel2) = 1 给出的 I(i,j) 的值 - (1 - P(i, j | Kernel1)) * (1 - P(i, j | Kernel2))。然后我得到了以下结果,这在平滑度方面是不错的。但是在我的电脑上运行需要10秒,太慢了。

代码:

def genDensityBox(self, height, width, muY1, muX1, muY2, muX2, sigmaK1, sigmaY2, sigmaX2):
    densityBox = np.zeros((height, width))
    for y in range(height):
        for x in range(width):
            densityBox[y, x] += 1. - (1. - multivariateNormal(y, x, muY1, muX1, sigmaK1, sigmaK1)) * (1. - multivariateNormal(y, x, muY2, muX2, sigmaY2, sigmaX2))
    return densityBox

def multivariateNormal(y, x, muY, muX, sigmaY, sigmaX):
    return norm.pdf(y, loc=muY, scale=sigmaY) * norm.pdf(x, loc=muX, scale=sigmaX)

方法2:

分别生成对应于两个内核的两个图像,然后使用特定的 alpha 值将它们混合在一起。每个图像都是通过两个一维高斯滤波器的外积生成的。然后我得到了以下结果,非常粗略。但是这种方法的优点是由于在两个向量之间使用了外积,因此速度非常快。

由于第一个很慢,第二个很粗糙,我试图找到一种新的方法,它可以同时实现良好的平滑度和低时间复杂度。有人可以帮帮我吗?

谢谢!

对于第二种方法,可以轻松生成二维高斯图,如here所述:

def gkern(self, sigmaY, sigmaX, yKernelLen, xKernelLen, nsigma=3):
    """Returns a 2D Gaussian kernel array."""
    yInterval = (2*nsigma+1.)/(yKernelLen)
    yRow = np.linspace(-nsigma-yInterval/2.,nsigma+yInterval/2.,yKernelLen + 1)
    kernelY = np.diff(st.norm.cdf(yRow, 0, sigmaY))
    xInterval = (2*nsigma+1.)/(xKernelLen)
    xRow = np.linspace(-nsigma-xInterval/2.,nsigma+xInterval/2.,xKernelLen + 1)
    kernelX = np.diff(st.norm.cdf(xRow, 0, sigmaX))    
    kernelRaw = np.sqrt(np.outer(kernelY, kernelX))
    kernel = kernelRaw / (kernelRaw.sum())
    return kernel

【问题讨论】:

    标签: python numpy heatmap gaussian


    【解决方案1】:

    您的方法很好,除了您不应该循环 norm.pdf 而只是推送您希望评估内核的所有值,然后将输出重塑为所需的图像形状。

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.stats import multivariate_normal
    
    # create 2 kernels
    m1 = (-1,-1)
    s1 = np.eye(2)
    k1 = multivariate_normal(mean=m1, cov=s1)
    
    m2 = (1,1)
    s2 = np.eye(2)
    k2 = multivariate_normal(mean=m2, cov=s2)
    
    # create a grid of (x,y) coordinates at which to evaluate the kernels
    xlim = (-3, 3)
    ylim = (-3, 3)
    xres = 100
    yres = 100
    
    x = np.linspace(xlim[0], xlim[1], xres)
    y = np.linspace(ylim[0], ylim[1], yres)
    xx, yy = np.meshgrid(x,y)
    
    # evaluate kernels at grid points
    xxyy = np.c_[xx.ravel(), yy.ravel()]
    zz = k1.pdf(xxyy) + k2.pdf(xxyy)
    
    # reshape and plot image
    img = zz.reshape((xres,yres))
    plt.imshow(img); plt.show()
    

    这种方法应该不会花费太长时间:

    In [26]: %timeit zz = k1.pdf(xxyy) + k2.pdf(xxyy)
    1000 loops, best of 3: 1.16 ms per loop
    

    【讨论】:

      【解决方案2】:

      根据 Paul 的回答,我制作了一个函数来制作高斯热图,以高斯中心作为输入(这可能对其他人有所帮助):

      import numpy as np
      import matplotlib.pyplot as plt
      from scipy.stats import multivariate_normal
      
      def points_to_gaussian_heatmap(centers, height, width, scale):
          gaussians = []
          for y,x in centers:
              s = np.eye(2)*scale
              g = multivariate_normal(mean=(x,y), cov=s)
              gaussians.append(g)
      
          # create a grid of (x,y) coordinates at which to evaluate the kernels
          x = np.arange(0, width)
          y = np.arange(0, height)
          xx, yy = np.meshgrid(x,y)
          xxyy = np.stack([xx.ravel(), yy.ravel()]).T
          
          # evaluate kernels at grid points
          zz = sum(g.pdf(xxyy) for g in gaussians)
      
          img = zz.reshape((height,width))
          return img
      
      W = 800  # width of heatmap
      H = 400  # height of heatmap
      SCALE = 64  # increase scale to make larger gaussians
      CENTERS = [(100,100), 
                 (100,300), 
                 (300,100)] # center points of the gaussians
      
      img = points_to_gaussian_heatmap(CENTERS, H, W, SCALE)
      
      plt.imshow(img); plt.show()
      

      【讨论】:

        猜你喜欢
        • 2019-07-15
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2016-01-15
        • 2017-11-24
        • 2020-10-29
        • 2015-08-28
        • 1970-01-01
        相关资源
        最近更新 更多