Skip to content

Commit e623ae3

Browse files
committed
test: add unit test for arrow_cast function to validate casting to Float64 and Int32
1 parent 1914a0b commit e623ae3

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

python/tests/test_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,18 @@ def test_temporal_functions(df):
905905
)
906906

907907

908+
def test_arrow_cast(df):
909+
df = df.select(
910+
f.arrow_cast(column("a"), "Float64").alias("a_as_float"),
911+
f.arrow_cast(column("a"), "Int32").alias("a_as_int"),
912+
)
913+
result = df.collect()
914+
assert len(result) == 1
915+
result = result[0]
916+
assert result.column(0) == pa.array([1.0, 2.0, 3.0], type=pa.float64())
917+
assert result.column(1) == pa.array([1, 2, 3], type=pa.int32())
918+
919+
908920
def test_case(df):
909921
df = df.select(
910922
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),

src/functions.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ macro_rules! expr_fn {
384384
($FUNC: ident, $($arg:ident)*, $DOC: expr) => {
385385
#[doc = $DOC]
386386
#[pyfunction]
387-
fn $FUNC($($arg: PyExpr),*, data_type: &str) -> PyExpr {
388-
functions::expr_fn::$FUNC($($arg.into()),*, data_type.to_string()).into()
387+
fn $FUNC($($arg: PyExpr),*) -> PyExpr {
388+
functions::expr_fn::$FUNC($($arg.into()),*).into()
389389
}
390390
};
391391
}
@@ -563,7 +563,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
563563
expr_fn_vec!(named_struct);
564564
expr_fn!(from_unixtime, unixtime);
565565
expr_fn!(arrow_typeof, arg_1);
566-
expr_fn!(arrow_cast, expr data_type);
566+
expr_fn!(arrow_cast, datatype);
567567
expr_fn!(random);
568568

569569
// Array Functions

0 commit comments

Comments
 (0)