【问题标题】:Finding the index of a numpy array in a list在列表中查找 numpy 数组的索引
【发布时间】:2017-05-07 10:06:33
【问题描述】:
import numpy as np
foo = [1, "hello", np.array([[1,2,3]]) ]

我希望

foo.index( np.array([[1,2,3]]) ) 

返回

2

但我得到了

ValueError:具有多个元素的数组的真值是 模糊的。使用 a.any() 或 a.all()

有什么比我目前的解决方案更好的吗?似乎效率低下。

def find_index_of_array(list, array):
    for i in range(len(list)):
        if np.all(list[i]==array):
            return i

find_index_of_array(foo, np.array([[1,2,3]]) )
# 2

【问题讨论】:

  • 非常非常有趣。
  • 非同质列表只是一个例子,还是你真的有一个包含许多不同类型的列表?
  • @mgilson 只是我做作的例子。我正在使用等维度的 numpy 数组列表
  • 您能否重铸此以使用is 而不是== 进行比较?
  • @MadPhysicist -- 事实证明,如果你使用相同的数组,numpy python 做正确的事。 lst = [array]; lst.find(array) # 0。这样做的原因是因为is 检查速度非常快(指针比较),并且由于在列表中搜索您已经引用过的内容是相当普遍的,python 在回退到 @ 之前会进行 is 比较987654330@比较。

标签: python arrays list numpy


【解决方案1】:

这里出错的原因很明显是因为numpy的ndarray覆盖==返回一个数组而不是布尔值。

AFAIK,这里没有简单的解决方案。只要
np.all(val == array) 位有效,以下将有效。

next((i for i, val in enumerate(lst) if np.all(val == array)), -1)

该位是否有效主要取决于数组中的其他元素是什么以及它们是否可以与 numpy 数组进行比较。

【讨论】:

  • 请注意,这与list.index 不同,后者在没有此类项目时会抛出ValueError。但是好的和简单的解决方案!
  • @MSeifert -- 是的。我想要一个更像str.find 的 API。如果你想要一个例外,那么你可以删除-1 部分(仅将生成器传递给next)。在这种情况下,如果找不到 StopIteration,您将收到它。
【解决方案2】:

这个怎么样?

arr = np.array([[1,2,3]])
foo = np.array([1, 'hello', arr], dtype=np.object)

# if foo array is of heterogeneous elements (str, int, array)
[idx for idx, el in enumerate(foo) if type(el) == type(arr)]

# if foo array has only numpy arrays in it
[idx for idx, el in enumerate(foo) if np.array_equal(el, arr)]

输出:

[2]

注意:即使foo 是一个列表,这也将起作用。我只是把它作为numpy 数组放在这里。

【讨论】:

  • OP 在 cmets 中说真正的列表只会包含数组,所以这充其量只是一个预处理步骤。
  • 从技术上讲,第一种方法只有在列表中只有一个数组并且您预先知道该数组就是您要查找的数组时才能可靠地工作。
  • 是的。 OP 的问题不是要求 a numpy 数组的索引吗?
  • 没有。 OP 的问题是如何让 == 比较器返回一个带有 numpy 数组的布尔值,以便他可以找到正确的索引,就像在您的第二个解决方案中一样。
【解决方案3】:

为了提高性能,您可能只想处理输入列表中的 NumPy 数组。因此,我们可以在进入循环之前进行类型检查并索引数组元素。

因此,实现将是 -

def find_index_of_array_v2(list1, array1):
    idx = np.nonzero([type(i).__module__ == np.__name__ for i in list1])[0]
    for i in idx:
        if np.all(list1[i]==array1):
            return i

【讨论】:

  • 不幸的是,OP 的列表仅包含开头的 numpy 数组(基于 cmets),因此这将提供比优化更多的开销。
  • @MadPhysicist 真的吗?我认为 OP 有一个foo = [1, "hello", np.array([[1,2,3]]) ] 的样本,这是一个混合样本。我错过了提到".. list consists of only numpy arrays"
  • 是的。对问题的第三条评论。 OP 想要最通用的答案是有道理的,这就是为什么他用人为的数组询问的原因,但不幸的是,它确实阻碍了你的答案。
  • @MadPhysicist 我想我会把它留给未来可能有混合形式作为输入列表的读者。可能对他们有用。
【解决方案4】:

这里的问题(您可能已经知道但只是重复一遍)是list.index 的工作方式如下:

for idx, item in enumerate(your_list):
    if item == wanted_item:
        return idx

if item == wanted_item 行是问题所在,因为它隐式地将item == wanted_item 转换为布尔值。但是numpy.ndarray(除非它是一个标量)然后引发ValueError

ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()

方案一:适配器(瘦包装)类

每当我需要使用诸如list.index 之类的python 函数时,我通常会在numpy.ndarray 周围使用一个薄包装器(适配器):

class ArrayWrapper(object):

    __slots__ = ["_array"]  # minimizes the memory footprint of the class.

    def __init__(self, array):
        self._array = array

    def __eq__(self, other_array):
        # array_equal also makes sure the shape is identical!
        # If you don't mind broadcasting you can also use
        # np.all(self._array == other_array)
        return np.array_equal(self._array, other_array)

    def __array__(self):
        # This makes sure that `np.asarray` works and quite fast.
        return self._array

    def __repr__(self):
        return repr(self._array)

这些瘦包装器比手动使用一些 enumerate 循环或理解更昂贵,但您不必重新实现 python 函数。假设列表只包含 numpy-arrays(否则你需要做一些if ... else ... 检查):

list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]

在此步骤之后,您可以使用此列表中的所有 python 函数:

>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1

这些包装器不再是 numpy-arrays,但是你有很薄的包装器,所以额外的列表非常小。因此,根据您的需要,您可以保留包装列表和原始列表,并选择在哪个上执行操作,例如您现在也可以list.count 相同的数组:

>>> list_of_wrapped_arrays.count(np.ones((3)))
2

list.remove:

>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]), 
 array([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]]), 
 array([ 1.,  1.,  1.])]

解决方案2:子类和ndarray.view

这种方法使用numpy.array 的显式子类。它的优点是您可以获得所有内置数组功能并且只修改请求的操作(即__eq__):

class ArrayWrapper(np.ndarray):
    def __eq__(self, other_array):
        return np.array_equal(self, other_array)

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]

>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]

>>> view_list.index(np.array([2,2,2]))
1

您再次通过这种方式获得大多数列表方法:list.removelist.count 除了list.index

但是,如果某些操作隐式使用__eq__,这种方法可能会产生微妙的行为。您始终可以使用 np.asarray.view(np.ndarray) 将其重新解释为普通的 numpy 数组:

>>> view_list[1]
ArrayWrapper([ 2.,  2.,  2.])

>>> view_list[1].view(np.ndarray)
array([ 2.,  2.,  2.])

>>> np.asarray(view_list[1])
array([ 2.,  2.,  2.])

替代方案:覆盖 __bool__(或 __nonzero__ 用于 python 2)

除了在__eq__ 方法中解决问题,您还可以覆盖__bool____nonzero__

class ArrayWrapper(np.ndarray):
    # This could also be done in the adapter solution.
    def __bool__(self):
        return bool(np.all(self))

    __nonzero__ = __bool__

这再次使list.index 工作正常:

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1

但这肯定会改变更多的行为!例如:

>>> if ArrayWrapper([1,2,3]):
...     print('that was previously impossible!')
that was previously impossible!

【讨论】:

  • 我一直在寻找一种方法来覆盖数组类本身,但这对现有对象没有帮助。非常好的解决方案。
  • @MadPhysicist 是的,也可以使用ndarray.view 和覆盖__eq__ 的子类。步骤保持不变,在应用list.index-操作之前,您需要一次创建这些视图的列表。
  • 是的,这样做的好处是您可以使用子类列表作为您唯一的列表,而无需进一步修改周围的代码。
  • 无论哪种方式,您都需要更多的支持。如果您将子类化方式放入单独的答案中,我也会赞成。
  • @MSeifert 作为 python 的初学者,从教学的角度来看,这对我非常有用。谢谢。
【解决方案5】:

这应该可以完成工作:

[i for i,j in enumerate(foo) if j.__class__.__name__=='ndarray']
[2]

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-11-19
    • 2023-03-15
    • 2019-02-06
    • 2017-12-23
    • 1970-01-01
    相关资源
    最近更新 更多