【发布时间】:2019-01-22 14:52:12
【问题描述】:
我正在尝试获取每行 2D torch.tensor 的中位数。但与使用标准数组或 numpy 相比,结果不是我所期望的
import torch
import numpy as np
from statistics import median
print(torch.__version__)
>>> 0.4.1
y = [[1, 2, 3, 5, 9, 1],[1, 2, 3, 5, 9, 1]]
median(y[0])
>>> 2.5
np.median(y,axis=1)
>>> array([2.5, 2.5])
yt = torch.tensor(y,dtype=torch.float32)
yt.median(1)[0]
>>> tensor([2., 2.])
【问题讨论】:
标签: python pytorch median torch