【问题标题】:Calculate covariance of torch tensor (2d feature map)计算火炬张量的协方差(二维特征图)
【发布时间】:2021-02-09 17:52:44
【问题描述】:

我有一个形状为 (batch_size, number_maps, x_val, y_val) 的火炬张量。张量使用 sigmoid 函数进行归一化,因此在[0, 1] 范围内。我想找到每个地图的协方差,所以我想要一个形状为 (batch_size, number_maps, 2, 2) 的张量。据我所知,在 numpy 中没有 torch.cov() 函数。如何在不将其转换为 numpy 的情况下有效地计算协方差?

编辑:

def get_covariance(tensor):
bn, nk, w, h = tensor.shape
tensor_reshape = tensor.reshape(bn, nk, 2, -1)
x = tensor_reshape[:, :, 0, :]
y = tensor_reshape[:, :, 1, :]
mean_x = torch.mean(x, dim=2).unsqueeze(-1)
mean_y = torch.mean(y, dim=2).unsqueeze(-1)

xx = torch.sum((x - mean_x) * (x - mean_x), dim=2).unsqueeze(-1) / (h*w - 1)
xy = torch.sum((x - mean_x) * (y - mean_y), dim=2).unsqueeze(-1) / (h*w - 1)
yx = xy
yy = torch.sum((y - mean_y) * (y - mean_y), dim=2).unsqueeze(-1) / (h*w - 1)

cov = torch.cat((xx, xy, yx, yy), dim=2)
cov = cov.reshape(bn, nk, 2, 2)

return cov

我现在尝试了以下方法,但我很确定它不正确。

【问题讨论】:

    标签: python pytorch covariance torch


    【解决方案1】:

    你可以试试 Github 上推荐的功能:

    def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
        """Estimates covariance matrix like numpy.cov"""
        # ensure at least 2D
        if x.dim() == 1:
            x = x.view(-1, 1)
    
        # treat each column as a data point, each row as a variable
        if rowvar and x.shape[0] != 1:
            x = x.t()
    
        if ddof is None:
            if bias == 0:
                ddof = 1
            else:
                ddof = 0
    
        w = aweights
        if w is not None:
            if not torch.is_tensor(w):
                w = torch.tensor(w, dtype=torch.float)
            w_sum = torch.sum(w)
            avg = torch.sum(x * (w/w_sum)[:,None], 0)
        else:
            avg = torch.mean(x, 0)
    
        # Determine the normalization
        if w is None:
            fact = x.shape[0] - ddof
        elif ddof == 0:
            fact = w_sum
        elif aweights is None:
            fact = w_sum - ddof
        else:
            fact = w_sum - ddof * torch.sum(w * w) / w_sum
    
        xm = x.sub(avg.expand_as(x))
    
        if w is None:
            X_T = xm.t()
        else:
            X_T = torch.mm(torch.diag(w), xm).t()
    
        c = torch.mm(X_T, xm)
        c = c / fact
    
        return c.squeeze()
    

    https://github.com/pytorch/pytorch/issues/19037

    【讨论】:

      猜你喜欢
      • 2022-07-20
      • 2021-01-29
      • 2021-09-27
      • 2021-09-14
      • 2020-10-26
      • 1970-01-01
      • 2017-02-23
      • 2019-01-19
      • 2021-08-06
      相关资源
      最近更新 更多