【发布时间】:2018-12-09 23:05:45
【问题描述】:
我在将一些代码从 tensorflow 移植到 pytorch 时遇到了一些麻烦。
所以我有一个尺寸为 10x30 的矩阵,代表 10 个示例,每个示例具有 30 个特征。然后我有另一个尺寸为 10x5 的矩阵,其中包含第一个矩阵中每个示例的 5 个最接近示例的索引。我想使用第二个矩阵中包含的索引来“收集”第一个矩阵中每个示例的 5 个最接近的示例,从而为我留下一个形状为 10x5x30 的 3d 张量。
在 tensorflow 中,这是通过 tf.gather(matrix1, matrix2) 完成的。有谁知道我如何在 pytorch 中做到这一点?
【问题讨论】:
-
我不太确定它在 TF 中是如何完成的,但你检查过
torch.gather吗?
标签: tensorflow pytorch