【发布时间】:2020-04-01 06:39:05
【问题描述】:
给定
import torch
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]
意思是我想用不同的 k 对每一列进行操作。
不是kthvalue 也不是 topk 提供这样的 API。
有没有一种矢量化的方法?
备注 - 第 k 个值不是第 k 个索引中的值,而是第 k 个最小的元素。 Pytorch docs
torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
返回一个 namedtuple (values, indices),其中 values 是给定维度 dim 中输入张量每一行的第 k 个最小元素。 index 是找到的每个元素的索引位置。
【问题讨论】:
-
你能澄清一下你是如何结束
[1,5,1]的吗?不应该是[1,5,7]吗?另外,请注意A将无效,因为您没有为构造函数提供有效的数据类型。请在输入中包含 working minimal reproducible example。当然,您的预期输出仍然可以是伪代码。 -
另外,请注意 Python 是零索引的,所以
k应该是torch.tensor([0, 1, 0])。 -
ccl 请阅读关于 kthvalue 的 pytorch 文档,你输入 k 作为你想要的索引之上的一个
-
用一个工作示例编辑,当然更正了所需的输出:)
-
每行的第 k 个值通常称为列?
标签: pytorch vectorization