diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 5ce55142b..562aca141 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1111,20 +1111,18 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame): @overload def map( self, - arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2], - na_action: Literal["ignore"] = ..., - ) -> Series[S2]: ... - @overload - def map( - self, - arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2], - na_action: None = None, + arg: ( + Callable[[S1], S2 | None | NAType] + | Mapping[S1, S2 | None | NAType] + | Series[S2] + ), + na_action: Literal["ignore"] | None = None, ) -> Series[S2]: ... @overload def map( self, - arg: Callable[[Any], Any] | Mapping[Any, Any] | Series, - na_action: Literal["ignore"] | None = ..., + arg: Callable[[Any], object] | Mapping[Any, object] | Series[Any], + na_action: Literal["ignore"] | None = None, ) -> Series: ... @overload def aggregate( diff --git a/tests/series/test_series.py b/tests/series/test_series.py index 6a41e339b..27ceb35cc 100644 --- a/tests/series/test_series.py +++ b/tests/series/test_series.py @@ -3700,6 +3700,7 @@ def test_map() -> None: pd.Series, str, ) + check(assert_type(s.map(mapping), "pd.Series[str]"), pd.Series, str) def callable(x: int) -> str: return str(x) @@ -3709,16 +3710,35 @@ def callable(x: int) -> str: pd.Series, str, ) + check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str) series = pd.Series(["a", "b", "c"]) check( assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str ) + check(assert_type(s.map(series), "pd.Series[str]"), pd.Series, str) unknown_series = pd.Series([1, 0, None]) check( - assert_type(unknown_series.map({1: True, 0: False, None: None}), pd.Series), + assert_type( + unknown_series.map({1: True, 0: False, None: None}), "pd.Series[bool]" + ), + pd.Series, + bool, + ) + check( + assert_type( + unknown_series.map({1: True, 0: False, None: None}, na_action="ignore"), + "pd.Series[bool]", + ), pd.Series, + bool, + ) + s_mixed = pd.Series([1, "a"]) + check( + assert_type(s_mixed.map({1: 1.0, "a": 2.0}), "pd.Series[float]"), + pd.Series, + float, ) @@ -3736,10 +3756,22 @@ def callable(x: int | NAType) -> str | NAType: check( assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str ) + # na_action defaults to None + check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str) series = pd.Series(["a", "b", "c"]) check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str) + def callable2(x: int | NAType | None) -> str | None: + if isinstance(x, int): + return str(x) + return None + + check( + assert_type(s.map(callable2, na_action=None), "pd.Series[str]"), pd.Series, str + ) + check(assert_type(s.map(callable2), "pd.Series[str]"), pd.Series, str) + def test_case_when() -> None: c = pd.Series([6, 7, 8, 9], name="c")