Skip to content

Commit 33c9d13

Browse files
committed
Support fake globals in Python generic classes' __init__ methods
1 parent 2ca8f95 commit 33c9d13

File tree

2 files changed

+67
-34
lines changed

2 files changed

+67
-34
lines changed

Lib/annotationlib.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,17 @@ def _get_annotate_attr(annotate, attr, default):
897897
if isinstance(annotate, types.MethodType):
898898
return _get_annotate_attr(annotate.__func__, attr, default)
899899

900+
# Python generics are callable. Usually, the __init__ method sets attributes.
901+
# However, typing._BaseGenericAlias overrides the __init__ method, so we need
902+
# to use the original class method for fake globals and the like.
903+
# _BaseGenericAlias also override __call__, so let's handle this earlier than
904+
# other class construction.
905+
if (
906+
(typing := sys.modules.get("typing", None))
907+
and isinstance(annotate, typing._BaseGenericAlias)
908+
):
909+
return getattr(annotate.__origin__.__init__, attr, default)
910+
900911
# If annotate is a class instance, its __call__ is the relevant function.
901912
# However, __call__ Could be a method, a function descriptor, or any other callable.
902913
# Normal functions have a __call__ property which is a useless method wrapper,
@@ -937,6 +948,22 @@ def _direct_call_annotate(func, annotate, *args):
937948
# argument.
938949
return _direct_call_annotate(func, annotate.__func__, self, *args)
939950

951+
# Python generics (typing._BaseGenericAlias) override __call__, so let's handle
952+
# them earlier than other class construction.
953+
if (
954+
(typing := sys.modules.get("typing", None))
955+
and isinstance(annotate, typing._BaseGenericAlias)
956+
):
957+
inst = annotate.__new__(annotate.__origin__)
958+
func(inst, *args)
959+
# Try to set the original class on the instance, if possible.
960+
# This is the same logic used in typing for custom generics.
961+
try:
962+
inst.__orig_class__ = annotate
963+
except Exception:
964+
pass
965+
return inst
966+
940967
# If annotate is a class instance, its __call__ is the function.
941968
# __call__ Could be a method, a function descriptor, or any other callable.
942969
# Normal functions have a __call__ property which is a useless method wrapper,
@@ -954,7 +981,8 @@ def _direct_call_annotate(func, annotate, *args):
954981
func(inst, *args)
955982
return inst
956983

957-
# Generic instantiation is slightly different.
984+
# Generic instantiation is slightly different. Since we want to give
985+
# __call__ priority, the custom logic for builtin generics is here.
958986
if isinstance(annotate, types.GenericAlias):
959987
inst = annotate.__new__(annotate.__origin__)
960988
func(inst, *args)

Lib/test/test_annotationlib.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,44 @@ def __init__(self, format, /, __Format=Format,
15591559
# We manually set the __orig_class__ for this special-case, check this too.
15601560
self.assertEqual(annotations.__orig_class__, Annotate[str, type])
15611561

1562+
def test_callable_typing_generic_class_annotate_forwardref_fakeglobals(self):
1563+
# Normally, generics are 'typing._GenericAlias' objects. These are implemented
1564+
# in Python with a __call__ method (in _typing.BaseGenericAlias), but this
1565+
# needs to be bypassed so we can inject fake globals into the origin class'
1566+
# __init__ method.
1567+
class Annotate[T]:
1568+
def __init__(self, format, /, __Format=Format,
1569+
__NotImplementedError=NotImplementedError):
1570+
if format == __Format.VALUE:
1571+
self.data = {'x': str}
1572+
elif format == __Format.VALUE_WITH_FAKE_GLOBALS:
1573+
self.data = {"x": int}
1574+
else:
1575+
raise __NotImplementedError(format)
1576+
def __getitem__(self, item):
1577+
return self.data[item]
1578+
def __iter__(self):
1579+
return iter(self.data)
1580+
def __len__(self):
1581+
return len(self.data)
1582+
def __getattr__(self, attr):
1583+
val = getattr(collections.abc.Mapping, attr)
1584+
if isinstance(val, types.FunctionType):
1585+
return types.MethodType(val, self)
1586+
return val
1587+
def __eq__(self, other):
1588+
return dict(self.items()) == dict(other.items())
1589+
1590+
annotations = annotationlib.call_annotate_function(
1591+
Annotate[int],
1592+
Format.FORWARDREF,
1593+
)
1594+
1595+
self.assertEqual(annotations, {"x": int})
1596+
1597+
# We manually set the __orig_class__ for this special-case, check this too.
1598+
self.assertEqual(annotations.__orig_class__, Annotate[int])
1599+
15621600
def test_user_annotate_forwardref_value_fallback(self):
15631601
# If Format.FORWARDREF and Format.VALUE_WITH_FAKE_GLOBALS are not supported
15641602
# use Format.VALUE
@@ -1808,39 +1846,6 @@ class Annotate:
18081846

18091847
self.assertEqual(annotations, {"x": int})
18101848

1811-
def test_callable_typing_generic_class_annotate_forwardref_value_fallback(self):
1812-
# Normally, generics are 'typing._GenericAlias' objects. These are implemented
1813-
# in Python with a __call__ method (in _typing.BaseGenericAlias), so should work
1814-
# as with any callable class instance.
1815-
class Annotate[T]:
1816-
def __init__(self, format, /, __Format=Format,
1817-
__NotImplementedError=NotImplementedError):
1818-
if format == __Format.VALUE_WITH_FAKE_GLOBALS:
1819-
self.data = {"x": int}
1820-
else:
1821-
raise __NotImplementedError(format)
1822-
def __getitem__(self, item):
1823-
return self.data[item]
1824-
def __iter__(self):
1825-
return iter(self.data)
1826-
def __len__(self):
1827-
return len(self.data)
1828-
def __getattr__(self, attr):
1829-
val = getattr(collections.abc.Mapping, attr)
1830-
if isinstance(val, types.FunctionType):
1831-
return types.MethodType(val, self)
1832-
return val
1833-
def __eq__(self, other):
1834-
return dict(self.items()) == dict(other.items())
1835-
1836-
annotations = annotationlib.call_annotate_function(
1837-
Annotate[int],
1838-
Format.FORWARDREF,
1839-
)
1840-
1841-
self.assertEqual(annotations, {"x": int})
1842-
self.assertEqual(annotations.__orig_class__, Annotate[int])
1843-
18441849
def test_callable_partial_annotate_forwardref_value_fallback(self):
18451850
# functools.partial is implemented in C. Ensure that the annotate function
18461851
# is extracted and called correctly, particularly with Placeholder args.

0 commit comments

Comments
 (0)