【问题标题】:Comparing two dictionaries with numpy matrices as values将两个字典与 numpy 矩阵作为值进行比较
【发布时间】:2014-12-12 19:04:14
【问题描述】:

我想断言两个 Python 字典是相等的(这意味着:键的数量相等,并且从键到值的每个映射都是相等的;顺序并不重要)。一个简单的方法是assert A==B,但是,如果字典的值为numpy arrays,这将不起作用。如何编写一个函数来检查两个字典是否相等?

>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

编辑 我知道 numpy 矩阵应检查是否与 .all() 相等。我正在寻找的是一种检查此问题的通用方法,而无需检查isinstance(np.ndarray)。这可能吗?

没有numpy数组的相关主题:

【问题讨论】:

    标签: python numpy dictionary equality


    【解决方案1】:

    【讨论】:

    • 这不返回布尔值。相反,如果对象不相等,它会引发异常。它可以用来构造一个 cmp 函数,但它本身不是一个。
    • @GuilhermedeLazari 在某些时候你在这里分裂头发。只需使用 try/except 块创建您的 cmp 函数。它几乎是自己写的。
    • @GuilhermedeLazari 最初的问题是“我想断言两个 Python 字典是相等的”
    • 只有当您预先知道这些值是 numpy 数组时,此答案才有效。问题是在不首先检查值的实例类型的情况下找到一种通用方法。
    • 我已经有一段时间没有使用它了,但是文档说“给定两个对象(标量、列表、元组、字典或 numpy 数组),检查这些对象的所有元素是否相等。在第一个冲突值处引发异常。”似乎它应该适用于其他类型,但如果你发现这不是真的,那可能是一个错误。
    【解决方案2】:

    我将回答隐藏在你的问题标题和前半部分中的一半问题,因为坦率地说,这是一个更常见的问题需要解决,现有的答案并不能很好地解决它。这个问题是“如何比较两个 numpy 数组的字典是否相等”?

    问题的第一部分是“从远处”检查字典:查看它们的键是否相同。如果所有键都相同,则第二部分是比较每个对应的值。

    现在微妙的问题是很多 numpy 数组不是整数值的,double-precision is imprecise。因此,除非您有整数值(或其他非浮点型)数组,否则您可能需要检查这些值是否几乎相同,即在机器精度范围内。所以在这种情况下,您不会使用np.array_equal(检查精确的数值相等性),而是使用np.allclose(对两个数组之间的相对和绝对误差使用有限容差)。

    问题的前半部分很简单:检查字典的键是否一致,并使用生成器推导来比较每个值(并在推导之外使用all 来验证每个项目是同样):

    import numpy as np
    
    # some dummy data
    
    # these are equal exactly
    dct1 = {'a': np.array([2, 3, 4])}
    dct2 = {'a': np.array([2, 3, 4])}
    
    # these are equal _roughly_
    dct3 = {'b': np.array([42.0, 0.2])}
    dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])}  # still 0.2, right?
    
    def compare_exact(first, second):
        """Return whether two dicts of arrays are exactly equal"""
        if first.keys() != second.keys():
            return False
        return all(np.array_equal(first[key], second[key]) for key in first)
    
    def compare_approximate(first, second):
        """Return whether two dicts of arrays are roughly equal"""
        if first.keys() != second.keys():
            return False
        return all(np.allclose(first[key], second[key]) for key in first)
    
    # let's try them:
    print(compare_exact(dct1, dct2))  # True
    print(compare_exact(dct3, dct4))  # False
    print(compare_approximate(dct3, dct4))  # True
    

    正如您在上面的示例中所看到的,整数数组比较准确,并且取决于您正在做什么(或者如果您很幸运),它甚至可以用于浮点数。但是,如果您的浮点数是任何算术的结果(例如线性变换?),您绝对应该使用近似检查。有关后一个选项的完整描述,请参阅the docs of numpy.allclose(及其元素朋友numpy.isclose),特别注意rtolatol 关键字参数。

    【讨论】:

      【解决方案3】:

      您可以分离两个字典的键和值,并比较键与键以及值与值: 这是解决方案

      import numpy as np
      
      def dic_to_keys_values(dic):
          keys, values = list(dic.keys()), list(dic.values())
          return keys, values
      
      def numpy_assert_almost_dict_values(dict1, dict2):
          keys1, values1 = dic_to_keys_values(dict1)
          keys2, values2 = dic_to_keys_values(dict2)
          np.testing.assert_equal(keys1, keys2)
          np.testing.assert_almost_equal(values1, values2)
      
      dict1 = {"b": np.array([1, 2, 0.2])}
      dict2 = {"b": np.array([1, 2, 3 * 0.1 - 0.1])}  # almost 0.2, but not equal
      dict3 = {"b": np.array([999, 888, 444])} # completely different
      
      numpy_assert_almost_dict_values(dict1, dict2) # no exception because almost equal
      # numpy_assert_almost_dict_values(dict1, dict3) # exception because not equal
      
      

      (注意,上面检查了精确的键和几乎相等的值)

      【讨论】:

        【解决方案4】:

        考虑这段代码

        >>> import numpy as np
        >>> np.identity(5)
        array([[ 1.,  0.,  0.,  0.,  0.],
               [ 0.,  1.,  0.,  0.,  0.],
               [ 0.,  0.,  1.,  0.,  0.],
               [ 0.,  0.,  0.,  1.,  0.],
               [ 0.,  0.,  0.,  0.,  1.]])
        >>> np.identity(5)+np.ones([5,5])
        array([[ 2.,  1.,  1.,  1.,  1.],
               [ 1.,  2.,  1.,  1.,  1.],
               [ 1.,  1.,  2.,  1.,  1.],
               [ 1.,  1.,  1.,  2.,  1.],
               [ 1.,  1.,  1.,  1.,  2.]])
        >>> np.identity(5) == np.identity(5)+np.ones([5,5])
        array([[False, False, False, False, False],
               [False, False, False, False, False],
               [False, False, False, False, False],
               [False, False, False, False, False],
               [False, False, False, False, False]], dtype=bool)
        >>> 
        

        注意比较的结果是一个矩阵,而不是一个布尔值。 dict比较将使用values cmp方法比较值,这意味着在比较矩阵值时,dict比较会得到一个复合结果。你想要做的是使用 numpy.all 将复合数组结果折叠成标量布尔结果

        >>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
        False
        >>> np.all(np.identity(5) == np.identity(5))
        True
        >>> 
        

        您需要编写自己的函数来比较这些字典,测试值类型以查看它们是否为矩阵,然后使用numpy.all 进行比较,否则使用==。当然,如果您愿意,您也可以随时花哨并开始子类化 dict 和重载 cmp

        【讨论】:

        • 我对此不是很清楚,但我希望有一种通用的方法,而不需要明确检查类型。今天它是一个numpy数组,明天它是我今天从未听说过的类型。
        • 恐怕没有办法绕过它。如果您(或 numpy 或其他人的)类型覆盖 cmp 以返回非标量,则标准 python 比较将无法处理它。
        • 您不需要编写自己的函数,因为 numpy 已经涵盖了您。请参阅 vitral 的回答。
        猜你喜欢
        • 2019-05-23
        • 2017-08-13
        • 1970-01-01
        • 2011-07-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2019-03-14
        • 1970-01-01
        相关资源
        最近更新 更多