这个功能怎么样:
In [257]: def myfunc():
...: def foo(x):
...: res = [np.sin(i) if i<np.pi else np.cos(i) for i in x]
...: return np.array(res)
...: return foo
...:
In [258]: y=myfunc()
In [260]: y(np.linspace(0, 2*np.pi,5))
Out[260]:
array([ 0.0000000e+00, 1.0000000e+00, -1.0000000e+00, -1.8369702e-16,
1.0000000e+00])
我猜可以写成lambda;但有什么意义呢? lambda 只是编写“匿名”单行函数的一种方式。否则,该语法没有什么特别之处。
piecewise 可以做同样的事情;我们也可以where 或其他掩码从sin 和cos 中选择值。但我的主要观点是“lambda”并不是什么特别的东西,其次,像列表理解这样简单的东西可以绕过歧义错误。
In [268]: def myfunc():
...: return lambda x: np.where(x<np.pi, np.sin(x), np.cos(x))
# alt: lambda x: np.sin(x, where=x<np.pi, out=np.cos(x))
In [270]: myfunc()(np.linspace(0,2*np.pi,5))
Out[270]:
array([ 0.0000000e+00, 1.0000000e+00, -1.0000000e+00, -1.8369702e-16,
1.0000000e+00])
为恩佐
恩佐建议apply_along_axes。虽然我反对它对速度没有帮助,但更重要的是它在这里没有帮助。
In [261]: def my_func():
...: return lambda x: np.sin(x) if x < np.pi else np.cos(x)
...:
In [262]: y=my_func()
In [263]: y(0)
Out[263]: 0.0
名字错误:
In [265]: np.apply_along_axes(y, 0, np.linspace(0,2*np.pi,5))
Traceback (most recent call last):
File "<ipython-input-265-24f826b71908>", line 1, in <module>
np.apply_along_axes(y, 0, np.linspace(0,2*np.pi,5))
File "/usr/local/lib/python3.8/dist-packages/numpy/__init__.py", line 303, in __getattr__
raise AttributeError("module {!r} has no attribute "
AttributeError: module 'numpy' has no attribute 'apply_along_axes'
但我们仍然得到歧义错误:
In [266]: np.apply_along_axis(y, 0, np.linspace(0,2*np.pi,5))
Traceback (most recent call last):
File "<ipython-input-266-fca42b21012f>", line 1, in <module>
np.apply_along_axis(y, 0, np.linspace(0,2*np.pi,5))
File "<__array_function__ internals>", line 5, in apply_along_axis
File "/usr/local/lib/python3.8/dist-packages/numpy/lib/shape_base.py", line 379, in apply_along_axis
res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
File "<ipython-input-261-ac6ffeefd41a>", line 2, in <lambda>
return lambda x: np.sin(x) if x < np.pi else np.cos(x)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
为什么?因为apply 旨在将一维数组传递给函数。它适用于 2d 数组,但更适用于嵌套迭代看起来很混乱的 3d(或更大)。使用一维输入,它只传递整个数组:
In [267]: np.apply_along_axis(lambda x:str(x), 0, np.linspace(0,2*np.pi,5))
Out[267]:
array('[0. 1.57079633 3.14159265 4.71238898 6.28318531]',
dtype='<U56')
添加维度apply... 确实有效。但通常当人们对apply... 有问题时,他们的目标是提高速度(“矢量化”)。他们认为,或者至少希望,它会比简单的迭代更快。它可能不会被弃用,但肯定可以使用一些免责声明!
In [292]: y=my_func()
In [293]: timeit np.array([y(i) for i in np.linspace(0, 2*np.pi, 5)])
88.7 µs ± 966 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [294]: timeit np.apply_along_axis(y, 0, [np.linspace(0, 2*np.pi, 5)])
181 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)