这里的问题(您可能已经知道但只是重复一遍)是list.index 的工作方式如下:
for idx, item in enumerate(your_list):
if item == wanted_item:
return idx
if item == wanted_item 行是问题所在,因为它隐式地将item == wanted_item 转换为布尔值。但是numpy.ndarray(除非它是一个标量)然后引发ValueError:
ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()
方案一:适配器(瘦包装)类
每当我需要使用诸如list.index 之类的python 函数时,我通常会在numpy.ndarray 周围使用一个薄包装器(适配器):
class ArrayWrapper(object):
__slots__ = ["_array"] # minimizes the memory footprint of the class.
def __init__(self, array):
self._array = array
def __eq__(self, other_array):
# array_equal also makes sure the shape is identical!
# If you don't mind broadcasting you can also use
# np.all(self._array == other_array)
return np.array_equal(self._array, other_array)
def __array__(self):
# This makes sure that `np.asarray` works and quite fast.
return self._array
def __repr__(self):
return repr(self._array)
这些瘦包装器比手动使用一些 enumerate 循环或理解更昂贵,但您不必重新实现 python 函数。假设列表只包含 numpy-arrays(否则你需要做一些if ... else ... 检查):
list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]
在此步骤之后,您可以使用此列表中的所有 python 函数:
>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1
这些包装器不再是 numpy-arrays,但是你有很薄的包装器,所以额外的列表非常小。因此,根据您的需要,您可以保留包装列表和原始列表,并选择在哪个上执行操作,例如您现在也可以list.count 相同的数组:
>>> list_of_wrapped_arrays.count(np.ones((3)))
2
或list.remove:
>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]]),
array([[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]]),
array([ 1., 1., 1.])]
这种方法使用numpy.array 的显式子类。它的优点是您可以获得所有内置数组功能并且只修改请求的操作(即__eq__):
class ArrayWrapper(np.ndarray):
def __eq__(self, other_array):
return np.array_equal(self, other_array)
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
您再次通过这种方式获得大多数列表方法:list.remove、list.count 除了list.index。
但是,如果某些操作隐式使用__eq__,这种方法可能会产生微妙的行为。您始终可以使用 np.asarray 或 .view(np.ndarray) 将其重新解释为普通的 numpy 数组:
>>> view_list[1]
ArrayWrapper([ 2., 2., 2.])
>>> view_list[1].view(np.ndarray)
array([ 2., 2., 2.])
>>> np.asarray(view_list[1])
array([ 2., 2., 2.])
替代方案:覆盖 __bool__(或 __nonzero__ 用于 python 2)
除了在__eq__ 方法中解决问题,您还可以覆盖__bool__ 或__nonzero__:
class ArrayWrapper(np.ndarray):
# This could also be done in the adapter solution.
def __bool__(self):
return bool(np.all(self))
__nonzero__ = __bool__
这再次使list.index 工作正常:
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
但这肯定会改变更多的行为!例如:
>>> if ArrayWrapper([1,2,3]):
... print('that was previously impossible!')
that was previously impossible!