Skip to content

Commit b749c49

Browse files
committed
Support arbitrary callables as __init__ methods for class annotate functions
1 parent 4acb56b commit b749c49

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

Lib/annotationlib.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -891,8 +891,11 @@ def _stringify_single(anno):
891891

892892

893893
def _get_annotate_attr(annotate, attr, default):
894-
if (value := getattr(annotate, attr, None)) is not None:
895-
return value
894+
# Try to get the attr on the annotate function. If it doesn't exist, we might
895+
# need to look in other places on the object. If all of those fail, we can
896+
# return the default at the end.
897+
if hasattr(annotate, attr):
898+
return getattr(annotate, attr)
896899

897900
# Redirect method attribute access to the underlying function. The C code
898901
# verifies that the __func__ attribute is some kind of callable, so we need
@@ -909,7 +912,7 @@ def _get_annotate_attr(annotate, attr, default):
909912
(typing := sys.modules.get("typing", None))
910913
and isinstance(annotate, typing._BaseGenericAlias)
911914
):
912-
return getattr(annotate.__origin__.__init__, attr, default)
915+
return _get_annotate_attr(annotate.__origin__.__init__, attr, default)
913916

914917
# If annotate is a class instance, its __call__ is the relevant function.
915918
# However, __call__ Could be a method, a function descriptor, or any other callable.
@@ -921,16 +924,17 @@ def _get_annotate_attr(annotate, attr, default):
921924
):
922925
return _get_annotate_attr(annotate.__call__, attr, default)
923926

924-
# Classes and generics are callable, usually the __init__ method sets attributes,
927+
# Classes and generics are callable. Usually the __init__ method sets attributes,
925928
# so let's access this method for fake globals and the like.
929+
# Technically __init__ can be any callable object, so we recurse.
926930
if isinstance(annotate, type) or isinstance(annotate, types.GenericAlias):
927-
return getattr(annotate.__init__, attr, default)
931+
return _get_annotate_attr(annotate.__init__, attr, default)
928932

929933
# Most 'wrapped' functions, including functools.cache and staticmethod, need us
930934
# to manually, recursively unwrap. For partial.update_wrapper functions, the
931935
# attribute is accessible on the function itself, so we never get this far.
932-
if (unwrapped := getattr(annotate, "__wrapped__", None)) is not None:
933-
return _get_annotate_attr(unwrapped, attr, default)
936+
if hasattr(annotate, "__wrapped__"):
937+
return _get_annotate_attr(annotate.__wrapped__, attr, default)
934938

935939
# Partial functions and methods both store their underlying function as a
936940
# func attribute. They can wrap any callable, so we need to recursively unwrap.
@@ -983,14 +987,16 @@ def _direct_call_annotate(func, annotate, *args):
983987
# __new__() to create the instance
984988
if isinstance(annotate, type):
985989
inst = annotate.__new__(annotate)
986-
func(inst, *args)
990+
# func might refer to some non-function object.
991+
_direct_call_annotate(func, annotate.__init__, inst, *args)
987992
return inst
988993

989994
# Generic instantiation is slightly different. Since we want to give
990995
# __call__ priority, the custom logic for builtin generics is here.
991996
if isinstance(annotate, types.GenericAlias):
992997
inst = annotate.__new__(annotate.__origin__)
993-
func(inst, *args)
998+
# func might refer to some non-function object.
999+
_direct_call_annotate(func, annotate.__init__, inst, *args)
9941000
# Try to set the original class on the instance, if possible.
9951001
# This is the same logic used in typing for custom generics.
9961002
try:

Lib/test/test_annotationlib.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,31 @@ def __init__(self, format, /, __Format=Format,
15351535

15361536
self.assertEqual(annotations, {"x": int})
15371537

1538+
def test_callable_class_custom_init_annotate_forwardref_fakeglobals(self):
1539+
# Calling the class will construct a new instance and call its __init__ function
1540+
# as an annotate function, except this __init__ is not a method,
1541+
# but a partial function.
1542+
def custom_init(self, second, format, /, __Format=Format,
1543+
__NotImplementedError=NotImplementedError):
1544+
if format == __Format.VALUE:
1545+
super(type(self), self).__init__({"x": str})
1546+
elif format == __Format.VALUE_WITH_FAKE_GLOBALS:
1547+
super(type(self), self).__init__({"x": second})
1548+
else:
1549+
raise __NotImplementedError(format)
1550+
1551+
class Annotate(dict):
1552+
pass
1553+
1554+
Annotate.__init__ = functools.partial(custom_init, functools.Placeholder, int)
1555+
1556+
annotations = annotationlib.call_annotate_function(
1557+
Annotate,
1558+
Format.FORWARDREF
1559+
)
1560+
1561+
self.assertEqual(annotations, {"x": int})
1562+
15381563
def test_callable_generic_class_annotate_forwardref_fakeglobals(self):
15391564
# Subscripted generic classes are types.GenericAlias instances
15401565
# for dict subclasses. Check that they are still
@@ -1990,7 +2015,6 @@ def _(format, /, __Format=Format,
19902015
else:
19912016
raise __NotImplementedError(format)
19922017

1993-
print("Single dispatch")
19942018
annotations = annotationlib.call_annotate_function(
19952019
format,
19962020
Format.FORWARDREF,

0 commit comments

Comments
 (0)