Skip to content

Commit 61b76a5

Browse files
committed
Support recursive unwrapping and calling of methods as annotate functions
1 parent 8df3d24 commit 61b76a5

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

Lib/annotationlib.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -891,9 +891,11 @@ def _get_annotate_attr(annotate, attr, default):
891891
if (value := getattr(annotate, attr, None)) is not None:
892892
return value
893893

894+
# Redirect method attribute access to the underlying function. The C code
895+
# verifies that the __func__ attribute is some kind of callable, so we need
896+
# to look for attributes recursively.
894897
if isinstance(annotate, types.MethodType):
895-
if call_func := getattr(annotate, "__func__", None):
896-
return getattr(call_func, attr, default)
898+
return _get_annotate_attr(annotate.__func__, attr, default)
897899

898900
# Class instances themselves aren't methods, their __call__ functions are.
899901
if isinstance(annotate.__call__, types.MethodType):
@@ -919,32 +921,35 @@ def _get_annotate_attr(annotate, attr, default):
919921

920922
return default
921923

922-
def _direct_call_annotate(func, annotate, format):
924+
def _direct_call_annotate(func, annotate, *args):
923925
# If annotate is a method, we need to pass self as the first param.
924926
if (
925927
hasattr(annotate, "__func__") and
926928
(self := getattr(annotate, "__self__", None))
927929
):
928-
return func(self, format)
930+
# We don't know what type of callable will be in the __func__ attribute,
931+
# so let's try again with knowledge of that type, including self as the first
932+
# argument.
933+
return _direct_call_annotate(func, annotate.__func__, self, *args)
929934

930935
# If annotate is a class instance, its __call__ function is the method.
931936
if (
932937
hasattr(annotate.__call__, "__func__") and
933938
(self := getattr(annotate.__call__, "__self__", None))
934939
):
935-
return func(self, format)
940+
return func(self, *args)
936941

937942
# If annotate is a class, `func` is the __init__ method, so we still need to call
938943
# __new__() to create the instance
939944
if isinstance(annotate, type):
940945
inst = annotate.__new__(annotate)
941-
func(inst, format)
946+
func(inst, *args)
942947
return inst
943948

944949
# Generic instantiation is slightly different.
945950
if isinstance(annotate, types.GenericAlias):
946951
inst = annotate.__new__(annotate.__origin__)
947-
func(inst, format)
952+
func(inst, *args)
948953
# Try to set the original class on the instance, if possible.
949954
try:
950955
inst.__orig_class__ = annotate
@@ -959,14 +964,14 @@ def _direct_call_annotate(func, annotate, format):
959964
if isinstance(annotate, functools.partial):
960965
# Partial methods
961966
if self := getattr(annotate, "__self__", None):
962-
return functools.partial(func, self, *annotate.args, **annotate.keywords)(format)
963-
return functools.partial(func, *annotate.args, **annotate.keywords)(format)
967+
return functools.partial(func, self, *annotate.args, **annotate.keywords)(*args)
968+
return functools.partial(func, *annotate.args, **annotate.keywords)(*args)
964969

965970
# If annotate is a cached function, we've now updated the function data, so
966971
# let's not use the old cache. Furthermore, we're about to call the function
967972
# and never use it again, so let's not bother trying to cache it.
968973
# Or, if it's a normal function or unsupported callable, we should just call it.
969-
return func(format)
974+
return func(*args)
970975

971976

972977
def get_annotate_from_class_namespace(obj):

Lib/test/test_annotationlib.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,27 @@ def format(self, format, /, __Format=Format,
16321632

16331633
self.assertEqual(annotations, {"x": str})
16341634

1635+
def test_callable_custom_method_annotate_forwardref_value_fallback(self):
1636+
# If Format.STRING and Format.VALUE_WITH_FAKE_GLOBALS are not
1637+
# supported fall back to Format.VALUE and convert to strings
1638+
class Annotate(dict):
1639+
def __init__(inst, self, format, /, __Format=Format,
1640+
__NotImplementedError=NotImplementedError):
1641+
if format == __Format.VALUE:
1642+
super().__init__({"x": str})
1643+
else:
1644+
raise __NotImplementedError(format)
1645+
1646+
# This wouldn't happen on a normal class, but it's technically legal.
1647+
custom_method = types.MethodType(Annotate, Annotate(None, Format.VALUE))
1648+
1649+
annotations = annotationlib.call_annotate_function(
1650+
custom_method,
1651+
Format.FORWARDREF,
1652+
)
1653+
1654+
self.assertEqual(annotations, {"x": str})
1655+
16351656
def test_callable_classmethod_annotate_forwardref_value_fallback(self):
16361657
# If Format.STRING and Format.VALUE_WITH_FAKE_GLOBALS are not
16371658
# supported fall back to Format.VALUE and convert to strings

0 commit comments

Comments
 (0)