【发布时间】:2016-08-26 10:55:21
【问题描述】:
我编写了一个脚本来评估arr 的某些条目是否在check_elements 中。我的方法不比较单个条目,而是比较arr 内的整个向量。因此,脚本会检查 [8, 3]、[4, 5]、... 是否在 check_elements 中。
这是一个例子:
import numpy as np
# arr.shape -> (2, 3, 2)
arr = np.array([[[8, 3],
[4, 5],
[6, 2]],
[[9, 0],
[1, 10],
[7, 11]]])
# check_elements.shape -> (3, 2)
# generally: (n, 2)
check_elements = np.array([[4, 5], [9, 0], [7, 11]])
# rslt.shape -> (2, 3)
rslt = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.bool)
for i, j in np.ndindex((arr.shape[0], arr.shape[1])):
if arr[i, j] in check_elements: # <-- condition is checked against
# the whole last dimension
rslt[i, j] = True
else:
rslt[i, j] = False
现在:
print(rslt)
...将打印:
[[False True False]
[ True False True]]
获取 I 使用的索引:
print(np.transpose(np.nonzero(rslt)))
...打印以下内容:
[[0 1] # arr[0, 1] -> [4, 5] -> is in check_elements
[1 0] # arr[1, 0] -> [9, 0] -> is in check_elements
[1 2]] # arr[1, 2] -> [7, 11] -> is in check_elements
如果我要检查单个值的条件,例如 arr > 3 或 np.where(...),那么这项任务会很简单且高效,但我不对单个值感兴趣。我想检查整个最后一个维度(或它的切片)的条件。
我的问题是:有没有更快的方法来达到同样的效果?我对矢量化尝试和np.where 之类的东西不能用于我的问题是否正确,因为它们总是对单个值而不是整个维度或该维度的切片进行操作?
【问题讨论】: