【问题标题】:Bitwise input verification with Numpy arrays使用 Numpy 数组进行按位输入验证
【发布时间】:2020-10-20 00:48:12
【问题描述】:

我已阅读 this question 并了解 Numpy 数组不能在布尔上下文中使用。假设我想对函数输入的有效性执行逐元素布尔检查。我可以在仍然使用 Numpy 矢量化的同时实现这种行为吗?如果可以,如何实现? (如果不是,为什么?)

在以下示例中,我从两个输入计算一个值,同时检查两个输入是否有效(都必须大于 0)

import math, numpy
def calculate(input_1, input_2):
    if input_1 < 0 or input_2 < 0:
        return 0
    return math.sqrt(input_1) + math.sqrt(input_2)
calculate_many = (lambda x: calculate(x, 20 - x))(np.arange(-20, 40))

由于ValueError,这本身不适用于 Numpy 数组。但是,绝对不要在负输入上运行 math.sqrt,因为这会导致另一个错误。

使用列表推导的一种解决方案如下:

calculate_many = [calculate(x, 20 - x) for x in np.arange(-20, 40)]/=

但是,这不再使用矢量化,并且如果 arange 的大小急剧增加,将会非常缓慢。 有没有办法在仍然使用矢量化的同时实现这个if 检查?

【问题讨论】:

  • math.sqrt 仅适用于标量,因此您的 calculate,即使没有 if,也不适用于整个数组 - 即。没有“矢量化”。有一个 np.sqrt 可以与整个数组一起使用。它接受where 参数来控制评估哪些值(与out 参数一起使用)。
  • 可选的where 是类数组是什么意思?这是否意味着如果我想对input_1 中的所有正数进行平方根并将所有其他输出条目保留为未定义,我会写np.sqrt(input_1, input_1 &gt; 0)
  • ^上面运行时不起作用,但我不知道怎么写,文档没有解释。

标签: python numpy


【解决方案1】:

我相信下面的表达式执行矢量化操作并避免使用循环/lambda 函数

np.sqrt(((input1>0) & 1)*input1) + np.sqrt(((input2>0) & 1)*input2)

【讨论】:

  • 我是一个不熟悉行业的学生。像我使用它们的方式那样使用 lambdas 是否通常被认为是不好的做法?
  • 我会说这不是一个坏习惯。但它也不是矢量化操作。当您没有长数组时,或者在其他非 numpy 情况下,这应该不是一件坏事。事实上它是pythonic!
【解决方案2】:
In [121]: x = np.array([1, 10, 21, -1.])                                                
In [122]: y = 20-x                                                                      
In [123]: np.sqrt(x)                                                                    
/usr/local/bin/ipython3:1: RuntimeWarning: invalid value encountered in sqrt
  #!/usr/bin/python3
Out[123]: array([1.        , 3.16227766, 4.58257569,        nan])

有几种方法可以处理“超出范围”的值。

@Sam 的方法是调整输入以使其有效

In [129]: ((x>0) & 1)*x                                                                 
Out[129]: array([ 1., 10., 21., -0.])

另一个是使用掩码来限制计算的值。

您的函数跳过sqrt 要么输入为负;要么相反,它会在两者都有效的情况下进行计算。这与单独测试不同。

In [124]: mask = (x>=0) & (y>=0)                                                        
In [125]: mask                                                                          
Out[125]: array([ True,  True, False, False])

我们可以这样使用掩码:

In [126]: res = np.zeros_like(x)                                                        
In [127]: res[mask] = np.sqrt(x[mask]) + np.sqrt(y[mask])                               
In [128]: res                                                                           
Out[128]: array([5.35889894, 6.32455532, 0.        , 0.        ])

在我的 cmets 中,我建议使用 np.sqrtwhere 参数。不过,它也需要一个out 参数。

In [130]: np.sqrt(x, where=mask, out=np.zeros_like(x)) +
          np.sqrt(y, where=mask, out=np.zeros_like(x))                                                                   
Out[130]: array([5.35889894, 6.32455532, 0.        , 0.        ])

或者,如果我们对 Out[123] 中的 nan 感到满意,我们可以取消 RuntimeWarning。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-04-30
    • 2013-02-20
    • 2019-04-22
    • 2020-03-21
    • 1970-01-01
    相关资源
    最近更新 更多