【问题标题】:kth-value per row in pytorch?pytorch中每行的第k个值?
【发布时间】:2020-04-01 06:39:05
【问题描述】:

给定

import torch    
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]

意思是我想用不同的 k 对每一列进行操作。
不是kthvalue 也不是 topk 提供这样的 API。 有没有一种矢量化的方法?
备注 - 第 k 个值不是第 k 个索引中的值,而是第 k 个最小的元素。 Pytorch docs

torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
返回一个 namedtuple (values, indices),其中 values 是给定维度 dim 中输入张量每一行的第 k 个最小元素。 index 是找到的每个元素的索引位置。

【问题讨论】:

  • 你能澄清一下你是如何结束[1,5,1]的吗?不应该是[1,5,7]吗?另外,请注意 A 将无效,因为您没有为构造函数提供有效的数据类型。请在输入中包含 working minimal reproducible example。当然,您的预期输出仍然可以是伪代码。
  • 另外,请注意 Python 是零索引的,所以 k 应该是 torch.tensor([0, 1, 0])
  • ccl 请阅读关于 kthvalue 的 pytorch 文档,你输入 k 作为你想要的索引之上的一个
  • 用一个工作示例编辑,当然更正了所需的输出:)
  • 每行的第 k 个值通常称为列?

标签: pytorch vectorization


【解决方案1】:

假设您不需要原始矩阵的索引(如果需要,也只需对第二个返回值使用花哨的索引),您可以简单地对值进行排序(默认情况下按最后一个索引)并返回适当的值,如下所示:

def kth_smallest(tensor, indices):
    tensor_sorted, _ = torch.sort(tensor)
    return tensor_sorted[torch.arange(len(indices)), indices]

这个测试用例给你你想要的值:

tensor = torch.tensor(
    [[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)

print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]

【讨论】:

  • 它缺少 kthvalue 函数相对于 nlogn 的 O(n) 优势,但它是我错过的一个很好的选择,如果真的 kth_value 不能被矢量化。
  • 是的,很遗憾,但我知道 PyTorch 中没有 O(n) 预制解决方案。如果此操作至关重要,您可以编写自己的C++/CUDA extension 或集成numba/cupy(不确定这种方法是否以及有多复杂)。
猜你喜欢
  • 1970-01-01
  • 2021-03-04
  • 1970-01-01
  • 2022-08-23
  • 2021-01-06
  • 1970-01-01
  • 2014-03-16
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多