【问题标题】:numpy binned mean, conserving extra axesnumpy binned mean,保留额外的轴
【发布时间】:2017-09-28 11:20:08
【问题描述】:

我似乎陷入了numpy 的以下问题。

我有一个数组X,形状为:X.shape = (nexp, ntime, ndim, npart) 我需要根据binvals(和一些bins)中的值沿npart 维度计算此数组的分箱统计信息,但将所有其他维度保留在那里,因为我必须使用分箱统计数据来删除一些原始数组X 中的偏差。分箱值的形状为 binvals.shape = (nexp, ntime, npart)

一个完整的、最小的例子,用来解释我想要做什么。请注意,实际上,我正在处理大型数组和数百个 bin(所以这个实现需要很长时间):

import numpy as np

np.random.seed(12345)

X = np.random.randn(24).reshape(1,2,3,4)
binvals = np.random.randn(8).reshape(1,2,4)
bins = [-np.inf, 0, np.inf]
nexp, ntime, ndim, npart = X.shape

cleanX = np.zeros_like(X)
for ne in range(nexp):
    for nt in range(ntime):
        indices = np.digitize(binvals[ne, nt, :], bins)
        for nd in range(ndim):
            for nb in range(1, len(bins)):
                inds = indices==nb
                cleanX[ne, nt, nd, inds] = X[ne, nt, nd, inds] - \
                     np.mean(X[ne, nt, nd, inds], axis = -1)

看看这个结果可能会更清楚吗?

In [8]: X
Out[8]: 
array([[[[-0.20470766,  0.47894334, -0.51943872, -0.5557303 ],
         [ 1.96578057,  1.39340583,  0.09290788,  0.28174615],
         [ 0.76902257,  1.24643474,  1.00718936, -1.29622111]],

        [[ 0.27499163,  0.22891288,  1.35291684,  0.88642934],
         [-2.00163731, -0.37184254,  1.66902531, -0.43856974],
         [-0.53974145,  0.47698501,  3.24894392, -1.02122752]]]])

In [10]: cleanX
Out[10]: 
array([[[[ 0.        ,  0.67768523, -0.32069682, -0.35698841],
         [ 0.        ,  0.80405255, -0.49644541, -0.30760713],
         [ 0.        ,  0.92730041,  0.68805503, -1.61535544]],

        [[ 0.02303938, -0.02303938,  0.23324375, -0.23324375],
         [-0.81489739,  0.81489739,  1.05379752, -1.05379752],
         [-0.50836323,  0.50836323,  2.13508572, -2.13508572]]]])


In [12]: binvals
Out[12]: 
array([[[ -5.77087303e-01,   1.24121276e-01,   3.02613562e-01,
           5.23772068e-01],
        [  9.40277775e-04,   1.34380979e+00,  -7.13543985e-01,
          -8.31153539e-01]]])

是否有矢量化解决方案?我想过使用scipy.stats.binned_statistic,但我似乎无法理解如何使用它来达到这个目的。谢谢!

【问题讨论】:

  • 你能提供一些虚拟输入吗?
  • 什么意思?什么都能做:X=np.random.randn(120).reshape(3,4,2,5)binvals=np.random.randn(24).reshape(3,4,2)bins=np.linspace(binvals.min(), binvals.max(), 10)
  • 我收到了IndexError: boolean index did not match..,其中包含已发布代码的示例数据。
  • 确实,那是因为我放错了 binvals 的形状。应该是:X=np.random.randn(120).reshape(3,4,2,5)binvals=np.random.randn(60).reshape(3,4,5)bins=np.linspace(binvals.min(), binvals.max(), 10),如问题所示。
  • 如果您确定了一些输入和预期输出,解决这些问题会容易得多。 stackoverflow.com/help/mcve

标签: arrays numpy multidimensional-array binning


【解决方案1】:
import numpy as np

np.random.seed(100)

nexp = 3
ntime = 4
ndim = 5
npart = 100
nbins = 4

binvals = np.random.rand(nexp, ntime, npart)
X = np.random.rand(nexp, ntime, ndim, npart)
bins = np.linspace(0, 1, nbins + 1)

d = np.digitize(binvals, bins)[:, :, np.newaxis, :]
r = np.arange(1, len(bins)).reshape((-1, 1, 1, 1, 1))
m = d[np.newaxis, ...] == r
counts = np.sum(m, axis=-1, keepdims=True).clip(min=1)
means = np.sum(X[np.newaxis, ...] * m, axis=-1, keepdims=True) / counts
cleanX = X - np.choose(d - 1, means)

【讨论】:

  • 好吧,我得再考虑一下,但在我看来,这与我一直在寻找的东西并不一样。
  • @user6760680 我添加了一个没有循环的替代解决方案(应该更快),但会消耗更多内存。
  • 我花了一段时间才明白什么没有让我信服,但关键是您正在对 a 进行分箱,而我必须计算 a 上的统计信息,但要对不同的数组进行分箱。
  • @user6760680 好的,我明白你的意思,我误解了问题,我会解决它。
  • 顺便说一句,答案最后使用np.choose,显然仅限于32个不同的选择(所以你最多只能使用32个垃圾箱)......你需要一个不同的路径最后一步,如果您需要更多...
【解决方案2】:

好的,我想我明白了,主要基于@jdehesa 的回答。

clean2 = np.zeros_like(X)
d = np.digitize(binvals, bins)
for i in range(1, len(bins)):
    m = d == i
    minds = np.where(m)
    sl = [*minds[:2], slice(None), minds[2]]
    msum = m.sum(axis=-1)
    clean2[sl] = (X - \
                  (np.sum(X * m[...,np.newaxis,:], axis=-1) / 
                  msum[..., np.newaxis])[..., np.newaxis])[sl]

这给出了与我的原始代码相同的结果。 在此处示例中的小数组上,此解决方案的速度大约是原始代码的三倍。我希望它在更大的阵列上更快。

更新:

确实,它在更大的阵列上更快(没有进行任何正式测试),但尽管如此,它在性能方面只是达到了可接受的水平......非常欢迎任何关于额外矢量化的进一步建议。

【讨论】:

  • 我也更新了我的答案。我的代码没有给出相同的结果,但是......当我运行它时它产生的值接近于零(我猜这就是重点),而你的原始代码产生的值高达 +/- 6(这很奇怪,因为@ 987654322@ 值在[0, 1])... 我不知道有什么区别!以防万一它对你有用......
  • @jdehesa X 值来自标准正态分布,因此它们不限于 [0,1]。我检查了我的代码,它可以满足我的需求,即使它可能没有我希望的那么快。无论如何,非常感谢您的建议,至少对于显着提高性能非常有用!
猜你喜欢
  • 2017-08-20
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多