代码之家  ›  专栏  ›  技术社区  ›  nyanpasu64

如何比较持有numpy.ndarray(bool(a==b)引发ValueError)的数据类的相等性?

  •  1
  • nyanpasu64  · 技术社区  · 6 年前

    如果我创建了一个包含Numpy ndarray的Python数据类,我就不能再使用自动生成的 __eq__ 不再。

    import numpy as np
    
    @dataclass
    class Instr:
        foo: np.ndarray
        bar: np.ndarray
    
    arr = np.array([1])
    arr2 = np.array([1, 2])
    print(Instr(arr, arr) == Instr(arr2, arr2))
    

    ValueError:具有多个元素的数组的真值不明确。使用a.any()或a.all()

    这是因为 ndarray.__eq__ 有时 返回 ndarray 通过比较 a[0] b[0] ,以此类推,直到2中较长的一个。这是非常复杂和不直观的,事实上只有当数组是不同的形状,或者具有不同的值或其他东西时才会引发错误。

    如何安全地比较 @dataclass 持有核弹阵列?


    @数据类 的实现 __情商__ 是使用 eval() . 堆栈跟踪中缺少其源,无法使用 inspect ,但实际上它使用了 元组比较 ,它调用bool(foo)。

    import dis
    dis.dis(Instr.__eq__)
    

    节选:

      3          12 LOAD_FAST                0 (self)
                 14 LOAD_ATTR                1 (foo)
                 16 LOAD_FAST                0 (self)
                 18 LOAD_ATTR                2 (bar)
                 20 BUILD_TUPLE              2
                 22 LOAD_FAST                1 (other)
                 24 LOAD_ATTR                1 (foo)
                 26 LOAD_FAST                1 (other)
                 28 LOAD_ATTR                2 (bar)
                 30 BUILD_TUPLE              2
                 32 COMPARE_OP               2 (==)
                 34 RETURN_VALUE
    
    1 回复  |  直到 6 年前
        1
  •  2
  •   FHTMitchell    5 年前

    解决办法是你自己 __eq__ 方法和集合 eq=False 所以数据类不会生成自己的(尽管检查 docs 最后一步是不必要的,但我认为还是要明确一点。

    import numpy as np
    
    def array_eq(arr1, arr2):
        return (isinstance(arr1, np.ndarray) and
                isinstance(arr2, np.ndarray) and
                arr1.shape == arr2.shape and
                (arr1 == arr2).all())
    
    @dataclass(eq=False)
    class Instr:
    
        foo: np.ndarray
        bar: np.ndarray
    
        def __eq__(self, other):
            if not isinstance(other, Instr):
                return NotImplemented
            return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
    

    编辑

    通用数据类的通用快速解决方案,其中一些值是numpy数组,而另一些不是

    import numpy as np
    from dataclasses import dataclass, astuple
    
    def array_safe_eq(a, b) -> bool:
        """Check if a and b are equal, even if they are numpy arrays"""
        if a is b:
            return True
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            return a.shape == b.shape and (a == b).all()
        try:
            return a == b
        except TypeError:
            return NotImplemented
    
    def dc_eq(dc1, dc2) -> bool:
       """checks if two dataclasses which hold numpy arrays are equal"""
       if dc1 is dc2:
            return True
       if dc1.__class__ is not dc2.__class__:
           return NotImplmeneted  # better than False
       t1 = astuple(dc1)
       t2 = astuple(dc2)
       return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
    
    # usage
    @dataclass(eq=False)
    class T:
    
       a: int
       b: np.ndarray
       c: np.ndarray
    
       def __eq__(self, other):
            return dc_eq(self, other)