Skip to content

Commit a9a7f88

Browse files
committed
Add _direct_call_annotate() and support callable classes
1 parent b600f8c commit a9a7f88

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

Lib/annotationlib.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -734,12 +734,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
734734
argdefs=_get_annotate_attr(annotate, "__defaults__", None),
735735
kwdefaults=_get_annotate_attr(annotate, "__kwdefaults__", None),
736736
)
737-
if isinstance(annotate.__call__, types.MethodType):
738-
annos = func(
739-
annotate.__call__.__self__, Format.VALUE_WITH_FAKE_GLOBALS
740-
)
741-
else:
742-
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
737+
annos = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS)
743738
if _is_evaluate:
744739
return _stringify_single(annos)
745740
return {
@@ -791,10 +786,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
791786
kwdefaults=annotate_kwdefaults,
792787
)
793788
try:
794-
if isinstance(annotate.__call__, types.MethodType):
795-
result = func(annotate.__call__.__self__, Format.VALUE_WITH_FAKE_GLOBALS)
796-
else:
797-
result = func(Format.VALUE_WITH_FAKE_GLOBALS)
789+
result = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS)
798790
except NotImplementedError:
799791
# FORWARDREF and VALUE_WITH_FAKE_GLOBALS not supported, fall back to VALUE
800792
return annotate(Format.VALUE)
@@ -823,10 +815,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
823815
argdefs=annotate_defaults,
824816
kwdefaults=annotate_kwdefaults,
825817
)
826-
if isinstance(annotate.__call__, types.MethodType):
827-
result = func(annotate.__call__.__self__, Format.VALUE_WITH_FAKE_GLOBALS)
828-
else:
829-
result = func(Format.VALUE_WITH_FAKE_GLOBALS)
818+
result = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS)
830819
globals.transmogrify(cell_dict)
831820
if _is_evaluate:
832821
if isinstance(result, ForwardRef):
@@ -902,12 +891,29 @@ def _get_annotate_attr(annotate, attr, default):
902891
if (value := getattr(annotate, attr, None)) is not None:
903892
return value
904893

905-
if call_method := getattr(annotate, "__call__", None):
906-
if call_func := getattr(call_method, "__func__", None):
894+
if isinstance(annotate.__call__, types.MethodType):
895+
if call_func := getattr(annotate.__call__, "__func__", None):
907896
return getattr(call_func, attr, default)
897+
elif isinstance(annotate, type):
898+
return getattr(annotate.__init__, attr, default)
908899

909900
return default
910901

902+
def _direct_call_annotate(func, annotate, format):
903+
# If annotate is a method, we need to pass its self as the first param
904+
if (
905+
hasattr(annotate.__call__, "__func__") and
906+
(self := getattr(annotate.__call__, "__self__", None))
907+
):
908+
return func(self, format)
909+
910+
if isinstance(annotate, type):
911+
inst = annotate.__new__(annotate)
912+
func(inst, format)
913+
return inst
914+
915+
return func(format)
916+
911917

912918
def get_annotate_from_class_namespace(obj):
913919
"""Retrieve the annotate function from a class namespace dictionary.

0 commit comments

Comments
 (0)