我有一种使用广播的方法,它更有效,因为它将在 C 中循环。基本上使数组具有可比性。一开始可能看起来很困难,但是一旦你知道它是如何工作的,这就是你将使用的全部[条件适用]....
import numpy as np
x = np.array([[[255, 0, 0],[ 0, 255, 0], [ 0, 255, 0], [ 0, 255, 0]], [[255, 0, 0],[ 0, 255, 0], [ 0, 255, 0], [ 0, 255, 0]]])
print(x.shape)
# (2, 4, 3)
color_list = np.array([[255,0,0], [255,255,0], [255,0,255]])
print(color_list.shape)
# (3, 3)
# make array compatible
x = x[:, :, np.newaxis, :]
### Analogy for interpreting broadcasting
# Here repeating is for analogy and does not mean it will allocate new copy of memory
# element wise comparision, possibler due to broadcast
# shape of x is (2, 4, 1, 3)
# By broadcasting conceptually x will be repeated along axis=2 this will make (2, 4, 3, 3)
# color_list will be repeated over (2, 4) making it (2, 4, 3, 3) and they will have same shape also the final shape after == will be (2, 4, 3, 3)
f1 = np.all(x[:, :, np.newaxis, :] == color_list, axis=3)
#array([[[ True, False, False],
# [False, False, False],
# [False, False, False],
# [False, False, False]],
#
# [[ True, False, False],
# [False, False, False],
# [False, False, False],
# [False, False, False]]])
mask = np.any(f1, axis=2)
我们有形状为(W, H, C) == (2, 4, 3)的目标数组,我们需要找到大小为3的color_list == [[255,0,0], [255,255,0], [255,0,255]]数组
理想情况下,我们想做cross comparison,我的意思是如果一侧有M,另一侧有N 条目,那么经过一些操作后,我们会得到M * N 的结果。这似乎每 N 次重复 M 个条目并进行比较。虽然乍一看这似乎不可能,但 numpy 提供了广播。这将在概念上重复您的 for 循环之类的条目(实际上它的内存效率很高,它不会创建实际的副本)
所以我们需要广播,以使这两个数组兼容,但它们不兼容,正如broadcasting rules 中提到的那样,形状从右到左比较,它们必须相同或其中一个必须为 1。
color_list 形状为 (3, 3),x 形状为 (2, 4, 3)。我们将在 x 中添加新轴以使其与广播兼容,即x[:, :, np.newaxis, :],其形状为 (2, 4, 1, 3)。现在两者都兼容,我们可以比较。
比较最后一个轴,它是颜色通道轴 = 3,然后在最后一个轴上比较他的轴 = 2 将给出 (W, H) 布尔值,如果颜色通道三元组存在于color_list 中,则每个条目表示 True .
这种技术与给定两个点数组时可用于计算距离矩阵的技术完全相同,如Fast way to calculate min distance between two numpy arrays of 3D points