您的列表理解适用于嵌套列表
In [100]: xl=[[0,1,3],[0,0,1],[1,1,0]]
In [101]: [row.index(1) for index, row in enumerate(xl) if 1 in row]
Out[101]: [1, 2, 0]
(注意index 只返回第三行中的第一个匹配项)。
但不适用于numpy.array:
In [102]: xa=np.array(xl)
In [103]: [row.index(1) for index, row in enumerate(xa) if 1 in row]
...
AttributeError: 'numpy.ndarray' object has no attribute 'index'
而不是稀疏矩阵:
In [104]: xs=sparse.csr_matrix(xl)
In [105]: xs
Out[105]:
<3x3 sparse matrix of type '<class 'numpy.int32'>'
with 5 stored elements in Compressed Sparse Row format>
In [106]: [row.index(1) for index, row in enumerate(xs) if 1 in row]
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
如果我删除 if 测试,我会得到一个不同的错误,即密集数组错误的变体。
In [108]: [row.index(1) for index, row in enumerate(xs)]
...
AttributeError: index not found
看看枚举给了我们什么工作;
In [109]: [(index,row) for index, row in enumerate(xs)]
Out[109]:
[(0, <1x3 sparse matrix of type '<class 'numpy.int32'>'
with 2 stored elements in Compressed Sparse Row format>),
(1, <1x3 sparse matrix of type '<class 'numpy.int32'>'
with 1 stored elements in Compressed Sparse Row format>),
(2, <1x3 sparse matrix of type '<class 'numpy.int32'>'
with 2 stored elements in Compressed Sparse Row format>)]
row 是另一个稀疏矩阵,与xs[0] 等相同。因此1 in row 和row.index(1) 表达式必须与数组或矩阵一起使用,否则会出错。
我们已经看到index 方法也没有。那是一种列表方法——你必须对数组或稀疏矩阵使用其他东西。您的理解包含if 子句,因为如果找不到该项目,列表index 会引发错误。从这个意义上说,if in 和 index 一起出现。
in 适用于数组,但给出稀疏矩阵的值错误:
In [114]: 1 in xa[0]
Out[114]: True
In [115]: 1 in xs[0]
....
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
更常见的 ValueError 是由以下等价物产生的:
In [117]: if np.array([True, False, True]):'yes'
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
也就是说,给if 一个布尔数组。在您的情况下,此故障发生在 sparse 代码中。实际上,in 尚未针对稀疏实现。
因此,如果您坚持使用这种列表推导方法,则必须将稀疏矩阵转换为列表列表:
In [120]: [row.index(1) for index, row in enumerate(xs.toarray().tolist()) if 1 in row]
Out[120]: [1, 2, 0]
这是unutbu's 答案的变体:
使用矩阵/数组相等性测试来查找所有匹配的元素:
In [121]: xs==1
Out[121]:
<3x3 sparse matrix of type '<class 'numpy.bool_'>'
with 4 stored elements in Compressed Sparse Row format>
In [122]: (xs==1).A
Out[122]:
array([[False, True, False],
[False, False, True],
[ True, True, False]], dtype=bool)
然后使用内置方法获取那些True 元素的索引:
In [123]: (xs==1).nonzero()
Out[123]: (array([0, 1, 2, 2], dtype=int32), array([1, 2, 0, 1], dtype=int32))
该元组的第二个元素是您想要的列表(第 3 行有 2 个值)。
或者收集行的值(记住,在迭代每一行时是一个矩阵)
In [125]: [i.nonzero() for i in (xs==1)]
Out[125]:
[(array([0], dtype=int32), array([1], dtype=int32)),
(array([0], dtype=int32), array([2], dtype=int32)),
(array([0, 0], dtype=int32), array([0, 1], dtype=int32))]
将该列表简化为简单的索引列表需要更多的摆弄
In [131]: [i.nonzero()[1].tolist() for i in (xs==1)]
Out[131]: [[1], [2], [0, 1]]