一个大问题是输入真的是巨大(~15.6 GiB)。另一个是在最坏的情况下它被传输了多达 400 次(导致在 RAM 中写入多达 6240 GiB)。问题是重叠区域写了多次。
更好的解决方案是遍历前两个维度(“图像”之一)以找到必须复制的边界框,如 @dankal444 所建议的那样。这类似于基于Z-buffer 的算法在计算机图形学中所做的。
基于此,一个更好的解决方案是使用scanline-rendering 算法。在您的情况下,该算法比传统算法简单得多,因为您使用的是边界框而不是复杂的多边形。对于每条扫描线(此处为 2000 条),您可以快速过滤写入扫描线的边界框,然后对其进行迭代。对于您的简单情况,经典算法有点太复杂了。对于每条扫描线,迭代过滤的边界框并覆盖它们在每个像素中的索引就足够了。此操作可以使用 Numba 在 parallel 中完成。它非常快,因为计算主要在 CPU 缓存中执行。
最后的操作是根据之前的索引执行实际的数据写入(仍然使用 Numba 并行)。此操作仍然是内存绑定,但输出数组只写入一次(最坏情况下只会写入 15.6 GiB 的 RAM,float32 会写入 7.8 GiB项目)。在大多数机器上,这应该只需要几分之一秒。如果这还不够,您可以尝试使用专用 GPU,因为 GPU RAM 通常比主 RAM 快得多(通常快一个数量级)。
这里是实现:
# Assume the last dimension of `large` and `embeddings` is contiguous in memory
@nb.njit('void(float32[:,:,::1], float32[:,::1], int_[:,::1], int_[:,::1])', parallel=True)
def fastFill(large, embeddings, boxes_y, boxes_x):
n, m, l = large.shape
boxCount = embeddings.shape[0]
assert embeddings.shape == (boxCount, l)
assert boxes_y.shape == (boxCount, 2)
assert boxes_x.shape == (boxCount, 2)
imageBoxIds = np.full((n, m), -1, dtype=np.int16)
for y in nb.prange(n):
# Filtering -- A sort is not required since the number of bounding-box is small
boxIds = np.where((boxes_y[:,0] <= y) & (y < boxes_y[:,1]))[0]
for k in boxIds:
lower, upper = boxes_x[k]
imageBoxIds[y, lower:upper] = k
# Actual filling
for y in nb.prange(n):
for x in range(m):
boxId = imageBoxIds[y, x]
if boxId >= 0:
large[y, x, :] = embeddings[boxId]
这是基准:
large = np.zeros((1000, 750, 700), dtype=np.float32) # 8 times smaller in memory
boxes_y = np.cumsum(np.random.randint(0, large.shape[0]//2, size=(400, 2)), axis=1)
boxes_x = np.cumsum(np.random.randint(0, large.shape[1]//2, size=(400, 2)), axis=1)
embeddings = np.random.rand(400, 700).astype(np.float32)
# Called many times
for i in range(400):
large[boxes_y[i][0]:boxes_y[i][1], boxes_x[i][0]:boxes_x[i][1]] = embeddings[i]
# Called many times
fastFill(large, embeddings, boxes_y, boxes_x)
这是我机器上的结果:
Initial code: 2.71 s
Numba (sequential): 0.13 s
Numba (parallel): 0.12 s (x22 times faster than the initial code)
请注意,由于virtual zero-mapped memory,第一次运行速度较慢。在这种情况下,Numba 版本的速度仍然快 10 倍左右。