【问题标题】:Condition on array passed as argument to lambda function数组的条件作为参数传递给 lambda 函数
【发布时间】:2021-05-18 00:58:37
【问题描述】:

我想在我的 lambda 函数中添加一个条件

import numpy as np
import matplotlib.pyplot as plt    

def my_func():
    return lambda x: np.sin(x) if x < np.pi else np.cos(x)
    
x = np.linspace(0, 2*np.pi, 1000)
y = my_func()
plt.plot(x, y(x))
plt.show()

当我将数组“x”传递给 lambda 函数时,它显然给出了一个模糊/多个真值的数组,并建议使用 any() 或 all()。但这不符合我的目的。

我怎样才能达到预期的结果?谢谢。

【问题讨论】:

  • 这确实在这个例子中提供了预期的结果,但是我需要一个 lambda 函数来作为我的主程序的回报,而 numpy.piecewise 给出了评估的数组。
  • 如果将y(x) 替换为numpy.apply_along_axes(y, 0, x) 会是一个解决方案吗?
  • lambda 有什么特别之处? lambda 只是定义函数的 1 行方法。您正在绘制一个数组,y(x) 的结果,而不是y。当x 是一个数组时,if x&lt;pi 会产生歧义错误。
  • @Enzo, apply_along_axis 将一维数组传递给函数,迭代其他维度。它不是一个速度工具。不推荐。

标签: python-3.x numpy lambda


【解决方案1】:

这个功能怎么样:

In [257]: def myfunc():
     ...:     def foo(x):
     ...:         res = [np.sin(i) if i<np.pi else np.cos(i) for i in x]
     ...:         return np.array(res)
     ...:     return foo
     ...: 
In [258]: y=myfunc()
In [260]: y(np.linspace(0, 2*np.pi,5))
Out[260]: 
array([ 0.0000000e+00,  1.0000000e+00, -1.0000000e+00, -1.8369702e-16,
        1.0000000e+00])

我猜可以写成lambda;但有什么意义呢? lambda 只是编写“匿名”单行函数的一种方式。否则,该语法没有什么特别之处。

piecewise 可以做同样的事情;我们也可以where 或其他掩码从sincos 中选择值。但我的主要观点是“lambda”并不是什么特别的东西,其次,像列表理解这样简单的东西可以绕过歧义错误。

In [268]: def myfunc():
     ...:     return lambda x: np.where(x<np.pi, np.sin(x), np.cos(x))
              # alt: lambda x: np.sin(x, where=x<np.pi, out=np.cos(x))

In [270]: myfunc()(np.linspace(0,2*np.pi,5))
Out[270]: 
array([ 0.0000000e+00,  1.0000000e+00, -1.0000000e+00, -1.8369702e-16,
        1.0000000e+00])

为恩佐

恩佐建议apply_along_axes。虽然我反对它对速度没有帮助,但更重要的是它在这里没有帮助。

In [261]: def my_func():
     ...:     return lambda x: np.sin(x) if x < np.pi else np.cos(x)
     ...: 
In [262]: y=my_func()
In [263]: y(0)
Out[263]: 0.0

名字错误:

In [265]: np.apply_along_axes(y, 0, np.linspace(0,2*np.pi,5))
Traceback (most recent call last):
  File "<ipython-input-265-24f826b71908>", line 1, in <module>
    np.apply_along_axes(y, 0, np.linspace(0,2*np.pi,5))
  File "/usr/local/lib/python3.8/dist-packages/numpy/__init__.py", line 303, in __getattr__
    raise AttributeError("module {!r} has no attribute "
AttributeError: module 'numpy' has no attribute 'apply_along_axes'

但我们仍然得到歧义错误:

In [266]: np.apply_along_axis(y, 0, np.linspace(0,2*np.pi,5))
Traceback (most recent call last):
  File "<ipython-input-266-fca42b21012f>", line 1, in <module>
    np.apply_along_axis(y, 0, np.linspace(0,2*np.pi,5))
  File "<__array_function__ internals>", line 5, in apply_along_axis
  File "/usr/local/lib/python3.8/dist-packages/numpy/lib/shape_base.py", line 379, in apply_along_axis
    res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
  File "<ipython-input-261-ac6ffeefd41a>", line 2, in <lambda>
    return lambda x: np.sin(x) if x < np.pi else np.cos(x)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

为什么?因为apply 旨在将一维数组传递给函数。它适用于 2d 数组,但更适用于嵌套迭代看起来很混乱的 3d(或更大)。使用一维输入,它只传递整个数组:

In [267]: np.apply_along_axis(lambda x:str(x), 0, np.linspace(0,2*np.pi,5))
Out[267]: 
array('[0.         1.57079633 3.14159265 4.71238898 6.28318531]',
      dtype='<U56')

添加维度apply... 确实有效。但通常当人们对apply... 有问题时,他们的目标是提高速度(“矢量化”)。他们认为,或者至少希望,它会比简单的迭代更快。它可能不会被弃用,但肯定可以使用一些免责声明!

In [292]: y=my_func()
In [293]: timeit np.array([y(i) for i in np.linspace(0, 2*np.pi, 5)])
88.7 µs ± 966 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [294]: timeit np.apply_along_axis(y, 0, [np.linspace(0, 2*np.pi, 5)])
181 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

【讨论】:

  • 如果您没有按照 OP 要求使用 lambda,您的解决方案比我的解决方案如何?要修复我的解决方案,只需执行np.apply_along_axis(y, 0, [np.linspace(0, 2*np.pi, 5)]) 和 OP 即可。
  • numpy.where 帮助我保留了 lambda 函数并应用了条件。我知道 lambda 本身不需要,因此定义为 foo 的子函数的解决方案也适用于此。我很惊讶我以前使用过这两种东西,但是对于这个特定的问题却没有考虑到这一点。谢谢!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-12-07
  • 2013-01-27
  • 2018-10-26
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多