首先要声明的是,尽管 Python 是一种动态类型1的语言,但是 Python 在 3.x 版本中还是引入了类型标注 (Type Annotation) 机制,使得我们可以以一种类似静态语言的方式来标注函数入参、返回的类型,或是类的成员类型。这种方式看似使得代码变得更加繁琐,但是合理的类型标注可以被编辑器/IDE 的静态检查机制所利用,进而提供更丰富的代码补全服务和静态错误排查服务。这使得我们在使用 Python 开发结构非常复杂的软件的时候能够提高编写效率,并减少大部分低级错误。

在类型声明中,泛型(Generic)机制是最为复杂的一种。本文着眼于 Python 的泛型机制的实现,而非细致的类型标注介绍。如果对类型标注和泛型缺乏基本了解,建议检索相关教程文章阅读之后再来阅读本文。


Python 在 3.6 到 3.7 版本的迭代中对于泛型的实现做出了重大的改变。我一开始是从 3.6 开始的。

使用泛型可以分为两个层面。首先是只是做进阶版本的类型提升。例如实现一个字典,我们可以用泛型来进行更加通用的标注:

1
2
3
4
5
6
7
8
9
10
from typing import TypeVar, Generic, Hashable)

_KT = TypeVar("_KT", bound=Hashable)
_VT = TypeVar("_VT")

class SomeDict(Generic[_KT, _VT]):

def __getitem__(self, key: _KT) -> _VT:
# implementation

这时,泛型信息只是被静态类型检查系统(如 Pylance、mypy)所使用,在运行时(runtime),这些泛型信息其实并没有直接其作用。换言之,如果在实现层面,上面 __getitem__ 函数完全可能返回非 _VT 类型的数据。如果我们想在运行时检查返回的类型,这就需要我们动态地获取构造泛型对象的时传给类的类型参数。这个问题就比较复杂了,这也是我这篇文章要探讨的重点。

1 Python 3.6 中的泛型实现

1.1 从泛型类中获取类型参数

我们来看看 Python 3.6 源码中 Generic 的实现。Generic 源码的实现位于 typing.py 文件中,你可以从 Github 上找到这个文件的源码(typing.py)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Generic(metaclass=GenericMeta):
"""Abstract base class for generic types.
A generic type is typically declared by inheriting from
this class parameterized with one or more type variables.
For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
def __getitem__(self, key: KT) -> VT:
...
# Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
try:
return mapping[key]
except KeyError:
return default
"""

__slots__ = ()

def __new__(cls, *args, **kwds):
if cls._gorg is Generic:
raise TypeError("Type Generic cannot be instantiated; "
"it can be used only as a base class")
return _generic_new(cls.__next_in_mro__, cls, *args, **kwds)

def _generic_new(base_cls, cls, *args, **kwds):
# Assure type is erased on instantiation,
# but attempt to store it in __orig_class__
if cls.__origin__ is None:
if (base_cls.__new__ is object.__new__ and
cls.__init__ is not object.__init__):
return base_cls.__new__(cls)
else:
return base_cls.__new__(cls, *args, **kwds)
else:
origin = cls._gorg
if (base_cls.__new__ is object.__new__ and
cls.__init__ is not object.__init__):
obj = base_cls.__new__(origin)
else:
obj = base_cls.__new__(origin, *args, **kwds)
try:
obj.__orig_class__ = cls
except AttributeError:
pass
obj.__init__(*args, **kwds)
return obj

Generic 的基类非常简单,它只覆盖了 __new__ 的实现。我们需要注意的是 Generic 的元类 GenericMeta。观察 GenericMeta 的源码,我们可以发现这个类实现了 __getitem__ 函数,联想到我们定义泛型类的时候会写 Generic[_T] 的形式,这实际上是在调用 GenericMeta__getitem__ 方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class GenericMeta(TypingMeta, abc.ABCMeta):
# ...
@_tp_cache
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
if not params and self._gorg is not Tuple:
raise TypeError(
"Parameter list to %s[...] cannot be empty" % _qualname(self))
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
if self is Generic:
# Generic can only be subscripted with unique type variables.
if not all(isinstance(p, TypeVar) for p in params):
raise TypeError(
"Parameters to Generic[...] must all be type variables")
if len(set(params)) != len(params):
raise TypeError(
"Parameters to Generic[...] must all be unique")
tvars = params
args = params
elif self in (Tuple, Callable):
tvars = _type_vars(params)
args = params
elif self is _Protocol:
# _Protocol is internal, don't check anything.
tvars = params
args = params
elif self.__origin__ in (Generic, _Protocol):
# Can't subscript Generic[...] or _Protocol[...].
raise TypeError("Cannot subscript already-subscripted %s" %
repr(self))
else:
# Subscripting a regular Generic subclass.
_check_generic(self, params)
tvars = _type_vars(params)
args = params

prepend = (self,) if self.__origin__ is None else ()
return self.__class__(self.__name__,
prepend + self.__bases__,
_no_slots_copy(self.__dict__),
tvars=tvars,
args=args,
origin=self,
extra=self.__extra__,
orig_bases=self.__orig_bases__)

注意函数的返回,这个函数设计使用 self.__class__,也即使用元类创建了一个新的类,这个创建过程实际上调用的是 GenericMeta__new__ 方法、

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def __new__(cls, name, bases, namespace,
tvars=None, args=None, origin=None, extra=None, orig_bases=None):
"""Create a new generic class. GenericMeta.__new__ accepts
keyword arguments that are used for internal bookkeeping, therefore
an override should pass unused keyword arguments to super().
"""
if tvars is not None:
# Called from __getitem__() below.
assert origin is not None
assert all(isinstance(t, TypeVar) for t in tvars), tvars
else:
# Called from class statement.
assert tvars is None, tvars
assert args is None, args
assert origin is None, origin

# Get the full set of tvars from the bases.
tvars = _type_vars(bases)
# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...].
gvars = None
for base in bases:
if base is Generic:
raise TypeError("Cannot inherit from plain Generic")
if (isinstance(base, GenericMeta) and
base.__origin__ is Generic):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
raise TypeError(
"Some type variables (%s) "
"are not listed in Generic[%s]" %
(", ".join(str(t) for t in tvars if t not in gvarset),
", ".join(str(g) for g in gvars)))
tvars = gvars

initial_bases = bases
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
bases = (extra,) + bases
bases = tuple(b._gorg if isinstance(b, GenericMeta) else b for b in bases)

# remove bare Generic from bases if there are other generic bases
if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
bases = tuple(b for b in bases if b is not Generic)
namespace.update({'__origin__': origin, '__extra__': extra,
'_gorg': None if not origin else origin._gorg})
self = super().__new__(cls, name, bases, namespace, _root=True)
super(GenericMeta, self).__setattr__('_gorg',
self if not origin else origin._gorg)
self.__parameters__ = tvars
# Be prepared that GenericMeta will be subclassed by TupleMeta
# and CallableMeta, those two allow ..., (), or [] in __args___.
self.__args__ = tuple(... if a is _TypingEllipsis else
() if a is _TypingEmpty else
a for a in args) if args else None
# Speed hack (https://github.com/python/typing/issues/196).
self.__next_in_mro__ = _next_in_mro(self)
# Preserve base classes on subclassing (__bases__ are type erased now).
if orig_bases is None:
self.__orig_bases__ = initial_bases

# This allows unparameterized generic collections to be used
# with issubclass() and isinstance() in the same way as their
# collections.abc counterparts (e.g., isinstance([], Iterable)).
if (
'__subclasshook__' not in namespace and extra or
# allow overriding
getattr(self.__subclasshook__, '__name__', '') == '__extrahook__'
):
self.__subclasshook__ = _make_subclasshook(self)
if isinstance(extra, abc.ABCMeta):
self._abc_registry = extra._abc_registry
self._abc_cache = extra._abc_cache
elif origin is not None:
self._abc_registry = origin._abc_registry
self._abc_cache = origin._abc_cache

if origin and hasattr(origin, '__qualname__'): # Fix for Python 3.2.
self.__qualname__ = origin.__qualname__
self.__tree_hash__ = (hash(self._subs_tree()) if origin else
super(GenericMeta, self).__hash__())
return self

注意 __new__ 的入参和 __getitem__ 返回处调用 __class__ 构造函数时的入参对应关系。在上面的 63 行,GenericMeta 在新构造的类中设置了 __args__ 属性,而这些属性就是构造泛型类的入参。考虑开始的 SomeDict 的例子。我们可以使用下面的方法来获取自定义字典类的 Key 与 Value 的类型。

1
2
3
dict_class = SomeDict[str, int]
key_type = dict_class.__args__[0]
val_type = dict_class.__args__[1]

1.2 从泛型示例中获取泛型参数

注意这里我们是从泛型的类中获取参数,那么在实例里面呢?我们需要注意到 Generic__new__ 方法, 这个方法又调用了 _generic_new 来创建实际的实例对象。注意到 _generic_new 最终是使用 base_cls 参数指定的类来创建的对象,而非 cls。注意到 _generic_new 的调用方式是:

1
_generic_new(cls.__next_in_mro__, cls, *args, **kwds)

即实际创建的对象并非是 cls。以 dict_class = SomeDict[str, int] 为例,这里 dict_class 是通过 GenericMeta.__getitem__ 函数调用 GenericMeta.__new__ 函数创建出来的一个新的类。我们利用 dict_class 来创建实例的时候,其实创建出来的并非 dict_class 类型的实例,而是 dict_class.__next_in_mro__ 这个特殊变量的保存的类的实例。要搞清楚这些关系,我们还是要回到 GenericMeta.__new__ 这个函数的实现中。

注意到第 55-57 行:

1
2
3
namespace.update({'__origin__': origin, '__extra__': extra,
'_gorg': None if not origin else origin._gorg})
self = super().__new__(cls, name, bases, namespace, _root=True)

新创建出来的类(即 dict_class)的 __origin__ 属性(也即 _generic_new 函数中的 cls.__origin__),记录了泛型模板的类型。注意 GenericMeta.__new__ 除了可以通过 GenericMeta.__getitem__ 触发以外,还可以通过正常的类继承过程触发。此时 origin 会是空的。反过来看,如果一个泛型类是通过 GenericMeta.__getitem__ 类创建的,那么他会拥有 __origin__ 属性来指向其模板类;而如果这个类是通过类集成产生的,则这个新类不会拥有 __origin__ 属性。

这里的泛型模板说的是在泛型定义语句中位于方括号左侧的那个类,相对于 dict_class,其泛型模板就是 SomeDict

那么 __next_in_mro__ 又是什么呢?mro 的全称是 Method Resolution Order,2意指在类型集成树中,尤其是多继承的系统中,子类解析父类方法的一个顺序。这个属性同样是 GenericMeta.__new__ 方法中设置的。注意第 67 行

1
self.__next_in_mro__ = _next_in_mro(self)

其中 _next_in_mro 的实现是:

1
2
3
4
5
6
7
8
9
10
11
def _next_in_mro(cls):
"""Helper for Generic.__new__.
Returns the class after the last occurrence of Generic or
Generic[...] in cls.__mro__.
"""
next_in_mro = object
# Look for the last occurrence of Generic or Generic[...].
for i, c in enumerate(cls.__mro__[:-1]):
if isinstance(c, GenericMeta) and c._gorg is Generic:
next_in_mro = cls.__mro__[i + 1]
return next_in_mro

这里的 __mro__ 是 Python 类的一个内建属性,用来记录 Python 类的父类方法解相顺序。我们还需要理解 _gorg 的含义。这个属性的设置代码位于 GenericMeta.__new__ 的 55 - 58 行

1
2
3
4
5
namespace.update({'__origin__': origin, '__extra__': extra,
'_gorg': None if not origin else origin._gorg})
self = super().__new__(cls, name, bases, namespace, _root=True)
super(GenericMeta, self).__setattr__('_gorg',
self if not origin else origin._gorg)

_gorg 的设置逻辑可以表示为:如果当前创建的类没有 origin,_gorg 会设置成类自身,否则会设置成 origin_gorg 属性。注意 origin 实际上代表了一个不可示例化的泛型模板通过 Generic[...] 形式的调用转化成可实例化的泛型类的过程中,新的可实例化的类与其模板之间的关系。这么看 _gorg__origin__ 具有相似的作用。但是 _gorgorigin 为空时,可以指向类本身。回过头看 _next_in_mro,我们可以发现这个函数是查找到的在方法解析树上 Generic 声明后的一个可实例化类,大多数情况下这里获取的是 object 类。单如果你声明 Generic 子类的时候使用了多继承,那么 _next_in_mro 会找到出现在 Generic 声明后面的类。考虑下面的例子:

1
2
3
4
5
6
class A:
pass


class B(Generic[T], A):
pass

这里 _next_in_mro 找出来的就会是 A。OK,现在我们回到 _generic_new 函数,注意尽管代码的形式是 obj = base_cls.__new__(origin, *args, **kwds),但是事实上被实例化的并非 base_cls,而是 origin 类型,这种写法的意思是使用 base_cls__new__ 方法去构造 origin,而不是使用 origin__new__ 方法。由于 base_cls 可以确保不是 Generic 的子类(_next_in_mro 的作用),因此此处的调用不会出现 _generic_new 的无限递归调用。

现在我们弄清楚了 _generic_new 的构造过程,知道了最终构造出来的示例是 origin 类型。在 dict_class = SomeDict[str, int] 的例子中,dict_class() 构造出来的对象就是 SomeDict 类型的。这也意味着在 dict_class 中调用 __getitem__ 方法,方法内部使用 type(self) 或者 self.__class__ 拿到的是 SomeDict,这个类自然不会包含 [str, int] 的类型信息。

再来看 _genric_new 函数,在创建 obj = base_cls.__new__(origin, *args, **kwds) 创建出对象之后,还执行了下面的操作:

1
2
3
4
try:
obj.__orig_class__ = cls
except AttributeError:
pass

注意这里的 cls 是指来自 Generic.__new__ 的第一个参数,也即你构造泛型类对象时形式上的这个类。放到 dict_class = SomeDict[str, int] 的例子中,你在使用 dict_class() 创建示例时,__orig_class__ 指向的就是 dict_class 本身。这时我们就可以解决从实例中获取泛型参数的问题了。我们通过 SomeDict.__getitem__ 函数中加入类型检查的例子来说明这个问题:

1
2
3
4
5
6
7
8
class SomeDict(Generic[_KT, _VT]):

def __getitem__(self, key: _KT) -> _VT:
# get value
value = ...
value_type = self.__orig_class__.__args__[1]
if not isinstance(value, value_type):
raise ValueType("value type is inconsistent with dict declaration")

1.3 后记:处理前向声明

为了解决循环引用的问题,有时候我们会使用字符串表示的前向引用来声明泛型的类型参数。例如:

1
2
3
4
5
6
7
8
class A(Generic[_T]):
pass

AA = A["B"]


class B:
pass

这时,你使用我们上面提到的手段来获取泛型的类型参数拿到的并不是你所期望的类型本身,而会是一个名为 _ForwardRef 特殊的类型。这个类定义在标准库的 typing.py 文件中。观察 GenericMeta.__getitem__ 函数,可以发现我们传入的泛型类型参数会经过 _type_check 函数进行一次映射,这个函数的定义是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def _type_check(arg, msg):
"""Check that the argument is a type, and return it (internal helper).
As a special case, accept None and return type(None) instead.
Also, _TypeAlias instances (e.g. Match, Pattern) are acceptable.
The msg argument is a human-readable error message, e.g.
"Union[arg, ...]: arg should be a type."
We append the repr() of the actual value (truncated to 100 chars).
"""
if arg is None:
return type(None)
if isinstance(arg, str):
arg = _ForwardRef(arg)
if (
isinstance(arg, _TypingBase) and type(arg).__name__ == '_ClassVar' or
not isinstance(arg, (type, _TypingBase)) and not callable(arg)
):
raise TypeError(msg + " Got %.100r." % (arg,))
# Bare Union etc. are not valid as type arguments
if (
type(arg).__name__ in ('_Union', '_Optional') and
not getattr(arg, '__origin__', None) or
isinstance(arg, TypingMeta) and arg._gorg in (Generic, _Protocol)
):
raise TypeError("Plain %s is not valid as type argument" % arg)
return arg

可以看到,当类型输入是字符串类型时,这个函数会使用 _ForwardRef 包裹对应的字符串。为了获取真正的类型,我们需要根据这个前向声明类的对象获取到真正类对象。我们来看 _ForwardRef 的实现(省略了一些次要函数):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class _ForwardRef(_TypingBase, _root=True):
"""Internal wrapper to hold a forward reference."""

__slots__ = ('__forward_arg__', '__forward_code__',
'__forward_evaluated__', '__forward_value__')

def __init__(self, arg):
super().__init__(arg)
if not isinstance(arg, str):
raise TypeError('Forward reference must be a string -- got %r' % (arg,))
try:
code = compile(arg, '<string>', 'eval')
except SyntaxError:
raise SyntaxError('Forward reference must be an expression -- got %r' %
(arg,))
self.__forward_arg__ = arg
self.__forward_code__ = code
self.__forward_evaluated__ = False
self.__forward_value__ = None

def _eval_type(self, globalns, localns):
if not self.__forward_evaluated__ or localns is not globalns:
if globalns is None and localns is None:
globalns = localns = {}
elif globalns is None:
globalns = localns
elif localns is None:
localns = globalns
self.__forward_value__ = _type_check(
eval(self.__forward_code__, globalns, localns),
"Forward references must evaluate to types.")
self.__forward_evaluated__ = True
return self.__forward_value__

这个类定义了一个 _eval_type 函数,其接受两个命名域,并输出真正的类型。因此我们要想办法维护合适的命名域。这里所说的命名域本质上是一个从类名到类对象的大字典。

此处我就不展开讲如何维护一个命名域了,后续有精力了再整理这些内容。

2 新版本的 Python 中泛型实现的变化

PEP-560 中,Python 对泛型系统的实现进行了重新设计。

To Be Continued

https://github.com/python/typing/issues/629 https://hanjianwei.com/2013/07/25/python-mro/


  1. 静态类型和动态类型有什么区别?↩︎

  2. Python 的方法解析顺序(MRO)↩︎