@@ -3584,6 +3584,7 @@ def test_map() -> None:
35843584 pd .Series ,
35853585 str ,
35863586 )
3587+ check (assert_type (s .map (mapping ), "pd.Series[str]" ), pd .Series , str )
35873588
35883589 def callable (x : int ) -> str :
35893590 return str (x )
@@ -3593,16 +3594,29 @@ def callable(x: int) -> str:
35933594 pd .Series ,
35943595 str ,
35953596 )
3597+ check (assert_type (s .map (callable ), "pd.Series[str]" ), pd .Series , str )
35963598
35973599 series = pd .Series (["a" , "b" , "c" ])
35983600 check (
35993601 assert_type (s .map (series , na_action = "ignore" ), "pd.Series[str]" ), pd .Series , str
36003602 )
3603+ check (assert_type (s .map (series ), "pd.Series[str]" ), pd .Series , str )
36013604
36023605 unknown_series = pd .Series ([1 , 0 , None ])
36033606 check (
3604- assert_type (unknown_series .map ({1 : True , 0 : False , None : None }), pd .Series ),
3607+ assert_type (
3608+ unknown_series .map ({1 : True , 0 : False , None : None }), "pd.Series[bool]"
3609+ ),
3610+ pd .Series ,
3611+ bool ,
3612+ )
3613+ check (
3614+ assert_type (
3615+ unknown_series .map ({1 : True , 0 : False , None : None }, na_action = "ignore" ),
3616+ "pd.Series[bool]" ,
3617+ ),
36053618 pd .Series ,
3619+ bool ,
36063620 )
36073621
36083622
@@ -3626,6 +3640,16 @@ def callable(x: int | NAType) -> str | NAType:
36263640 series = pd .Series (["a" , "b" , "c" ])
36273641 check (assert_type (s .map (series , na_action = None ), "pd.Series[str]" ), pd .Series , str )
36283642
3643+ def callable2 (x : int | NAType | None ) -> str | None :
3644+ if isinstance (x , int ):
3645+ return str (x )
3646+ return None
3647+
3648+ check (
3649+ assert_type (s .map (callable2 , na_action = None ), "pd.Series[str]" ), pd .Series , str
3650+ )
3651+ check (assert_type (s .map (callable2 ), "pd.Series[str]" ), pd .Series , str )
3652+
36293653
36303654def test_case_when () -> None :
36313655 c = pd .Series ([6 , 7 , 8 , 9 ], name = "c" )
0 commit comments