【问题标题】:Efficiently select subsection of numpy array有效地选择 numpy 数组的子部分
【发布时间】:2014-10-20 10:16:41
【问题描述】:

我想根据逻辑比较将一个 numpy 数组拆分为三个不同的数组。我要拆分的 numpy 数组称为x。它的形状如下所示,但它的条目有所不同:(为了回应 Saullo Castro 的评论,我包含了一个稍微不同的数组 x。)

array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
      [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

此数组的值沿列单调递增。我还有另外两个名为lowest_gridpointshighest_gridpoints 的数组。这些数组的条目也有所不同,但形状始终与以下内容相同:

 array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

我要申请的选拔程序如下:

  • 应从x 中删除所有值低于lowest_gridpoints 中任何值的列,并构成数组temp1
  • 所有值高于highest_gridpoints 中任何值的列都应从x 中删除,并构成数组temp2
  • x 中不包含在 temp1temp2 中的所有列构成数组 x_new

我编写的以下代码实现了该任务。

if np.any( x[:,-1] > highest_gridpoints ) or np.any( x[:,0] < lowest_gridpoints ):
    for idx, sample, in enumerate(x.T):
        if np.any( sample > highest_gridpoints):
            max_idx = idx
            break
        elif np.any( sample < lowest_gridpoints ):
            min_idx = idx 
    temp1, temp2 = np.array([[],[]]), np.array([[],[]])
    if 'min_idx' in locals():
        temp1 = x[:,0:min_idx+1]
    if 'max_idx' in locals():
        temp2 = x[:,max_idx:]
    if 'min_idx' in locals() or 'max_idx' in locals():
        if 'min_idx' not in locals():
            min_idx = -1
        if 'max_idx' not in locals():
            max_idx = x.shape[1]
        x_new = x[:,min_idx+1:max_idx]

但是,由于大量使用循环,我怀疑这段代码效率很低。另外,我认为语法很臃肿。

是否有人对更有效地完成上述任务或看起来简洁的代码有想法?

【问题讨论】:

  • 你的例子为我返回了[]...如果有一个不同的输入可以用于比较就好了...
  • @SaulloCastro:感谢您的评论。我稍微修改了数组 x。你知道如何修改我的代码吗?
  • 您是否希望 temp1 和 temp2 互斥,或者列的值是否会低于lowest_gridpoints 中的值而另一个值高于highest_gridpoints 中的值?另外,您的意思是沿行单调递增吗?
  • 也许你可以使用np.argsort(x[i] + [lowest_gridpoints[i]])[-1]。这将为您提供大于lowest_gridpoints[i] 的第一个元素的索引。对所有i 执行此操作并获得最大值(highest_gridpoints 的最小值)
  • @greschd:这很好。我希望temp1temp2 是互斥的。在我的代码中,这是由break 命令在` if np.any(sample >highest_gridpoints): In doubt, I classify columns of x` 到para2 而不是para1 之后保证的。我的意思是沿 np.arrays 的第二维单调递增,因此x[0,i] &gt;= x[0,j] for i > j。我希望(并认为)这指的是列。

标签: python arrays loops numpy


【解决方案1】:

仅问题的第一部分

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

# construct an array of two rows of bools expressing your conditions
indices1 = array((x[0,:]<low[0], x[1,:]<low[1]))
print indices1

# do an or of the values along the first axis
indices1 = any(indices1, axis=0)
# now it's a single row array
print indices1

# use the indices1 to extract what you want,
# the double transposition because the elements
# of a 2d array are  the rows
tmp1 = x.T[indices1].T
print tmp1

# [[ True  True False False False]
#  [ True  True False False False]]
# [ True  True False False False]
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]

next 构造类似indices2tmp2,剩余的索引是前两个索引的oring 的否定。 (即numpy.logical_not(numpy.logical_or(i1,i2)))。

附录

如果您有数千个条目,另一种方法可能更快,暗示numpy.searchsorted

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')

h0l = searchsorted(x[0,:], high[0], side='left')
h1l = searchsorted(x[1,:], high[1], side='left')

lr = max(l0r, l1r)
hl = min(h0l, h1l)

print lr, hl
print x[:,:lr]
print x[:,lr:hl]
print x[:,hl]

# 2 4
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]
# [[ 0.70164242  0.84519205]
#  [ 0.05        0.05      ]]
# [ 1.4   0.05]

排除重叠可以通过hl = max(lr, hl)获取。注意在之前的方法中,数组切片被复制到新对象,在这里你可以看到x,如果你想要新对象,你必须明确。

编辑 不必要的优化

如果我们在sortedsearches 的第二对中仅使用x 的上半部分(如果您查看代码,您会明白我的意思......)我们有两个好处,1)非常搜索的小幅加速(sortedsearch 总是足够快)和 2) 自动管理重叠的情况。

作为奖励,将x 的段复制到新数组中的代码。 NB x 改为强制重叠

from numpy import *

# I changed x to force overlap
x = array([[ 0.46006547,  1.4 ,        1.4,   1.4,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05,  0.05, 0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
lr = max(l0r, l1r)

h0l = searchsorted(x[0,lr:], high[0], side='left')
h1l = searchsorted(x[1,lr:], high[1], side='left')

hl = min(h0l, h1l) + lr

t1 = x[:,range(lr)]
xn = x[:,range(lr,hl)]
ncol = shape(x)[1]
t2 = x[:,range(hl,ncol)]

print x
del(x)
print
print t1
print
# note that xn is a void array 
print xn
print
print t2

# [[ 0.46006547  1.4         1.4         1.4         1.4       ]
#  [ 0.00912908  0.00912908  0.05        0.05        0.05      ]]
# 
# [[ 0.46006547  1.4       ]
#  [ 0.00912908  0.00912908]]
# 
# []
# 
# [[ 1.4   1.4   1.4 ]
#  [ 0.05  0.05  0.05]]

【讨论】:

  • 我开始担心我没有理解 OP 要求。
  • 感谢您的回答;你的附录对我来说效果很好,除了一个修改:为了避免重叠,我不得不使用if hl &lt; lr: hl = hl + lr
猜你喜欢
  • 2015-11-16
  • 1970-01-01
  • 2020-10-27
  • 2015-07-30
  • 1970-01-01
  • 1970-01-01
  • 2016-10-15
  • 2016-04-04
  • 1970-01-01
相关资源
最近更新 更多