【问题标题】:pytorch equivalent tf.gatherpytorch 等效 tf.gather
【发布时间】:2018-12-09 23:05:45
【问题描述】:

我在将一些代码从 tensorflow 移植到 pytorch 时遇到了一些麻烦。

所以我有一个尺寸为 10x30 的矩阵,代表 10 个示例,每个示例具有 30 个特征。然后我有另一个尺寸为 10x5 的矩阵,其中包含第一个矩阵中每个示例的 5 个最接近示例的索引。我想使用第二个矩阵中包含的索引来“收集”第一个矩阵中每个示例的 5 个最接近的示例,从而为我留下一个形状为 10x5x30 的 3d 张量。

在 tensorflow 中,这是通过 tf.gather(matrix1, matrix2) 完成的。有谁知道我如何在 pytorch 中做到这一点?

【问题讨论】:

  • 我不太确定它在 TF 中是如何完成的,但你检查过torch.gather吗?

标签: tensorflow pytorch


【解决方案1】:

这个怎么样?

matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]

它使用整数数组索引的技巧。

【讨论】:

  • 它导致张量的形状为 (10, 5, 30).... 它真的适用于 OP...吗?
【解决方案2】:

我有一个场景,我必须在整数数组上应用gather()

考试-01

torch.Tensor().gather(dim, input_tensor)
# here,
#   input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = my_tensor.gather(0, input_tensor) # 0 -> is the dimension

考试-02

torch.gather(param_tensor, dim, input_tensor)
# here,
#   input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = torch.gather(my_tensor, 0, input_tensor) # 0 -> is the dimension

【讨论】:

    猜你喜欢
    • 2020-10-30
    • 2021-04-04
    • 2021-08-01
    • 2021-03-16
    • 2022-10-12
    • 2019-01-02
    • 2020-03-28
    • 2020-09-23
    • 2021-11-29
    相关资源
    最近更新 更多