【问题标题】:How can I efficiently modify/make pairwise distance matrix?如何有效地修改/制作成对距离矩阵?
【发布时间】:2020-02-03 13:59:02
【问题描述】:
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)
    dist = (x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)))
    return dist

上面是用于计算x(M个点)和y(N个点)之间的成对距离矩阵(M*N)的代码。

当两点之间的距离大于特定值'T'时,我希望制作具有0元素的成对距离矩阵。

在这种情况下,我该怎么办?

谢谢

【问题讨论】:

    标签: pytorch tensor


    【解决方案1】:

    我想你在找torch.where:

    new_dist = troch.where(dist > T, dist, 0.)
    

    【讨论】:

    • 没错!非常感谢
    猜你喜欢
    • 2020-05-16
    • 2020-09-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-03-10
    • 2017-08-02
    • 1970-01-01
    • 2016-06-16
    相关资源
    最近更新 更多