CSR 矩阵X 的非零项由以下方式获得
X[i].data
并且(排列)实际行的值将通过在其上附加X.shape[1] - len(X[i].data) 零来获得。
logsumexp(a) = max(a) + log(∑ exp[a - max(a)])
对于矢量a。让我们设置b = X[i].data 和k = X.shape[1] - len(X[i].data) 并将我们之前的X 置换行表示为
(b, 0ₖ)
使用 0ₖ 表示长度为零的向量 k 和 (⋅, ⋅) 表示连接。那么
logsumexp((b, 0ₖ))
= max((b, 0ₖ)) + log(∑ exp[(b, 0ₖ) - max((b, 0ₖ))])
= max(max(b), 0) + log(∑ exp[(b, 0ₖ) - max(max(b), 0)])
= max(max(b), 0) + log(∑ exp[b - max(max(b), 0)] + ∑ exp[0ₖ - max(max(b), 0)])
= max(max(b), 0) + log(∑ exp[b - max(max(b), 0)] + k × exp[-max(max(b), 0)])
所以我们得到了算法
def logsumexp_csr_row(x):
data = x.data
mx = max(np.max(data), 0)
tmp = data - mx
r = np.exp(tmp, out=tmp).sum()
k = X.shape[1] - len(data)
return mx + np.log(r + k * np.exp(-mx))
用于 CSR 行向量。将此算法扩展到完整矩阵很容易通过列表推导完成,尽管更有效的形式会使用 indptr 循环遍历行:
def logsumexp_csr_rows(X):
result = np.empty(X.shape[0])
for i in range(X.shape[0]):
data = X.data[X.indptr[i]:X.indptr[i+1]]
# fill in from logsumexp_csr_row
result[i] = mx + np.log(r + k * np.exp(-mx))
return result
按列的版本要复杂得多;转置矩阵并转换回 CSR 可能是最简单的。
更新好吧,我误解了问题:OP根本对处理零不感兴趣,所以上面的推导没用,算法应该是
def logsumexp_row_nonzeros(X):
result = np.empty(X.shape[0])
for i in range(X.shape[0]):
result[i] = logsumexp(X.data[X.indptr[i]:X.indptr[i+1]])
return result
这只是在 CSR 矩阵上填充按行操作的一般方案。对于按列,转置,转换回 CSR 并应用上述内容。