【问题标题】:Test if an array is broadcastable to a shape?测试一个数组是否可以广播到一个形状?
【发布时间】:2014-09-04 19:05:48
【问题描述】:

测试数组是否可以广播到给定形状的最佳方法是什么?

trying 的“pythonic”方法不适用于我的情况,因为其目的是对操作进行惰性评估。

我在问如何在下面实现is_broadcastable

>>> x = np.ones([2,2,2])
>>> y = np.ones([2,2])
>>> is_broadcastable(x,y)
True
>>> y = np.ones([2,3])
>>> is_broadcastable(x,y)
False

或者更好:

>>> is_broadcastable(x.shape, y.shape)

【问题讨论】:

标签: python arrays numpy multidimensional-array


【解决方案1】:

您可以使用np.broadcast。例如:

In [47]: x = np.ones([2,2,2])

In [48]: y = np.ones([2,3])

In [49]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:     
Not compatible for broadcasting

In [50]: y = np.ones([2,2])

In [51]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:
Result has shape (2, 2, 2)

对于惰性求值的实现,您可能还会发现np.broadcast_arrays 很有用。

【讨论】:

  • broadcast_arrays 是我为了获得懒惰而寻找的东西,但我仍然需要一个验证阶段 - 我如何测试形状 A 是否可以广播到形状 B,假设我不手头有形状 B 的数组吗?
【解决方案2】:

如果您只是想避免使用给定形状实现数组,可以使用 as_strided:

import numpy as np
from numpy.lib.stride_tricks import as_strided

def is_broadcastable(shp1, shp2):
    x = np.array([1])
    a = as_strided(x, shape=shp1, strides=[0] * len(shp1))
    b = as_strided(x, shape=shp2, strides=[0] * len(shp2))
    try:
        c = np.broadcast_arrays(a, b)
        return True
    except ValueError:
        return False

is_broadcastable((1000, 1000, 1000), (1000, 1, 1000))  # True
is_broadcastable((1000, 1000, 1000), (3,))  # False

这是节省内存的,因为 a 和 b 都由单个记录支持

【讨论】:

    【解决方案3】:

    我真的觉得你们想多了,为什么不简单点呢?

    def is_broadcastable(shp1, shp2):
        for a, b in zip(shp1[::-1], shp2[::-1]):
            if a == 1 or b == 1 or a == b:
                pass
            else:
                return False
        return True
    

    【讨论】:

    • 我想不出这有什么问题...没有任何我可能遗漏的微妙的额外规则,是吗?
    【解决方案4】:

    要将其推广到任意多个形状,您可以这样做:

    def is_broadcast_compatible(*shapes):
        if len(shapes) < 2:
            return True
        else:
            for dim in zip(*[shape[::-1] for shape in shapes]):
                if len(set(dim).union({1})) <= 2:
                    pass
                else:
                    return False
            return True
    

    对应的测试用例如下:

    import unittest
    
    
    class TestBroadcastCompatibility(unittest.TestCase):
        def check_true(self, *shapes):
            self.assertTrue(is_broadcast_compatible(*shapes), msg=shapes)
    
        def check_false(self, *shapes):
            self.assertFalse(is_broadcast_compatible(*shapes), msg=shapes)
    
        def test(self):
            self.check_true((1, 2, 3), (1, 2, 3))
            self.check_true((3, 1, 3), (3, 3, 3))
            self.check_true((1,), (2,), (2,))
    
            self.check_false((1, 2, 3), (1, 2, 2))
            self.check_false((1, 2, 3), (1, 2, 3, 4))
            self.check_false((1,), (2,), (3,))
    

    【讨论】:

      【解决方案5】:

      当您想要检查任意数量的类似数组的对象(与传递形状相反)时,我们可以将np.nditer 用于broadcasting array iteration

      def is_broadcastable(*arrays):
          try:
              np.nditer(arrays)
              return True
          except ValueError:
              return False
      

      请注意,这只适用于np.ndarray 或定义__array__ 的类(被调用)。

      【讨论】:

        【解决方案6】:

        numpy.broadcast_shapes 现在从 numpy 1.20 开始可用,因此它可以像这样轻松实现:

        import numpy as np
        
        def is_broadcastable(shp1, shp2):
            try:
                c = np.broadcast_shapes(shp1, shp2)
                return True
            except ValueError:
                return False
        

        在底层,它使用零长度列表numpy数组来调用broadcast_arrays,这样做:

        np.empty(shp, dtype=[])
        

        这样可以避免分配内存。类似于@ChrisB提出的解决方案,但不依赖as_strided的技巧,我觉得有点混乱。

        【讨论】:

          【解决方案7】:

          当顺序很重要时,例如测试np.broadcast_to(a, b.shape) 是否有效,这似乎很有效:

          def is_broadcastable(src, dst):
              try:
                  return np.broadcast_shapes(src, dst) == dst
              except ValueError:
                  return False
          

          np.broadcast_shaps 的返回值与给定参数的顺序无关,我们需要确保结果与dst 相同; dst 是“更大”的尺寸。

          【讨论】:

            猜你喜欢
            • 1970-01-01
            • 2023-04-03
            • 1970-01-01
            • 1970-01-01
            • 2019-12-03
            • 1970-01-01
            • 2012-08-29
            • 1970-01-01
            • 1970-01-01
            相关资源
            最近更新 更多