【发布时间】:2021-04-06 02:34:54
【问题描述】:
我在我的 pytorch 文件中间创建了一个 zero_mask2 函数(复制如下)。但是,它太慢了。所以,我正在寻找更好的方法。
首先,让我解释一下下面的函数的要点。
输入
-
all_idx_before'([batch_size, number of points, 20]的维度):batch_size是深度学习社会中众所周知的mini batch的定义。在每一批次中,都有很多类似1024的点。而且,在每个点中,一个点都有 20 个值。 -
idx([batch_size, number of points, 5]的维度):批次和点数与之前相同。但是,在这种情况下,一个点只有 5 个值。 -
p:任意数字,例如7。
玩具示例
假设all_idx_before 中的一个点具有(9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5, 18, 19, 15, 16, 17, 13, 1, 20)。
idx 中的一个点具有(10, 6, 2, 4, 18)。
在这个设置中,我想在idx 中找到每个值的索引。例如,idx 中的(10) 是all_idx_before 中的第二个元素,idx 中的(6) 是all_idx_before 中的第五个元素...
因此,这一点的潜在输出将类似于(1, 5, 8, 10, 12)、('second' becomes '1' because of number system in python.)。
我在每个批次和每个点都执行此逻辑。所以,我用于迭代。但这太慢了。有没有办法并行执行此操作?我使用 NumPy 格式。但如果它更好,我可以使用张量。
def zero_mask2(idx, all_idx_before, p):
size_batch = len(idx)
size_points = len(idx[0])
size_neighbor = len(idx[0][0])
mask = torch.empty(size_batch, size_points, size_neighbor)
for i in range(size_batch):
for j in range(size_points):
for l in range(size_neighbor):
mask[i][j][l] = np.where(all_idx_before[i][j] == idx[i][j][l])[0][0]
mask_t = mask <= p # smaller than p or same : true , bigger than p : false
return mask_t
【问题讨论】:
-
您是否保证
idx中的每个值都会在all_idx_before中准确找到一次? -
是的!这是有保证的,而且是唯一的。
标签: python numpy indexing pytorch