@@ -45,6 +45,8 @@ def correct_function(values, index, a):
4545 {"key" : ["a" , "a" , "b" , "b" , "a" ], "data" : [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]},
4646 columns = ["key" , "data" ],
4747 )
48+ expected = data .groupby ("key" ).sum () * 2.7
49+
4850 # py signature binding
4951 with pytest .raises (TypeError , match = "missing a required argument: 'a'" ):
5052 data .groupby ("key" ).agg (incorrect_function , engine = "numba" , b = 1 )
@@ -59,11 +61,13 @@ def correct_function(values, index, a):
5961 # numba signature check after binding
6062 with pytest .raises (NumbaUtilError , match = "numba does not support" ):
6163 data .groupby ("key" ).agg (incorrect_function , engine = "numba" , a = 1 )
62- data .groupby ("key" ).agg (correct_function , engine = "numba" , a = 1 )
64+ actual = data .groupby ("key" ).agg (correct_function , engine = "numba" , a = 1 )
65+ tm .assert_frame_equal (expected + 1 , actual )
6366
6467 with pytest .raises (NumbaUtilError , match = "numba does not support" ):
6568 data .groupby ("key" )["data" ].agg (incorrect_function , engine = "numba" , a = 1 )
66- data .groupby ("key" )["data" ].agg (correct_function , engine = "numba" , a = 1 )
69+ actual = data .groupby ("key" )["data" ].agg (correct_function , engine = "numba" , a = 1 )
70+ tm .assert_series_equal (expected ["data" ] + 1 , actual )
6771
6872
6973@pytest .mark .filterwarnings ("ignore" )
0 commit comments