【问题标题】:Custom class with __add__ to add with NumPy array带有 __add__ 的自定义类与 NumPy 数组一起添加
【发布时间】:2017-08-29 19:40:44
【问题描述】:

我有一个自定义类将 __add__ 和 __radd__ 实现为

import numpy

class Foo(object):

    def __init__(self, val):
        self.val = val

    def __add__(self, other):
        print('__add__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return self.val + other

    def __radd__(self, other):
        print('__radd__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return other + self.val

我首先测试__add__

r1 = Foo(numpy.arange(3)) + numpy.arange(3,6)
print('type results = %s' % type(r1))
print('result = {}'.format(r1))

它会导致预期的结果

>>> __add__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'numpy.ndarray'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [3  5  7]

但是,测试 __radd__

r2 = numpy.arange(3) + Foo(numpy.arange(3,6))
print('type results = %s' % type(r2))
print('result = {}'.format(r2))

我明白了

>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [array([3, 4, 5]) array([4, 5, 6]) array([5, 6, 7])]

这对我来说没有任何意义。 NumPy 是否为任意对象重载 __add__,然后优先于我的 __radd__?如果是,他们为什么要这样做?此外,我该如何避免这种情况,我真的希望能够在左侧添加带有 NumPy 数组的自定义类。谢谢。

【问题讨论】:

  • 是的,让算术运算符与numpy 数组一起工作有点棘手,有很多底层机制。我相信numpy 提供了可以让你相对轻松地做到这一点的mixin。如果没有其他人有时间,我也许可以稍后再研究。你可以read more about it here
  • 感谢您的链接。我没有完全理解这些 ufunc 是什么以及它们是如何工作的,但是通过在我的课堂上设置 __numpy_ufunc__ = None__array_ufunc__ = None 用于 NumPy 13.0+),我得到了我想要的结果。
  • ufuncs 是向量化函数。
  • 另外,您应该发布您的答案并接受它。这是一个相当不错的问题。

标签: python numpy operator-overloading


【解决方案1】:

这被 cmets 隐藏了,但应该是答案。

默认情况下,Numpy 操作在每个元素的基础上进行,获取任意对象,然后尝试按元素执行操作(根据广播规则)。

这意味着,例如,给定

class N:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        return self.x + other

    def __radd__(self, other):
        return other + self.x

由于 Python 的操作符解析

N(3) + np.array([1, 2, 3])

将使用N(3) 和整个数组作为other 一次到达上述__add__,然后执行常规的Numpy 添加。

另一方面

np.array([1, 2, 3]) + N(3)

将成功进入 Numpy 的 ufuncs(本例中为运算符),因为它们将任意对象作为“其他”,然后尝试依次执行:

1 + N(3)
2 + N(3)
3 + N(3)

这意味着上面的 __add__被调用 3 次而不是一次,每个元素调用一次,显着减慢了操作。要禁用此行为,并使Numpy 在获取N 对象时引发NotImplementedError,从而允许RHS 重载radd 接管,请将以下内容添加到类的主体中:

class N:
    ...
    __numpy_ufunc__ = None # Numpy up to 13.0
    __array_ufunc__ = None # Numpy 13.0 and above

如果向后兼容性不是问题,则只需要第二个。

【讨论】:

    猜你喜欢
    • 2018-03-03
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2014-03-27
    • 2019-09-04
    相关资源
    最近更新 更多