【问题标题】:find index of certain value in 3 dimension array in parallel way以并行方式在 3 维数组中查找某个值的索引
【发布时间】:2021-04-06 02:34:54
【问题描述】:

我在我的 pytorch 文件中间创建了一个 zero_mask2 函数(复制如下)。但是,它太慢了。所以,我正在寻找更好的方法。

首先,让我解释一下下面的函数的要点。

输入

  • all_idx_before'[batch_size, number of points, 20]的维度):batch_size是深度学习社会中众所周知的mini batch的定义。在每一批次中,都有很多类似1024的点。而且,在每个点中,一个点都有 20 个值。

  • idx[batch_size, number of points, 5]的维度):批次和点数与之前相同。但是,在这种情况下,一个点只有 5 个值。

  • p:任意数字,例如7

玩具示例

假设all_idx_before 中的一个点具有(9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5, 18, 19, 15, 16, 17, 13, 1, 20)

idx 中的一个点具有(10, 6, 2, 4, 18)

在这个设置中,我想在idx 中找到每个值的索引。例如,idx 中的(10)all_idx_before 中的第二个元素,idx 中的(6)all_idx_before 中的第五个元素...

因此,这一点的潜在输出将类似于(1, 5, 8, 10, 12)('second' becomes '1' because of number system in python.)

我在每个批次和每个点都执行此逻辑。所以,我用于迭代。但这太慢了。有没有办法并行执行此操作?我使用 NumPy 格式。但如果它更好,我可以使用张量。

def zero_mask2(idx, all_idx_before, p):

size_batch = len(idx)
size_points = len(idx[0])
size_neighbor = len(idx[0][0])

mask = torch.empty(size_batch, size_points, size_neighbor)

for i in range(size_batch):
 for j in range(size_points):
  for l in range(size_neighbor):
   mask[i][j][l] = np.where(all_idx_before[i][j] == idx[i][j][l])[0][0] 

mask_t = mask <= p # smaller than p or same : true , bigger than p : false
return mask_t

【问题讨论】:

  • 您是否保证idx 中的每个值都会在all_idx_before准确找到一次?
  • 是的!这是有保证的,而且是唯一的。

标签: python numpy indexing pytorch


【解决方案1】:

您可以比较和使用argmax

tmp = all_idx_before[..., None, :] == idx[..., None]  # compare along additional last dimension
mask = torch.argmax(tmp, dim=-1)

【讨论】:

  • 哦!感谢您的回答。我会试试这个,让你知道。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2019-07-05
  • 2015-01-05
  • 1970-01-01
  • 2021-03-03
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多