【问题标题】:Pytorch median - is it bug or am I using it wrongPytorch 中位数 - 是错误还是我用错了
【发布时间】:2019-01-22 14:52:12
【问题描述】:

我正在尝试获取每行 2D torch.tensor 的中位数。但与使用标准数组或 numpy 相比,结果不是我所期望的

import torch
import numpy as np
from statistics import median

print(torch.__version__)
>>> 0.4.1

y = [[1, 2, 3, 5, 9, 1],[1, 2, 3, 5, 9, 1]]
median(y[0])
>>> 2.5

np.median(y,axis=1)
>>> array([2.5, 2.5])

yt = torch.tensor(y,dtype=torch.float32)
yt.median(1)[0]
>>> tensor([2., 2.])

【问题讨论】:

    标签: python pytorch median torch


    【解决方案1】:

    看起来这是本期提到的 Torch 的预期行为

    https://github.com/pytorch/pytorch/issues/1837
    https://github.com/torch/torch7/pull/182

    上面链接中提到的推理

    Median 在奇数个元素的情况下返回“中间”元素,否则在中间元素之前返回一个(也可以做另一个约定来取两个中间元素的平均值,但这将是两倍以上贵,所以我决定买这个)。

    【讨论】:

    • 谢谢,那我就用numpy。这没什么大不了的。
    • 我可以通过其他函数模拟真实均值吗?
    • x.quantile(q=0.5) 提供与np.median(x) 相同的行为
    【解决方案2】:

    您可以使用 pytorch 模拟 numpy 中位数:

    import torch
    import numpy as np
    y =[1, 2, 3, 5, 9, 1]
    print("numpy=",np.median(y))
    print(sorted([1, 2, 3, 5, 9, 1]))
    yt = torch.tensor(y,dtype=torch.float32)
    ymax = torch.tensor([yt.max()])
    print("torch=",yt.median())
    print("torch_fixed=",(torch.cat((yt,ymax)).median()+yt.median())/2.)
    

    【讨论】:

      猜你喜欢
      • 2011-01-11
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-09-06
      • 1970-01-01
      • 2023-04-08
      • 2019-07-04
      • 1970-01-01
      相关资源
      最近更新 更多