代码之家  ›  专栏  ›  技术社区  ›  rollcat

Mypy:寻找平均函数的完美签名

  •  4
  • rollcat  · 技术社区  · 7 年前

    def avg(xs):
        it = iter(xs)
        try:
            s = next(it)
            i = 1
        except StopIteration:
            raise ValueError("Cannot average empty sequence")
        for x in it:
            s += x
            i += 1
        return s / i
    

    这段代码的优点在于,它可以与iterables一起工作,并为iterables生成正确的结果 int , float complex ,但也适用于 datetime.timedelta . 尝试添加签名时会出现问题。我尝试了以下方法:

    def avg(xs: t.Iterable[t.Any]) -> t.Any: ...
    

    def avg(xs: t.Iterable[T]) -> T: ...
    

    这失败是因为 T

    N = TypeVar("N", int, float, complex, datetime.timedelta)
    def avg(xs: t.Iterable[N]) -> N: ...
    

    失败是因为 int / int 是一个 浮动 // 对于几乎所有其他内容,都给出了错误的结果。也很糟糕,因为只要支持加法和除法,代码就应该适用于其他类型。

    N = TypeVar("N", float, complex, datetime.timedelta)
    def avg(xs: t.Iterable[N]) -> N: ...
    

    ...然后我还试着用 abc typing.overload

    最优雅的解决方案是什么 mypy --strict

    1 回复  |  直到 7 年前
        1
  •  1
  •   Michael0x2a    7 年前

    因此,不幸的是,Python/PEP 484中的数字系统目前有点混乱。

    从技术上讲,我们有一个 "numeric tower" 这应该表示一组ABC,Python中所有“类似数字”的实体都应该遵循这些ABC。

    此外,Python中的许多内置类型(例如 int , float , complex timedelta

    使问题复杂化的是 numbers module is largely dynamically typed

    目前,您需要做的是:

    1. 安装 typing_extensions 模块( python3 -m pip install typing_extensions )
    2. python3 -m pip install -U git+git://github.com/python/mypy.git

    然后,我们可以为“支持添加或分割”类型定义协议,如下所示:

    from datetime import timedelta
    
    from typing import TypeVar, Iterable
    from typing_extensions import Protocol
    
    T = TypeVar('T')
    S = TypeVar('S', covariant=True)
    
    class SupportsAddAndDivide(Protocol[S]):
        def __add__(self: T, other: T) -> T: ...
    
        def __truediv__(self, other: int) -> S: ...
    
    def avg(xs: Iterable[SupportsAddAndDivide[S]]) -> S:
        it = iter(xs)
        try:
            s = next(it)
            i = 1
        except StopIteration:
            raise ValueError("Cannot average empty sequence")
        for x in it:
            s += x
            i += 1
        return s / i
    
    reveal_type(avg([1, 2, 3]))
    reveal_type(avg([3.24, 4.22, 5.33]))
    reveal_type(avg([3 + 2j, 3j]))
    reveal_type(avg([timedelta(1), timedelta(2), timedelta(3)]))
    

    test.py:27: error: Revealed type is 'builtins.float*'
    test.py:28: error: Revealed type is 'builtins.float*'
    test.py:29: error: Revealed type is 'builtins.complex*'
    test.py:30: error: Revealed type is 'datetime.timedelta*'