Skip to content

Commit 00baef9

Browse files
committed
Simplify and add more tests
1 parent 14f610c commit 00baef9

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

pandas-stubs/core/series.pyi

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,24 +1103,15 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
11031103
fill_value: int | _str | dict | None = None,
11041104
sort: _bool = True,
11051105
) -> DataFrame: ...
1106-
@overload
1107-
def map(
1108-
self,
1109-
arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2],
1110-
na_action: Literal["ignore"],
1111-
) -> Series[S2]: ...
1112-
@overload
1113-
def map(
1114-
self,
1115-
arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2],
1116-
na_action: None = None,
1117-
) -> Series[S2]: ...
1118-
@overload
11191106
def map(
11201107
self,
1121-
arg: Callable[[Any], Any] | Mapping[Any, Any] | Series,
1108+
arg: (
1109+
Callable[[S1], S2 | None | NAType]
1110+
| Mapping[S1, S2 | None | NAType]
1111+
| Series[S2]
1112+
),
11221113
na_action: Literal["ignore"] | None = None,
1123-
) -> Series: ...
1114+
) -> Series[S2]: ...
11241115
@overload
11251116
def aggregate(
11261117
self: Series[int],

tests/series/test_series.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3584,6 +3584,7 @@ def test_map() -> None:
35843584
pd.Series,
35853585
str,
35863586
)
3587+
check(assert_type(s.map(mapping), "pd.Series[str]"), pd.Series, str)
35873588

35883589
def callable(x: int) -> str:
35893590
return str(x)
@@ -3593,16 +3594,29 @@ def callable(x: int) -> str:
35933594
pd.Series,
35943595
str,
35953596
)
3597+
check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str)
35963598

35973599
series = pd.Series(["a", "b", "c"])
35983600
check(
35993601
assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str
36003602
)
3603+
check(assert_type(s.map(series), "pd.Series[str]"), pd.Series, str)
36013604

36023605
unknown_series = pd.Series([1, 0, None])
36033606
check(
3604-
assert_type(unknown_series.map({1: True, 0: False, None: None}), pd.Series),
3607+
assert_type(
3608+
unknown_series.map({1: True, 0: False, None: None}), "pd.Series[bool]"
3609+
),
3610+
pd.Series,
3611+
bool,
3612+
)
3613+
check(
3614+
assert_type(
3615+
unknown_series.map({1: True, 0: False, None: None}, na_action="ignore"),
3616+
"pd.Series[bool]",
3617+
),
36053618
pd.Series,
3619+
bool,
36063620
)
36073621

36083622

@@ -3626,6 +3640,16 @@ def callable(x: int | NAType) -> str | NAType:
36263640
series = pd.Series(["a", "b", "c"])
36273641
check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)
36283642

3643+
def callable2(x: int | NAType | None) -> str | None:
3644+
if isinstance(x, int):
3645+
return str(x)
3646+
return None
3647+
3648+
check(
3649+
assert_type(s.map(callable2, na_action=None), "pd.Series[str]"), pd.Series, str
3650+
)
3651+
check(assert_type(s.map(callable2), "pd.Series[str]"), pd.Series, str)
3652+
36293653

36303654
def test_case_when() -> None:
36313655
c = pd.Series([6, 7, 8, 9], name="c")

0 commit comments

Comments
 (0)