From c358c44fadaa3040653c65748c303e0692831357 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:40:22 -0800 Subject: [PATCH 01/39] [SPARK-55162][PYTHON] Extract transformers from ArrowStreamUDFSerializer Extract the struct flattening/wrapping logic from ArrowStreamUDFSerializer into reusable transformers in a new transformers.py module. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/pandas/serializers.py | 52 +++--- python/pyspark/sql/pandas/transformers.py | 71 ++++++++ .../sql/tests/pandas/test_transformers.py | 169 ++++++++++++++++++ 4 files changed, 263 insertions(+), 30 deletions(-) create mode 100644 python/pyspark/sql/pandas/transformers.py create mode 100644 python/pyspark/sql/tests/pandas/test_transformers.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 6ba2e619f703..3bbd319fbe35 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -575,6 +575,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_window", "pyspark.sql.tests.pandas.test_pandas_sqlmetrics", "pyspark.sql.tests.pandas.test_converter", + "pyspark.sql.tests.pandas.test_transformers", "pyspark.sql.tests.test_python_datasource", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ffdd6c9901ea..ff3bf4e5f2eb 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -37,6 +37,10 @@ ArrowTableToRowsConversion, ArrowArrayToPandasConversion, ) +from pyspark.sql.pandas.transformers import ( + FlattenStructTransformer, + WrapStructTransformer, +) from pyspark.sql.pandas.types import ( from_arrow_type, is_variant, @@ -143,50 +147,38 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. + + Uses transformers for data processing: + - load_stream: ArrowStreamSerializer -> FlattenStructTransformer + - dump_stream: WrapStructTransformer -> ArrowStreamSerializer (with START_ARROW_STREAM marker) """ + def __init__(self): + super().__init__() + self._flatten_struct = FlattenStructTransformer() + self._wrap_struct = WrapStructTransformer() + def load_stream(self, stream): """ Flatten the struct into Arrow's record batches. """ - import pyarrow as pa - batches = super().load_stream(stream) - for batch in batches: - struct = batch.column(0) - yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))] + # Wrap each batch in a list for downstream UDF processing + return ([batch] for batch in self._flatten_struct(batches)) def dump_stream(self, iterator, stream): """ Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. """ - import pyarrow as pa - - def wrap_and_init_stream(): - should_write_start_length = True - for batch, _ in iterator: - assert isinstance(batch, pa.RecordBatch) - - # Wrap the root struct - if batch.num_columns == 0: - # When batch has no column, it should still create - # an empty batch with the number of rows set. 
- struct = pa.array([{}] * batch.num_rows) - else: - struct = pa.StructArray.from_arrays( - batch.columns, fields=pa.struct(list(batch.schema)) - ) - batch = pa.RecordBatch.from_arrays([struct], ["_0"]) + import itertools - # Write the first record batch with initialization. - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - yield batch + batches = self._wrap_struct(iterator) + first = next(batches, None) + if first is None: + return - return super().dump_stream(wrap_and_init_stream(), stream) + write_int(SpecialLengths.START_ARROW_STREAM, stream) + return super().dump_stream(itertools.chain([first], batches), stream) class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): diff --git a/python/pyspark/sql/pandas/transformers.py b/python/pyspark/sql/pandas/transformers.py new file mode 100644 index 000000000000..79d0df47b21d --- /dev/null +++ b/python/pyspark/sql/pandas/transformers.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Arrow batch transformers for building data processing pipelines. + +These are pure callable classes that transform Iterator[RecordBatch] -> Iterator[...]. +They should have no side effects (no I/O, no writing to streams). +""" + +from typing import TYPE_CHECKING, Iterator, Tuple + +if TYPE_CHECKING: + import pyarrow as pa + + +class FlattenStructTransformer: + """ + Flatten a single struct column into a RecordBatch. + + Input: Iterator of RecordBatch with a single struct column + Output: Iterator of RecordBatch (flattened) + """ + + def __call__( + self, batches: Iterator["pa.RecordBatch"] + ) -> Iterator["pa.RecordBatch"]: + import pyarrow as pa + + for batch in batches: + struct = batch.column(0) + yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) + + +class WrapStructTransformer: + """ + Wrap a RecordBatch's columns into a single struct column. + + Input: Iterator of (RecordBatch, arrow_type) + Output: Iterator of RecordBatch with a single struct column + """ + + def __call__( + self, iterator: Iterator[Tuple["pa.RecordBatch", "pa.DataType"]] + ) -> Iterator["pa.RecordBatch"]: + import pyarrow as pa + + for batch, _ in iterator: + if batch.num_columns == 0: + # When batch has no column, it should still create + # an empty batch with the number of rows set. 
+ struct = pa.array([{}] * batch.num_rows) + else: + struct = pa.StructArray.from_arrays( + batch.columns, fields=pa.struct(list(batch.schema)) + ) + yield pa.RecordBatch.from_arrays([struct], ["_0"]) diff --git a/python/pyspark/sql/tests/pandas/test_transformers.py b/python/pyspark/sql/tests/pandas/test_transformers.py new file mode 100644 index 000000000000..779a0f4b6980 --- /dev/null +++ b/python/pyspark/sql/tests/pandas/test_transformers.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from pyspark.testing.sqlutils import ( + have_pyarrow, + pyarrow_requirement_message, +) + +if have_pyarrow: + import pyarrow as pa + + from pyspark.sql.pandas.transformers import FlattenStructTransformer, WrapStructTransformer + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class TransformerTests(unittest.TestCase): + def test_flatten_struct_transformer_basic(self): + """Test flattening a struct column into separate columns.""" + # Create a batch with a single struct column containing two fields + struct_array = pa.StructArray.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + transformer = FlattenStructTransformer() + result = list(transformer(iter([batch]))) + + self.assertEqual(len(result), 1) + flattened = result[0] + self.assertEqual(flattened.num_columns, 2) + self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3]) + self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"]) + self.assertEqual(flattened.schema.names, ["x", "y"]) + + def test_flatten_struct_transformer_multiple_batches(self): + """Test flattening multiple batches.""" + batches = [] + for i in range(3): + struct_array = pa.StructArray.from_arrays( + [pa.array([i * 10 + j for j in range(2)])], + names=["val"], + ) + batches.append(pa.RecordBatch.from_arrays([struct_array], ["_0"])) + + transformer = FlattenStructTransformer() + result = list(transformer(iter(batches))) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0].column(0).to_pylist(), [0, 1]) + self.assertEqual(result[1].column(0).to_pylist(), [10, 11]) + self.assertEqual(result[2].column(0).to_pylist(), [20, 21]) + + def test_flatten_struct_transformer_empty_batch(self): + """Test flattening an empty batch.""" + struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + struct_array = pa.array([], type=struct_type) + batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + transformer = FlattenStructTransformer() + result = list(transformer(iter([batch]))) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].num_rows, 0) + self.assertEqual(result[0].num_columns, 2) + + def test_wrap_struct_transformer_basic(self): + """Test 
wrapping columns into a struct.""" + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + + transformer = WrapStructTransformer() + result = list(transformer(iter([(batch, arrow_type)]))) + + self.assertEqual(len(result), 1) + wrapped = result[0] + self.assertEqual(wrapped.num_columns, 1) + self.assertEqual(wrapped.schema.names, ["_0"]) + + # Verify the struct content + struct_col = wrapped.column(0) + self.assertEqual(len(struct_col), 3) + # Access struct fields + self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3]) + self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"]) + + def test_wrap_struct_transformer_multiple_batches(self): + """Test wrapping multiple batches.""" + batches_with_types = [] + arrow_type = pa.struct([("val", pa.int64())]) + for i in range(3): + batch = pa.RecordBatch.from_arrays( + [pa.array([i * 10 + j for j in range(2)])], + names=["val"], + ) + batches_with_types.append((batch, arrow_type)) + + transformer = WrapStructTransformer() + result = list(transformer(iter(batches_with_types))) + + self.assertEqual(len(result), 3) + for i, wrapped in enumerate(result): + self.assertEqual(wrapped.num_columns, 1) + struct_col = wrapped.column(0) + self.assertEqual(struct_col.field(0).to_pylist(), [i * 10, i * 10 + 1]) + + def test_wrap_struct_transformer_empty_columns(self): + """Test wrapping a batch with no columns.""" + # Create an empty schema batch with some rows + batch = pa.RecordBatch.from_arrays([], names=[]) + # Manually set num_rows by creating from pydict + batch = pa.RecordBatch.from_pydict({}, schema=pa.schema([])) + # Create batch with rows using a workaround + batch = pa.record_batch({"dummy": [1, 2, 3]}).select([]).slice(0, 3) + # Actually, let's create it properly + schema = pa.schema([]) + batch = pa.RecordBatch.from_arrays([], schema=schema) + + arrow_type = pa.struct([]) + + transformer = WrapStructTransformer() + result = list(transformer(iter([(batch, arrow_type)]))) + + self.assertEqual(len(result), 1) + wrapped = result[0] + self.assertEqual(wrapped.num_columns, 1) + # Empty struct batch has 0 rows + self.assertEqual(wrapped.num_rows, 0) + + def test_wrap_struct_transformer_empty_batch(self): + """Test wrapping an empty batch with schema.""" + schema = pa.schema([("x", pa.int64()), ("y", pa.string())]) + batch = pa.RecordBatch.from_arrays( + [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], + schema=schema, + ) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + + transformer = WrapStructTransformer() + result = list(transformer(iter([(batch, arrow_type)]))) + + self.assertEqual(len(result), 1) + wrapped = result[0] + self.assertEqual(wrapped.num_rows, 0) + self.assertEqual(wrapped.num_columns, 1) + + +if __name__ == "__main__": + from pyspark.testing import main + + main() From d0c2644649674a9974b7ac17db04f176e8d2f10c Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:40:35 -0800 Subject: [PATCH 02/39] fix: format --- python/pyspark/sql/pandas/transformers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/pandas/transformers.py b/python/pyspark/sql/pandas/transformers.py index 79d0df47b21d..62d5e258c49e 100644 --- a/python/pyspark/sql/pandas/transformers.py +++ b/python/pyspark/sql/pandas/transformers.py @@ -36,9 +36,7 @@ class FlattenStructTransformer: Output: 
Iterator of RecordBatch (flattened) """ - def __call__( - self, batches: Iterator["pa.RecordBatch"] - ) -> Iterator["pa.RecordBatch"]: + def __call__(self, batches: Iterator["pa.RecordBatch"]) -> Iterator["pa.RecordBatch"]: import pyarrow as pa for batch in batches: From fc236834f4d87d54912feb17f24caa0cd58a0db9 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:46:40 -0800 Subject: [PATCH 03/39] refactor: simplify --- python/pyspark/sql/pandas/serializers.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ff3bf4e5f2eb..cec694b319ef 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -147,10 +147,6 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. - - Uses transformers for data processing: - - load_stream: ArrowStreamSerializer -> FlattenStructTransformer - - dump_stream: WrapStructTransformer -> ArrowStreamSerializer (with START_ARROW_STREAM marker) """ def __init__(self): @@ -163,8 +159,7 @@ def load_stream(self, stream): Flatten the struct into Arrow's record batches. """ batches = super().load_stream(stream) - # Wrap each batch in a list for downstream UDF processing - return ([batch] for batch in self._flatten_struct(batches)) + return map(list, self._flatten_struct(batches)) def dump_stream(self, iterator, stream): """ From 1464cfb1775b7e3c982dd3b795474e00e30dce41 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:35:21 -0800 Subject: [PATCH 04/39] refactor: use function and apply maps --- python/pyspark/sql/pandas/serializers.py | 13 +-- python/pyspark/sql/pandas/transformers.py | 46 ++++----- .../sql/tests/pandas/test_transformers.py | 93 +++---------------- 3 files changed, 34 insertions(+), 118 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index cec694b319ef..a33a3bc1b89a 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -38,8 +38,8 @@ ArrowArrayToPandasConversion, ) from pyspark.sql.pandas.transformers import ( - FlattenStructTransformer, - WrapStructTransformer, + flatten_struct, + wrap_struct, ) from pyspark.sql.pandas.types import ( from_arrow_type, @@ -149,17 +149,12 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer): for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. """ - def __init__(self): - super().__init__() - self._flatten_struct = FlattenStructTransformer() - self._wrap_struct = WrapStructTransformer() - def load_stream(self, stream): """ Flatten the struct into Arrow's record batches. 
""" batches = super().load_stream(stream) - return map(list, self._flatten_struct(batches)) + return map(list, map(flatten_struct, batches)) def dump_stream(self, iterator, stream): """ @@ -167,7 +162,7 @@ def dump_stream(self, iterator, stream): """ import itertools - batches = self._wrap_struct(iterator) + batches = map(lambda x: wrap_struct(x[0]), iterator) first = next(batches, None) if first is None: return diff --git a/python/pyspark/sql/pandas/transformers.py b/python/pyspark/sql/pandas/transformers.py index 62d5e258c49e..930a1af09695 100644 --- a/python/pyspark/sql/pandas/transformers.py +++ b/python/pyspark/sql/pandas/transformers.py @@ -18,52 +18,40 @@ """ Arrow batch transformers for building data processing pipelines. -These are pure callable classes that transform Iterator[RecordBatch] -> Iterator[...]. +These are pure functions that transform RecordBatch -> RecordBatch. They should have no side effects (no I/O, no writing to streams). """ -from typing import TYPE_CHECKING, Iterator, Tuple +from typing import TYPE_CHECKING if TYPE_CHECKING: import pyarrow as pa -class FlattenStructTransformer: +def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": """ Flatten a single struct column into a RecordBatch. - Input: Iterator of RecordBatch with a single struct column - Output: Iterator of RecordBatch (flattened) + Used by: ArrowStreamUDFSerializer.load_stream """ + import pyarrow as pa - def __call__(self, batches: Iterator["pa.RecordBatch"]) -> Iterator["pa.RecordBatch"]: - import pyarrow as pa - - for batch in batches: - struct = batch.column(0) - yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) + struct = batch.column(0) + return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) -class WrapStructTransformer: +def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": """ Wrap a RecordBatch's columns into a single struct column. - Input: Iterator of (RecordBatch, arrow_type) - Output: Iterator of RecordBatch with a single struct column + Used by: ArrowStreamUDFSerializer.dump_stream """ + import pyarrow as pa - def __call__( - self, iterator: Iterator[Tuple["pa.RecordBatch", "pa.DataType"]] - ) -> Iterator["pa.RecordBatch"]: - import pyarrow as pa - - for batch, _ in iterator: - if batch.num_columns == 0: - # When batch has no column, it should still create - # an empty batch with the number of rows set. - struct = pa.array([{}] * batch.num_rows) - else: - struct = pa.StructArray.from_arrays( - batch.columns, fields=pa.struct(list(batch.schema)) - ) - yield pa.RecordBatch.from_arrays([struct], ["_0"]) + if batch.num_columns == 0: + # When batch has no column, it should still create + # an empty batch with the number of rows set. 
+ struct = pa.array([{}] * batch.num_rows) + else: + struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) + return pa.RecordBatch.from_arrays([struct], ["_0"]) diff --git a/python/pyspark/sql/tests/pandas/test_transformers.py b/python/pyspark/sql/tests/pandas/test_transformers.py index 779a0f4b6980..60aa459a5739 100644 --- a/python/pyspark/sql/tests/pandas/test_transformers.py +++ b/python/pyspark/sql/tests/pandas/test_transformers.py @@ -24,141 +24,74 @@ if have_pyarrow: import pyarrow as pa - from pyspark.sql.pandas.transformers import FlattenStructTransformer, WrapStructTransformer + from pyspark.sql.pandas.transformers import flatten_struct, wrap_struct @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class TransformerTests(unittest.TestCase): - def test_flatten_struct_transformer_basic(self): + def test_flatten_struct_basic(self): """Test flattening a struct column into separate columns.""" - # Create a batch with a single struct column containing two fields struct_array = pa.StructArray.from_arrays( [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["x", "y"], ) batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) - transformer = FlattenStructTransformer() - result = list(transformer(iter([batch]))) + flattened = flatten_struct(batch) - self.assertEqual(len(result), 1) - flattened = result[0] self.assertEqual(flattened.num_columns, 2) self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3]) self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"]) self.assertEqual(flattened.schema.names, ["x", "y"]) - def test_flatten_struct_transformer_multiple_batches(self): - """Test flattening multiple batches.""" - batches = [] - for i in range(3): - struct_array = pa.StructArray.from_arrays( - [pa.array([i * 10 + j for j in range(2)])], - names=["val"], - ) - batches.append(pa.RecordBatch.from_arrays([struct_array], ["_0"])) - - transformer = FlattenStructTransformer() - result = list(transformer(iter(batches))) - - self.assertEqual(len(result), 3) - self.assertEqual(result[0].column(0).to_pylist(), [0, 1]) - self.assertEqual(result[1].column(0).to_pylist(), [10, 11]) - self.assertEqual(result[2].column(0).to_pylist(), [20, 21]) - - def test_flatten_struct_transformer_empty_batch(self): + def test_flatten_struct_empty_batch(self): """Test flattening an empty batch.""" struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) struct_array = pa.array([], type=struct_type) batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) - transformer = FlattenStructTransformer() - result = list(transformer(iter([batch]))) + flattened = flatten_struct(batch) - self.assertEqual(len(result), 1) - self.assertEqual(result[0].num_rows, 0) - self.assertEqual(result[0].num_columns, 2) + self.assertEqual(flattened.num_rows, 0) + self.assertEqual(flattened.num_columns, 2) - def test_wrap_struct_transformer_basic(self): + def test_wrap_struct_basic(self): """Test wrapping columns into a struct.""" batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["x", "y"], ) - arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) - transformer = WrapStructTransformer() - result = list(transformer(iter([(batch, arrow_type)]))) + wrapped = wrap_struct(batch) - self.assertEqual(len(result), 1) - wrapped = result[0] self.assertEqual(wrapped.num_columns, 1) self.assertEqual(wrapped.schema.names, ["_0"]) - # Verify the struct content struct_col = wrapped.column(0) self.assertEqual(len(struct_col), 3) - 
# Access struct fields self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3]) self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"]) - def test_wrap_struct_transformer_multiple_batches(self): - """Test wrapping multiple batches.""" - batches_with_types = [] - arrow_type = pa.struct([("val", pa.int64())]) - for i in range(3): - batch = pa.RecordBatch.from_arrays( - [pa.array([i * 10 + j for j in range(2)])], - names=["val"], - ) - batches_with_types.append((batch, arrow_type)) - - transformer = WrapStructTransformer() - result = list(transformer(iter(batches_with_types))) - - self.assertEqual(len(result), 3) - for i, wrapped in enumerate(result): - self.assertEqual(wrapped.num_columns, 1) - struct_col = wrapped.column(0) - self.assertEqual(struct_col.field(0).to_pylist(), [i * 10, i * 10 + 1]) - - def test_wrap_struct_transformer_empty_columns(self): + def test_wrap_struct_empty_columns(self): """Test wrapping a batch with no columns.""" - # Create an empty schema batch with some rows - batch = pa.RecordBatch.from_arrays([], names=[]) - # Manually set num_rows by creating from pydict - batch = pa.RecordBatch.from_pydict({}, schema=pa.schema([])) - # Create batch with rows using a workaround - batch = pa.record_batch({"dummy": [1, 2, 3]}).select([]).slice(0, 3) - # Actually, let's create it properly schema = pa.schema([]) batch = pa.RecordBatch.from_arrays([], schema=schema) - arrow_type = pa.struct([]) - - transformer = WrapStructTransformer() - result = list(transformer(iter([(batch, arrow_type)]))) + wrapped = wrap_struct(batch) - self.assertEqual(len(result), 1) - wrapped = result[0] self.assertEqual(wrapped.num_columns, 1) - # Empty struct batch has 0 rows self.assertEqual(wrapped.num_rows, 0) - def test_wrap_struct_transformer_empty_batch(self): + def test_wrap_struct_empty_batch(self): """Test wrapping an empty batch with schema.""" schema = pa.schema([("x", pa.int64()), ("y", pa.string())]) batch = pa.RecordBatch.from_arrays( [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], schema=schema, ) - arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) - transformer = WrapStructTransformer() - result = list(transformer(iter([(batch, arrow_type)]))) + wrapped = wrap_struct(batch) - self.assertEqual(len(result), 1) - wrapped = result[0] self.assertEqual(wrapped.num_rows, 0) self.assertEqual(wrapped.num_columns, 1) From 5f9627b486c8729fdf540998023705a0a90f65e8 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 22:59:30 -0800 Subject: [PATCH 05/39] refactor: move to conversion.py --- dev/sparktestsupport/modules.py | 1 - python/pyspark/sql/conversion.py | 38 +++++++ python/pyspark/sql/pandas/serializers.py | 9 +- python/pyspark/sql/pandas/transformers.py | 57 ---------- .../sql/tests/pandas/test_transformers.py | 102 ------------------ python/pyspark/sql/tests/test_conversion.py | 80 ++++++++++++++ 6 files changed, 121 insertions(+), 166 deletions(-) delete mode 100644 python/pyspark/sql/pandas/transformers.py delete mode 100644 python/pyspark/sql/tests/pandas/test_transformers.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3bbd319fbe35..6ba2e619f703 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -575,7 +575,6 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_window", "pyspark.sql.tests.pandas.test_pandas_sqlmetrics", "pyspark.sql.tests.pandas.test_converter", - 
"pyspark.sql.tests.pandas.test_transformers", "pyspark.sql.tests.test_python_datasource", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 0a6478d49431..4da824a1c172 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -56,6 +56,44 @@ import pandas as pd +class ArrowBatchTransformer: + """ + Pure functions that transform RecordBatch -> RecordBatch. + They should have no side effects (no I/O, no writing to streams). + """ + + @staticmethod + def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": + """ + Flatten a single struct column into a RecordBatch. + + Used by: ArrowStreamUDFSerializer.load_stream + """ + import pyarrow as pa + + struct = batch.column(0) + return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) + + @staticmethod + def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": + """ + Wrap a RecordBatch's columns into a single struct column. + + Used by: ArrowStreamUDFSerializer.dump_stream + """ + import pyarrow as pa + + if batch.num_columns == 0: + # When batch has no column, it should still create + # an empty batch with the number of rows set. + struct = pa.array([{}] * batch.num_rows) + else: + struct = pa.StructArray.from_arrays( + batch.columns, fields=pa.struct(list(batch.schema)) + ) + return pa.RecordBatch.from_arrays([struct], ["_0"]) + + class LocalDataToArrowConversion: """ Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index a33a3bc1b89a..f45ba0f8e1ff 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -37,10 +37,7 @@ ArrowTableToRowsConversion, ArrowArrayToPandasConversion, ) -from pyspark.sql.pandas.transformers import ( - flatten_struct, - wrap_struct, -) +from pyspark.sql.conversion import ArrowBatchTransformer from pyspark.sql.pandas.types import ( from_arrow_type, is_variant, @@ -154,7 +151,7 @@ def load_stream(self, stream): Flatten the struct into Arrow's record batches. """ batches = super().load_stream(stream) - return map(list, map(flatten_struct, batches)) + return map(list, map(ArrowBatchTransformer.flatten_struct, batches)) def dump_stream(self, iterator, stream): """ @@ -162,7 +159,7 @@ def dump_stream(self, iterator, stream): """ import itertools - batches = map(lambda x: wrap_struct(x[0]), iterator) + batches = map(lambda x: ArrowBatchTransformer.wrap_struct(x[0]), iterator) first = next(batches, None) if first is None: return diff --git a/python/pyspark/sql/pandas/transformers.py b/python/pyspark/sql/pandas/transformers.py deleted file mode 100644 index 930a1af09695..000000000000 --- a/python/pyspark/sql/pandas/transformers.py +++ /dev/null @@ -1,57 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Arrow batch transformers for building data processing pipelines. - -These are pure functions that transform RecordBatch -> RecordBatch. -They should have no side effects (no I/O, no writing to streams). -""" - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import pyarrow as pa - - -def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": - """ - Flatten a single struct column into a RecordBatch. - - Used by: ArrowStreamUDFSerializer.load_stream - """ - import pyarrow as pa - - struct = batch.column(0) - return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) - - -def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": - """ - Wrap a RecordBatch's columns into a single struct column. - - Used by: ArrowStreamUDFSerializer.dump_stream - """ - import pyarrow as pa - - if batch.num_columns == 0: - # When batch has no column, it should still create - # an empty batch with the number of rows set. - struct = pa.array([{}] * batch.num_rows) - else: - struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) - return pa.RecordBatch.from_arrays([struct], ["_0"]) diff --git a/python/pyspark/sql/tests/pandas/test_transformers.py b/python/pyspark/sql/tests/pandas/test_transformers.py deleted file mode 100644 index 60aa459a5739..000000000000 --- a/python/pyspark/sql/tests/pandas/test_transformers.py +++ /dev/null @@ -1,102 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import unittest -from pyspark.testing.sqlutils import ( - have_pyarrow, - pyarrow_requirement_message, -) - -if have_pyarrow: - import pyarrow as pa - - from pyspark.sql.pandas.transformers import flatten_struct, wrap_struct - - -@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) -class TransformerTests(unittest.TestCase): - def test_flatten_struct_basic(self): - """Test flattening a struct column into separate columns.""" - struct_array = pa.StructArray.from_arrays( - [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], - names=["x", "y"], - ) - batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) - - flattened = flatten_struct(batch) - - self.assertEqual(flattened.num_columns, 2) - self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3]) - self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"]) - self.assertEqual(flattened.schema.names, ["x", "y"]) - - def test_flatten_struct_empty_batch(self): - """Test flattening an empty batch.""" - struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) - struct_array = pa.array([], type=struct_type) - batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) - - flattened = flatten_struct(batch) - - self.assertEqual(flattened.num_rows, 0) - self.assertEqual(flattened.num_columns, 2) - - def test_wrap_struct_basic(self): - """Test wrapping columns into a struct.""" - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], - names=["x", "y"], - ) - - wrapped = wrap_struct(batch) - - self.assertEqual(wrapped.num_columns, 1) - self.assertEqual(wrapped.schema.names, ["_0"]) - - struct_col = wrapped.column(0) - self.assertEqual(len(struct_col), 3) - self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3]) - self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"]) - - def test_wrap_struct_empty_columns(self): - """Test wrapping a batch with no columns.""" - schema = pa.schema([]) - batch = pa.RecordBatch.from_arrays([], schema=schema) - - wrapped = wrap_struct(batch) - - self.assertEqual(wrapped.num_columns, 1) - self.assertEqual(wrapped.num_rows, 0) - - def test_wrap_struct_empty_batch(self): - """Test wrapping an empty batch with schema.""" - schema = pa.schema([("x", pa.int64()), ("y", pa.string())]) - batch = pa.RecordBatch.from_arrays( - [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], - schema=schema, - ) - - wrapped = wrap_struct(batch) - - self.assertEqual(wrapped.num_rows, 0) - self.assertEqual(wrapped.num_columns, 1) - - -if __name__ == "__main__": - from pyspark.testing import main - - main() diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 9773b2154c63..c3fa1fd19304 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -23,6 +23,7 @@ ArrowTableToRowsConversion, LocalDataToArrowConversion, ArrowTimestampConversion, + ArrowBatchTransformer, ) from pyspark.sql.types import ( ArrayType, @@ -64,6 +65,85 @@ def __eq__(self, other): return self.score == other.score +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class ArrowBatchTransformerTests(unittest.TestCase): + def test_flatten_struct_basic(self): + """Test flattening a struct column into separate columns.""" + import pyarrow as pa + + struct_array = pa.StructArray.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + flattened = 
ArrowBatchTransformer.flatten_struct(batch) + + self.assertEqual(flattened.num_columns, 2) + self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3]) + self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"]) + self.assertEqual(flattened.schema.names, ["x", "y"]) + + def test_flatten_struct_empty_batch(self): + """Test flattening an empty batch.""" + import pyarrow as pa + + struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + struct_array = pa.array([], type=struct_type) + batch = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + flattened = ArrowBatchTransformer.flatten_struct(batch) + + self.assertEqual(flattened.num_rows, 0) + self.assertEqual(flattened.num_columns, 2) + + def test_wrap_struct_basic(self): + """Test wrapping columns into a struct.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + + wrapped = ArrowBatchTransformer.wrap_struct(batch) + + self.assertEqual(wrapped.num_columns, 1) + self.assertEqual(wrapped.schema.names, ["_0"]) + + struct_col = wrapped.column(0) + self.assertEqual(len(struct_col), 3) + self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3]) + self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"]) + + def test_wrap_struct_empty_columns(self): + """Test wrapping a batch with no columns.""" + import pyarrow as pa + + schema = pa.schema([]) + batch = pa.RecordBatch.from_arrays([], schema=schema) + + wrapped = ArrowBatchTransformer.wrap_struct(batch) + + self.assertEqual(wrapped.num_columns, 1) + self.assertEqual(wrapped.num_rows, 0) + + def test_wrap_struct_empty_batch(self): + """Test wrapping an empty batch with schema.""" + import pyarrow as pa + + schema = pa.schema([("x", pa.int64()), ("y", pa.string())]) + batch = pa.RecordBatch.from_arrays( + [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], + schema=schema, + ) + + wrapped = ArrowBatchTransformer.wrap_struct(batch) + + self.assertEqual(wrapped.num_rows, 0) + self.assertEqual(wrapped.num_columns, 1) + + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class ConversionTests(unittest.TestCase): def test_conversion(self): From b396e1077534b19413cdeee209b90fa21be8f3f7 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 23 Jan 2026 22:59:46 -0800 Subject: [PATCH 06/39] fix: format --- python/pyspark/sql/conversion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 4da824a1c172..81ceea857e19 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -88,9 +88,7 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": # an empty batch with the number of rows set. 
struct = pa.array([{}] * batch.num_rows) else: - struct = pa.StructArray.from_arrays( - batch.columns, fields=pa.struct(list(batch.schema)) - ) + struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) From f6dbc95d3a0fe8cf79349cffe55324d36a1098ed Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 24 Jan 2026 12:13:26 -0800 Subject: [PATCH 07/39] fix: keep wrapper --- python/pyspark/sql/pandas/serializers.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f45ba0f8e1ff..d1d641611853 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -151,7 +151,8 @@ def load_stream(self, stream): Flatten the struct into Arrow's record batches. """ batches = super().load_stream(stream) - return map(list, map(ArrowBatchTransformer.flatten_struct, batches)) + flattened = map(ArrowBatchTransformer.flatten_struct, batches) + return map(lambda b: [b], flattened) def dump_stream(self, iterator, stream): """ @@ -159,13 +160,15 @@ def dump_stream(self, iterator, stream): """ import itertools - batches = map(lambda x: ArrowBatchTransformer.wrap_struct(x[0]), iterator) - first = next(batches, None) + first = next(iterator, None) if first is None: return write_int(SpecialLengths.START_ARROW_STREAM, stream) - return super().dump_stream(itertools.chain([first], batches), stream) + batches = map( + lambda x: ArrowBatchTransformer.wrap_struct(x[0]), itertools.chain([first], iterator) + ) + return super().dump_stream(batches, stream) class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): From 535ad45082c9e2360dfeb9872622137c1fa7ef8d Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 24 Jan 2026 12:31:41 -0800 Subject: [PATCH 08/39] refactor: use transformer for GroupArrowUDFSerializer --- python/pyspark/sql/conversion.py | 4 +++- python/pyspark/sql/pandas/serializers.py | 12 ++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 81ceea857e19..c580899647fa 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -67,7 +67,9 @@ def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": """ Flatten a single struct column into a RecordBatch. - Used by: ArrowStreamUDFSerializer.load_stream + Used by: + - ArrowStreamUDFSerializer.load_stream + - GroupArrowUDFSerializer.load_stream """ import pyarrow as pa diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d1d641611853..4c87605b1b59 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1016,20 +1016,16 @@ def load_stream(self, stream): """ Flatten the struct into Arrow's record batches. 
""" - import pyarrow as pa - - def process_group(batches: "Iterator[pa.RecordBatch]"): - for batch in batches: - struct = batch.column(0) - yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) - dataframes_in_group = None while dataframes_in_group is None or dataframes_in_group > 0: dataframes_in_group = read_int(stream) if dataframes_in_group == 1: - batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream)) + batch_iter = map( + ArrowBatchTransformer.flatten_struct, + ArrowStreamSerializer.load_stream(self, stream), + ) yield batch_iter # Make sure the batches are fully iterated before getting the next group for _ in batch_iter: From 26b0a70a00707a971fe045722d9f7c4f6a008d6e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 24 Jan 2026 13:04:07 -0800 Subject: [PATCH 09/39] refactor: use flatten_struct in ArrowStreamArrowUDTFSerializer --- python/pyspark/sql/conversion.py | 7 ++++--- python/pyspark/sql/pandas/serializers.py | 26 ++++++++---------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index c580899647fa..31c43ddd2797 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -63,17 +63,18 @@ class ArrowBatchTransformer: """ @staticmethod - def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": + def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.RecordBatch": """ - Flatten a single struct column into a RecordBatch. + Flatten a struct column at given index into a RecordBatch. Used by: - ArrowStreamUDFSerializer.load_stream - GroupArrowUDFSerializer.load_stream + - ArrowStreamArrowUDTFSerializer.load_stream """ import pyarrow as pa - struct = batch.column(0) + struct = batch.column(column_index) return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type)) @staticmethod diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 4c87605b1b59..6802a83a1462 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -193,23 +193,15 @@ def load_stream(self, stream): """ Flatten the struct into Arrow's record batches. 
""" - import pyarrow as pa - - batches = super().load_stream(stream) - for batch in batches: - result_batches = [] - for i in range(batch.num_columns): - if i in self.table_arg_offsets: - struct = batch.column(i) - # Flatten the struct and create a RecordBatch from it - flattened_batch = pa.RecordBatch.from_arrays( - struct.flatten(), schema=pa.schema(struct.type) - ) - result_batches.append(flattened_batch) - else: - # Keep the column as it is for non-table columns - result_batches.append(batch.column(i)) - yield result_batches + for batch in super().load_stream(stream): + # For each column: flatten struct columns at table_arg_offsets into RecordBatch, + # keep other columns as Array + yield [ + ArrowBatchTransformer.flatten_struct(batch, column_index=i) + if i in self.table_arg_offsets + else batch.column(i) + for i in range(batch.num_columns) + ] def _create_array(self, arr, arrow_type): import pyarrow as pa From 8467a5929da4c921618f56176c32a5dee6a0e02d Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 24 Jan 2026 14:51:16 -0800 Subject: [PATCH 10/39] fix: import --- python/pyspark/sql/pandas/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6802a83a1462..4db6bbb9305e 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -36,8 +36,8 @@ LocalDataToArrowConversion, ArrowTableToRowsConversion, ArrowArrayToPandasConversion, + ArrowBatchTransformer, ) -from pyspark.sql.conversion import ArrowBatchTransformer from pyspark.sql.pandas.types import ( from_arrow_type, is_variant, From 31370e218f3850d3a31dd316e323fba989842c05 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:46:11 -0800 Subject: [PATCH 11/39] refactor: extract converter logic out to Conversion --- python/pyspark/sql/conversion.py | 61 ++++++++++++++++ python/pyspark/sql/pandas/serializers.py | 90 +++++------------------- 2 files changed, 80 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 0a6478d49431..3d6f7e8c87e0 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -994,3 +994,64 @@ def convert_legacy( integer_object_nulls=True, ) return converter(ser) + + @classmethod + def create_converter( + cls, + timezone: str, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + input_types: Optional[List] = None, + ) -> Callable[["pa.Array", int], "pd.Series"]: + """ + Create an arrow_to_pandas converter function. + + Parameters + ---------- + timezone : str + Timezone for timestamp conversion. + struct_in_pandas : str + How to represent struct in pandas ("dict", "row", etc.) + ndarray_as_list : bool + Whether to convert ndarray as list. + df_for_struct : bool + If True, convert struct columns to DataFrame instead of Series. + input_types : list, optional + Spark types for each column, used for precise type conversion. 
+ + Returns + ------- + callable + Function (arrow_column, idx) -> pd.Series or pd.DataFrame + """ + import pyarrow.types as types + + from pyspark.sql.pandas.types import from_arrow_type, is_variant + + def convert(arr: "pa.Array", spark_type=None) -> "pd.Series": + return cls.convert_legacy( + arr, + spark_type or from_arrow_type(arr.type), + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + ) + + def converter(arrow_column: "pa.Array", idx: int) -> "pd.Series": + spark_type = input_types[idx] if input_types is not None else None + + # Special case: flatten struct to DataFrame when df_for_struct is enabled + if df_for_struct and types.is_struct(arrow_column.type) and not is_variant(arrow_column.type): + import pandas as pd + + return pd.concat( + [ + convert(col, spark_type[i].dataType if spark_type else None).rename(f.name) + for i, (col, f) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) + ], + axis=1, + ) + return convert(arrow_column, spark_type) + + return converter diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ffdd6c9901ea..7fe20881f2d2 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -357,17 +357,7 @@ def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled - - def arrow_to_pandas( - self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None - ): - return ArrowArrayToPandasConversion.convert_legacy( - arrow_column, - spark_type or from_arrow_type(arrow_column.type), - timezone=self._timezone, - struct_in_pandas=struct_in_pandas, - ndarray_as_list=ndarray_as_list, - ) + self._arrow_to_pandas = ArrowArrayToPandasConversion.create_converter(timezone=timezone) def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): """ @@ -494,9 +484,7 @@ def load_stream(self, stream): import pandas as pd for batch in batches: - pandas_batches = [ - self.arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) - ] + pandas_batches = [self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns)] if len(pandas_batches) == 0: yield [pd.Series([pyspark._NoValue] * batch.num_rows)] else: @@ -525,52 +513,17 @@ def __init__( ): super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled) self._assign_cols_by_name = assign_cols_by_name - self._df_for_struct = df_for_struct self._struct_in_pandas = struct_in_pandas - self._ndarray_as_list = ndarray_as_list self._arrow_cast = arrow_cast self._input_types = input_types - - def arrow_to_pandas(self, arrow_column, idx): - import pyarrow.types as types - - # If the arrow type is struct, return a pandas dataframe where the fields of the struct - # correspond to columns in the DataFrame. However, if the arrow struct is actually a - # Variant, which is an atomic type, treat it as a non-struct arrow type. 
- if ( - self._df_for_struct - and types.is_struct(arrow_column.type) - and not is_variant(arrow_column.type) - ): - import pandas as pd - - series = [ - # Need to be explicit here because it's in a comprehension - super(ArrowStreamPandasUDFSerializer, self) - .arrow_to_pandas( - column, - i, - self._struct_in_pandas, - self._ndarray_as_list, - spark_type=( - self._input_types[idx][i].dataType - if self._input_types is not None - else None - ), - ) - .rename(field.name) - for i, (column, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) - ] - s = pd.concat(series, axis=1) - else: - s = super().arrow_to_pandas( - arrow_column, - idx, - self._struct_in_pandas, - self._ndarray_as_list, - spark_type=self._input_types[idx] if self._input_types is not None else None, - ) - return s + # Override parent's _arrow_to_pandas converter with full config + self._arrow_to_pandas = ArrowArrayToPandasConversion.create_converter( + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + input_types=input_types, + ) def _create_struct_array( self, @@ -1126,7 +1079,6 @@ def load_stream(self, stream): Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to process batches one by one without consuming all batches upfront. """ - dataframes_in_group = None while dataframes_in_group is None or dataframes_in_group > 0: @@ -1136,7 +1088,7 @@ def load_stream(self, stream): # Lazily read and convert Arrow batches to pandas Series one at a time # from the stream. This avoids loading all batches into memory for the group batch_iter = ( - tuple(self.arrow_to_pandas(c, i) for i, c in enumerate(batch.columns)) + tuple(self._arrow_to_pandas(c, i) for i, c in enumerate(batch.columns)) for batch in ArrowStreamSerializer.load_stream(self, stream) ) yield batch_iter @@ -1186,9 +1138,7 @@ def process_group(batches: "Iterator[pa.RecordBatch]"): # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch for batch in batches: # The batch from ArrowStreamSerializer is already flattened (no struct wrapper) - series = [ - self.arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) - ] + series = [self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns)] yield series dataframes_in_group = None @@ -1277,11 +1227,11 @@ def load_stream(self, stream): batches2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] yield ( [ - self.arrow_to_pandas(c, i) + self._arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches(batches1).itercolumns()) ], [ - self.arrow_to_pandas(c, i) + self._arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches(batches2).itercolumns()) ], ) @@ -1450,7 +1400,7 @@ def gen_data_and_state(batches): ) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(state_arrow)][0] + state_pandas = [self._arrow_to_pandas(c, i) for i, c in enumerate(state_arrow)][0] for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -1482,7 +1432,7 @@ def gen_data_and_state(batches): data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() - data_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(data_arrow)] + data_pandas = [self._arrow_to_pandas(c, i) for i, c in enumerate(data_arrow)] # state info yield ( @@ -1764,7 +1714,7 @@ def 
row_stream(): for batch in batches: self._update_batch_size_stats(batch) data_pandas = [ - self.arrow_to_pandas(c, i) + self._arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) ] for row in pd.concat(data_pandas, axis=1).itertuples(index=False): @@ -1894,14 +1844,12 @@ def row_stream(): flatten_state_table = flatten_columns(batch, "inputData") data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_state_table.itercolumns()) + self._arrow_to_pandas(c, i) for i, c in enumerate(flatten_state_table.itercolumns()) ] flatten_init_table = flatten_columns(batch, "initState") init_data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_init_table.itercolumns()) + self._arrow_to_pandas(c, i) for i, c in enumerate(flatten_init_table.itercolumns()) ] assert not (bool(init_data_pandas) and bool(data_pandas)) From 71c5e492f30c58f8b53a89227b291035f042baca Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:46:26 -0800 Subject: [PATCH 12/39] fix: format --- python/pyspark/sql/conversion.py | 6 +++++- python/pyspark/sql/pandas/serializers.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 3d6f7e8c87e0..50d2e891a15f 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1042,7 +1042,11 @@ def converter(arrow_column: "pa.Array", idx: int) -> "pd.Series": spark_type = input_types[idx] if input_types is not None else None # Special case: flatten struct to DataFrame when df_for_struct is enabled - if df_for_struct and types.is_struct(arrow_column.type) and not is_variant(arrow_column.type): + if ( + df_for_struct + and types.is_struct(arrow_column.type) + and not is_variant(arrow_column.type) + ): import pandas as pd return pd.concat( diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 7fe20881f2d2..2243a783c8f7 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -484,7 +484,9 @@ def load_stream(self, stream): import pandas as pd for batch in batches: - pandas_batches = [self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns)] + pandas_batches = [ + self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) + ] if len(pandas_batches) == 0: yield [pd.Series([pyspark._NoValue] * batch.num_rows)] else: @@ -1138,7 +1140,9 @@ def process_group(batches: "Iterator[pa.RecordBatch]"): # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch for batch in batches: # The batch from ArrowStreamSerializer is already flattened (no struct wrapper) - series = [self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns)] + series = [ + self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) + ] yield series dataframes_in_group = None @@ -1844,12 +1848,14 @@ def row_stream(): flatten_state_table = flatten_columns(batch, "inputData") data_pandas = [ - self._arrow_to_pandas(c, i) for i, c in enumerate(flatten_state_table.itercolumns()) + self._arrow_to_pandas(c, i) + for i, c in enumerate(flatten_state_table.itercolumns()) ] flatten_init_table = flatten_columns(batch, "initState") init_data_pandas = [ - self._arrow_to_pandas(c, i) for i, c in enumerate(flatten_init_table.itercolumns()) + self._arrow_to_pandas(c, i) + for i, c in 
enumerate(flatten_init_table.itercolumns()) ] assert not (bool(init_data_pandas) and bool(data_pandas)) From 84bdb214baddd1e5ca121bd566b156c03c679c7e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 25 Jan 2026 12:02:09 -0800 Subject: [PATCH 13/39] refactor: simplify and add comments --- python/pyspark/sql/conversion.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 50d2e891a15f..2de64871298c 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -25,6 +25,8 @@ _dedup_names, _deduplicate_field_names, _create_converter_to_pandas, + from_arrow_type, + is_variant, to_arrow_schema, ) from pyspark.sql.pandas.utils import require_minimum_pyarrow_version @@ -1027,8 +1029,6 @@ def create_converter( """ import pyarrow.types as types - from pyspark.sql.pandas.types import from_arrow_type, is_variant - def convert(arr: "pa.Array", spark_type=None) -> "pd.Series": return cls.convert_legacy( arr, @@ -1041,21 +1041,23 @@ def convert(arr: "pa.Array", spark_type=None) -> "pd.Series": def converter(arrow_column: "pa.Array", idx: int) -> "pd.Series": spark_type = input_types[idx] if input_types is not None else None - # Special case: flatten struct to DataFrame when df_for_struct is enabled + # If the arrow struct is actually a Variant, which is an atomic type, + # treat it as a non-struct arrow type. if ( - df_for_struct - and types.is_struct(arrow_column.type) - and not is_variant(arrow_column.type) + not df_for_struct + or not types.is_struct(arrow_column.type) + or is_variant(arrow_column.type) ): - import pandas as pd + return convert(arrow_column, spark_type) - return pd.concat( - [ - convert(col, spark_type[i].dataType if spark_type else None).rename(f.name) - for i, (col, f) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) - ], - axis=1, - ) - return convert(arrow_column, spark_type) + # Struct case: return a pandas DataFrame where the fields of the struct + # correspond to columns in the DataFrame. + import pandas as pd + + series = [ + convert(col, spark_type[i].dataType if spark_type else None).rename(field.name) + for i, (col, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) + ] + return pd.concat(series, axis=1) return converter From 7ce137a3b85a037f9b2a3deeb2a019e49e3cc193 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 25 Jan 2026 13:26:48 -0800 Subject: [PATCH 14/39] fix: type annotation --- python/pyspark/sql/conversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 2de64871298c..fcadc5c32627 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1005,7 +1005,7 @@ def create_converter( ndarray_as_list: bool = False, df_for_struct: bool = False, input_types: Optional[List] = None, - ) -> Callable[["pa.Array", int], "pd.Series"]: + ) -> Callable[["pa.Array", int], Union["pd.Series", "pd.DataFrame"]]: """ Create an arrow_to_pandas converter function. 
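The annotation change above reflects the shape of create_converter: a factory that captures the per-query options once and returns a (column, idx) closure whose result may be either a Series or a DataFrame. A minimal standalone sketch of that shape, using plain pyarrow/pandas and a hypothetical name (make_converter); the real converter additionally handles Spark types, structs, and variants:

    from typing import Callable, List, Optional, Union

    import pandas as pd
    import pyarrow as pa


    def make_converter(
        input_types: Optional[List] = None,
    ) -> Callable[[pa.Array, int], Union[pd.Series, pd.DataFrame]]:
        # The factory captures per-query options once; the returned closure is
        # then applied to every (column, index) pair of each record batch.
        def converter(arrow_column: pa.Array, idx: int) -> Union[pd.Series, pd.DataFrame]:
            # Per-column type hint looked up by position, mirroring input_types[idx].
            _spark_type = input_types[idx] if input_types is not None else None
            return arrow_column.to_pandas()

        return converter


    conv = make_converter()
    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    series = [conv(batch.column(i), i) for i in range(batch.num_columns)]
    print([s.tolist() for s in series])  # [[1, 2, 3], ['x', 'y', 'z']]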
@@ -1029,7 +1029,7 @@ def create_converter( """ import pyarrow.types as types - def convert(arr: "pa.Array", spark_type=None) -> "pd.Series": + def convert(arr: "pa.Array", spark_type: Optional[DataType] = None) -> "pd.Series": return cls.convert_legacy( arr, spark_type or from_arrow_type(arr.type), @@ -1038,7 +1038,7 @@ def convert(arr: "pa.Array", spark_type=None) -> "pd.Series": ndarray_as_list=ndarray_as_list, ) - def converter(arrow_column: "pa.Array", idx: int) -> "pd.Series": + def converter(arrow_column: "pa.Array", idx: int) -> Union["pd.Series", "pd.DataFrame"]: spark_type = input_types[idx] if input_types is not None else None # If the arrow struct is actually a Variant, which is an atomic type, From 6a0e89788be02391294c3584f5d323c49052028e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 25 Jan 2026 17:43:40 -0800 Subject: [PATCH 15/39] refactor: extract to_pandas transformer --- python/pyspark/sql/conversion.py | 53 ++++++++++++++++++++++++ python/pyspark/sql/pandas/serializers.py | 13 +----- 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 4268f03ad10e..86dd21e328d2 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -96,6 +96,59 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) + @staticmethod + def to_pandas( + timezone: str, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + input_types: Optional[List] = None, + ) -> Callable[["pa.RecordBatch"], List["pd.Series"]]: + """ + Create a transformer that converts a RecordBatch to a list of pandas Series. + + This is a batch-level transformer that internally uses the column-level + converter from ArrowArrayToPandasConversion.create_converter. + + Parameters + ---------- + timezone : str + Timezone for timestamp conversion. + struct_in_pandas : str + How to represent struct in pandas ("dict", "row", etc.) + ndarray_as_list : bool + Whether to convert ndarray as list. + df_for_struct : bool + If True, convert struct columns to DataFrame instead of Series. + input_types : list, optional + Spark types for each column, used for precise type conversion. + + Returns + ------- + callable + Function (RecordBatch) -> List[pd.Series] + + Used by: ArrowStreamPandasSerializer.load_stream and subclasses + """ + import pandas as pd + + import pyspark + + converter = ArrowArrayToPandasConversion.create_converter( + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + input_types=input_types, + ) + + def transform(batch: "pa.RecordBatch") -> List["pd.Series"]: + if batch.num_columns == 0: + return [pd.Series([pyspark._NoValue] * batch.num_rows)] + return [converter(batch.column(i), i) for i in range(batch.num_columns)] + + return transform + class LocalDataToArrowConversion: """ diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 8384a184028b..04c332847da8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -454,17 +454,8 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
""" - batches = super().load_stream(stream) - import pandas as pd - - for batch in batches: - pandas_batches = [ - self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) - ] - if len(pandas_batches) == 0: - yield [pd.Series([pyspark._NoValue] * batch.num_rows)] - else: - yield pandas_batches + to_pandas = ArrowBatchTransformer.to_pandas(timezone=self._timezone) + yield from map(to_pandas, super().load_stream(stream)) def __repr__(self): return "ArrowStreamPandasSerializer" From 06aeec703b04a977c46bbea9e762dc3d7f959a75 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 18:48:13 -0800 Subject: [PATCH 16/39] refactor: simplify --- python/pyspark/sql/conversion.py | 126 +++++++---------------- python/pyspark/sql/pandas/serializers.py | 121 +++++++++++++++------- 2 files changed, 123 insertions(+), 124 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index be475aafbca1..ecade7fcacee 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1058,48 +1058,65 @@ def localize_tz(arr: "pa.Array") -> "pa.Array": class ArrowArrayToPandasConversion: - @classmethod - def convert_legacy( - cls, - arr: Union["pa.Array", "pa.ChunkedArray"], - spark_type: DataType, - *, - timezone: Optional[str] = None, - struct_in_pandas: Optional[str] = None, + @staticmethod + def convert( + arrow_column: Union["pa.Array", "pa.ChunkedArray"], + timezone: str, + struct_in_pandas: str = "dict", ndarray_as_list: bool = False, + spark_type: Optional[DataType] = None, df_for_struct: bool = False, ) -> Union["pd.Series", "pd.DataFrame"]: """ + Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. + Parameters ---------- - arr : :class:`pyarrow.Array`. - spark_type: target spark type, should always be specified. - timezone : The timezone to convert from. If there is a timestamp type, it's required. - struct_in_pandas : How to handle struct type. If there is a struct type, it's required. - ndarray_as_list : Whether `np.ndarray` is converted to a list or not. - df_for_struct: when true, and spark type is a StructType, return a DataFrame. + arrow_column : pa.Array or pa.ChunkedArray + The arrow column to convert. + timezone : str + Timezone for timestamp conversion. + struct_in_pandas : str, optional + How to represent struct in pandas ("dict", "row", etc.). Default is "dict". + ndarray_as_list : bool, optional + Whether to convert ndarray as list. Default is False. + spark_type : DataType, optional + Spark type for the column. If None, inferred from arrow type. + df_for_struct : bool, optional + If True, convert struct columns to DataFrame instead of Series. Default is False. + + Returns + ------- + pd.Series or pd.DataFrame + Converted pandas Series. If df_for_struct is True and the type is StructType, + returns a DataFrame with columns corresponding to struct fields. 
""" import pyarrow as pa import pandas as pd - assert isinstance(arr, (pa.Array, pa.ChunkedArray)) + if spark_type is None: + spark_type = from_arrow_type(arrow_column.type) + + assert isinstance(arrow_column, (pa.Array, pa.ChunkedArray)) if df_for_struct and isinstance(spark_type, StructType): import pyarrow.types as types - assert types.is_struct(arr.type) - assert len(spark_type.names) == len(arr.type.names), f"{spark_type} {arr.type} " + assert types.is_struct(arrow_column.type) + assert len(spark_type.names) == len( + arrow_column.type.names + ), f"{spark_type} {arrow_column.type} " series = [ - cls.convert_legacy( + ArrowArrayToPandasConversion.convert( field_arr, - spark_type=field.dataType, - timezone=timezone, + timezone, struct_in_pandas=struct_in_pandas, ndarray_as_list=ndarray_as_list, + spark_type=field.dataType, df_for_struct=False, # always False for child fields ) - for field_arr, field in zip(arr.flatten(), spark_type) + for field_arr, field in zip(arrow_column.flatten(), spark_type) ] pdf = pd.concat(series, axis=1) pdf.columns = spark_type.names # type: ignore[assignment] @@ -1114,7 +1131,7 @@ def convert_legacy( "coerce_temporal_nanoseconds": True, "integer_object_nulls": True, } - ser = arr.to_pandas(**pandas_options) + ser = arrow_column.to_pandas(**pandas_options) converter = _create_converter_to_pandas( data_type=spark_type, @@ -1126,68 +1143,3 @@ def convert_legacy( integer_object_nulls=True, ) return converter(ser) - - @classmethod - def create_converter( - cls, - timezone: str, - struct_in_pandas: str = "dict", - ndarray_as_list: bool = False, - df_for_struct: bool = False, - input_types: Optional[List] = None, - ) -> Callable[["pa.Array", int], Union["pd.Series", "pd.DataFrame"]]: - """ - Create an arrow_to_pandas converter function. - - Parameters - ---------- - timezone : str - Timezone for timestamp conversion. - struct_in_pandas : str - How to represent struct in pandas ("dict", "row", etc.) - ndarray_as_list : bool - Whether to convert ndarray as list. - df_for_struct : bool - If True, convert struct columns to DataFrame instead of Series. - input_types : list, optional - Spark types for each column, used for precise type conversion. - - Returns - ------- - callable - Function (arrow_column, idx) -> pd.Series or pd.DataFrame - """ - import pyarrow.types as types - - def convert(arr: "pa.Array", spark_type: Optional[DataType] = None) -> "pd.Series": - return cls.convert_legacy( - arr, - spark_type or from_arrow_type(arr.type), - timezone=timezone, - struct_in_pandas=struct_in_pandas, - ndarray_as_list=ndarray_as_list, - ) - - def converter(arrow_column: "pa.Array", idx: int) -> Union["pd.Series", "pd.DataFrame"]: - spark_type = input_types[idx] if input_types is not None else None - - # If the arrow struct is actually a Variant, which is an atomic type, - # treat it as a non-struct arrow type. - if ( - not df_for_struct - or not types.is_struct(arrow_column.type) - or is_variant(arrow_column.type) - ): - return convert(arrow_column, spark_type) - - # Struct case: return a pandas DataFrame where the fields of the struct - # correspond to columns in the DataFrame. 
- import pandas as pd - - series = [ - convert(col, spark_type[i].dataType if spark_type else None).rename(field.name) - for i, (col, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) - ] - return pd.concat(series, axis=1) - - return converter diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d58aae4b4df9..bd8f679f14a2 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -416,19 +416,6 @@ def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled - self._arrow_to_pandas = ArrowArrayToPandasConversion.create_converter(timezone=timezone) - - def arrow_to_pandas( - self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None - ): - return ArrowArrayToPandasConversion.convert_legacy( - arrow_column, - spark_type or from_arrow_type(arrow_column.type), - timezone=self._timezone, - struct_in_pandas=struct_in_pandas, - ndarray_as_list=ndarray_as_list, - df_for_struct=False, - ) def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): """ @@ -578,26 +565,13 @@ def __init__( super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled) self._assign_cols_by_name = assign_cols_by_name self._struct_in_pandas = struct_in_pandas + self._ndarray_as_list = ndarray_as_list + self._df_for_struct = df_for_struct self._arrow_cast = arrow_cast if input_type is not None: assert isinstance(input_type, StructType) self._input_type = input_type - def arrow_to_pandas(self, arrow_column, idx): - if self._input_type is not None: - spark_type = self._input_type[idx].dataType - else: - spark_type = from_arrow_type(arrow_column.type) - - return ArrowArrayToPandasConversion.convert_legacy( - arr=arrow_column, - spark_type=spark_type, - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - def _create_struct_array( self, df: "pd.DataFrame", @@ -1094,7 +1068,17 @@ def load_stream(self, stream): # Lazily read and convert Arrow batches to pandas Series one at a time # from the stream. 
This avoids loading all batches into memory for the group series_iter = ( - tuple(self.arrow_to_pandas(c, i) for i, c in enumerate(batch.columns)) + tuple( + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) + for i, c in enumerate(batch.columns) + ) for batch in batches ) yield series_iter @@ -1139,7 +1123,15 @@ def process_group(batches: "Iterator[pa.RecordBatch]"): for batch in batches: # The batch from ArrowStreamSerializer is already flattened (no struct wrapper) series = [ - self._arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) + ArrowArrayToPandasConversion.convert( + batch.column(i), + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) + for i in range(batch.num_columns) ] yield series @@ -1198,11 +1190,25 @@ def load_stream(self, stream): for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): yield ( [ - self.arrow_to_pandas(c, i) + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) for i, c in enumerate(pa.Table.from_batches(left_batches).itercolumns()) ], [ - self.arrow_to_pandas(c, i) + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) for i, c in enumerate(pa.Table.from_batches(right_batches).itercolumns()) ], ) @@ -1365,7 +1371,17 @@ def gen_data_and_state(batches): ) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_pandas = [self._arrow_to_pandas(c, i) for i, c in enumerate(state_arrow)][0] + state_pandas = [ + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) + for i, c in enumerate(state_arrow) + ][0] for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -1397,7 +1413,17 @@ def gen_data_and_state(batches): data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() - data_pandas = [self._arrow_to_pandas(c, i) for i, c in enumerate(data_arrow)] + data_pandas = [ + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) + for i, c in enumerate(data_arrow) + ] # state info yield ( @@ -1661,7 +1687,14 @@ def row_stream(): for batch in batches: self._update_batch_size_stats(batch) data_pandas = [ - self._arrow_to_pandas(c, i) + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + 
df_for_struct=self._df_for_struct, + ) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) ] for row in pd.concat(data_pandas, axis=1).itertuples(index=False): @@ -1791,13 +1824,27 @@ def row_stream(): flatten_state_table = flatten_columns(batch, "inputData") data_pandas = [ - self._arrow_to_pandas(c, i) + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) for i, c in enumerate(flatten_state_table.itercolumns()) ] flatten_init_table = flatten_columns(batch, "initState") init_data_pandas = [ - self._arrow_to_pandas(c, i) + ArrowArrayToPandasConversion.convert( + c, + self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=self._input_type[i] if self._input_type else None, + df_for_struct=self._df_for_struct, + ) for i, c in enumerate(flatten_init_table.itercolumns()) ] From 56f6a377b669859e8194baab7f53cc2171dd3099 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 18:57:52 -0800 Subject: [PATCH 17/39] fix: revert changes --- python/pyspark/sql/conversion.py | 53 ------------------------ python/pyspark/sql/pandas/serializers.py | 18 +++++++- 2 files changed, 16 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index ecade7fcacee..0b5cff777ff5 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -97,59 +97,6 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) - @staticmethod - def to_pandas( - timezone: str, - struct_in_pandas: str = "dict", - ndarray_as_list: bool = False, - df_for_struct: bool = False, - input_types: Optional[List] = None, - ) -> Callable[["pa.RecordBatch"], List["pd.Series"]]: - """ - Create a transformer that converts a RecordBatch to a list of pandas Series. - - This is a batch-level transformer that internally uses the column-level - converter from ArrowArrayToPandasConversion.create_converter. - - Parameters - ---------- - timezone : str - Timezone for timestamp conversion. - struct_in_pandas : str - How to represent struct in pandas ("dict", "row", etc.) - ndarray_as_list : bool - Whether to convert ndarray as list. - df_for_struct : bool - If True, convert struct columns to DataFrame instead of Series. - input_types : list, optional - Spark types for each column, used for precise type conversion. 
- - Returns - ------- - callable - Function (RecordBatch) -> List[pd.Series] - - Used by: ArrowStreamPandasSerializer.load_stream and subclasses - """ - import pandas as pd - - import pyspark - - converter = ArrowArrayToPandasConversion.create_converter( - timezone=timezone, - struct_in_pandas=struct_in_pandas, - ndarray_as_list=ndarray_as_list, - df_for_struct=df_for_struct, - input_types=input_types, - ) - - def transform(batch: "pa.RecordBatch") -> List["pd.Series"]: - if batch.num_columns == 0: - return [pd.Series([pyspark._NoValue] * batch.num_rows)] - return [converter(batch.column(i), i) for i in range(batch.num_columns)] - - return transform - class LocalDataToArrowConversion: """ diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index bd8f679f14a2..0f3d499b32d3 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -538,8 +538,22 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - to_pandas = ArrowBatchTransformer.to_pandas(timezone=self._timezone) - yield from map(to_pandas, super().load_stream(stream)) + import pandas as pd + import pyspark + + batches = super().load_stream(stream) + for batch in batches: + if batch.num_columns == 0: + yield [pd.Series([pyspark._NoValue] * batch.num_rows)] + else: + pandas_batches = [ + ArrowArrayToPandasConversion.convert( + batch.column(i), + self._timezone, + ) + for i in range(batch.num_columns) + ] + yield pandas_batches def __repr__(self): return "ArrowStreamPandasSerializer" From 4bdf46c466f35950b462603f3a7c0d069e5d9e3b Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 19:06:53 -0800 Subject: [PATCH 18/39] revert: bring back convert_legacy --- python/pyspark/sql/conversion.py | 95 +++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 0b5cff777ff5..274b01d981c6 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1005,13 +1005,15 @@ def localize_tz(arr: "pa.Array") -> "pa.Array": class ArrowArrayToPandasConversion: - @staticmethod - def convert( - arrow_column: Union["pa.Array", "pa.ChunkedArray"], - timezone: str, - struct_in_pandas: str = "dict", + @classmethod + def convert_legacy( + cls, + arr: Union["pa.Array", "pa.ChunkedArray"], + spark_type: DataType, + *, + timezone: Optional[str] = None, + struct_in_pandas: Optional[str] = None, ndarray_as_list: bool = False, - spark_type: Optional[DataType] = None, df_for_struct: bool = False, ) -> Union["pd.Series", "pd.DataFrame"]: """ @@ -1019,18 +1021,18 @@ def convert( Parameters ---------- - arrow_column : pa.Array or pa.ChunkedArray + arr : pa.Array or pa.ChunkedArray The arrow column to convert. - timezone : str - Timezone for timestamp conversion. + spark_type : DataType + Target spark type, should always be specified. + timezone : str, optional + The timezone to convert from. If there is a timestamp type, it's required. struct_in_pandas : str, optional - How to represent struct in pandas ("dict", "row", etc.). Default is "dict". + How to handle struct type. If there is a struct type, it's required. ndarray_as_list : bool, optional - Whether to convert ndarray as list. Default is False. - spark_type : DataType, optional - Spark type for the column. If None, inferred from arrow type. 
+ Whether `np.ndarray` is converted to a list or not. df_for_struct : bool, optional - If True, convert struct columns to DataFrame instead of Series. Default is False. + When true, and spark type is a StructType, return a DataFrame. Returns ------- @@ -1041,29 +1043,24 @@ def convert( import pyarrow as pa import pandas as pd - if spark_type is None: - spark_type = from_arrow_type(arrow_column.type) - - assert isinstance(arrow_column, (pa.Array, pa.ChunkedArray)) + assert isinstance(arr, (pa.Array, pa.ChunkedArray)) if df_for_struct and isinstance(spark_type, StructType): import pyarrow.types as types - assert types.is_struct(arrow_column.type) - assert len(spark_type.names) == len( - arrow_column.type.names - ), f"{spark_type} {arrow_column.type} " + assert types.is_struct(arr.type) + assert len(spark_type.names) == len(arr.type.names), f"{spark_type} {arr.type} " series = [ - ArrowArrayToPandasConversion.convert( + cls.convert_legacy( field_arr, - timezone, + spark_type=field.dataType, + timezone=timezone, struct_in_pandas=struct_in_pandas, ndarray_as_list=ndarray_as_list, - spark_type=field.dataType, df_for_struct=False, # always False for child fields ) - for field_arr, field in zip(arrow_column.flatten(), spark_type) + for field_arr, field in zip(arr.flatten(), spark_type) ] pdf = pd.concat(series, axis=1) pdf.columns = spark_type.names # type: ignore[assignment] @@ -1078,7 +1075,7 @@ def convert( "coerce_temporal_nanoseconds": True, "integer_object_nulls": True, } - ser = arrow_column.to_pandas(**pandas_options) + ser = arr.to_pandas(**pandas_options) converter = _create_converter_to_pandas( data_type=spark_type, @@ -1090,3 +1087,47 @@ def convert( integer_object_nulls=True, ) return converter(ser) + + @staticmethod + def convert( + arrow_column: Union["pa.Array", "pa.ChunkedArray"], + timezone: str, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + spark_type: Optional[DataType] = None, + df_for_struct: bool = False, + ) -> Union["pd.Series", "pd.DataFrame"]: + """ + Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. + + This is a convenience method that calls convert_legacy with a more intuitive signature. + + Parameters + ---------- + arrow_column : pa.Array or pa.ChunkedArray + The arrow column to convert. + timezone : str + Timezone for timestamp conversion. + struct_in_pandas : str, optional + How to represent struct in pandas ("dict", "row", etc.). Default is "dict". + ndarray_as_list : bool, optional + Whether to convert ndarray as list. Default is False. + spark_type : DataType, optional + Spark type for the column. If None, inferred from arrow type. + df_for_struct : bool, optional + If True, convert struct columns to DataFrame instead of Series. Default is False. + + Returns + ------- + pd.Series or pd.DataFrame + Converted pandas Series. If df_for_struct is True and the type is StructType, + returns a DataFrame with columns corresponding to struct fields. 
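When no Spark type is supplied, the wrapper falls back to inferring one from the Arrow type. A small sketch of that fallback using the internal helper from_arrow_type (requires a PySpark installation; the helper is internal, not public API):

    import pyarrow as pa

    from pyspark.sql.pandas.types import from_arrow_type  # internal PySpark helper

    col = pa.array([1, 2, 3], type=pa.int64())

    # Mirrors the `spark_type or from_arrow_type(arrow_column.type)` fallback:
    # with no explicit Spark type, one is derived from the Arrow type.
    inferred = from_arrow_type(col.type)
    print(inferred)  # LongType()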
+ """ + return ArrowArrayToPandasConversion.convert_legacy( + arrow_column, + spark_type or from_arrow_type(arrow_column.type), + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + ) From eecdd13e067f54fdbe932cadf7eecef12c34aa3a Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:15:17 -0800 Subject: [PATCH 19/39] fix: comments --- python/pyspark/sql/pandas/serializers.py | 67 ++++++++++++++++++------ 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 0f3d499b32d3..9a99f34e9622 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -404,18 +404,35 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): A timezone to respect when handling timestamp values safecheck : bool If True, conversion from Arrow to Pandas checks for overflow/truncation - assign_cols_by_name : bool - If True, then Pandas DataFrames will get columns by name int_to_decimal_coercion_enabled : bool If True, applies additional coercions in Python before converting to Arrow This has performance penalties. + struct_in_pandas : str, optional + How to represent struct in pandas ("dict", "row", etc.). Default is "dict". + ndarray_as_list : bool, optional + Whether to convert ndarray as list. Default is False. + df_for_struct : bool, optional + If True, convert struct columns to DataFrame instead of Series. Default is False. + input_type : StructType, optional + Spark types for each column. Default is None. """ - def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): + def __init__( + self, + timezone, + safecheck, + int_to_decimal_coercion_enabled: bool = False, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + ): super().__init__() self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled + self._struct_in_pandas = struct_in_pandas + self._ndarray_as_list = ndarray_as_list + self._df_for_struct = df_for_struct def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): """ @@ -550,6 +567,14 @@ def load_stream(self, stream): ArrowArrayToPandasConversion.convert( batch.column(i), self._timezone, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + spark_type=( + self._input_type[i] + if hasattr(self, "_input_type") and self._input_type + else None + ), + df_for_struct=self._df_for_struct, ) for i in range(batch.num_columns) ] @@ -576,11 +601,15 @@ def __init__( input_type: Optional[StructType] = None, int_to_decimal_coercion_enabled: bool = False, ): - super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled) + super().__init__( + timezone, + safecheck, + int_to_decimal_coercion_enabled, + struct_in_pandas, + ndarray_as_list, + df_for_struct, + ) self._assign_cols_by_name = assign_cols_by_name - self._struct_in_pandas = struct_in_pandas - self._ndarray_as_list = ndarray_as_list - self._df_for_struct = df_for_struct self._arrow_cast = arrow_cast if input_type is not None: assert isinstance(input_type, StructType) @@ -1257,11 +1286,15 @@ def __init__( int_to_decimal_coercion_enabled, ): super().__init__( - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + timezone=timezone, + 
safecheck=safecheck, + assign_cols_by_name=assign_cols_by_name, + df_for_struct=False, + struct_in_pandas="dict", + ndarray_as_list=False, arrow_cast=True, + input_type=None, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() @@ -1643,11 +1676,15 @@ def __init__( int_to_decimal_coercion_enabled, ): super().__init__( - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + timezone=timezone, + safecheck=safecheck, + assign_cols_by_name=assign_cols_by_name, + df_for_struct=False, + struct_in_pandas="dict", + ndarray_as_list=False, arrow_cast=True, + input_type=None, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) self.arrow_max_records_per_batch = ( arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 From 9efc62358467e76aa0a6e9659ac22dc3af243d6e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 23:21:00 -0800 Subject: [PATCH 20/39] fix: type --- python/pyspark/sql/pandas/serializers.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 9a99f34e9622..03abf003f55a 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -570,7 +570,7 @@ def load_stream(self, stream): struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, spark_type=( - self._input_type[i] + self._input_type[i].dataType if hasattr(self, "_input_type") and self._input_type else None ), @@ -1117,7 +1117,7 @@ def load_stream(self, stream): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(batch.columns) @@ -1171,7 +1171,7 @@ def process_group(batches: "Iterator[pa.RecordBatch]"): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i in range(batch.num_columns) @@ -1238,7 +1238,7 @@ def load_stream(self, stream): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches(left_batches).itercolumns()) @@ -1249,7 +1249,7 @@ def load_stream(self, stream): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches(right_batches).itercolumns()) @@ -1424,7 +1424,7 @@ def gen_data_and_state(batches): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type 
else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(state_arrow) @@ -1466,7 +1466,7 @@ def gen_data_and_state(batches): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(data_arrow) @@ -1743,7 +1743,7 @@ def row_stream(): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) @@ -1880,7 +1880,7 @@ def row_stream(): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(flatten_state_table.itercolumns()) @@ -1893,7 +1893,7 @@ def row_stream(): self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i] if self._input_type else None, + spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(flatten_init_table.itercolumns()) From e5c6ad1c6a15a7fc3a865d820bada882c2804dab Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 27 Jan 2026 23:43:55 -0800 Subject: [PATCH 21/39] fix: unused import --- python/pyspark/sql/conversion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 274b01d981c6..69cdc7767de1 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -26,7 +26,6 @@ _deduplicate_field_names, _create_converter_to_pandas, from_arrow_type, - is_variant, to_arrow_schema, ) from pyspark.sql.pandas.utils import require_minimum_pyarrow_version From 1c7c9ed5082d3f49709a9f1aebe26566ce944160 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 28 Jan 2026 17:04:03 -0800 Subject: [PATCH 22/39] refactor: use spark_type from callsite --- python/pyspark/sql/conversion.py | 84 ++++++++++++++++-------- python/pyspark/sql/pandas/serializers.py | 70 ++++++++++++-------- 2 files changed, 102 insertions(+), 52 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 69cdc7767de1..14bab1f36335 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1004,9 +1004,18 @@ def localize_tz(arr: "pa.Array") -> "pa.Array": class ArrowArrayToPandasConversion: - @classmethod + """ + Conversion utilities for converting PyArrow Arrays and ChunkedArrays to pandas. + + This class provides methods to convert PyArrow columnar data structures to pandas + Series or DataFrames, with support for Spark-specific type handling and conversions. + + The class is primarily used by PySpark's Arrow-based serializers for UDF execution, + where Arrow data needs to be converted to pandas for Python UDF processing. 
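The role described in this docstring can be sketched end to end without Spark: one Arrow record batch is converted column by column into pandas Series, which a pandas function then consumes. The names below are hypothetical and plain Array.to_pandas stands in for the real, type-aware conversion:

    import pandas as pd
    import pyarrow as pa

    # Stand-in for one Arrow record batch arriving from the JVM side.
    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})

    # Column-by-column Arrow -> pandas conversion, the role this class plays.
    series = [batch.column(i).to_pandas() for i in range(batch.num_columns)]

    # A pandas UDF then operates on the converted Series.
    def add(a: pd.Series, b: pd.Series) -> pd.Series:
        return a + b

    print(add(*series).tolist())  # [11.0, 22.0, 33.0]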
+ """ + + @staticmethod def convert_legacy( - cls, arr: Union["pa.Array", "pa.ChunkedArray"], spark_type: DataType, *, @@ -1018,26 +1027,37 @@ def convert_legacy( """ Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. + This is the lower-level conversion method that requires explicit Spark type + specification. For a more convenient API, see :meth:`convert`. + Parameters ---------- arr : pa.Array or pa.ChunkedArray The arrow column to convert. spark_type : DataType - Target spark type, should always be specified. + Target Spark type. Must be specified and should match the Arrow array type. timezone : str, optional - The timezone to convert from. If there is a timestamp type, it's required. + The timezone to use for timestamp conversion. Required if the data contains + timestamp types. struct_in_pandas : str, optional - How to handle struct type. If there is a struct type, it's required. + How to handle struct types in pandas. Valid values are "dict", "row", or "legacy". + Required if the data contains struct types. ndarray_as_list : bool, optional - Whether `np.ndarray` is converted to a list or not. + Whether to convert numpy ndarrays to Python lists. Default is False. df_for_struct : bool, optional - When true, and spark type is a StructType, return a DataFrame. + If True and spark_type is a StructType, return a DataFrame with columns + corresponding to struct fields instead of a Series. Default is False. Returns ------- pd.Series or pd.DataFrame - Converted pandas Series. If df_for_struct is True and the type is StructType, + Converted pandas Series. If df_for_struct is True and spark_type is StructType, returns a DataFrame with columns corresponding to struct fields. + + Notes + ----- + This method handles date type columns specially to avoid overflow issues with + datetime64[ns] intermediate representations. """ import pyarrow as pa import pandas as pd @@ -1048,10 +1068,14 @@ def convert_legacy( import pyarrow.types as types assert types.is_struct(arr.type) - assert len(spark_type.names) == len(arr.type.names), f"{spark_type} {arr.type} " + assert len(spark_type.names) == len(arr.type.names), ( + f"Schema mismatch: spark_type has {len(spark_type.names)} fields, " + f"but arrow type has {len(arr.type.names)} fields. " + f"spark_type={spark_type}, arrow_type={arr.type}" + ) series = [ - cls.convert_legacy( + ArrowArrayToPandasConversion.convert_legacy( field_arr, spark_type=field.dataType, timezone=timezone, @@ -1065,10 +1089,11 @@ def convert_legacy( pdf.columns = spark_type.names # type: ignore[assignment] return pdf - # If the given column is a date type column, creates a series of datetime.date directly - # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by - # datetime64[ns] type handling. - # Cast dates to objects instead of datetime64[ns] dtype to avoid overflow. 
+ # Convert Arrow array to pandas Series with specific options: + # - date_as_object: Convert date types to Python datetime.date objects directly + # instead of datetime64[ns] to avoid overflow issues + # - coerce_temporal_nanoseconds: Handle nanosecond precision timestamps correctly + # - integer_object_nulls: Use object dtype for integer arrays with nulls pandas_options = { "date_as_object": True, "coerce_temporal_nanoseconds": True, @@ -1090,41 +1115,48 @@ def convert_legacy( @staticmethod def convert( arrow_column: Union["pa.Array", "pa.ChunkedArray"], - timezone: str, + target_type: DataType, + timezone: Optional[str] = None, struct_in_pandas: str = "dict", ndarray_as_list: bool = False, - spark_type: Optional[DataType] = None, df_for_struct: bool = False, ) -> Union["pd.Series", "pd.DataFrame"]: """ Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. - This is a convenience method that calls convert_legacy with a more intuitive signature. + This is a convenience method that provides a more intuitive API than + :meth:`convert_legacy`. Parameters ---------- arrow_column : pa.Array or pa.ChunkedArray - The arrow column to convert. - timezone : str - Timezone for timestamp conversion. + The Arrow column to convert. + target_type : DataType + The target Spark type for the column to be coverted to. + timezone : str, optional + Timezone for timestamp conversion. Required if the data contains timestamp types. struct_in_pandas : str, optional - How to represent struct in pandas ("dict", "row", etc.). Default is "dict". + How to represent struct types in pandas. Valid values are "dict", "row", or "legacy". + Default is "dict". ndarray_as_list : bool, optional - Whether to convert ndarray as list. Default is False. - spark_type : DataType, optional - Spark type for the column. If None, inferred from arrow type. + Whether to convert numpy ndarrays to Python lists. Default is False. df_for_struct : bool, optional - If True, convert struct columns to DataFrame instead of Series. Default is False. + If True, convert struct columns to a DataFrame with columns corresponding + to struct fields instead of a Series. Default is False. Returns ------- pd.Series or pd.DataFrame Converted pandas Series. If df_for_struct is True and the type is StructType, returns a DataFrame with columns corresponding to struct fields. + + See Also + -------- + convert_legacy : Lower-level conversion method that requires explicit Spark type. 
""" return ArrowArrayToPandasConversion.convert_legacy( arrow_column, - spark_type or from_arrow_type(arrow_column.type), + target_type, timezone=timezone, struct_in_pandas=struct_in_pandas, ndarray_as_list=ndarray_as_list, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 03abf003f55a..e417282b7e07 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -566,14 +566,12 @@ def load_stream(self, stream): pandas_batches = [ ArrowArrayToPandasConversion.convert( batch.column(i), - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(batch.column(i).type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=( - self._input_type[i].dataType - if hasattr(self, "_input_type") and self._input_type - else None - ), df_for_struct=self._df_for_struct, ) for i in range(batch.num_columns) @@ -1114,10 +1112,12 @@ def load_stream(self, stream): tuple( ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(batch.columns) @@ -1164,14 +1164,15 @@ def load_stream(self, stream): def process_group(batches: "Iterator[pa.RecordBatch]"): # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch for batch in batches: - # The batch from ArrowStreamSerializer is already flattened (no struct wrapper) series = [ ArrowArrayToPandasConversion.convert( batch.column(i), - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(batch.column(i).type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i in range(batch.num_columns) @@ -1235,10 +1236,12 @@ def load_stream(self, stream): [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches(left_batches).itercolumns()) @@ -1246,10 +1249,12 @@ def load_stream(self, stream): [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches(right_batches).itercolumns()) @@ -1418,16 +1423,21 @@ def gen_data_and_state(batches): ) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + state_spark_type = ( + self._input_type[-1].dataType + if self._input_type is not None + else from_arrow_type(state_batch.schema[0]) + ) state_pandas = [ ArrowArrayToPandasConversion.convert( 
c, - self._timezone, + state_spark_type, + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) - for i, c in enumerate(state_arrow) + for c in state_arrow ][0] for state_idx in range(0, len(state_pandas)): @@ -1463,10 +1473,12 @@ def gen_data_and_state(batches): data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(data_arrow) @@ -1740,10 +1752,12 @@ def row_stream(): data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) @@ -1877,10 +1891,12 @@ def row_stream(): data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(flatten_state_table.itercolumns()) @@ -1890,10 +1906,12 @@ def row_stream(): init_data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._timezone, + self._input_type[i].dataType + if self._input_type is not None + else from_arrow_type(c.type), + timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, - spark_type=self._input_type[i].dataType if self._input_type else None, df_for_struct=self._df_for_struct, ) for i, c in enumerate(flatten_init_table.itercolumns()) From cbb3a90c024956361d1c2aec2dd85b5d2da9b0c1 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:58:00 -0800 Subject: [PATCH 23/39] fix: use classmethod --- python/pyspark/sql/conversion.py | 20 ++++++++------------ python/pyspark/sql/pandas/serializers.py | 2 +- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 14bab1f36335..32c386ee84d6 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -1014,8 +1014,9 @@ class ArrowArrayToPandasConversion: where Arrow data needs to be converted to pandas for Python UDF processing. 
""" - @staticmethod + @classmethod def convert_legacy( + cls, arr: Union["pa.Array", "pa.ChunkedArray"], spark_type: DataType, *, @@ -1075,7 +1076,7 @@ def convert_legacy( ) series = [ - ArrowArrayToPandasConversion.convert_legacy( + cls.convert_legacy( field_arr, spark_type=field.dataType, timezone=timezone, @@ -1112,10 +1113,12 @@ def convert_legacy( ) return converter(ser) - @staticmethod + @classmethod def convert( + cls, arrow_column: Union["pa.Array", "pa.ChunkedArray"], target_type: DataType, + *, timezone: Optional[str] = None, struct_in_pandas: str = "dict", ndarray_as_list: bool = False, @@ -1124,15 +1127,12 @@ def convert( """ Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. - This is a convenience method that provides a more intuitive API than - :meth:`convert_legacy`. - Parameters ---------- arrow_column : pa.Array or pa.ChunkedArray The Arrow column to convert. target_type : DataType - The target Spark type for the column to be coverted to. + The target Spark type for the column to be converted to. timezone : str, optional Timezone for timestamp conversion. Required if the data contains timestamp types. struct_in_pandas : str, optional @@ -1149,12 +1149,8 @@ def convert( pd.Series or pd.DataFrame Converted pandas Series. If df_for_struct is True and the type is StructType, returns a DataFrame with columns corresponding to struct fields. - - See Also - -------- - convert_legacy : Lower-level conversion method that requires explicit Spark type. """ - return ArrowArrayToPandasConversion.convert_legacy( + return cls.convert_legacy( arrow_column, target_type, timezone=timezone, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index e417282b7e07..48985e92ec8d 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1426,7 +1426,7 @@ def gen_data_and_state(batches): state_spark_type = ( self._input_type[-1].dataType if self._input_type is not None - else from_arrow_type(state_batch.schema[0]) + else from_arrow_type(state_batch.schema[0].type) ) state_pandas = [ ArrowArrayToPandasConversion.convert( From bec3f44d889b0a97054681af8011567ad0eef156 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 28 Jan 2026 21:15:36 -0800 Subject: [PATCH 24/39] fix: simplify --- python/pyspark/sql/pandas/serializers.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 48985e92ec8d..dfaac9fa9c76 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1423,15 +1423,10 @@ def gen_data_and_state(batches): ) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_spark_type = ( - self._input_type[-1].dataType - if self._input_type is not None - else from_arrow_type(state_batch.schema[0].type) - ) state_pandas = [ ArrowArrayToPandasConversion.convert( c, - state_spark_type, + from_arrow_type(c.type), timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, @@ -1473,15 +1468,13 @@ def gen_data_and_state(batches): data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), + from_arrow_type(c.type), timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, 
df_for_struct=self._df_for_struct, ) - for i, c in enumerate(data_arrow) + for c in data_arrow ] # state info From cf5187610ef6b3eab6a57f38f592f795ed8e47e7 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 28 Jan 2026 23:13:37 -0800 Subject: [PATCH 25/39] fix: import and doc --- python/pyspark/sql/conversion.py | 2 -- python/pyspark/sql/pandas/serializers.py | 6 ++---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 32c386ee84d6..105c06c7a4a1 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -25,7 +25,6 @@ _dedup_names, _deduplicate_field_names, _create_converter_to_pandas, - from_arrow_type, to_arrow_schema, ) from pyspark.sql.pandas.utils import require_minimum_pyarrow_version @@ -69,7 +68,6 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record Flatten a struct column at given index into a RecordBatch. Used by: - - ArrowStreamGroupUDFSerializer.load_stream - ArrowStreamUDFSerializer.load_stream - SQL_GROUPED_MAP_ARROW_UDF mapper - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index dfaac9fa9c76..6e5facc79aa6 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1745,15 +1745,13 @@ def row_stream(): data_pandas = [ ArrowArrayToPandasConversion.convert( c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), + from_arrow_type(c.type), timezone=self._timezone, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, df_for_struct=self._df_for_struct, ) - for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) + for c in pa.Table.from_batches([batch]).itercolumns() ] for row in pd.concat(data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.key_offsets) From 78f2920ae4672b11456b7c4c7de151c1b804aa6d Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 29 Jan 2026 21:39:49 +0000 Subject: [PATCH 26/39] refactor: extract `to_pandas` transformer --- python/pyspark/sql/conversion.py | 53 ++++++ python/pyspark/sql/pandas/serializers.py | 219 ++++++++--------------- 2 files changed, 128 insertions(+), 144 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 105c06c7a4a1..4307eb700c0b 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -94,6 +94,59 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) + @classmethod + def to_pandas( + cls, + batch: Union["pa.RecordBatch", "pa.Table"], + timezone: str, + schema: Optional["StructType"] = None, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + ) -> List[Union["pd.Series", "pd.DataFrame"]]: + """ + Convert a RecordBatch or Table to a list of pandas Series. + + Parameters + ---------- + batch : pa.RecordBatch or pa.Table + The Arrow RecordBatch or Table to convert. + timezone : str + Timezone for timestamp conversion. + schema : StructType, optional + Spark schema for type conversion. If None, types are inferred from Arrow. 
+ struct_in_pandas : str + How to represent struct in pandas ("dict", "row", etc.) + ndarray_as_list : bool + Whether to convert ndarray as list. + df_for_struct : bool + If True, convert struct columns to DataFrame instead of Series. + + Returns + ------- + List[Union[pd.Series, pd.DataFrame]] + List of pandas Series (or DataFrame if df_for_struct=True), one for each column. + """ + import pandas as pd + + import pyspark + from pyspark.sql.pandas.types import from_arrow_type + + if batch.num_columns == 0: + return [pd.Series([pyspark._NoValue] * batch.num_rows)] + + return [ + ArrowArrayToPandasConversion.convert( + batch.column(i), + schema[i].dataType if schema is not None else from_arrow_type(batch.column(i).type), + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + ) + for i in range(batch.num_columns) + ] + class LocalDataToArrowConversion: """ diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6e5facc79aa6..55b93bd4da58 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -555,28 +555,17 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - import pandas as pd - import pyspark - - batches = super().load_stream(stream) - for batch in batches: - if batch.num_columns == 0: - yield [pd.Series([pyspark._NoValue] * batch.num_rows)] - else: - pandas_batches = [ - ArrowArrayToPandasConversion.convert( - batch.column(i), - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(batch.column(i).type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i in range(batch.num_columns) - ] - yield pandas_batches + yield from map( + lambda batch: ArrowBatchTransformer.to_pandas( + batch, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ), + super().load_stream(stream), + ) def __repr__(self): return "ArrowStreamPandasSerializer" @@ -1108,21 +1097,18 @@ def load_stream(self, stream): for (batches,) in self._load_group_dataframes(stream, num_dfs=1): # Lazily read and convert Arrow batches to pandas Series one at a time # from the stream. This avoids loading all batches into memory for the group - series_iter = ( - tuple( - ArrowArrayToPandasConversion.convert( - c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), + series_iter = map( + lambda batch: tuple( + ArrowBatchTransformer.to_pandas( + batch, timezone=self._timezone, + schema=self._input_type, struct_in_pandas=self._struct_in_pandas, ndarray_as_list=self._ndarray_as_list, df_for_struct=self._df_for_struct, ) - for i, c in enumerate(batch.columns) - ) - for batch in batches + ), + batches, ) yield series_iter # Make sure the batches are fully iterated before getting the next group @@ -1160,29 +1146,20 @@ def load_stream(self, stream): Each outer iterator element represents a group, containing an iterator of Series lists (one list per batch). 
""" - - def process_group(batches: "Iterator[pa.RecordBatch]"): - # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch - for batch in batches: - series = [ - ArrowArrayToPandasConversion.convert( - batch.column(i), - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(batch.column(i).type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i in range(batch.num_columns) - ] - yield series - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): # Lazily read and convert Arrow batches one at a time from the stream # This avoids loading all batches into memory for the group - series_iter = process_group(batches) + series_iter = map( + lambda batch: ArrowBatchTransformer.to_pandas( + batch, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ), + batches, + ) yield series_iter # Make sure the batches are fully iterated before getting the next group for _ in series_iter: @@ -1232,33 +1209,16 @@ def load_stream(self, stream): import pyarrow as pa for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield ( - [ - ArrowArrayToPandasConversion.convert( - c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i, c in enumerate(pa.Table.from_batches(left_batches).itercolumns()) - ], - [ - ArrowArrayToPandasConversion.convert( - c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i, c in enumerate(pa.Table.from_batches(right_batches).itercolumns()) - ], + yield tuple( + ArrowBatchTransformer.to_pandas( + pa.Table.from_batches(batches), + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) + for batches in (left_batches, right_batches) ) @@ -1422,18 +1382,14 @@ def gen_data_and_state(batches): schema=state_schema, ) - state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_pandas = [ - ArrowArrayToPandasConversion.convert( - c, - from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for c in state_arrow - ][0] + state_pandas = ArrowBatchTransformer.to_pandas( + state_batch, + timezone=self._timezone, + schema=None, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + )[0] for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -1463,19 +1419,14 @@ def gen_data_and_state(batches): state_for_current_group = state data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) - data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() - - data_pandas = [ - ArrowArrayToPandasConversion.convert( - c, - from_arrow_type(c.type), - timezone=self._timezone, - 
struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for c in data_arrow - ] + data_pandas = ArrowBatchTransformer.to_pandas( + data_batch_for_group, + timezone=self._timezone, + schema=None, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) # state info yield ( @@ -1742,17 +1693,14 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: self._update_batch_size_stats(batch) - data_pandas = [ - ArrowArrayToPandasConversion.convert( - c, - from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for c in pa.Table.from_batches([batch]).itercolumns() - ] + data_pandas = ArrowBatchTransformer.to_pandas( + batch, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) for row in pd.concat(data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.key_offsets) yield (batch_key, row) @@ -1874,39 +1822,22 @@ def flatten_columns(cur_batch, col_name): but each batch will have either init_data or input_data, not mix. """ + def to_pandas(table): + return ArrowBatchTransformer.to_pandas( + table, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) + def row_stream(): for batch in batches: self._update_batch_size_stats(batch) - flatten_state_table = flatten_columns(batch, "inputData") - data_pandas = [ - ArrowArrayToPandasConversion.convert( - c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i, c in enumerate(flatten_state_table.itercolumns()) - ] - - flatten_init_table = flatten_columns(batch, "initState") - init_data_pandas = [ - ArrowArrayToPandasConversion.convert( - c, - self._input_type[i].dataType - if self._input_type is not None - else from_arrow_type(c.type), - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for i, c in enumerate(flatten_init_table.itercolumns()) - ] + data_pandas = to_pandas(flatten_columns(batch, "inputData")) + init_data_pandas = to_pandas(flatten_columns(batch, "initState")) assert not (bool(init_data_pandas) and bool(data_pandas)) From 2ca4b27bf8f367bcb040888912abc951533aa5cd Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:42:46 +0000 Subject: [PATCH 27/39] refactor: simplify serializers and move data transformations to wrapper/mapper layers This refactoring separates concerns between serialization and data transformation: - Serializers now only handle data serialization/deserialization - Wrappers handle UDF wrapping and data format conversion - Mappers handle UDF result aggregation using transformer utilities Key changes: 1. 
Simplified serializer hierarchy by removing redundant classes: - Removed ArrowStreamPandasUDFSerializer, ArrowStreamGroupUDFSerializer, ArrowStreamArrowUDFSerializer, ArrowStreamUDTFSerializer - Renamed ArrowStreamMapIterSerializer to ArrowStreamGroupSerializer - Unified grouped/non-grouped UDF handling in ArrowStreamGroupSerializer 2. Introduced transformer utility classes: - ArrowBatchTransformer: Arrow batch operations (wrap_struct, flatten_struct, partial_batch, partial_table, concat_batches, merge_batches, reorder_columns) - PandasBatchTransformer: Pandas/Arrow conversions (to_arrow, concat_series_batches) 3. Moved data transformations from serializers to wrappers/mappers: - Moved to_arrow conversion from mappers to wrappers for Pandas agg UDFs - Wrappers now return RecordBatch directly instead of (result, arrow_type) tuples - Mappers simplified to use transformer methods for common operations Benefits: - Clearer separation of concerns - Reduced code duplication through transformer utilities - Easier to maintain and extend - Consistent data format handling across UDF types --- python/pyspark/sql/connect/session.py | 26 +- python/pyspark/sql/conversion.py | 831 ++++++++++++++++- python/pyspark/sql/pandas/conversion.py | 20 +- python/pyspark/sql/pandas/serializers.py | 1033 ++++------------------ python/pyspark/worker.py | 649 ++++++++------ 5 files changed, 1437 insertions(+), 1122 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 2288677763d3..a67b08f5c3f9 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -69,7 +69,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager -from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer +from pyspark.sql.pandas.serializers import ArrowStreamUDFSerializer from pyspark.sql.pandas.types import ( to_arrow_schema, to_arrow_type, @@ -630,19 +630,21 @@ def createDataFrame( safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] - ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true", False) - - _table = pa.Table.from_batches( - [ - ser._create_batch( - [ - (c, at, st) - for (_, c), at, st in zip(data.items(), arrow_types, spark_types) - ] - ) - ] + # Convert pandas data to Arrow RecordBatch + from pyspark.sql.conversion import PandasBatchTransformer + + batch_data = [ + (c, at, st) for (_, c), at, st in zip(data.items(), arrow_types, spark_types) + ] + record_batch = PandasBatchTransformer.to_arrow( + batch_data, + timezone=cast(str, timezone), + safecheck=safecheck == "true", + int_to_decimal_coercion_enabled=False, ) + _table = pa.Table.from_batches([record_batch]) + if isinstance(schema, StructType): assert arrow_schema is not None _table = _table.rename_columns( diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 4307eb700c0b..a997be17b6c9 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -20,11 +20,13 @@ import decimal from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError from pyspark.sql.pandas.types import ( _dedup_names, _deduplicate_field_names, _create_converter_to_pandas, + 
_create_converter_from_pandas, + from_arrow_type, to_arrow_schema, ) from pyspark.sql.pandas.utils import require_minimum_pyarrow_version @@ -60,6 +62,11 @@ class ArrowBatchTransformer: """ Pure functions that transform RecordBatch -> RecordBatch. They should have no side effects (no I/O, no writing to streams). + + This class provides utility methods for Arrow batch transformations used throughout + PySpark's Arrow UDF implementation. All methods are static and handle common patterns + like struct wrapping/unwrapping and schema conversions. + """ @staticmethod @@ -67,10 +74,14 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record """ Flatten a struct column at given index into a RecordBatch. - Used by: - - ArrowStreamUDFSerializer.load_stream - - SQL_GROUPED_MAP_ARROW_UDF mapper - - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper + Used by + ------- + - ArrowStreamGroupSerializer + - ArrowStreamArrowUDTFSerializer + - SQL_MAP_ARROW_ITER_UDF mapper + - SQL_GROUPED_MAP_ARROW_UDF mapper + - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper + - SQL_COGROUPED_MAP_ARROW_UDF mapper """ import pyarrow as pa @@ -82,7 +93,14 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": """ Wrap a RecordBatch's columns into a single struct column. - Used by: ArrowStreamUDFSerializer.dump_stream + Used by + ------- + - wrap_grouped_map_arrow_udf + - wrap_grouped_map_arrow_iter_udf + - wrap_cogrouped_map_arrow_udf + - wrap_arrow_batch_iter_udf + - ArrowStreamArrowUDTFSerializer.dump_stream + - TransformWithStateInPySparkRowSerializer.dump_stream """ import pyarrow as pa @@ -94,6 +112,339 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) + @classmethod + def partial_batch(cls, batch: "pa.RecordBatch", column_indices: List[int]) -> "pa.RecordBatch": + """ + Create a new RecordBatch with only selected columns. + + This method selects a subset of columns from a RecordBatch by their indices, + preserving column names and data types. + + Parameters + ---------- + batch : pa.RecordBatch + Input RecordBatch + column_indices : List[int] + Indices of columns to select (0-based) + + Returns + ------- + pa.RecordBatch + New RecordBatch containing only the selected columns + + Used by + ------- + - SQL_GROUPED_MAP_ARROW_UDF mapper + - partial_table + + Examples + -------- + >>> import pyarrow as pa + >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['a', 'b']) + >>> partial = ArrowBatchTransformer.partial_batch(batch, [1]) + >>> partial.schema.names + ['b'] + """ + import pyarrow as pa + + return pa.RecordBatch.from_arrays( + arrays=[batch.columns[i] for i in column_indices], + names=[batch.schema.names[i] for i in column_indices], + ) + + @classmethod + def partial_table(cls, batches: List["pa.RecordBatch"], column_indices: List[int]) -> "pa.Table": + """ + Combine multiple batches into a Table with only selected columns. + + This method selects a subset of columns from each RecordBatch and combines + them into a single Arrow Table. 
+ + Parameters + ---------- + batches : List[pa.RecordBatch] + List of RecordBatches to combine + column_indices : List[int] + Indices of columns to select (0-based) + + Returns + ------- + pa.Table + Combined Table containing only the selected columns + + Used by + ------- + - SQL_COGROUPED_MAP_ARROW_UDF mapper + + Examples + -------- + >>> import pyarrow as pa + >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['a', 'b']) + >>> batch2 = pa.RecordBatch.from_arrays([pa.array([5, 6]), pa.array([7, 8])], ['a', 'b']) + >>> table = ArrowBatchTransformer.partial_table([batch1, batch2], [1]) + >>> table.schema.names + ['b'] + >>> len(table) + 4 + """ + import pyarrow as pa + + return pa.Table.from_batches( + [cls.partial_batch(batch, column_indices) for batch in batches] + ) + + @classmethod + def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": + """ + Concatenate multiple RecordBatches into a single RecordBatch. + + This method handles both modern and legacy PyArrow versions. + + Parameters + ---------- + batches : List[pa.RecordBatch] + List of RecordBatches with the same schema + + Returns + ------- + pa.RecordBatch + Single RecordBatch containing all rows from input batches + + Used by + ------- + - SQL_GROUPED_AGG_ARROW_UDF mapper + - SQL_WINDOW_AGG_ARROW_UDF mapper + + Examples + -------- + >>> import pyarrow as pa + >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['a']) + >>> batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ['a']) + >>> result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + >>> result.to_pydict() + {'a': [1, 2, 3, 4]} + """ + import pyarrow as pa + + if hasattr(pa, "concat_batches"): + return pa.concat_batches(batches) + else: + # pyarrow.concat_batches not supported in old versions + return pa.RecordBatch.from_struct_array( + pa.concat_arrays([b.to_struct_array() for b in batches]) + ) + + @classmethod + def merge_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": + """ + Merge multiple RecordBatches horizontally by combining their columns. + + This is different from concat_batches which concatenates rows vertically. + This method combines columns from multiple batches into a single batch, + useful when multiple UDFs each produce a RecordBatch and we need to + combine their outputs. + + Parameters + ---------- + batches : List[pa.RecordBatch] + List of RecordBatches to merge (must have same number of rows) + + Returns + ------- + pa.RecordBatch + Single RecordBatch with all columns from input batches + + Used by + ------- + - SQL_GROUPED_AGG_ARROW_UDF mapper + - SQL_WINDOW_AGG_ARROW_UDF mapper + + Examples + -------- + >>> import pyarrow as pa + >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['a']) + >>> batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ['b']) + >>> result = ArrowBatchTransformer.merge_batches([batch1, batch2]) + >>> result.to_pydict() + {'_0': [1, 2], '_1': [3, 4]} + """ + import pyarrow as pa + + if len(batches) == 1: + return batches[0] + + # Combine all columns from all batches + all_columns = [] + for batch in batches: + all_columns.extend(batch.columns) + return pa.RecordBatch.from_arrays( + all_columns, ["_%d" % i for i in range(len(all_columns))] + ) + + @classmethod + def reorder_columns( + cls, batch: "pa.RecordBatch", target_schema: Union["pa.StructType", "StructType"] + ) -> "pa.RecordBatch": + """ + Reorder columns in a RecordBatch to match target schema field order. 
+ + This method is useful when columns need to be arranged in a specific order + for schema compatibility, particularly when assign_cols_by_name is enabled. + + Parameters + ---------- + batch : pa.RecordBatch + Input RecordBatch with columns to reorder + target_schema : pa.StructType or pyspark.sql.types.StructType + Target schema defining the desired column order. + Can be either PyArrow StructType or Spark StructType. + + Returns + ------- + pa.RecordBatch + New RecordBatch with columns reordered to match target schema + + Used by + ------- + - wrap_grouped_map_arrow_udf + - wrap_grouped_map_arrow_iter_udf + - wrap_cogrouped_map_arrow_udf + + Examples + -------- + >>> import pyarrow as pa + >>> from pyspark.sql.types import StructType, StructField, IntegerType + >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['b', 'a']) + >>> # Using PyArrow schema + >>> target_pa = pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.int64())]) + >>> result = ArrowBatchTransformer.reorder_columns(batch, target_pa) + >>> result.schema.names + ['a', 'b'] + >>> # Using Spark schema + >>> target_spark = StructType([StructField('a', IntegerType()), StructField('b', IntegerType())]) + >>> result = ArrowBatchTransformer.reorder_columns(batch, target_spark) + >>> result.schema.names + ['a', 'b'] + """ + import pyarrow as pa + + # Convert Spark StructType to PyArrow StructType if needed + if hasattr(target_schema, 'fields') and hasattr(target_schema.fields[0], 'dataType'): + # This is Spark StructType - convert to PyArrow + from pyspark.sql.pandas.types import to_arrow_schema + arrow_schema = to_arrow_schema(target_schema) + field_names = [field.name for field in arrow_schema] + else: + # This is PyArrow StructType + field_names = [field.name for field in target_schema] + + return pa.RecordBatch.from_arrays( + [batch.column(name) for name in field_names], + names=field_names, + ) + + @staticmethod + def cast_array( + arr: "pa.Array", + target_type: "pa.DataType", + arrow_cast: bool = False, + safecheck: bool = True, + error_message: Optional[str] = None, + ) -> "pa.Array": + """ + Cast an Arrow Array to a target type with type checking. + + Parameters + ---------- + arr : pa.Array + The Arrow Array to cast. + target_type : pa.DataType + Target Arrow data type. + arrow_cast : bool + If True, always attempt to cast. If False, raise error on type mismatch. + safecheck : bool + If True, use safe casting (fails on overflow/truncation). + error_message : str, optional + Custom error message for type mismatch. + + Returns + ------- + pa.Array + The casted array if types differ, or original array if types match. + """ + import pyarrow as pa + + assert isinstance(arr, pa.Array) + assert isinstance(target_type, pa.DataType) + + if arr.type == target_type: + return arr + elif arrow_cast: + return arr.cast(target_type=target_type, safe=safecheck) + else: + if error_message: + raise PySparkTypeError(error_message) + else: + raise PySparkTypeError( + f"Arrow type mismatch. Expected: {target_type}, but got: {arr.type}." + ) + + @staticmethod + def create_batch_from_arrays( + packed: Union[tuple, list], + arrow_cast: bool = False, + safecheck: bool = True, + ) -> "pa.RecordBatch": + """ + Create a RecordBatch from (array, type) pairs with type casting. + + Parameters + ---------- + packed : tuple or list + Either a (array, type) tuple for single array, or a list of (array, type) tuples. + arrow_cast : bool + If True, always attempt to cast. If False, raise error on type mismatch. 
+ safecheck : bool + If True, use safe casting (fails on overflow/truncation). + + Returns + ------- + pa.RecordBatch + RecordBatch with casted arrays. + """ + import pyarrow as pa + + if len(packed) == 2 and isinstance(packed[1], pa.DataType): + # single array UDF in a projection + arrs = [ + ArrowBatchTransformer.cast_array( + packed[0], + packed[1], + arrow_cast=arrow_cast, + safecheck=safecheck, + error_message=( + "Arrow UDFs require the return type to match the expected Arrow type. " + f"Expected: {packed[1]}, but got: {packed[0].type}." + ), + ) + ] + else: + # multiple array UDFs in a projection + arrs = [ + ArrowBatchTransformer.cast_array( + t[0], + t[1], + arrow_cast=arrow_cast, + safecheck=safecheck, + error_message=( + "Arrow UDFs require the return type to match the expected Arrow type. " + f"Expected: {t[1]}, but got: {t[0].type}." + ), + ) + for t in packed + ] + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + @classmethod def to_pandas( cls, @@ -148,6 +499,474 @@ def to_pandas( ] +class PandasBatchTransformer: + """ + Pure functions that transform between pandas DataFrames/Series and Arrow RecordBatches. + They should have no side effects (no I/O, no writing to streams). + + This class provides utility methods for converting between pandas and Arrow formats, + used primarily by Pandas UDF wrappers and serializers. + + """ + + @classmethod + def create_array( + cls, + series: "pd.Series", + arrow_type: Optional["pa.DataType"], + timezone: str, + safecheck: bool = True, + spark_type: Optional[DataType] = None, + arrow_cast: bool = False, + int_to_decimal_coercion_enabled: bool = False, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + ) -> "pa.Array": + """ + Create an Arrow Array from the given pandas.Series and optional type. 
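+
+        For example (illustrative), ``create_array(pd.Series([1, 2]), pa.int64(),
+        timezone="UTC")`` returns a ``pa.Int64Array`` of length 2.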
+ + Parameters + ---------- + series : pandas.Series + A single series + arrow_type : pyarrow.DataType, optional + If None, pyarrow's inferred type will be used + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + spark_type : DataType, optional + If None, spark type converted from arrow_type will be used + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values during conversion + error_class : str, optional + Custom error class for arrow type cast errors (e.g., "UDTF_ARROW_TYPE_CAST_ERROR") + + Returns + ------- + pyarrow.Array + """ + import pyarrow as pa + import pandas as pd + + if isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype(series.dtypes.categories.dtype) + + if arrow_type is not None: + dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) + conv = _create_converter_from_pandas( + dt, + timezone=timezone, + error_on_duplicated_field_names=False, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + ) + series = conv(series) + + if hasattr(series.array, "__arrow_array__"): + mask = None + else: + mask = series.isnull() + + # For UDTF (error_class is set), use Arrow-specific error handling + if error_class is not None: + try: + try: + return pa.Array.from_pandas( + series, mask=mask, type=arrow_type, safe=safecheck + ) + except pa.lib.ArrowException: + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except pa.lib.ArrowException: + raise PySparkRuntimeError( + errorClass=error_class, + messageParameters={ + "col_name": series.name, + "col_type": str(series.dtype), + "arrow_type": str(arrow_type), + }, + ) from None + + # For regular UDF, use standard error handling + try: + try: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) + except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError): + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except TypeError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e + except (ValueError, pa.lib.ArrowException) as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + if safecheck: + error_msg = error_msg + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e + + @staticmethod + def normalize_input( + series: Union["pd.Series", "pd.DataFrame", List], + ) -> "Iterator[Tuple[Any, Optional[pa.DataType], Optional[DataType]]]": + """ + Normalize input to a consistent format for batch conversion. + + Converts various input formats to an iterator of + (data, arrow_type, spark_type) tuples. 
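+
+        For example (illustrative), a bare ``pd.Series`` is normalized to a single
+        ``(series, None, None)`` tuple, and a ``(series, arrow_type)`` pair gains a
+        trailing ``None`` for the Spark type.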
+ + Parameters + ---------- + series : pandas.Series, pandas.DataFrame, or list + A single series/dataframe, list of series/dataframes, or list of + (series, arrow_type) or (series, arrow_type, spark_type) + + Returns + ------- + Iterator[Tuple[Any, Optional[pa.DataType], Optional[DataType]]] + Iterator of (data, arrow_type, spark_type) tuples + """ + import pyarrow as pa + + # Make input conform to + # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] + if ( + not isinstance(series, (list, tuple)) + or (len(series) == 2 and isinstance(series[1], pa.DataType)) + or ( + len(series) == 3 + and isinstance(series[1], pa.DataType) + and isinstance(series[2], DataType) + ) + ): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + return ((s[0], s[1], None) if len(s) == 2 else s for s in series) + + @classmethod + def create_struct_array( + cls, + df: "pd.DataFrame", + arrow_struct_type: "pa.StructType", + timezone: str, + safecheck: bool = True, + spark_type: Optional["StructType"] = None, + arrow_cast: bool = False, + int_to_decimal_coercion_enabled: bool = False, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + assign_cols_by_name: bool = True, + ) -> "pa.StructArray": + """ + Create an Arrow StructArray from the given pandas.DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + A pandas DataFrame + arrow_struct_type : pyarrow.StructType + Target Arrow struct type + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + spark_type : StructType, optional + Spark schema for type conversion + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values + error_class : str, optional + Custom error class for type cast errors + assign_cols_by_name : bool + If True, assign columns by name; otherwise by position + + Returns + ------- + pyarrow.StructArray + """ + import pyarrow as pa + + if len(df.columns) == 0: + return pa.array([{}] * len(df), arrow_struct_type) + + # Assign result columns by schema name if user labeled with strings + if assign_cols_by_name and any(isinstance(name, str) for name in df.columns): + struct_arrs = [ + cls.create_array( + df[field.name], + field.type, + timezone=timezone, + safecheck=safecheck, + spark_type=( + spark_type[field.name].dataType if spark_type is not None else None + ), + arrow_cast=arrow_cast, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + for field in arrow_struct_type + ] + # Assign result columns by position + else: + struct_arrs = [ + cls.create_array( + df[df.columns[i]].rename(field.name), + field.type, + timezone=timezone, + safecheck=safecheck, + spark_type=spark_type[i].dataType if spark_type is not None else None, + arrow_cast=arrow_cast, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + for i, field in enumerate(arrow_struct_type) + ] + + return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type)) + + @classmethod + def concat_series_batches( + cls, series_batches: List[List["pd.Series"]], 
arg_offsets: Optional[List[int]] = None + ) -> List["pd.Series"]: + """ + Concatenate multiple batches of pandas Series column-wise. + + Takes a list of batches where each batch is a list of Series (one per column), + and concatenates all Series column-by-column to produce a single list of + concatenated Series. + + Parameters + ---------- + series_batches : List[List[pd.Series]] + List of batches, each batch is a list of Series (one Series per column) + arg_offsets : Optional[List[int]] + If provided and series_batches is empty, determines the number of empty Series to create + + Returns + ------- + List[pd.Series] + List of concatenated Series, one per column + + Used by + ------- + - SQL_GROUPED_AGG_PANDAS_UDF mapper + - SQL_WINDOW_AGG_PANDAS_UDF mapper + + Examples + -------- + >>> import pandas as pd + >>> batch1 = [pd.Series([1, 2]), pd.Series([3, 4])] + >>> batch2 = [pd.Series([5, 6]), pd.Series([7, 8])] + >>> result = PandasBatchTransformer.concat_series_batches([batch1, batch2]) + >>> len(result) + 2 + >>> result[0].tolist() + [1, 2, 5, 6] + """ + import pandas as pd + + if not series_batches: + # Empty batches - create empty Series + if arg_offsets: + num_columns = max(arg_offsets) + 1 if arg_offsets else 0 + else: + num_columns = 0 + return [pd.Series(dtype=object) for _ in range(num_columns)] + + # Concatenate Series by column + num_columns = len(series_batches[0]) + return [ + pd.concat([batch[i] for batch in series_batches], ignore_index=True) + for i in range(num_columns) + ] + + @classmethod + def series_batches_to_dataframe(cls, series_batches) -> "pd.DataFrame": + """ + Convert an iterator of Series lists to a single DataFrame. + + Each batch is a list of Series (one per column). This method concatenates + Series within each batch horizontally (axis=1), then concatenates all + resulting DataFrames vertically (axis=0). + + Parameters + ---------- + series_batches : Iterator[List[pd.Series]] + Iterator where each element is a list of Series representing one batch + + Returns + ------- + pd.DataFrame + Combined DataFrame with all data, or empty DataFrame if no batches + + Used by + ------- + - wrap_grouped_map_pandas_udf + - wrap_grouped_map_pandas_iter_udf + + Examples + -------- + >>> import pandas as pd + >>> batch1 = [pd.Series([1, 2], name='a'), pd.Series([3, 4], name='b')] + >>> batch2 = [pd.Series([5, 6], name='a'), pd.Series([7, 8], name='b')] + >>> df = PandasBatchTransformer.series_batches_to_dataframe([batch1, batch2]) + >>> df.shape + (4, 2) + >>> df.columns.tolist() + ['a', 'b'] + """ + import pandas as pd + + # Materialize iterator and convert each batch to DataFrame + dataframes = [pd.concat(series_list, axis=1) for series_list in series_batches] + + # Concatenate all DataFrames vertically + return pd.concat(dataframes, axis=0) if dataframes else pd.DataFrame() + + @classmethod + def to_arrow( + cls, + series: Union["pd.Series", "pd.DataFrame", List], + timezone: str, + safecheck: bool = True, + int_to_decimal_coercion_enabled: bool = False, + as_struct: bool = False, + struct_in_pandas: Optional[str] = None, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + assign_cols_by_name: bool = True, + arrow_cast: bool = False, + ) -> "pa.RecordBatch": + """ + Convert pandas.Series/DataFrame or list to Arrow RecordBatch. 
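+
+        For example (illustrative), ``to_arrow([(pd.Series([1, 2]), pa.int64())],
+        timezone="UTC")`` returns a one-column ``pa.RecordBatch`` with column name ``_0``.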
+ + Parameters + ---------- + series : pandas.Series, pandas.DataFrame or list + A single series/dataframe, list of series/dataframes, or list of + (data, arrow_type) or (data, arrow_type, spark_type) + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + as_struct : bool + If True, treat all inputs as DataFrames and create struct arrays (for UDTF) + struct_in_pandas : str, optional + How struct types are represented. If "dict", struct types require DataFrame input. + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values (for UDTF) + error_class : str, optional + Custom error class for type cast errors (for UDTF) + assign_cols_by_name : bool + If True, assign columns by name; otherwise by position (for struct) + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + + Returns + ------- + pyarrow.RecordBatch + Arrow RecordBatch + """ + import pandas as pd + import pyarrow as pa + from pyspark.sql.pandas.types import is_variant + + arrs = [] + for s, arrow_type, spark_type in cls.normalize_input(series): + if as_struct: + # UDTF mode: require DataFrame, create struct array + if not isinstance(s, pd.DataFrame): + raise PySparkValueError( + "Output of an arrow-optimized Python UDTFs expects " + f"a pandas.DataFrame but got: {type(s)}" + ) + arrs.append( + cls.create_struct_array( + s, + arrow_type, + timezone=timezone, + safecheck=safecheck, + spark_type=spark_type, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + assign_cols_by_name=assign_cols_by_name, + arrow_cast=arrow_cast, + ) + ) + elif ( + struct_in_pandas == "dict" + and arrow_type is not None + and pa.types.is_struct(arrow_type) + and not is_variant(arrow_type) + ): + # Struct type with dict mode: require DataFrame + if not isinstance(s, pd.DataFrame): + raise PySparkValueError( + "Invalid return type. Please make sure that the UDF returns a " + "pandas.DataFrame when the specified return type is StructType." + ) + arrs.append( + cls.create_struct_array( + s, + arrow_type, + timezone=timezone, + safecheck=safecheck, + spark_type=spark_type, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + assign_cols_by_name=assign_cols_by_name, + arrow_cast=arrow_cast, + ) + ) + else: + # Normal mode: create array from Series + arrs.append( + cls.create_array( + s, + arrow_type, + timezone=timezone, + safecheck=safecheck, + spark_type=spark_type, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + arrow_cast=arrow_cast, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + ) + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + + class LocalDataToArrowConversion: """ Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow. 
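
For reference, a minimal sketch of how the new transformer utilities compose
(assuming this patch series is applied; `ArrowBatchTransformer` and
`PandasBatchTransformer` are the classes added to python/pyspark/sql/conversion.py
above, and `pyarrow`/`pandas` are available):

    import pyarrow as pa
    from pyspark.sql.conversion import ArrowBatchTransformer, PandasBatchTransformer

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array(["a", "b"])], ["id", "name"]
    )

    # wrap_struct packs all columns into a single struct column "_0";
    # flatten_struct is its inverse.
    wrapped = ArrowBatchTransformer.wrap_struct(batch)
    restored = ArrowBatchTransformer.flatten_struct(wrapped)
    assert restored.schema.names == ["id", "name"]

    # to_pandas yields one pandas Series per column; PandasBatchTransformer.to_arrow
    # converts (Series, arrow_type) pairs back into a RecordBatch named _0, _1, ...
    series = ArrowBatchTransformer.to_pandas(batch, timezone="UTC")
    rebuilt = PandasBatchTransformer.to_arrow(
        [(s, f.type) for s, f in zip(series, batch.schema)], timezone="UTC"
    )
    assert rebuilt.num_columns == 2
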
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 333f9803df3a..4bbae85d02c4 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -31,6 +31,7 @@ from pyspark.errors.exceptions.captured import unwrap_spark_exception from pyspark.util import _load_from_socket +from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.pandas.serializers import ArrowCollectSerializer from pyspark.sql.pandas.types import _dedup_names from pyspark.sql.types import ( @@ -807,7 +808,7 @@ def _create_from_pandas_with_arrow( assert isinstance(self, SparkSession) - from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + from pyspark.sql.pandas.serializers import ArrowStreamUDFSerializer from pyspark.sql.types import TimestampType from pyspark.sql.pandas.types import ( from_arrow_type, @@ -878,8 +879,8 @@ def _create_from_pandas_with_arrow( step = step if step > 0 else len(pdf) pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step)) - # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream - arrow_data = [ + # Create list of Arrow (columns, arrow_type, spark_type) for conversion to RecordBatch + pandas_data = [ [ ( c, @@ -893,9 +894,20 @@ def _create_from_pandas_with_arrow( for pdf_slice in pdf_slices ] + # Convert pandas data to Arrow RecordBatches before serialization + arrow_data = map( + lambda batch_data: PandasBatchTransformer.to_arrow( + batch_data, + timezone=timezone, + safecheck=safecheck, + int_to_decimal_coercion_enabled=False, + ), + pandas_data, + ) + jsparkSession = self._jsparkSession - ser = ArrowStreamPandasSerializer(timezone, safecheck, False) + ser = ArrowStreamUDFSerializer(timezone, safecheck, False) @no_type_check def reader_func(temp_filename): diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 55b93bd4da58..42d5700d31c8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -37,6 +37,7 @@ ArrowTableToRowsConversion, ArrowArrayToPandasConversion, ArrowBatchTransformer, + PandasBatchTransformer, ) from pyspark.sql.pandas.types import ( from_arrow_type, @@ -213,40 +214,120 @@ def __repr__(self): return "ArrowStreamSerializer" -class ArrowStreamUDFSerializer(ArrowStreamSerializer): +class ArrowStreamGroupSerializer(ArrowStreamSerializer): """ - Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch - for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. + Unified serializer for Arrow stream operations with optional grouping support. + + This serializer handles: + - Non-grouped operations: SQL_MAP_ARROW_ITER_UDF (num_dfs=0) + - Grouped operations: SQL_GROUPED_MAP_ARROW_UDF, SQL_GROUPED_MAP_PANDAS_UDF (num_dfs=1) + - Cogrouped operations: SQL_COGROUPED_MAP_ARROW_UDF, SQL_COGROUPED_MAP_PANDAS_UDF (num_dfs=2) + - Grouped aggregations: SQL_GROUPED_AGG_ARROW_UDF, SQL_GROUPED_AGG_PANDAS_UDF (num_dfs=1) + + The serializer handles Arrow stream I/O and START signal, while transformation logic + (flatten/wrap struct, pandas conversion) is handled by worker wrappers. 
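+
+    A minimal consumption sketch (illustrative, following the semantics documented
+    below; ``stream`` is the worker's input file object)::
+
+        ser = ArrowStreamGroupSerializer(num_dfs=1)
+        for group in ser.load_stream(stream):
+            for batch in group:   # pa.RecordBatch for the current group
+                ...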
+ + Used by + ------- + - SQL_MAP_ARROW_ITER_UDF: DataFrame.mapInArrow() + - SQL_GROUPED_MAP_ARROW_UDF: GroupedData.applyInArrow() + - SQL_GROUPED_MAP_ARROW_ITER_UDF: GroupedData.applyInArrow() with iter + - SQL_GROUPED_MAP_PANDAS_UDF: GroupedData.apply() + - SQL_GROUPED_MAP_PANDAS_ITER_UDF: GroupedData.apply() with iter + - SQL_COGROUPED_MAP_ARROW_UDF: DataFrame.groupby().cogroup().applyInArrow() + - SQL_COGROUPED_MAP_PANDAS_UDF: DataFrame.groupby().cogroup().apply() + - SQL_GROUPED_AGG_ARROW_UDF: GroupedData.agg() with arrow UDF + - SQL_GROUPED_AGG_ARROW_ITER_UDF: GroupedData.agg() with arrow iter UDF + - SQL_GROUPED_AGG_PANDAS_UDF: GroupedData.agg() with pandas UDF + - SQL_GROUPED_AGG_PANDAS_ITER_UDF: GroupedData.agg() with pandas iter UDF + - SQL_WINDOW_AGG_ARROW_UDF: Window aggregation with arrow UDF + - SQL_SCALAR_ARROW_UDF: Scalar arrow UDF + - SQL_SCALAR_ARROW_ITER_UDF: Scalar arrow iter UDF + + Parameters + ---------- + timezone : str, optional + Timezone for timestamp conversion (stored for compatibility) + safecheck : bool, optional + Safecheck flag (stored for compatibility) + int_to_decimal_coercion_enabled : bool, optional + Decimal coercion flag (stored for compatibility) + num_dfs : int, optional + Number of DataFrames per group: + - 0: Non-grouped mode (default) - yields all batches as single stream + - 1: Grouped mode - yields one iterator of batches per group + - 2: Cogrouped mode - yields tuple of two iterators per group + assign_cols_by_name : bool, optional + If True, assign DataFrame columns by name; otherwise by position (default: True) + arrow_cast : bool, optional + Arrow cast flag (stored for compatibility, default: False) """ + def __init__( + self, + timezone=None, + safecheck=None, + int_to_decimal_coercion_enabled: bool = False, + num_dfs: int = 0, + assign_cols_by_name: bool = True, + arrow_cast: bool = False, + ): + super().__init__() + # Store parameters for compatibility + self._timezone = timezone + self._safecheck = safecheck + self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled + self._num_dfs = num_dfs + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = arrow_cast + def load_stream(self, stream): """ - Flatten the struct into Arrow's record batches. - """ - batches = super().load_stream(stream) - flattened = map(ArrowBatchTransformer.flatten_struct, batches) - return map(lambda b: [b], flattened) + Deserialize Arrow record batches from stream. + + Returns + ------- + Iterator + - num_dfs=0: Iterator[pa.RecordBatch] - all batches in a single stream + - num_dfs=1: Iterator[Iterator[pa.RecordBatch]] - one iterator per group + - num_dfs=2: Iterator[Tuple[Iterator[pa.RecordBatch], Iterator[pa.RecordBatch]]] - + tuple of two iterators per cogrouped group + """ + if self._num_dfs > 0: + # Grouped mode: return raw Arrow batches per group + for batch_iters in self._load_group_dataframes(stream, num_dfs=self._num_dfs): + if self._num_dfs == 1: + yield batch_iters[0] + # Make sure the batches are fully iterated before getting the next group + for _ in batch_iters[0]: + pass + else: + yield batch_iters + else: + # Non-grouped mode: return raw Arrow batches + yield from super().load_stream(stream) def dump_stream(self, iterator, stream): """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + Serialize Arrow record batches to stream with START signal. + + The START_ARROW_STREAM marker is sent before the first batch to signal + the JVM that Arrow data is about to be transmitted. 
This allows proper + error handling during batch creation. + + Parameters + ---------- + iterator : Iterator[pa.RecordBatch] + Iterator of Arrow RecordBatches to serialize. For grouped/cogrouped UDFs, + this iterator is already flattened by the worker's func layer. + stream : file-like object + Output stream to write the serialized batches """ - batches = self._write_stream_start( - (ArrowBatchTransformer.wrap_struct(x[0]) for x in iterator), stream - ) + batches = self._write_stream_start(iterator, stream) return super().dump_stream(batches, stream) -class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): - """ - Same as :class:`ArrowStreamUDFSerializer` but it does not flatten when loading batches. - """ - - def load_stream(self, stream): - return ArrowStreamSerializer.load_stream(self, stream) - - -class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer): +class ArrowStreamArrowUDTFSerializer(ArrowStreamGroupSerializer): """ Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays. """ @@ -269,29 +350,6 @@ def load_stream(self, stream): for i in range(batch.num_columns) ] - def _create_array(self, arr, arrow_type): - import pyarrow as pa - - assert isinstance(arr, pa.Array) - assert isinstance(arrow_type, pa.DataType) - if arr.type == arrow_type: - return arr - else: - try: - # when safe is True, the cast will fail if there's a overflow or other - # unsafe conversion. - # RecordBatch.cast(...) isn't used as minimum PyArrow version - # required for RecordBatch.cast(...) is v16.0 - return arr.cast(target_type=arrow_type, safe=True) - except (pa.ArrowInvalid, pa.ArrowTypeError): - raise PySparkRuntimeError( - errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", - messageParameters={ - "expected": str(arrow_type), - "actual": str(arr.type), - }, - ) - def dump_stream(self, iterator, stream): """ Override to handle type coercion for ArrowUDTF outputs. @@ -322,7 +380,22 @@ def apply_type_coercion(): coerced_arrays = [] for i, field in enumerate(arrow_return_type): original_array = batch.column(i) - coerced_array = self._create_array(original_array, field.type) + try: + coerced_array = ArrowBatchTransformer.cast_array( + original_array, + field.type, + arrow_cast=True, + safecheck=True, + ) + except PySparkTypeError: + # Re-raise with UDTF-specific error + raise PySparkRuntimeError( + errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", + messageParameters={ + "expected": str(field.type), + "actual": str(original_array.type), + }, + ) coerced_arrays.append(coerced_array) coerced_batch = pa.RecordBatch.from_arrays( coerced_arrays, names=expected_field_names @@ -332,89 +405,23 @@ def apply_type_coercion(): return super().dump_stream(apply_type_coercion(), stream) -class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer): - """ - Serializer for grouped Arrow UDFs. - - Deserializes: - ``Iterator[Iterator[pa.RecordBatch]]`` - one inner iterator per group. - Each batch contains a single struct column. - - Serializes: - ``Iterator[Tuple[Iterator[pa.RecordBatch], pa.DataType]]`` - Each tuple contains iterator of flattened batches and their Arrow type. - - Used by: - - SQL_GROUPED_MAP_ARROW_UDF - - SQL_GROUPED_MAP_ARROW_ITER_UDF - - Parameters - ---------- - assign_cols_by_name : bool - If True, reorder serialized columns by schema name. - """ - - def __init__(self, assign_cols_by_name): - super().__init__() - self._assign_cols_by_name = assign_cols_by_name - - def load_stream(self, stream): - """ - Load grouped Arrow record batches from stream. 
- """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - yield batches - # Make sure the batches are fully iterated before getting the next group - for _ in batches: - pass - - def dump_stream(self, iterator, stream): - import pyarrow as pa - - # flatten inner list [([pa.RecordBatch], arrow_type)] into [(pa.RecordBatch, arrow_type)] - # so strip off inner iterator induced by ArrowStreamUDFSerializer.load_stream - batch_iter = ( - (batch, arrow_type) - for batches, arrow_type in iterator # tuple constructed in wrap_grouped_map_arrow_udf - for batch in batches - ) - - if self._assign_cols_by_name: - batch_iter = ( - ( - pa.RecordBatch.from_arrays( - [batch.column(field.name) for field in arrow_type], - names=[field.name for field in arrow_type], - ), - arrow_type, - ) - for batch, arrow_type in batch_iter - ) - - super().dump_stream(batch_iter, stream) - - -class ArrowStreamPandasSerializer(ArrowStreamSerializer): +class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ - Serializes pandas.Series as Arrow data with Arrow streaming format. + Serializer for UDFs that handles Arrow RecordBatch serialization. + + This is a thin wrapper around ArrowStreamSerializer kept for backward compatibility + and future extensibility. Currently it doesn't override any methods - all conversion + logic has been moved to wrappers/callers, and this serializer only handles pure + Arrow stream serialization. Parameters ---------- timezone : str - A timezone to respect when handling timestamp values + A timezone to respect when handling timestamp values (for compatibility) safecheck : bool - If True, conversion from Arrow to Pandas checks for overflow/truncation + If True, conversion checks for overflow/truncation (for compatibility) int_to_decimal_coercion_enabled : bool - If True, applies additional coercions in Python before converting to Arrow - This has performance penalties. - struct_in_pandas : str, optional - How to represent struct in pandas ("dict", "row", etc.). Default is "dict". - ndarray_as_list : bool, optional - Whether to convert ndarray as list. Default is False. - df_for_struct : bool, optional - If True, convert struct columns to DataFrame instead of Series. Default is False. - input_type : StructType, optional - Spark types for each column. Default is None. + If True, applies additional coercions (for compatibility) """ def __init__( @@ -422,374 +429,21 @@ def __init__( timezone, safecheck, int_to_decimal_coercion_enabled: bool = False, - struct_in_pandas: str = "dict", - ndarray_as_list: bool = False, - df_for_struct: bool = False, ): super().__init__() + # Store parameters for backward compatibility self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled - self._struct_in_pandas = struct_in_pandas - self._ndarray_as_list = ndarray_as_list - self._df_for_struct = df_for_struct - - def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): - """ - Create an Arrow Array from the given pandas.Series and optional type. - - Parameters - ---------- - series : pandas.Series - A single series - arrow_type : pyarrow.DataType, optional - If None, pyarrow's inferred type will be used - spark_type : DataType, optional - If None, spark type converted from arrow_type will be used - arrow_cast: bool, optional - Whether to apply Arrow casting when the user-specified return type mismatches the - actual return values. 
- - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - import pandas as pd - - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) - - if arrow_type is not None: - dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) - conv = _create_converter_from_pandas( - dt, - timezone=self._timezone, - error_on_duplicated_field_names=False, - int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, - ) - series = conv(series) - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - try: - try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=self._safecheck - ) - except pa.lib.ArrowInvalid: - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=self._safecheck - ) - else: - raise - except TypeError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e - except ValueError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - if self._safecheck: - error_msg = error_msg + ( - " It can be caused by overflows or other " - "unsafe conversions warned by Arrow. Arrow safe type check " - "can be disabled by using SQL config " - "`spark.sql.execution.pandas.convertToArrowArraySafely`." - ) - raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series or list of Series, - with optional type. - - Parameters - ---------- - series : pandas.Series or list - A single series, list of series, or list of (series, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [ - self._create_array(s, arrow_type, spark_type) for s, arrow_type, spark_type in series - ] - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def dump_stream(self, iterator, stream): - """ - Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or - a list of series accompanied by an optional pyarrow type to coerce the data to. - """ - batches = (self._create_batch(series) for series in iterator) - super().dump_stream(batches, stream) - - def load_stream(self, stream): - """ - Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
- """ - yield from map( - lambda batch: ArrowBatchTransformer.to_pandas( - batch, - timezone=self._timezone, - schema=self._input_type, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ), - super().load_stream(stream), - ) def __repr__(self): - return "ArrowStreamPandasSerializer" + return "ArrowStreamUDFSerializer" - -class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): - """ - Serializer used by Python worker to evaluate Pandas UDFs - """ - - def __init__( - self, - timezone, - safecheck, - assign_cols_by_name, - df_for_struct: bool = False, - struct_in_pandas: str = "dict", - ndarray_as_list: bool = False, - arrow_cast: bool = False, - input_type: Optional[StructType] = None, - int_to_decimal_coercion_enabled: bool = False, - ): - super().__init__( - timezone, - safecheck, - int_to_decimal_coercion_enabled, - struct_in_pandas, - ndarray_as_list, - df_for_struct, - ) - self._assign_cols_by_name = assign_cols_by_name - self._arrow_cast = arrow_cast - if input_type is not None: - assert isinstance(input_type, StructType) - self._input_type = input_type - - def _create_struct_array( - self, - df: "pd.DataFrame", - arrow_struct_type: "pa.StructType", - spark_type: Optional[StructType] = None, - ): - """ - Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type. - - Parameters - ---------- - df : pandas.DataFrame - A pandas DataFrame - arrow_struct_type : pyarrow.StructType - pyarrow struct type - - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - - if len(df.columns) == 0: - return pa.array([{}] * len(df), arrow_struct_type) - # Assign result columns by schema name if user labeled with strings - if self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - struct_arrs = [ - self._create_array( - df[field.name], - field.type, - spark_type=( - spark_type[field.name].dataType if spark_type is not None else None - ), - arrow_cast=self._arrow_cast, - ) - for field in arrow_struct_type - ] - # Assign result columns by position - else: - struct_arrs = [ - # the selected series has name '1', so we rename it to field.name - # as the name is used by _create_array to provide a meaningful error message - self._create_array( - df[df.columns[i]].rename(field.name), - field.type, - spark_type=spark_type[i].dataType if spark_type is not None else None, - arrow_cast=self._arrow_cast, - ) - for i, field in enumerate(arrow_struct_type) - ] - - return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type)) - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series pandas.DataFrame - or list of Series or DataFrame, with optional type. - - Parameters - ---------- - series : pandas.Series or pandas.DataFrame or list - A single series or dataframe, list of series or dataframe, - or list of (series or dataframe, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pandas as pd - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] 
- if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [] - for s, arrow_type, spark_type in series: - # Variants are represented in arrow as structs with additional metadata (checked by - # is_variant). If the data type is Variant, return a VariantVal atomic type instead of - # a dict of two binary values. - if ( - self._struct_in_pandas == "dict" - and arrow_type is not None - and pa.types.is_struct(arrow_type) - and not is_variant(arrow_type) - ): - # A pandas UDF should return pd.DataFrame when the return type is a struct type. - # If it returns a pd.Series, it should throw an error. - if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Invalid return type. Please make sure that the UDF returns a " - "pandas.DataFrame when the specified return type is StructType." - ) - arrs.append(self._create_struct_array(s, arrow_type, spark_type=spark_type)) - else: - arrs.append( - self._create_array( - s, arrow_type, spark_type=spark_type, arrow_cast=self._arrow_cast - ) - ) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def dump_stream(self, iterator, stream): - """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. - """ - batches = self._write_stream_start(map(self._create_batch, iterator), stream) - return ArrowStreamSerializer.dump_stream(self, batches, stream) - - def __repr__(self): - return "ArrowStreamPandasUDFSerializer" - - -class ArrowStreamArrowUDFSerializer(ArrowStreamSerializer): - """ - Serializer used by Python worker to evaluate Arrow UDFs - """ - - def __init__( - self, - safecheck, - arrow_cast, - ): - super().__init__() - self._safecheck = safecheck - self._arrow_cast = arrow_cast - - def _create_array(self, arr, arrow_type, arrow_cast): - import pyarrow as pa - - assert isinstance(arr, pa.Array) - assert isinstance(arrow_type, pa.DataType) - - if arr.type == arrow_type: - return arr - elif arrow_cast: - return arr.cast(target_type=arrow_type, safe=self._safecheck) - else: - raise PySparkTypeError( - "Arrow UDFs require the return type to match the expected Arrow type. " - f"Expected: {arrow_type}, but got: {arr.type}." - ) - - def dump_stream(self, iterator, stream): - """ - Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. 
- """ - import pyarrow as pa - - def create_batches(): - for packed in iterator: - if len(packed) == 2 and isinstance(packed[1], pa.DataType): - # single array UDF in a projection - arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)] - else: - # multiple array UDFs in a projection - arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed] - yield pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - batches = self._write_stream_start(create_batches(), stream) - return ArrowStreamSerializer.dump_stream(self, batches, stream) - - def __repr__(self): - return "ArrowStreamArrowUDFSerializer" - - -class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer): +class ArrowBatchUDFSerializer(ArrowStreamGroupSerializer): """ Serializer used by Python worker to evaluate Arrow Python UDFs when the legacy pandas conversion is disabled - (instead of legacy ArrowStreamPandasUDFSerializer). + (instead of legacy ArrowStreamGroupSerializer with pandas conversion in serializer). Parameters ---------- @@ -814,6 +468,7 @@ def __init__( super().__init__( safecheck=safecheck, arrow_cast=True, + num_dfs=0, ) assert isinstance(input_type, StructType) self._input_type = input_type @@ -842,14 +497,13 @@ def load_stream(self, stream): ] for batch in super().load_stream(stream): - columns = [ - [conv(v) for v in column.to_pylist()] if conv is not None else column.to_pylist() - for column, conv in zip(batch.itercolumns(), converters) - ] - if len(columns) == 0: + if batch.num_columns == 0: yield [[pyspark._NoValue] * batch.num_rows] else: - yield columns + yield [ + [conv(v) for v in col.to_pylist()] if conv else col.to_pylist() + for col, conv in zip(batch.itercolumns(), converters) + ] def dump_stream(self, iterator, stream): """ @@ -866,7 +520,7 @@ def dump_stream(self, iterator, stream): Returns ------- object - Result of writing the Arrow stream via ArrowStreamArrowUDFSerializer dump_stream + Result of writing the Arrow stream via ArrowStreamGroupSerializer dump_stream """ import pyarrow as pa @@ -886,343 +540,22 @@ def py_to_batch(): for packed in iterator: if len(packed) == 3 and isinstance(packed[1], pa.DataType): # single array UDF in a projection - yield create_array(packed[0], packed[1], packed[2]), packed[1] - else: - # multiple array UDFs in a projection - yield [(create_array(*t), t[1]) for t in packed] - - return super().dump_stream(py_to_batch(), stream) - - -class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): - """ - Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs. - """ - - def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enabled): - super().__init__( - timezone=timezone, - safecheck=safecheck, - # The output pandas DataFrame's columns are unnamed. - assign_cols_by_name=False, - # Set to 'False' to avoid converting struct type inputs into a pandas DataFrame. - df_for_struct=False, - # Defines how struct type inputs are converted. If set to "row", struct type inputs - # are converted into Rows. Without this setting, a struct type input would be treated - # as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1), - # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1} - # if struct_in_pandas="row", it becomes Row(name="Alice", age=1) - struct_in_pandas="row", - # When dealing with array type inputs, Arrow converts them into numpy.ndarrays. 
- # To ensure consistency across regular and arrow-optimized UDTFs, we further - # convert these numpy.ndarrays into Python lists. - ndarray_as_list=True, - # Enables explicit casting for mismatched return types of Arrow Python UDTFs. - arrow_cast=True, - input_type=input_type, - # Enable additional coercions for UDTF serialization - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series pandas.DataFrame - or list of Series or DataFrame, with optional type. - - Parameters - ---------- - series : pandas.Series or pandas.DataFrame or list - A single series or dataframe, list of series or dataframe, - or list of (series or dataframe, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pandas as pd - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [] - for s, arrow_type, spark_type in series: - if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Output of an arrow-optimized Python UDTFs expects " - f"a pandas.DataFrame but got: {type(s)}" - ) - - arrs.append(self._create_struct_array(s, arrow_type, spark_type)) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): - """ - Override the `_create_array` method in the superclass to create an Arrow Array - from a given pandas.Series and an arrow type. The difference here is that we always - use arrow cast when creating the arrow array. Also, the error messages are specific - to arrow-optimized Python UDTFs. - - Parameters - ---------- - series : pandas.Series - A single series - arrow_type : pyarrow.DataType, optional - If None, pyarrow's inferred type will be used - spark_type : DataType, optional - If None, spark type converted from arrow_type will be used - arrow_cast: bool, optional - Whether to apply Arrow casting when the user-specified return type mismatches the - actual return values. 
- - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - import pandas as pd - - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) - - if arrow_type is not None: - dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) - conv = _create_converter_from_pandas( - dt, - timezone=self._timezone, - error_on_duplicated_field_names=False, - ignore_unexpected_complex_type_values=True, - int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, - ) - series = conv(series) - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - - try: - try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=self._safecheck - ) - except pa.lib.ArrowException: - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=self._safecheck + yield ArrowBatchTransformer.create_batch_from_arrays( + (create_array(packed[0], packed[1], packed[2]), packed[1]), + arrow_cast=self._arrow_cast, + safecheck=self._safecheck, ) else: - raise - except pa.lib.ArrowException: - # Display the most user-friendly error messages instead of showing - # arrow's error message. This also works better with Spark Connect - # where the exception messages are by default truncated. - raise PySparkRuntimeError( - errorClass="UDTF_ARROW_TYPE_CAST_ERROR", - messageParameters={ - "col_name": series.name, - "col_type": str(series.dtype), - "arrow_type": arrow_type, - }, - ) from None - - def __repr__(self): - return "ArrowStreamPandasUDTFSerializer" - - -# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF, -# and SQL_GROUPED_AGG_ARROW_ITER_UDF -class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer): - def load_stream(self, stream): - """ - Yield an iterator that produces one tuple of column arrays per batch. - Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one - without consuming all batches upfront. - """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches one at a time from the stream - # This avoids loading all batches into memory for the group - columns_iter = (batch.columns for batch in batches) - yield columns_iter - # Make sure the batches are fully iterated before getting the next group - for _ in columns_iter: - pass - - def __repr__(self): - return "ArrowStreamAggArrowUDFSerializer" - - -# Serializer for SQL_GROUPED_AGG_PANDAS_UDF, SQL_WINDOW_AGG_PANDAS_UDF, -# and SQL_GROUPED_AGG_PANDAS_ITER_UDF -class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def __init__( - self, - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled, - ): - super().__init__( - timezone=timezone, - safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def load_stream(self, stream): - """ - Yield an iterator that produces one tuple of pandas.Series per batch. - Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to - process batches one by one without consuming all batches upfront. - """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches to pandas Series one at a time - # from the stream. 
This avoids loading all batches into memory for the group - series_iter = map( - lambda batch: tuple( - ArrowBatchTransformer.to_pandas( - batch, - timezone=self._timezone, - schema=self._input_type, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, + # multiple array UDFs in a projection + arrays_and_types = [(create_array(*t), t[1]) for t in packed] + yield ArrowBatchTransformer.create_batch_from_arrays( + arrays_and_types, arrow_cast=self._arrow_cast, safecheck=self._safecheck ) - ), - batches, - ) - yield series_iter - # Make sure the batches are fully iterated before getting the next group - for _ in series_iter: - pass - - def __repr__(self): - return "ArrowStreamAggPandasUDFSerializer" - - -# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF -class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def __init__( - self, - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled, - ): - super().__init__( - timezone=timezone, - safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def load_stream(self, stream): - """ - Deserialize Grouped ArrowRecordBatches and yield as Iterator[Iterator[pd.Series]]. - Each outer iterator element represents a group, containing an iterator of Series lists - (one list per batch). - """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches one at a time from the stream - # This avoids loading all batches into memory for the group - series_iter = map( - lambda batch: ArrowBatchTransformer.to_pandas( - batch, - timezone=self._timezone, - schema=self._input_type, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ), - batches, - ) - yield series_iter - # Make sure the batches are fully iterated before getting the next group - for _ in series_iter: - pass - - def dump_stream(self, iterator, stream): - """ - Flatten the Iterator[Iterator[[(df, arrow_type)]]] returned by func. - The mapper returns Iterator[[(df, arrow_type)]], so we flatten one level - to match the parent's expected format Iterator[[(df, arrow_type)]]. - """ - # Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]] - flattened_iter = (batch for generator in iterator for batch in generator) - super().dump_stream(flattened_iter, stream) - - def __repr__(self): - return "GroupPandasUDFSerializer" - - -class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): - """ - Serializes pyarrow.RecordBatch data with Arrow streaming format. - - Loads Arrow record batches as `[([pa.RecordBatch], [pa.RecordBatch])]` (one tuple per group) - and serializes `[([pa.RecordBatch], arrow_type)]`. - - Parameters - ---------- - assign_cols_by_name : bool - If True, then DataFrames will get columns by name - """ - - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. 
- """ - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches - - -class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two - lists of pandas.Series. - """ - import pyarrow as pa - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield tuple( - ArrowBatchTransformer.to_pandas( - pa.Table.from_batches(batches), - timezone=self._timezone, - schema=self._input_type, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - for batches in (left_batches, right_batches) - ) + return super().dump_stream(py_to_batch(), stream) -class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): +class ApplyInPandasWithStateSerializer(ArrowStreamUDFSerializer): """ Serializer used by Python worker to evaluate UDF for applyInPandasWithState. @@ -1253,14 +586,14 @@ def __init__( super().__init__( timezone=timezone, safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) + self._df_for_struct = False + self._struct_in_pandas = "dict" + self._ndarray_as_list = False + self._input_type = None + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = True self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() self.state_object_schema = state_object_schema @@ -1434,7 +767,7 @@ def gen_data_and_state(batches): state, ) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_state_generator = gen_data_and_state(_batches) @@ -1526,12 +859,18 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat merged_pdf = pd.concat(pdfs, ignore_index=True) merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) - return self._create_batch( + return PandasBatchTransformer.to_arrow( [ (count_pdf, self.result_count_pdf_arrow_type), (merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type), - ] + ], + timezone=self._timezone, + safecheck=self._safecheck, + int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, + struct_in_pandas=self._struct_in_pandas, + assign_cols_by_name=self._assign_cols_by_name, + arrow_cast=self._arrow_cast, ) def serialize_batches(): @@ -1605,7 +944,7 @@ def serialize_batches(): return ArrowStreamSerializer.dump_stream(self, batches, stream) -class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): +class TransformWithStateInPandasSerializer(ArrowStreamUDFSerializer): """ Serializer used by Python worker to evaluate UDF for :meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`. 
@@ -1634,14 +973,14 @@ def __init__( super().__init__( timezone=timezone, safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) + self._df_for_struct = False + self._struct_in_pandas = "dict" + self._ndarray_as_list = False + self._input_type = None + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = True self.arrow_max_records_per_batch = ( arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 ) @@ -1718,7 +1057,7 @@ def row_stream(): if rows: yield (batch_key, pd.DataFrame(rows)) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): @@ -1733,15 +1072,17 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to stream. """ - def flatten_iterator(): # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]] for packed in iterator: - iter_pdf_with_type = packed[0] - iter_pdf = iter_pdf_with_type[0] - pdf_type = iter_pdf_with_type[1] + iter_pdf, _ = packed[0] for pdf in iter_pdf: - yield (pdf, pdf_type) + yield PandasBatchTransformer.to_arrow( + pdf, + timezone=self._timezone, + safecheck=self._safecheck, + int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, + ) super().dump_stream(flatten_iterator(), stream) @@ -1883,7 +1224,7 @@ def row_stream(): else EMPTY_DATAFRAME.copy(), ) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): @@ -1894,7 +1235,7 @@ def row_stream(): yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) -class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer): +class TransformWithStateInPySparkRowSerializer(ArrowStreamGroupSerializer): """ Serializer used by Python worker to evaluate UDF for :meth:`pyspark.sql.GroupedData.transformWithState`. 
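With TransformWithStateInPandasSerializer (and, below, its PySpark-row counterpart) rebased onto the plain Arrow stream serializers, the pandas-to-Arrow conversion now happens inside flatten_iterator rather than in a pandas-aware base class. A minimal sketch of that per-DataFrame conversion is below; the helper name dataframes_to_batches is illustrative only, and it assumes PandasBatchTransformer.to_arrow accepts a bare pandas.DataFrame, as the hunk above calls it.

    import pandas as pd
    from pyspark.sql.conversion import PandasBatchTransformer

    def dataframes_to_batches(pdf_iter, timezone, safecheck, coerce_int_to_decimal):
        # One Arrow RecordBatch per user-produced pandas DataFrame, matching the
        # flatten_iterator change above; no pandas conversion remains in the serializer.
        for pdf in pdf_iter:
            yield PandasBatchTransformer.to_arrow(
                pdf,
                timezone=timezone,
                safecheck=safecheck,
                int_to_decimal_coercion_enabled=coerce_int_to_decimal,
            )

    # Hypothetical usage with toy data:
    batches = list(
        dataframes_to_batches(iter([pd.DataFrame({"id": [1, 2]})]), "UTC", True, False)
    )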
@@ -1946,7 +1287,7 @@ def generate_data_batches(batches): row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) yield row_key, row - _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): @@ -1980,9 +1321,11 @@ def flatten_iterator(): pdf_schema = pa.schema(list(pdf_type)) record_batch = pa.RecordBatch.from_pylist(rows_as_dict, schema=pdf_schema) - yield (record_batch, pdf_type) + # Wrap the batch into a struct before yielding + wrapped_batch = ArrowBatchTransformer.wrap_struct(record_batch) + yield wrapped_batch - return ArrowStreamUDFSerializer.dump_stream(self, flatten_iterator(), stream) + return ArrowStreamGroupSerializer.dump_stream(self, flatten_iterator(), stream) class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer): @@ -2071,7 +1414,7 @@ def row_iterator(): for key, init_state_row in init_result: yield (key, None, init_state_row) - _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 03bc1366e875..219f7ac9a79a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -51,26 +51,17 @@ LocalDataToArrowConversion, ArrowTableToRowsConversion, ArrowBatchTransformer, + PandasBatchTransformer, ) from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( - ArrowStreamPandasUDFSerializer, - ArrowStreamPandasUDTFSerializer, - ArrowStreamGroupUDFSerializer, - GroupPandasUDFSerializer, - CogroupArrowUDFSerializer, - CogroupPandasUDFSerializer, - ArrowStreamUDFSerializer, + ArrowStreamGroupSerializer, ApplyInPandasWithStateSerializer, TransformWithStateInPandasSerializer, TransformWithStateInPandasInitStateSerializer, TransformWithStateInPySparkRowSerializer, TransformWithStateInPySparkRowInitStateSerializer, - ArrowStreamArrowUDFSerializer, - ArrowStreamAggPandasUDFSerializer, - ArrowStreamAggArrowUDFSerializer, ArrowBatchUDFSerializer, - ArrowStreamUDTFSerializer, ArrowStreamArrowUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_type, TimestampType @@ -261,9 +252,11 @@ def verify_result_length(result, length): return ( args_kwargs_offsets, - lambda *a: ( - verify_result_length(verify_result_type(func(*a)), len(a[0])), - arrow_return_type, + lambda *a: PandasBatchTransformer.to_arrow( + [(verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ), ) @@ -301,9 +294,10 @@ def verify_result_length(result, length): return ( args_kwargs_offsets, - lambda *a: ( - verify_result_length(verify_result_type(func(*a)), len(a[0])), - arrow_return_type, + lambda *a: ArrowBatchTransformer.create_batch_from_arrays( + (verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type), + arrow_cast=True, + safecheck=True, ), ) @@ -484,9 +478,15 @@ def verify_element(elem): return elem - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + def to_batch(res): + return PandasBatchTransformer.to_arrow( + [(res, arrow_return_type)], + 
timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + ) + + return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema): @@ -506,7 +506,7 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu if not result.empty or len(result.columns) != 0: # if any column name of the result is a string # the column names of the result have to match the return type - # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer + # see create_array in pyspark.sql.pandas.serializers.ArrowStreamUDFSerializer field_names = set([field.name for field in return_type.fields]) # only the first len(field_names) result columns are considered # when truncating the return schema @@ -579,9 +579,12 @@ def verify_element(elem): return elem - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + def to_batch(res): + return ArrowBatchTransformer.create_batch_from_arrays( + (res, arrow_return_type), arrow_cast=True, safecheck=True + ) + + return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): @@ -614,12 +617,20 @@ def verify_element(elem): return elem - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + # For mapInArrow: batches are already flattened in func before calling wrapper + # User function receives flattened batches and returns flattened batches + # We just need to wrap results back into struct for serialization + def wrapper(batch_iter): + result = verify_result(f(batch_iter)) + for res in map(verify_element, result): + yield ArrowBatchTransformer.wrap_struct(res) + + return wrapper def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): + import pyarrow as pa + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -639,15 +650,22 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - return result.to_batches() + # Reorder columns by name if needed, then wrap each batch into struct + for batch in result.to_batches(): + if runner_conf.assign_cols_by_name: + batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + yield ArrowBatchTransformer.wrap_struct(batch) - return lambda kl, vl, kr, vr: ( - wrapped(kl, vl, kr, vr), - to_arrow_type(return_type, timezone="UTC"), - ) + return lambda kl, vl, kr, vr: wrapped(kl, vl, kr, vr) def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + + arrow_return_type = to_arrow_type( + return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types + ) + def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): import pandas as pd @@ -663,13 +681,20 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se verify_pandas_result( result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False ) - - return result - - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", 
prefers_large_types=runner_conf.use_large_var_types - ) - return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), arrow_return_type)] + + # Convert pandas DataFrame to Arrow RecordBatch and yield (consistent with Arrow cogrouped) + batch = PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + struct_in_pandas='dict', + assign_cols_by_name=runner_conf.assign_cols_by_name, + arrow_cast=True, + ) + yield batch + + return wrapped def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): @@ -785,15 +810,18 @@ def wrapped(key_batch, value_batches): verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - yield from result.to_batches() + # Reorder columns by name if needed, then wrap each batch into struct + for batch in result.to_batches(): + if runner_conf.assign_cols_by_name: + batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + yield ArrowBatchTransformer.wrap_struct(batch) - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - return lambda k, v: (wrapped(k, v), arrow_return_type) + return lambda k, v: wrapped(k, v) def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf): + import pyarrow as pa + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -810,30 +838,20 @@ def wrapped(key_batch, value_batches): key = tuple(c[0] for c in key_batch.columns) result = f(key, value_batches) - def verify_element(batch): + for batch in result: verify_arrow_batch(batch, runner_conf.assign_cols_by_name, expected_cols_and_types) - return batch - - yield from map(verify_element, result) + if runner_conf.assign_cols_by_name: + batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + yield ArrowBatchTransformer.wrap_struct(batch) - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - return lambda k, v: (wrapped(k, v), arrow_return_type) + return lambda k, v: wrapped(k, v) def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrapped(key_series, value_batches): import pandas as pd - # Convert value_batches (Iterator[list[pd.Series]]) to a single DataFrame - # Each value_series is a list of Series (one per column) for one batch - # Concatenate Series within each batch (axis=1), then concatenate batches (axis=0) - value_dataframes = [] - for value_series in value_batches: - value_dataframes.append(pd.concat(value_series, axis=1)) - - value_df = pd.concat(value_dataframes, axis=0) if value_dataframes else pd.DataFrame() + value_df = PandasBatchTransformer.series_batches_to_dataframe(value_batches) if len(argspec.args) == 1: result = f(value_df) @@ -853,9 +871,20 @@ def wrapped(key_series, value_batches): ) def flatten_wrapper(k, v): - # Return Iterator[[(df, arrow_type)]] directly + # Convert pandas DataFrame to Arrow RecordBatch and wrap as struct for JVM + from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer + for df in wrapped(k, v): - yield [(df, arrow_return_type)] + # Split DataFrame into list of (Series, arrow_type) tuples + series_list = [(df[col], arrow_return_type[i].type) for i, col in enumerate(df.columns)] + batch = PandasBatchTransformer.to_arrow( + series_list, 
+ timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) return flatten_wrapper @@ -864,18 +893,16 @@ def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf): def wrapped(key_series, value_batches): import pandas as pd - # value_batches is an Iterator[list[pd.Series]] (one list per batch) + # value_batches is an Iterator[List[pd.Series]] (one list per batch) # Convert each list of Series into a DataFrame - def dataframe_iter(): - for value_series in value_batches: - yield pd.concat(value_series, axis=1) + dataframe_iter = map(lambda series_list: pd.concat(series_list, axis=1), value_batches) if len(argspec.args) == 1: - result = f(dataframe_iter()) + result = f(dataframe_iter) elif len(argspec.args) == 2: # Extract key from pandas Series, preserving numpy types key = tuple(s.iloc[0] for s in key_series) - result = f(key, dataframe_iter()) + result = f(key, dataframe_iter) def verify_element(df): verify_pandas_result( @@ -890,9 +917,20 @@ def verify_element(df): ) def flatten_wrapper(k, v): - # Return Iterator[[(df, arrow_type)]] directly + # Convert pandas DataFrames to Arrow RecordBatches and wrap as struct for JVM + from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer + for df in wrapped(k, v): - yield [(df, arrow_return_type)] + # Split DataFrame into list of (Series, arrow_type) tuples + series_list = [(df[col], arrow_return_type[i].type) for i, col in enumerate(df.columns)] + batch = PandasBatchTransformer.to_arrow( + series_list, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) return flatten_wrapper @@ -1068,6 +1106,8 @@ def verify_element(result): def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1078,12 +1118,18 @@ def wrapped(*series): import pandas as pd result = func(*series) - return pd.Series([result]) + result_series = pd.Series([result]) + + # Convert to Arrow RecordBatch in wrapper + return PandasBatchTransformer.to_arrow( + [(result_series, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) - return ( - args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) + return (args_kwargs_offsets, wrapped) def wrap_grouped_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): @@ -1097,11 +1143,15 @@ def wrapped(*series): import pyarrow as pa result = func(*series) - return pa.array([result]) + array = pa.array([result]) + # Return RecordBatch directly instead of (array, type) tuple + return ArrowBatchTransformer.create_batch_from_arrays( + (array, arrow_return_type), arrow_cast=True, safecheck=True + ) return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + wrapped, ) @@ -1117,15 +1167,21 @@ def wrapped(batch_iter): # batch_iter: Iterator[pa.Array] (single) or Iterator[Tuple[pa.Array, ...]] (multiple) result = func(batch_iter) - return 
pa.array([result]) + array = pa.array([result]) + # Return RecordBatch directly instead of (array, type) tuple + return ArrowBatchTransformer.create_batch_from_arrays( + (array, arrow_return_type), arrow_cast=True, safecheck=True + ) return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + wrapped, ) def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1139,12 +1195,18 @@ def wrapped(series_iter): # Iterator[Tuple[pd.Series, ...]] (multiple columns) # This has already been adapted by the mapper function in read_udfs result = func(series_iter) - return pd.Series([result]) + result_series = pd.Series([result]) + + # Convert to Arrow RecordBatch in wrapper + return PandasBatchTransformer.to_arrow( + [(result_series, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) - return ( - args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) + return (args_kwargs_offsets, wrapped) def wrap_window_agg_pandas_udf( @@ -1208,7 +1270,13 @@ def wrapped(*series): return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: PandasBatchTransformer.to_arrow( + [(wrapped(*a), arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ), ) @@ -1229,7 +1297,9 @@ def wrapped(*series): return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: ArrowBatchTransformer.create_batch_from_arrays( + (wrapped(*a), arrow_return_type), arrow_cast=True, safecheck=True + ), ) @@ -1273,7 +1343,13 @@ def wrapped(begin_index, end_index, *series): return ( args_offsets[:2] + args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: PandasBatchTransformer.to_arrow( + [(wrapped(*a), arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ), ) @@ -1302,7 +1378,9 @@ def wrapped(begin_index, end_index, *series): return ( args_offsets[:2] + args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: ArrowBatchTransformer.create_batch_from_arrays( + (wrapped(*a), arrow_return_type), arrow_cast=True, safecheck=True + ), ) @@ -1534,14 +1612,14 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if runner_conf.use_legacy_pandas_udtf_conversion: # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - ser = ArrowStreamPandasUDTFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - input_type=input_type, + # UDTF uses ArrowStreamGroupSerializer with as_struct=True + ser = ArrowStreamGroupSerializer( + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) else: - ser = ArrowStreamUDTFSerializer() + ser = ArrowStreamGroupSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: # Read the table argument offsets num_table_arg_offsets = read_int(infile) @@ -2435,12 +2513,14 @@ def 
evaluate(*args: list, num_rows=1): if len(args) == 0: for _ in range(num_rows): for batch in convert_to_arrow(func()): - yield batch, arrow_return_type + # Wrap batch into struct for ArrowStreamUDTFSerializer.dump_stream + yield ArrowBatchTransformer.wrap_struct(batch) else: for row in zip(*args): for batch in convert_to_arrow(func(*row)): - yield batch, arrow_return_type + # Wrap batch into struct for ArrowStreamUDTFSerializer.dump_stream + yield ArrowBatchTransformer.wrap_struct(batch) return evaluate @@ -2731,11 +2811,11 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): state_object_schema = None if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - elif ( - eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF + elif eval_type in ( + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): state_server_port = read_int(infile) if state_server_port == -1: @@ -2743,47 +2823,55 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - if ( - eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF - or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, ): - ser = ArrowStreamGroupUDFSerializer(runner_conf.assign_cols_by_name) + ser = ArrowStreamGroupSerializer( + num_dfs=1, assign_cols_by_name=runner_conf.assign_cols_by_name + ) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, ): - ser = ArrowStreamAggArrowUDFSerializer(safecheck=True, arrow_cast=True) + ser = ArrowStreamGroupSerializer( + safecheck=True, arrow_cast=True, num_dfs=1 + ) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): - ser = ArrowStreamAggPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - runner_conf.int_to_decimal_coercion_enabled, + # Use ArrowStreamGroupSerializer for agg/window UDFs that still use old pattern + # load_stream returns raw batches, to_pandas conversion done in worker + ser = ArrowStreamGroupSerializer( + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + num_dfs=1, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) - elif ( - eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF + elif eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, ): - ser = GroupPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - runner_conf.int_to_decimal_coercion_enabled, + # Wrapper calls to_arrow directly, so use MapIter serializer 
in grouped mode + # dump_stream handles RecordBatch directly from wrapper + ser = ArrowStreamGroupSerializer( + num_dfs=1, assign_cols_by_name=runner_conf.assign_cols_by_name ) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: - ser = CogroupArrowUDFSerializer(runner_conf.assign_cols_by_name) + ser = ArrowStreamGroupSerializer( + num_dfs=2, assign_cols_by_name=runner_conf.assign_cols_by_name + ) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - ser = CogroupPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + # Mapper calls to_arrow directly, so use Arrow serializer + # load_stream returns raw batches, to_pandas conversion done in worker + # dump_stream handles RecordBatch directly from mapper + ser = ArrowStreamGroupSerializer( + safecheck=runner_conf.safecheck, arrow_cast=True, + num_dfs=2, ) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: ser = ApplyInPandasWithStateSerializer( @@ -2820,13 +2908,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): runner_conf.arrow_max_records_per_batch ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - ser = ArrowStreamUDFSerializer() + ser = ArrowStreamGroupSerializer() elif eval_type in ( PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, ): # Arrow cast and safe check are always enabled - ser = ArrowStreamArrowUDFSerializer(safecheck=True, arrow_cast=True) + ser = ArrowStreamGroupSerializer(safecheck=True, arrow_cast=True) elif ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not runner_conf.use_legacy_pandas_udf_conversion @@ -2841,38 +2929,37 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. 
- df_for_struct = ( - eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + pandas_udf_df_for_struct = eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, ) # Arrow-optimized Python UDF takes a struct type argument as a Row - struct_in_pandas = ( - "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" - ) - ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + pandas_udf_struct_in_pandas = "row" if is_arrow_batched_udf else "dict" + pandas_udf_ndarray_as_list = is_arrow_batched_udf # Arrow-optimized Python UDF takes input types - input_type = ( + pandas_udf_input_type = ( _parse_datatype_json_string(utf8_deserializer.loads(infile)) - if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + if is_arrow_batched_udf else None ) - ser = ArrowStreamPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - df_for_struct, - struct_in_pandas, - ndarray_as_list, - True, - input_type, + ser = ArrowStreamGroupSerializer( + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) + # Initialize transformer parameters (will be set if needed) + pandas_udf_input_type = None + pandas_udf_struct_in_pandas = "dict" + pandas_udf_ndarray_as_list = False + pandas_udf_df_for_struct = False + # Read all UDFs num_udfs = read_int(infile) udfs = [ @@ -2887,6 +2974,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF + # Check if we need to convert Arrow batches to pandas + needs_arrow_to_pandas = eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + ) + if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: # TODO: Better error message for num_udfs != 1 if is_scalar_iter: @@ -2899,6 +2993,24 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): arg_offsets, udf = udfs[0] def func(_, iterator): + # For MAP_ARROW_ITER, flatten struct before processing + if is_map_arrow_iter: + iterator = map(ArrowBatchTransformer.flatten_struct, iterator) + + # Convert Arrow batches to pandas if needed + if needs_arrow_to_pandas: + iterator = map( + lambda batch: ArrowBatchTransformer.to_pandas( + batch, + timezone=runner_conf.timezone, + schema=pandas_udf_input_type, + struct_in_pandas=pandas_udf_struct_in_pandas, + ndarray_as_list=pandas_udf_ndarray_as_list, + df_for_struct=pandas_udf_df_for_struct, + ), + iterator + ) + num_input_rows = 0 def map_batch(batch): @@ -2911,12 +3023,22 @@ def map_batch(batch): else: return tuple(udf_args) - iterator = map(map_batch, iterator) - result_iter = udf(iterator) + # For MAP_ARROW_ITER, pass the whole batch to UDF, not extracted columns + if is_map_arrow_iter: + # Count input rows for verification + def count_rows(batch): + nonlocal num_input_rows + num_input_rows += batch.num_rows + return batch + iterator = map(count_rows, iterator) + result_iter = udf(iterator) + else: + 
iterator = map(map_batch, iterator) + result_iter = udf(iterator) num_output_rows = 0 - for result_batch, result_type in result_iter: - num_output_rows += len(result_batch) + for result_batch in result_iter: + num_output_rows += len(result_batch.column(0)) # This check is for Scalar Iterator UDF to fail fast. # The length of the entire input can only be explicitly known # by consuming the input iterator in user side. Therefore, @@ -2926,7 +3048,7 @@ def map_batch(batch): raise PySparkRuntimeError( errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} ) - yield (result_batch, result_type) + yield result_batch if is_scalar_iter: try: @@ -2995,20 +3117,26 @@ def extract_key_value_indexes(grouped_arg_offsets): arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - def mapper(series_iter): - # Need to materialize the first series list to get the keys - first_series_list = next(series_iter) + def mapper(batch_iter): + # Convert Arrow batches to pandas Series lists in worker layer + series_iter = map( + lambda batch: ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone + ), + batch_iter, + ) - # Extract key Series from the first batch + # Materialize first batch to extract keys (keys are same for all batches in group) + first_series_list = next(series_iter) key_series = [first_series_list[o] for o in parsed_offsets[0][0]] - # Create generator for value Series lists (one list per batch) + # Create generator for value Series: yields List[pd.Series] per batch value_series_gen = ( [series_list[o] for o in parsed_offsets[0][1]] for series_list in itertools.chain((first_series_list,), series_iter) ) - # Flatten one level: yield from wrapper to return Iterator[[(df, arrow_type)]] + # Wrapper yields wrapped RecordBatches (one or more per group) yield from f(key_series, value_series_gen) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: @@ -3147,12 +3275,6 @@ def mapper(a): arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - def batch_from_offset(batch, offsets): - return pa.RecordBatch.from_arrays( - arrays=[batch.columns[o] for o in offsets], - names=[batch.schema.names[o] for o in offsets], - ) - def mapper(batches): # Flatten struct column into separate columns flattened = map(ArrowBatchTransformer.flatten_struct, batches) @@ -3160,9 +3282,9 @@ def mapper(batches): # Need to materialize the first batch to get the keys first_batch = next(flattened) - keys = batch_from_offset(first_batch, parsed_offsets[0][0]) + keys = ArrowBatchTransformer.partial_batch(first_batch, parsed_offsets[0][0]) value_batches = ( - batch_from_offset(batch, parsed_offsets[0][1]) + ArrowBatchTransformer.partial_batch(batch, parsed_offsets[0][1]) for batch in itertools.chain((first_batch,), flattened) ) @@ -3206,6 +3328,8 @@ def mapper(a): return f(keys, vals, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + import pyarrow as pa + # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. 
assert num_udfs == 1 @@ -3213,11 +3337,24 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) + # Cogrouped UDF receives tuple of (list of RecordBatches, list of RecordBatches) def mapper(a): - df1_keys = [a[0][o] for o in parsed_offsets[0][0]] - df1_vals = [a[0][o] for o in parsed_offsets[0][1]] - df2_keys = [a[1][o] for o in parsed_offsets[1][0]] - df2_vals = [a[1][o] for o in parsed_offsets[1][1]] + # a is tuple[list[pa.RecordBatch], list[pa.RecordBatch]] + left_batches, right_batches = a[0], a[1] + + # Convert batches to tables (batches already have flattened struct columns) + left_table = pa.Table.from_batches(left_batches) if left_batches else pa.table({}) + right_table = pa.Table.from_batches(right_batches) if right_batches else pa.table({}) + + # Convert tables to pandas Series lists + left_series = ArrowBatchTransformer.to_pandas(left_table, timezone=runner_conf.timezone) + right_series = ArrowBatchTransformer.to_pandas(right_table, timezone=runner_conf.timezone) + + df1_keys = [left_series[o] for o in parsed_offsets[0][0]] + df1_vals = [left_series[o] for o in parsed_offsets[0][1]] + df2_keys = [right_series[o] for o in parsed_offsets[1][0]] + df2_vals = [right_series[o] for o in parsed_offsets[1][1]] + return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: @@ -3230,20 +3367,11 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) - def batch_from_offset(batch, offsets): - return pa.RecordBatch.from_arrays( - arrays=[batch.columns[o] for o in offsets], - names=[batch.schema.names[o] for o in offsets], - ) - - def table_from_batches(batches, offsets): - return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches]) - def mapper(a): - df1_keys = table_from_batches(a[0], parsed_offsets[0][0]) - df1_vals = table_from_batches(a[0], parsed_offsets[0][1]) - df2_keys = table_from_batches(a[1], parsed_offsets[1][0]) - df2_vals = table_from_batches(a[1], parsed_offsets[1][1]) + df1_keys = ArrowBatchTransformer.partial_table(a[0], parsed_offsets[0][0]) + df1_vals = ArrowBatchTransformer.partial_table(a[0], parsed_offsets[0][1]) + df2_keys = ArrowBatchTransformer.partial_table(a[1], parsed_offsets[1][0]) + df2_vals = ArrowBatchTransformer.partial_table(a[1], parsed_offsets[1][1]) return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF: @@ -3253,13 +3381,15 @@ def mapper(a): arg_offsets, f = udfs[0] - # Convert to iterator of batches: Iterator[pa.Array] for single column, + # Convert to iterator of arrays: Iterator[pa.Array] for single column, # or Iterator[Tuple[pa.Array, ...]] for multiple columns + # a is Iterator[pa.RecordBatch], extract columns from each batch def mapper(a): if len(arg_offsets) == 1: - batch_iter = (batch_columns[arg_offsets[0]] for batch_columns in a) + batch_iter = (batch.column(arg_offsets[0]) for batch in a) else: - batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a) + batch_iter = (tuple(batch.column(o) for o in arg_offsets) for batch in a) + # f returns pa.RecordBatch directly return f(batch_iter) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF: @@ -3269,21 +3399,23 @@ def mapper(a): arg_offsets, f = udfs[0] - # Convert to iterator of pandas Series: + # POC: Convert raw Arrow batches to pandas Series in worker layer # - Iterator[pd.Series] for single column # - Iterator[Tuple[pd.Series, ...]] for multiple columns def mapper(batch_iter): - # 
batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch - # Convert to Iterator[pd.Series] or Iterator[Tuple[pd.Series, ...]] based on arg_offsets - if len(arg_offsets) == 1: - # Single column: Iterator[Tuple[pd.Series, ...]] -> Iterator[pd.Series] - series_iter = (batch_series[arg_offsets[0]] for batch_series in batch_iter) - else: - # Multiple columns: Iterator[Tuple[pd.Series, ...]] -> - # Iterator[Tuple[pd.Series, ...]] - series_iter = ( - tuple(batch_series[o] for o in arg_offsets) for batch_series in batch_iter - ) + # batch_iter is Iterator[pa.RecordBatch] (raw batches from serializer) + # Convert to pandas Series lists, then select columns based on arg_offsets + series_lists = map( + lambda batch: ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), + batch_iter, + ) + series_iter = ( + series_list[arg_offsets[0]] + if len(arg_offsets) == 1 + else tuple(series_list[o] for o in arg_offsets) + for series_list in series_lists + ) + # Wrapper returns RecordBatch directly (conversion done in wrapper) return f(series_iter) elif eval_type in ( @@ -3293,36 +3425,18 @@ def mapper(batch_iter): import pyarrow as pa # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF, - # convert iterator of batch columns to a concatenated RecordBatch + # convert iterator of RecordBatch to a concatenated RecordBatch def mapper(a): - # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch - batches = [] - for batch_columns in a: - # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch - batch = pa.RecordBatch.from_arrays( - batch_columns, names=["_%d" % i for i in range(len(batch_columns))] - ) - batches.append(batch) + # a is Iterator[pa.RecordBatch] - collect and concatenate all batches + concatenated_batch = ArrowBatchTransformer.concat_batches(list(a)) - # Concatenate all batches into one - if hasattr(pa, "concat_batches"): - concatenated_batch = pa.concat_batches(batches) - else: - # pyarrow.concat_batches not supported in old versions - concatenated_batch = pa.RecordBatch.from_struct_array( - pa.concat_arrays([b.to_struct_array() for b in batches]) - ) - - # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array) - result = tuple( + # Each UDF returns pa.RecordBatch + result_batches = [ f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs - ) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. 
- if len(result) == 1: - return result[0] - else: - return result + ] + + # Merge RecordBatches from all UDFs horizontally + return ArrowBatchTransformer.merge_batches(result_batches) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -3331,45 +3445,70 @@ def mapper(a): import pandas as pd # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF, - # convert iterator of batch tuples to concatenated pandas Series + # batch_iter is now Iterator[pa.RecordBatch] (raw batches from serializer) + # Convert to pandas and concatenate into single Series per column def mapper(batch_iter): - # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch - # Collect all batches and concatenate into single Series per column - batches = list(batch_iter) - if not batches: - # Empty batches - determine num_columns from all UDFs' arg_offsets - all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] - num_columns = max(all_offsets) + 1 if all_offsets else 0 - concatenated = [pd.Series(dtype=object) for _ in range(num_columns)] - else: - # Use actual number of columns from the first batch - num_columns = len(batches[0]) - concatenated = [ - pd.concat([batch[i] for batch in batches], ignore_index=True) - for i in range(num_columns) - ] + # Convert raw Arrow batches to pandas Series in worker layer + series_batches = [ + ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone) + for batch in batch_iter + ] + # Concatenate all batches column-wise + all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] + concatenated = PandasBatchTransformer.concat_series_batches(series_batches, all_offsets) - result = tuple(f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - if len(result) == 1: - return result[0] - else: - return result + # Each UDF returns pa.RecordBatch (conversion done in wrapper) + result_batches = [f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs] + + # Merge all RecordBatches horizontally + return ArrowBatchTransformer.merge_batches(result_batches) else: + import pyarrow as pa + + # Check if we need to convert Arrow batches to pandas (for scalar pandas UDF) + needs_arrow_to_pandas = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF def mapper(a): - result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. 
- if len(result) == 1: - return result[0] - else: - return result + # Each UDF returns a RecordBatch (single column) + result_batches = tuple( + f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs + ) + # Merge all RecordBatches into a single RecordBatch with multiple columns + return ArrowBatchTransformer.merge_batches(result_batches) + + # For grouped/cogrouped map UDFs: + # All wrappers yield batches, so mapper returns generator → need flatten + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, + ): - def func(_, it): - return map(mapper, it) + def func(_, it): + # All grouped/cogrouped wrappers yield, so flatten with chain.from_iterable + return itertools.chain.from_iterable(map(mapper, it)) + + else: + + def func(_, it): + # Convert Arrow batches to pandas if needed + if needs_arrow_to_pandas: + it = map( + lambda batch: ArrowBatchTransformer.to_pandas( + batch, + timezone=runner_conf.timezone, + schema=pandas_udf_input_type, + struct_in_pandas=pandas_udf_struct_in_pandas, + ndarray_as_list=pandas_udf_ndarray_as_list, + df_for_struct=pandas_udf_df_for_struct, + ), + it + ) + return map(mapper, it) # profiling is not supported for UDF return func, None, ser, ser From f359724db96659998b86d1e6e1ddb81957c900f4 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 30 Jan 2026 21:44:31 +0000 Subject: [PATCH 28/39] refactor: simplify ArrowStreamArrowUDTFSerializer using zip_batches - Use ArrowBatchTransformer.zip_batches for type coercion instead of manual loop - Simplify error handling logic - Handle empty struct case properly - Unwrap wrapped batches from worker before type coercion, then wrap back for JVM - All 46/47 UDTF tests pass (1 known failure unrelated to this change) --- python/pyspark/sql/conversion.py | 856 ++++++++++------------- python/pyspark/sql/pandas/serializers.py | 108 +-- python/pyspark/worker.py | 63 +- 3 files changed, 478 insertions(+), 549 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index a997be17b6c9..3713d95d6d3a 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -18,7 +18,7 @@ import array import datetime import decimal -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple, Union, overload from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError from pyspark.sql.pandas.types import ( @@ -60,12 +60,12 @@ class ArrowBatchTransformer: """ - Pure functions that transform RecordBatch -> RecordBatch. + Pure functions that transform Arrow data structures (Arrays, RecordBatches). They should have no side effects (no I/O, no writing to streams). This class provides utility methods for Arrow batch transformations used throughout PySpark's Arrow UDF implementation. All methods are static and handle common patterns - like struct wrapping/unwrapping and schema conversions. + like struct wrapping/unwrapping, schema conversions, and creating RecordBatches from Arrays. 
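    For illustration, a minimal round-trip sketch (not part of this patch; assumes
    pyarrow is installed and uses the wrap_struct/flatten_struct helpers of this class):

    >>> import pyarrow as pa
    >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array(["x", "y"])], ["a", "b"])
    >>> wrapped = ArrowBatchTransformer.wrap_struct(batch)
    >>> wrapped.schema.names
    ['_0']
    >>> ArrowBatchTransformer.flatten_struct(wrapped, column_index=0).schema.names
    ['a', 'b']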
""" @@ -112,86 +112,6 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) - @classmethod - def partial_batch(cls, batch: "pa.RecordBatch", column_indices: List[int]) -> "pa.RecordBatch": - """ - Create a new RecordBatch with only selected columns. - - This method selects a subset of columns from a RecordBatch by their indices, - preserving column names and data types. - - Parameters - ---------- - batch : pa.RecordBatch - Input RecordBatch - column_indices : List[int] - Indices of columns to select (0-based) - - Returns - ------- - pa.RecordBatch - New RecordBatch containing only the selected columns - - Used by - ------- - - SQL_GROUPED_MAP_ARROW_UDF mapper - - partial_table - - Examples - -------- - >>> import pyarrow as pa - >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['a', 'b']) - >>> partial = ArrowBatchTransformer.partial_batch(batch, [1]) - >>> partial.schema.names - ['b'] - """ - import pyarrow as pa - - return pa.RecordBatch.from_arrays( - arrays=[batch.columns[i] for i in column_indices], - names=[batch.schema.names[i] for i in column_indices], - ) - - @classmethod - def partial_table(cls, batches: List["pa.RecordBatch"], column_indices: List[int]) -> "pa.Table": - """ - Combine multiple batches into a Table with only selected columns. - - This method selects a subset of columns from each RecordBatch and combines - them into a single Arrow Table. - - Parameters - ---------- - batches : List[pa.RecordBatch] - List of RecordBatches to combine - column_indices : List[int] - Indices of columns to select (0-based) - - Returns - ------- - pa.Table - Combined Table containing only the selected columns - - Used by - ------- - - SQL_COGROUPED_MAP_ARROW_UDF mapper - - Examples - -------- - >>> import pyarrow as pa - >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['a', 'b']) - >>> batch2 = pa.RecordBatch.from_arrays([pa.array([5, 6]), pa.array([7, 8])], ['a', 'b']) - >>> table = ArrowBatchTransformer.partial_table([batch1, batch2], [1]) - >>> table.schema.names - ['b'] - >>> len(table) - 4 - """ - import pyarrow as pa - - return pa.Table.from_batches( - [cls.partial_batch(batch, column_indices) for batch in batches] - ) @classmethod def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": @@ -226,6 +146,26 @@ def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": """ import pyarrow as pa + if not batches: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_CONCAT", + messageParameters={"reason": "Cannot concatenate empty list of batches"}, + ) + + # Assert all batches have the same schema + first_schema = batches[0].schema + for i, batch in enumerate(batches[1:], start=1): + if batch.schema != first_schema: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_CONCAT", + messageParameters={ + "reason": ( + f"All batches must have the same schema. " + f"Batch 0 has schema {first_schema}, but batch {i} has schema {batch.schema}." 
+ ) + }, + ) + if hasattr(pa, "concat_batches"): return pa.concat_batches(batches) else: @@ -235,48 +175,109 @@ def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": ) @classmethod - def merge_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": + def zip_batches( + cls, + items: Union[ + List["pa.RecordBatch"], + List["pa.Array"], + List[Tuple["pa.Array", "pa.DataType"]], + ], + safecheck: bool = True, + ) -> "pa.RecordBatch": """ - Merge multiple RecordBatches horizontally by combining their columns. + Zip multiple RecordBatches or Arrays horizontally by combining their columns. This is different from concat_batches which concatenates rows vertically. - This method combines columns from multiple batches into a single batch, - useful when multiple UDFs each produce a RecordBatch and we need to - combine their outputs. + This method combines columns from multiple batches/arrays into a single batch, + useful when multiple UDFs each produce a RecordBatch or when combining arrays. Parameters ---------- - batches : List[pa.RecordBatch] - List of RecordBatches to merge (must have same number of rows) + items : List[pa.RecordBatch], List[pa.Array], or List[Tuple[pa.Array, pa.DataType]] + - List of RecordBatches to zip (must have same number of rows) + - List of Arrays to combine directly + - List of (array, type) tuples for type casting (always attempts cast if types don't match) + safecheck : bool, default True + If True, use safe casting (fails on overflow/truncation) (only used when items are tuples). Returns ------- pa.RecordBatch - Single RecordBatch with all columns from input batches + Single RecordBatch with all columns from input batches/arrays Used by ------- - SQL_GROUPED_AGG_ARROW_UDF mapper - SQL_WINDOW_AGG_ARROW_UDF mapper + - wrap_scalar_arrow_udf + - wrap_grouped_agg_arrow_udf + - ArrowBatchUDFSerializer.dump_stream Examples -------- >>> import pyarrow as pa >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['a']) >>> batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ['b']) - >>> result = ArrowBatchTransformer.merge_batches([batch1, batch2]) + >>> result = ArrowBatchTransformer.zip_batches([batch1, batch2]) + >>> result.to_pydict() + {'_0': [1, 2], '_1': [3, 4]} + >>> # Can also zip arrays directly + >>> result = ArrowBatchTransformer.zip_batches([pa.array([1, 2]), pa.array([3, 4])]) >>> result.to_pydict() {'_0': [1, 2], '_1': [3, 4]} + >>> # Can also zip with type casting + >>> result = ArrowBatchTransformer.zip_batches( + ... [(pa.array([1, 2]), pa.int64()), (pa.array([3, 4]), pa.int64())] + ... ) """ import pyarrow as pa + + if not items: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_ZIP", + messageParameters={"reason": "Cannot zip empty list"}, + ) + + # Check if items are RecordBatches, Arrays, or (array, type) tuples + first_item = items[0] - if len(batches) == 1: - return batches[0] - - # Combine all columns from all batches - all_columns = [] - for batch in batches: - all_columns.extend(batch.columns) + if isinstance(first_item, pa.RecordBatch): + # Handle RecordBatches + batches = items + if len(batches) == 1: + return batches[0] + + # Assert all batches have the same number of rows + num_rows = batches[0].num_rows + for i, batch in enumerate(batches[1:], start=1): + assert batch.num_rows == num_rows, ( + f"All batches must have the same number of rows. " + f"Batch 0 has {num_rows} rows, but batch {i} has {batch.num_rows} rows." 
+ ) + + # Combine all columns from all batches + all_columns = [] + for batch in batches: + all_columns.extend(batch.columns) + elif isinstance(first_item, tuple) and len(first_item) == 2: + # Handle (array, type) tuples with type casting (always attempt cast if types don't match) + all_columns = [ + cls._cast_array( + array, + arrow_type, + safecheck=safecheck, + error_message=( + "Arrow UDFs require the return type to match the expected Arrow type. " + f"Expected: {arrow_type}, but got: {array.type}." + ), + ) + for array, arrow_type in items + ] + else: + # Handle Arrays directly + all_columns = items + + # Create RecordBatch from columns return pa.RecordBatch.from_arrays( all_columns, ["_%d" % i for i in range(len(all_columns))] ) @@ -343,108 +344,6 @@ def reorder_columns( names=field_names, ) - @staticmethod - def cast_array( - arr: "pa.Array", - target_type: "pa.DataType", - arrow_cast: bool = False, - safecheck: bool = True, - error_message: Optional[str] = None, - ) -> "pa.Array": - """ - Cast an Arrow Array to a target type with type checking. - - Parameters - ---------- - arr : pa.Array - The Arrow Array to cast. - target_type : pa.DataType - Target Arrow data type. - arrow_cast : bool - If True, always attempt to cast. If False, raise error on type mismatch. - safecheck : bool - If True, use safe casting (fails on overflow/truncation). - error_message : str, optional - Custom error message for type mismatch. - - Returns - ------- - pa.Array - The casted array if types differ, or original array if types match. - """ - import pyarrow as pa - - assert isinstance(arr, pa.Array) - assert isinstance(target_type, pa.DataType) - - if arr.type == target_type: - return arr - elif arrow_cast: - return arr.cast(target_type=target_type, safe=safecheck) - else: - if error_message: - raise PySparkTypeError(error_message) - else: - raise PySparkTypeError( - f"Arrow type mismatch. Expected: {target_type}, but got: {arr.type}." - ) - - @staticmethod - def create_batch_from_arrays( - packed: Union[tuple, list], - arrow_cast: bool = False, - safecheck: bool = True, - ) -> "pa.RecordBatch": - """ - Create a RecordBatch from (array, type) pairs with type casting. - - Parameters - ---------- - packed : tuple or list - Either a (array, type) tuple for single array, or a list of (array, type) tuples. - arrow_cast : bool - If True, always attempt to cast. If False, raise error on type mismatch. - safecheck : bool - If True, use safe casting (fails on overflow/truncation). - - Returns - ------- - pa.RecordBatch - RecordBatch with casted arrays. - """ - import pyarrow as pa - - if len(packed) == 2 and isinstance(packed[1], pa.DataType): - # single array UDF in a projection - arrs = [ - ArrowBatchTransformer.cast_array( - packed[0], - packed[1], - arrow_cast=arrow_cast, - safecheck=safecheck, - error_message=( - "Arrow UDFs require the return type to match the expected Arrow type. " - f"Expected: {packed[1]}, but got: {packed[0].type}." - ), - ) - ] - else: - # multiple array UDFs in a projection - arrs = [ - ArrowBatchTransformer.cast_array( - t[0], - t[1], - arrow_cast=arrow_cast, - safecheck=safecheck, - error_message=( - "Arrow UDFs require the return type to match the expected Arrow type. " - f"Expected: {t[1]}, but got: {t[0].type}." 
- ), - ) - for t in packed - ] - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - @classmethod def to_pandas( cls, @@ -498,258 +397,62 @@ def to_pandas( for i in range(batch.num_columns) ] - -class PandasBatchTransformer: - """ - Pure functions that transform between pandas DataFrames/Series and Arrow RecordBatches. - They should have no side effects (no I/O, no writing to streams). - - This class provides utility methods for converting between pandas and Arrow formats, - used primarily by Pandas UDF wrappers and serializers. - - """ - - @classmethod - def create_array( - cls, - series: "pd.Series", - arrow_type: Optional["pa.DataType"], - timezone: str, - safecheck: bool = True, - spark_type: Optional[DataType] = None, - arrow_cast: bool = False, - int_to_decimal_coercion_enabled: bool = False, - ignore_unexpected_complex_type_values: bool = False, - error_class: Optional[str] = None, - ) -> "pa.Array": - """ - Create an Arrow Array from the given pandas.Series and optional type. - - Parameters - ---------- - series : pandas.Series - A single series - arrow_type : pyarrow.DataType, optional - If None, pyarrow's inferred type will be used - timezone : str - Timezone for timestamp conversion - safecheck : bool - Whether to perform safe type checking - spark_type : DataType, optional - If None, spark type converted from arrow_type will be used - arrow_cast : bool - Whether to apply Arrow casting when type mismatches - int_to_decimal_coercion_enabled : bool - Whether to enable int to decimal coercion - ignore_unexpected_complex_type_values : bool - Whether to ignore unexpected complex type values during conversion - error_class : str, optional - Custom error class for arrow type cast errors (e.g., "UDTF_ARROW_TYPE_CAST_ERROR") - - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - import pandas as pd - - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) - - if arrow_type is not None: - dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) - conv = _create_converter_from_pandas( - dt, - timezone=timezone, - error_on_duplicated_field_names=False, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, - ) - series = conv(series) - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - - # For UDTF (error_class is set), use Arrow-specific error handling - if error_class is not None: - try: - try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=safecheck - ) - except pa.lib.ArrowException: - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=safecheck - ) - else: - raise - except pa.lib.ArrowException: - raise PySparkRuntimeError( - errorClass=error_class, - messageParameters={ - "col_name": series.name, - "col_type": str(series.dtype), - "arrow_type": str(arrow_type), - }, - ) from None - - # For regular UDF, use standard error handling - try: - try: - return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) - except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError): - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=safecheck - ) - else: - raise - except TypeError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." 
- ) - raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e - except (ValueError, pa.lib.ArrowException) as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - if safecheck: - error_msg = error_msg + ( - " It can be caused by overflows or other " - "unsafe conversions warned by Arrow. Arrow safe type check " - "can be disabled by using SQL config " - "`spark.sql.execution.pandas.convertToArrowArraySafely`." - ) - raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e - - @staticmethod - def normalize_input( - series: Union["pd.Series", "pd.DataFrame", List], - ) -> "Iterator[Tuple[Any, Optional[pa.DataType], Optional[DataType]]]": - """ - Normalize input to a consistent format for batch conversion. - - Converts various input formats to an iterator of - (data, arrow_type, spark_type) tuples. - - Parameters - ---------- - series : pandas.Series, pandas.DataFrame, or list - A single series/dataframe, list of series/dataframes, or list of - (series, arrow_type) or (series, arrow_type, spark_type) - - Returns - ------- - Iterator[Tuple[Any, Optional[pa.DataType], Optional[DataType]]] - Iterator of (data, arrow_type, spark_type) tuples - """ - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - return ((s[0], s[1], None) if len(s) == 2 else s for s in series) - @classmethod - def create_struct_array( + def _cast_array( cls, - df: "pd.DataFrame", - arrow_struct_type: "pa.StructType", - timezone: str, + arr: "pa.Array", + target_type: "pa.DataType", safecheck: bool = True, - spark_type: Optional["StructType"] = None, - arrow_cast: bool = False, - int_to_decimal_coercion_enabled: bool = False, - ignore_unexpected_complex_type_values: bool = False, - error_class: Optional[str] = None, - assign_cols_by_name: bool = True, - ) -> "pa.StructArray": + error_message: Optional[str] = None, + ) -> "pa.Array": """ - Create an Arrow StructArray from the given pandas.DataFrame. + Cast an Arrow Array to a target type with type checking. + + This is a private method used internally by zip_batches. Parameters ---------- - df : pandas.DataFrame - A pandas DataFrame - arrow_struct_type : pyarrow.StructType - Target Arrow struct type - timezone : str - Timezone for timestamp conversion + arr : pa.Array + The Arrow Array to cast. + target_type : pa.DataType + Target Arrow data type. safecheck : bool - Whether to perform safe type checking - spark_type : StructType, optional - Spark schema for type conversion - arrow_cast : bool - Whether to apply Arrow casting when type mismatches - int_to_decimal_coercion_enabled : bool - Whether to enable int to decimal coercion - ignore_unexpected_complex_type_values : bool - Whether to ignore unexpected complex type values - error_class : str, optional - Custom error class for type cast errors - assign_cols_by_name : bool - If True, assign columns by name; otherwise by position + If True, use safe casting (fails on overflow/truncation). + error_message : str, optional + Custom error message for type mismatch (used if cast fails). 
Returns ------- - pyarrow.StructArray + pa.Array + The casted array if types differ, or original array if types match. """ import pyarrow as pa - if len(df.columns) == 0: - return pa.array([{}] * len(df), arrow_struct_type) + assert isinstance(arr, pa.Array) + assert isinstance(target_type, pa.DataType) - # Assign result columns by schema name if user labeled with strings - if assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - struct_arrs = [ - cls.create_array( - df[field.name], - field.type, - timezone=timezone, - safecheck=safecheck, - spark_type=( - spark_type[field.name].dataType if spark_type is not None else None - ), - arrow_cast=arrow_cast, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, - error_class=error_class, - ) - for field in arrow_struct_type - ] - # Assign result columns by position - else: - struct_arrs = [ - cls.create_array( - df[df.columns[i]].rename(field.name), - field.type, - timezone=timezone, - safecheck=safecheck, - spark_type=spark_type[i].dataType if spark_type is not None else None, - arrow_cast=arrow_cast, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, - error_class=error_class, - ) - for i, field in enumerate(arrow_struct_type) - ] + if arr.type == target_type: + return arr + + try: + return arr.cast(target_type=target_type, safe=safecheck) + except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e: + if error_message: + raise PySparkTypeError(error_message) from e + else: + raise PySparkTypeError( + f"Arrow type mismatch. Expected: {target_type}, but got: {arr.type}." + ) from e - return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type)) +class PandasBatchTransformer: + """ + Pure functions that transform between pandas DataFrames/Series and Arrow RecordBatches. + They should have no side effects (no I/O, no writing to streams). + + This class provides utility methods for converting between pandas and Arrow formats, + used primarily by Pandas UDF wrappers and serializers. + + """ @classmethod def concat_series_batches( @@ -900,59 +603,86 @@ def to_arrow( import pyarrow as pa from pyspark.sql.pandas.types import is_variant + # Normalize input to a consistent format + # Make input conform to [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] 
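        # (Illustrative note, not from the original patch.) Accepted inputs are a bare
        # Series/DataFrame, a (data, arrow_type) pair, a (data, arrow_type, spark_type)
        # triple, or a list of any of these; a single pair/triple is first wrapped in a
        # one-element list so the per-column loop below sees a uniform shape.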
+ if ( + not isinstance(series, (list, tuple)) + or (len(series) == 2 and isinstance(series[1], pa.DataType)) + or ( + len(series) == 3 + and isinstance(series[1], pa.DataType) + and isinstance(series[2], DataType) + ) + ): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + normalized_series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) + arrs = [] - for s, arrow_type, spark_type in cls.normalize_input(series): - if as_struct: - # UDTF mode: require DataFrame, create struct array - if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Output of an arrow-optimized Python UDTFs expects " - f"a pandas.DataFrame but got: {type(s)}" - ) - arrs.append( - cls.create_struct_array( - s, - arrow_type, - timezone=timezone, - safecheck=safecheck, - spark_type=spark_type, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, - error_class=error_class, - assign_cols_by_name=assign_cols_by_name, - arrow_cast=arrow_cast, - ) - ) - elif ( + for s, arrow_type, spark_type in normalized_series: + if as_struct or ( struct_in_pandas == "dict" and arrow_type is not None and pa.types.is_struct(arrow_type) and not is_variant(arrow_type) ): - # Struct type with dict mode: require DataFrame + # Struct mode: require DataFrame, create struct array if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Invalid return type. Please make sure that the UDF returns a " - "pandas.DataFrame when the specified return type is StructType." + error_msg = ( + "Output of an arrow-optimized Python UDTFs expects " + if as_struct + else "Invalid return type. Please make sure that the UDF returns a " ) - arrs.append( - cls.create_struct_array( - s, - arrow_type, - timezone=timezone, - safecheck=safecheck, - spark_type=spark_type, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, - error_class=error_class, - assign_cols_by_name=assign_cols_by_name, - arrow_cast=arrow_cast, + if not as_struct: + error_msg += "pandas.DataFrame when the specified return type is StructType." 
+ else: + error_msg += f"a pandas.DataFrame but got: {type(s)}" + raise PySparkValueError(error_msg) + + # Create struct array from DataFrame + if len(s.columns) == 0: + struct_arr = pa.array([{}] * len(s), arrow_type) + else: + # Determine column selection strategy + use_name_matching = assign_cols_by_name and any( + isinstance(name, str) for name in s.columns ) - ) + + struct_arrs = [] + for i, field in enumerate(arrow_type): + # Get Series and spark_type based on matching strategy + if use_name_matching: + series = s[field.name] + field_spark_type = ( + spark_type[field.name].dataType if spark_type is not None else None + ) + else: + series = s[s.columns[i]].rename(field.name) + field_spark_type = ( + spark_type[i].dataType if spark_type is not None else None + ) + + struct_arrs.append( + PandasSeriesToArrowConversion.create_array( + series, + field.type, + timezone=timezone, + safecheck=safecheck, + spark_type=field_spark_type, + arrow_cast=arrow_cast, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + ) + + struct_arr = pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_type)) + + arrs.append(struct_arr) else: # Normal mode: create array from Series arrs.append( - cls.create_array( + PandasSeriesToArrowConversion.create_array( s, arrow_type, timezone=timezone, @@ -1356,6 +1086,48 @@ def convert_other(value: Any) -> Any: else: # pragma: no cover assert False, f"Need converter for {dataType} but failed to find one." + @staticmethod + def create_array( + results: Sequence[Any], + arrow_type: "pa.DataType", + spark_type: DataType, + safecheck: bool = True, + int_to_decimal_coercion_enabled: bool = False, + ) -> "pa.Array": + """ + Create an Arrow Array from a sequence of Python values. + + Parameters + ---------- + results : Sequence[Any] + Sequence of Python values to convert + arrow_type : pa.DataType + Target Arrow data type + spark_type : DataType + Spark data type for conversion + safecheck : bool + If True, use safe casting (fails on overflow/truncation) + int_to_decimal_coercion_enabled : bool + If True, applies additional coercions in Python before converting to Arrow + + Returns + ------- + pa.Array + Arrow Array with converted values + """ + import pyarrow as pa + + conv = LocalDataToArrowConversion._create_converter( + spark_type, + none_on_identity=True, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ) + converted = [conv(res) for res in results] if conv is not None else results + try: + return pa.array(converted, type=arrow_type) + except pa.lib.ArrowInvalid: + return pa.array(converted).cast(target_type=arrow_type, safe=safecheck) + @staticmethod def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table": require_minimum_pyarrow_version() @@ -1873,6 +1645,134 @@ def localize_tz(arr: "pa.Array") -> "pa.Array": assert False, f"Need converter for {pa_type} but failed to find one." +class PandasSeriesToArrowConversion: + """ + Conversion utilities for converting pandas Series to PyArrow Arrays. + + This class provides methods to convert pandas Series to PyArrow Arrays, + with support for Spark-specific type handling and conversions. + + The class is primarily used by PySpark's Pandas UDF wrappers and serializers, + where pandas data needs to be converted to Arrow for efficient serialization. 
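    A minimal usage sketch for illustration (not part of this patch; exact reprs may
    vary by pandas/pyarrow version):

    >>> import pandas as pd
    >>> import pyarrow as pa
    >>> arr = PandasSeriesToArrowConversion.create_array(
    ...     pd.Series(["a", None, "c"]), pa.string(), timezone="UTC"
    ... )
    >>> arr.null_count
    1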
+ """ + + @classmethod + def create_array( + cls, + series: "pd.Series", + arrow_type: Optional["pa.DataType"], + timezone: str, + safecheck: bool = True, + spark_type: Optional[DataType] = None, + arrow_cast: bool = False, + int_to_decimal_coercion_enabled: bool = False, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + ) -> "pa.Array": + """ + Create an Arrow Array from the given pandas.Series and optional type. + + Parameters + ---------- + series : pandas.Series + A single series + arrow_type : pyarrow.DataType, optional + If None, pyarrow's inferred type will be used + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + spark_type : DataType, optional + If None, spark type converted from arrow_type will be used + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values during conversion + error_class : str, optional + Custom error class for arrow type cast errors (e.g., "UDTF_ARROW_TYPE_CAST_ERROR") + + Returns + ------- + pyarrow.Array + """ + import pyarrow as pa + import pandas as pd + + if isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype(series.dtypes.categories.dtype) + + if arrow_type is not None: + dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) + conv = _create_converter_from_pandas( + dt, + timezone=timezone, + error_on_duplicated_field_names=False, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + ) + series = conv(series) + + if hasattr(series.array, "__arrow_array__"): + mask = None + else: + mask = series.isnull() + + # For UDTF (error_class is set), use Arrow-specific error handling + if error_class is not None: + try: + try: + return pa.Array.from_pandas( + series, mask=mask, type=arrow_type, safe=safecheck + ) + except pa.lib.ArrowException: + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except pa.lib.ArrowException: + raise PySparkRuntimeError( + errorClass=error_class, + messageParameters={ + "col_name": series.name, + "col_type": str(series.dtype), + "arrow_type": str(arrow_type), + }, + ) from None + + # For regular UDF, use standard error handling + try: + try: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) + except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError): + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except TypeError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e + except (ValueError, pa.lib.ArrowException) as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + if safecheck: + error_msg += ( + " It can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." 
+ ) + raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e + + class ArrowArrayToPandasConversion: """ Conversion utilities for converting PyArrow Arrays and ChunkedArrays to pandas. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 42d5700d31c8..d17d2317a982 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -363,44 +363,67 @@ def apply_type_coercion(): arrow_return_type, pa.StructType ), f"Expected pa.StructType, got {type(arrow_return_type)}" - # Handle empty struct case specially - if batch.num_columns == 0: - coerced_batch = batch # skip type coercion + # Batch is already wrapped into a struct column by worker (wrap_arrow_udtf) + # Unwrap it first to access individual columns + if batch.num_columns == 1 and batch.column(0).type == pa.struct(list(arrow_return_type)): + # Batch is wrapped, unwrap it + unwrapped_batch = ArrowBatchTransformer.flatten_struct(batch, column_index=0) + elif batch.num_columns == 0: + # Empty batch: wrap it back to struct column + coerced_batch = ArrowBatchTransformer.wrap_struct(batch) + yield coerced_batch + continue else: - expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = batch.schema.names - - if expected_field_names != actual_field_names: - raise PySparkTypeError( - "Target schema's field names are not matching the record batch's " - "field names. " - f"Expected: {expected_field_names}, but got: {actual_field_names}." - ) + # Batch is not wrapped (shouldn't happen, but handle it) + unwrapped_batch = batch + + # Handle empty struct case specially (no columns to coerce) + if len(arrow_return_type) == 0: + # Empty struct: wrap unwrapped batch (which should also be empty) back to struct column + coerced_batch = ArrowBatchTransformer.wrap_struct(unwrapped_batch) + yield coerced_batch + continue + + # Check field names match + expected_field_names = [field.name for field in arrow_return_type] + actual_field_names = unwrapped_batch.schema.names + if expected_field_names != actual_field_names: + raise PySparkTypeError( + "Target schema's field names are not matching the record batch's " + "field names. " + f"Expected: {expected_field_names}, but got: {actual_field_names}." 
+ ) - coerced_arrays = [] - for i, field in enumerate(arrow_return_type): - original_array = batch.column(i) - try: - coerced_array = ArrowBatchTransformer.cast_array( - original_array, - field.type, - arrow_cast=True, - safecheck=True, - ) - except PySparkTypeError: - # Re-raise with UDTF-specific error + # Use zip_batches for type coercion: create (array, type) tuples + arrays_and_types = [ + (unwrapped_batch.column(i), field.type) + for i, field in enumerate(arrow_return_type) + ] + try: + coerced_batch = ArrowBatchTransformer.zip_batches( + arrays_and_types, safecheck=True + ) + except PySparkTypeError as e: + # Re-raise with UDTF-specific error + # Find the first array that failed type coercion + # arrays_and_types contains (array, field_type) tuples + for array, expected_type in arrays_and_types: + if array.type != expected_type: raise PySparkRuntimeError( errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", messageParameters={ - "expected": str(field.type), - "actual": str(original_array.type), + "expected": str(expected_type), + "actual": str(array.type), }, - ) - coerced_arrays.append(coerced_array) - coerced_batch = pa.RecordBatch.from_arrays( - coerced_arrays, names=expected_field_names - ) - yield coerced_batch, arrow_return_type + ) from e + # If no type mismatch found, re-raise original error + raise + + # Rename columns to match expected field names + coerced_batch = coerced_batch.rename_columns(expected_field_names) + # Wrap into struct column for JVM + coerced_batch = ArrowBatchTransformer.wrap_struct(coerced_batch) + yield coerced_batch return super().dump_stream(apply_type_coercion(), stream) @@ -525,31 +548,28 @@ def dump_stream(self, iterator, stream): import pyarrow as pa def create_array(results, arrow_type, spark_type): - conv = LocalDataToArrowConversion._create_converter( + return LocalDataToArrowConversion.create_array( + results, + arrow_type, spark_type, - none_on_identity=True, + safecheck=self._safecheck, int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, ) - converted = [conv(res) for res in results] if conv is not None else results - try: - return pa.array(converted, type=arrow_type) - except pa.lib.ArrowInvalid: - return pa.array(converted).cast(target_type=arrow_type, safe=self._safecheck) def py_to_batch(): for packed in iterator: if len(packed) == 3 and isinstance(packed[1], pa.DataType): # single array UDF in a projection - yield ArrowBatchTransformer.create_batch_from_arrays( - (create_array(packed[0], packed[1], packed[2]), packed[1]), - arrow_cast=self._arrow_cast, + yield ArrowBatchTransformer.zip_batches( + [(create_array(packed[0], packed[1], packed[2]), packed[1])], safecheck=self._safecheck, ) else: # multiple array UDFs in a projection arrays_and_types = [(create_array(*t), t[1]) for t in packed] - yield ArrowBatchTransformer.create_batch_from_arrays( - arrays_and_types, arrow_cast=self._arrow_cast, safecheck=self._safecheck + yield ArrowBatchTransformer.zip_batches( + arrays_and_types, + safecheck=self._safecheck, ) return super().dump_stream(py_to_batch(), stream) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 219f7ac9a79a..91ef242bc605 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -52,6 +52,7 @@ ArrowTableToRowsConversion, ArrowBatchTransformer, PandasBatchTransformer, + ArrowArrayToPandasConversion, ) from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( @@ -294,10 +295,8 @@ def verify_result_length(result, length): 
return ( args_kwargs_offsets, - lambda *a: ArrowBatchTransformer.create_batch_from_arrays( - (verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type), - arrow_cast=True, - safecheck=True, + lambda *a: ArrowBatchTransformer.zip_batches( + [(verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type)], ), ) @@ -580,8 +579,8 @@ def verify_element(elem): return elem def to_batch(res): - return ArrowBatchTransformer.create_batch_from_arrays( - (res, arrow_return_type), arrow_cast=True, safecheck=True + return ArrowBatchTransformer.zip_batches( + [(res, arrow_return_type)], ) return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) @@ -1145,8 +1144,8 @@ def wrapped(*series): result = func(*series) array = pa.array([result]) # Return RecordBatch directly instead of (array, type) tuple - return ArrowBatchTransformer.create_batch_from_arrays( - (array, arrow_return_type), arrow_cast=True, safecheck=True + return ArrowBatchTransformer.zip_batches( + [(array, arrow_return_type)], ) return ( @@ -1169,8 +1168,8 @@ def wrapped(batch_iter): result = func(batch_iter) array = pa.array([result]) # Return RecordBatch directly instead of (array, type) tuple - return ArrowBatchTransformer.create_batch_from_arrays( - (array, arrow_return_type), arrow_cast=True, safecheck=True + return ArrowBatchTransformer.zip_batches( + [(array, arrow_return_type)], ) return ( @@ -1297,8 +1296,8 @@ def wrapped(*series): return ( args_kwargs_offsets, - lambda *a: ArrowBatchTransformer.create_batch_from_arrays( - (wrapped(*a), arrow_return_type), arrow_cast=True, safecheck=True + lambda *a: ArrowBatchTransformer.zip_batches( + [(wrapped(*a), arrow_return_type)], ), ) @@ -1378,8 +1377,8 @@ def wrapped(begin_index, end_index, *series): return ( args_offsets[:2] + args_kwargs_offsets, - lambda *a: ArrowBatchTransformer.create_batch_from_arrays( - (wrapped(*a), arrow_return_type), arrow_cast=True, safecheck=True + lambda *a: ArrowBatchTransformer.zip_batches( + [(wrapped(*a), arrow_return_type)], ), ) @@ -2513,14 +2512,24 @@ def evaluate(*args: list, num_rows=1): if len(args) == 0: for _ in range(num_rows): for batch in convert_to_arrow(func()): - # Wrap batch into struct for ArrowStreamUDTFSerializer.dump_stream - yield ArrowBatchTransformer.wrap_struct(batch) + # Handle empty batch: wrap it immediately for JVM + if batch.num_columns == 0: + yield ArrowBatchTransformer.wrap_struct(batch), arrow_return_type + else: + # Yield (batch, arrow_return_type) tuple for serializer + # Serializer will handle type coercion and wrapping + yield batch, arrow_return_type else: for row in zip(*args): for batch in convert_to_arrow(func(*row)): - # Wrap batch into struct for ArrowStreamUDTFSerializer.dump_stream - yield ArrowBatchTransformer.wrap_struct(batch) + # Handle empty batch: wrap it immediately for JVM + if batch.num_columns == 0: + yield ArrowBatchTransformer.wrap_struct(batch), arrow_return_type + else: + # Yield (batch, arrow_return_type) tuple for serializer + # Serializer will handle type coercion and wrapping + yield batch, arrow_return_type return evaluate @@ -3282,9 +3291,9 @@ def mapper(batches): # Need to materialize the first batch to get the keys first_batch = next(flattened) - keys = ArrowBatchTransformer.partial_batch(first_batch, parsed_offsets[0][0]) + keys = first_batch.select(parsed_offsets[0][0]) value_batches = ( - ArrowBatchTransformer.partial_batch(batch, parsed_offsets[0][1]) + batch.select(parsed_offsets[0][1]) for batch in 
itertools.chain((first_batch,), flattened) ) @@ -3368,10 +3377,10 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): - df1_keys = ArrowBatchTransformer.partial_table(a[0], parsed_offsets[0][0]) - df1_vals = ArrowBatchTransformer.partial_table(a[0], parsed_offsets[0][1]) - df2_keys = ArrowBatchTransformer.partial_table(a[1], parsed_offsets[1][0]) - df2_vals = ArrowBatchTransformer.partial_table(a[1], parsed_offsets[1][1]) + df1_keys = pa.Table.from_batches([batch.select(parsed_offsets[0][0]) for batch in a[0]]) + df1_vals = pa.Table.from_batches([batch.select(parsed_offsets[0][1]) for batch in a[0]]) + df2_keys = pa.Table.from_batches([batch.select(parsed_offsets[1][0]) for batch in a[1]]) + df2_vals = pa.Table.from_batches([batch.select(parsed_offsets[1][1]) for batch in a[1]]) return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF: @@ -3436,7 +3445,7 @@ def mapper(a): ] # Merge RecordBatches from all UDFs horizontally - return ArrowBatchTransformer.merge_batches(result_batches) + return ArrowBatchTransformer.zip_batches(result_batches) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -3461,7 +3470,7 @@ def mapper(batch_iter): result_batches = [f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs] # Merge all RecordBatches horizontally - return ArrowBatchTransformer.merge_batches(result_batches) + return ArrowBatchTransformer.zip_batches(result_batches) else: import pyarrow as pa @@ -3475,7 +3484,7 @@ def mapper(a): f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs ) # Merge all RecordBatches into a single RecordBatch with multiple columns - return ArrowBatchTransformer.merge_batches(result_batches) + return ArrowBatchTransformer.zip_batches(result_batches) # For grouped/cogrouped map UDFs: # All wrappers yield batches, so mapper returns generator → need flatten From c0f084e973a270d00901af5065c19882988913ea Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 00:51:35 +0000 Subject: [PATCH 29/39] fix: correct pandas UDF struct handling and parameter initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit fixes several issues with pandas UDF handling after the serializer refactoring: 1. Fix parameter initialization order in read_udfs(): - Move pandas_udf_* parameter defaults BEFORE the if-elif chain - Previously they were reset AFTER being set, causing df_for_struct to always be False for scalar pandas UDFs 2. Add struct_in_pandas="dict" to scalar UDF wrappers: - wrap_scalar_pandas_udf: enables DataFrame→struct array conversion - wrap_pandas_batch_iter_udf: same fix for iter variant 3. Fix grouped map UDF column matching: - Use assign_cols_by_name to match DataFrame columns by name when available, otherwise by position - Handle empty DataFrame (0 columns) by creating empty struct array 4. Fix Arrow batch handling: - zip_batches: convert items to list for pa.RecordBatch.from_arrays - mapper: special handling for SQL_ARROW_BATCHED_UDF to return raw result instead of calling zip_batches 5. Fix error handling in create_array: - Only catch ArrowInvalid for arrow_cast retry (not ArrowTypeError) - Add ArrowTypeError to TypeError handler for proper error messages - Update error message format to match expected test output All 264+ pandas UDF tests pass including struct type tests. 
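For context, here is a standalone sketch (hypothetical helper name, not code from this patch) of the by-name vs. by-position column matching described in item 3 above:

    import pandas as pd
    import pyarrow as pa

    def select_series(df, fields, assign_cols_by_name):
        # Match DataFrame columns to Arrow fields by name when any column label is a
        # string; otherwise fall back to positional matching.
        if assign_cols_by_name and any(isinstance(c, str) for c in df.columns):
            return [(df[f.name], f.type) for f in fields]
        return [(df[df.columns[i]].rename(f.name), f.type) for i, f in enumerate(fields)]

    fields = [pa.field("id", pa.int64()), pa.field("v", pa.float64())]
    df = pd.DataFrame({"v": [0.5], "id": [7]})
    by_name = select_series(df, fields, assign_cols_by_name=True)
    by_pos = select_series(df, fields, assign_cols_by_name=False)
    print([float(s.iloc[0]) for s, _ in by_name])  # [7.0, 0.5] -- values matched by field name
    print([float(s.iloc[0]) for s, _ in by_pos])   # [0.5, 7.0] -- values taken positionally
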
--- python/pyspark/sql/conversion.py | 12 +++-- python/pyspark/worker.py | 93 +++++++++++++++++++++++++++----- 2 files changed, 87 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 3713d95d6d3a..73088dc80fc2 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -275,7 +275,7 @@ def zip_batches( ] else: # Handle Arrays directly - all_columns = items + all_columns = list(items) # Create RecordBatch from columns return pa.RecordBatch.from_arrays( @@ -1747,14 +1747,16 @@ def create_array( try: try: return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) - except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError): + except pa.lib.ArrowInvalid: + # Only catch ArrowInvalid for arrow_cast, not ArrowTypeError + # ArrowTypeError should propagate to the TypeError handler if arrow_cast: return pa.Array.from_pandas(series, mask=mask).cast( target_type=arrow_type, safe=safecheck ) else: raise - except TypeError as e: + except (TypeError, pa.lib.ArrowTypeError) as e: error_msg = ( "Exception thrown when converting pandas.Series (%s) " "with name '%s' to Arrow Array (%s)." @@ -1767,7 +1769,9 @@ def create_array( ) if safecheck: error_msg += ( - " It can be disabled by using SQL config " + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " "`spark.sql.execution.pandas.convertToArrowArraySafely`." ) raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 91ef242bc605..d1f274c2e16d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -258,6 +258,8 @@ def verify_result_length(result, length): timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", ), ) @@ -483,6 +485,8 @@ def to_batch(res): timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", ) return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) @@ -871,16 +875,40 @@ def wrapped(key_series, value_batches): def flatten_wrapper(k, v): # Convert pandas DataFrame to Arrow RecordBatch and wrap as struct for JVM + import pyarrow as pa from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer for df in wrapped(k, v): - # Split DataFrame into list of (Series, arrow_type) tuples - series_list = [(df[col], arrow_return_type[i].type) for i, col in enumerate(df.columns)] + # Handle empty DataFrame (no columns) + if len(df.columns) == 0: + # Create empty struct array with correct schema + # Use generic field names (_0, _1, ...) 
to match wrap_struct behavior + struct_fields = [ + pa.field(f"_{i}", field.type) for i, field in enumerate(arrow_return_type) + ] + struct_arr = pa.array([{}] * len(df), pa.struct(struct_fields)) + batch = pa.RecordBatch.from_arrays([struct_arr], ["_0"]) + yield batch + continue + + # Build list of (Series, arrow_type) tuples + # If assign_cols_by_name and DataFrame has named columns, match by name + # Otherwise, match by position + if runner_conf.assign_cols_by_name and any(isinstance(name, str) for name in df.columns): + series_list = [ + (df[field.name], field.type) for field in arrow_return_type + ] + else: + series_list = [ + (df[df.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] batch = PandasBatchTransformer.to_arrow( series_list, timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, ) # Wrap columns as struct for JVM compatibility yield ArrowBatchTransformer.wrap_struct(batch) @@ -917,16 +945,40 @@ def verify_element(df): def flatten_wrapper(k, v): # Convert pandas DataFrames to Arrow RecordBatches and wrap as struct for JVM + import pyarrow as pa from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer for df in wrapped(k, v): - # Split DataFrame into list of (Series, arrow_type) tuples - series_list = [(df[col], arrow_return_type[i].type) for i, col in enumerate(df.columns)] + # Handle empty DataFrame (no columns) + if len(df.columns) == 0: + # Create empty struct array with correct schema + # Use generic field names (_0, _1, ...) to match wrap_struct behavior + struct_fields = [ + pa.field(f"_{i}", field.type) for i, field in enumerate(arrow_return_type) + ] + struct_arr = pa.array([{}] * len(df), pa.struct(struct_fields)) + batch = pa.RecordBatch.from_arrays([struct_arr], ["_0"]) + yield batch + continue + + # Build list of (Series, arrow_type) tuples + # If assign_cols_by_name and DataFrame has named columns, match by name + # Otherwise, match by position + if runner_conf.assign_cols_by_name and any(isinstance(name, str) for name in df.columns): + series_list = [ + (df[field.name], field.type) for field in arrow_return_type + ] + else: + series_list = [ + (df[df.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] batch = PandasBatchTransformer.to_arrow( series_list, timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, ) # Wrap columns as struct for JVM compatibility yield ArrowBatchTransformer.wrap_struct(batch) @@ -2791,6 +2843,13 @@ def mapper(_, it): def read_udfs(pickleSer, infile, eval_type, runner_conf): state_server_port = None key_schema = None + + # Initialize transformer parameters (will be set if needed for specific UDF types) + pandas_udf_input_type = None + pandas_udf_struct_in_pandas = "dict" + pandas_udf_ndarray_as_list = False + pandas_udf_df_for_struct = False + if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, @@ -2963,12 +3022,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) - # Initialize transformer parameters (will be set if needed) - pandas_udf_input_type = None - pandas_udf_struct_in_pandas = "dict" - pandas_udf_ndarray_as_list = False - 
pandas_udf_df_for_struct = False - # Read all UDFs num_udfs = read_int(infile) udfs = [ @@ -3478,13 +3531,25 @@ def mapper(batch_iter): # Check if we need to convert Arrow batches to pandas (for scalar pandas UDF) needs_arrow_to_pandas = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + # For SQL_ARROW_BATCHED_UDF, the wrapper returns (result, arrow_return_type, return_type) + # tuple that the serializer handles. For SQL_SCALAR_PANDAS_UDF and SQL_SCALAR_ARROW_UDF, + # the wrappers return RecordBatches that need to be zipped. + is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + def mapper(a): - # Each UDF returns a RecordBatch (single column) - result_batches = tuple( + result = tuple( f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs ) - # Merge all RecordBatches into a single RecordBatch with multiple columns - return ArrowBatchTransformer.zip_batches(result_batches) + # For arrow batched UDF, return raw result (serializer handles conversion) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. + if is_arrow_batched_udf: + if len(result) == 1: + return result[0] + else: + return result + # For pandas/arrow scalar UDFs, each returns a RecordBatch - merge them + return ArrowBatchTransformer.zip_batches(result) # For grouped/cogrouped map UDFs: # All wrappers yield batches, so mapper returns generator → need flatten From bf3ba3fe748fd8eddf12da0d53961290981d3bd5 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 01:05:54 +0000 Subject: [PATCH 30/39] refactor: remove unused ArrowStreamGroupSerializer parameters Remove timezone, int_to_decimal_coercion_enabled, and assign_cols_by_name parameters that were stored but never used by the serializer or subclasses. --- python/pyspark/sql/conversion.py | 97 +++++++++++---------- python/pyspark/sql/pandas/serializers.py | 41 ++++----- python/pyspark/worker.py | 104 +++++++++-------------- 3 files changed, 105 insertions(+), 137 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 73088dc80fc2..d4fdae54f28c 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -62,11 +62,11 @@ class ArrowBatchTransformer: """ Pure functions that transform Arrow data structures (Arrays, RecordBatches). They should have no side effects (no I/O, no writing to streams). - + This class provides utility methods for Arrow batch transformations used throughout PySpark's Arrow UDF implementation. All methods are static and handle common patterns like struct wrapping/unwrapping, schema conversions, and creating RecordBatches from Arrays. - + """ @staticmethod @@ -112,29 +112,28 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) - @classmethod def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": """ Concatenate multiple RecordBatches into a single RecordBatch. - + This method handles both modern and legacy PyArrow versions. 
- + Parameters ---------- batches : List[pa.RecordBatch] List of RecordBatches with the same schema - + Returns ------- pa.RecordBatch Single RecordBatch containing all rows from input batches - + Used by ------- - SQL_GROUPED_AGG_ARROW_UDF mapper - SQL_WINDOW_AGG_ARROW_UDF mapper - + Examples -------- >>> import pyarrow as pa @@ -145,13 +144,13 @@ def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": {'a': [1, 2, 3, 4]} """ import pyarrow as pa - + if not batches: raise PySparkValueError( errorClass="INVALID_ARROW_BATCH_CONCAT", messageParameters={"reason": "Cannot concatenate empty list of batches"}, ) - + # Assert all batches have the same schema first_schema = batches[0].schema for i, batch in enumerate(batches[1:], start=1): @@ -165,7 +164,7 @@ def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": ) }, ) - + if hasattr(pa, "concat_batches"): return pa.concat_batches(batches) else: @@ -186,11 +185,11 @@ def zip_batches( ) -> "pa.RecordBatch": """ Zip multiple RecordBatches or Arrays horizontally by combining their columns. - + This is different from concat_batches which concatenates rows vertically. This method combines columns from multiple batches/arrays into a single batch, useful when multiple UDFs each produce a RecordBatch or when combining arrays. - + Parameters ---------- items : List[pa.RecordBatch], List[pa.Array], or List[Tuple[pa.Array, pa.DataType]] @@ -199,12 +198,12 @@ def zip_batches( - List of (array, type) tuples for type casting (always attempts cast if types don't match) safecheck : bool, default True If True, use safe casting (fails on overflow/truncation) (only used when items are tuples). - + Returns ------- pa.RecordBatch Single RecordBatch with all columns from input batches/arrays - + Used by ------- - SQL_GROUPED_AGG_ARROW_UDF mapper @@ -212,7 +211,7 @@ def zip_batches( - wrap_scalar_arrow_udf - wrap_grouped_agg_arrow_udf - ArrowBatchUDFSerializer.dump_stream - + Examples -------- >>> import pyarrow as pa @@ -240,7 +239,7 @@ def zip_batches( # Check if items are RecordBatches, Arrays, or (array, type) tuples first_item = items[0] - + if isinstance(first_item, pa.RecordBatch): # Handle RecordBatches batches = items @@ -278,9 +277,7 @@ def zip_batches( all_columns = list(items) # Create RecordBatch from columns - return pa.RecordBatch.from_arrays( - all_columns, ["_%d" % i for i in range(len(all_columns))] - ) + return pa.RecordBatch.from_arrays(all_columns, ["_%d" % i for i in range(len(all_columns))]) @classmethod def reorder_columns( @@ -288,10 +285,10 @@ def reorder_columns( ) -> "pa.RecordBatch": """ Reorder columns in a RecordBatch to match target schema field order. - + This method is useful when columns need to be arranged in a specific order for schema compatibility, particularly when assign_cols_by_name is enabled. - + Parameters ---------- batch : pa.RecordBatch @@ -299,18 +296,18 @@ def reorder_columns( target_schema : pa.StructType or pyspark.sql.types.StructType Target schema defining the desired column order. Can be either PyArrow StructType or Spark StructType. 
- + Returns ------- pa.RecordBatch New RecordBatch with columns reordered to match target schema - + Used by ------- - wrap_grouped_map_arrow_udf - wrap_grouped_map_arrow_iter_udf - wrap_cogrouped_map_arrow_udf - + Examples -------- >>> import pyarrow as pa @@ -328,17 +325,18 @@ def reorder_columns( ['a', 'b'] """ import pyarrow as pa - + # Convert Spark StructType to PyArrow StructType if needed - if hasattr(target_schema, 'fields') and hasattr(target_schema.fields[0], 'dataType'): + if hasattr(target_schema, "fields") and hasattr(target_schema.fields[0], "dataType"): # This is Spark StructType - convert to PyArrow from pyspark.sql.pandas.types import to_arrow_schema + arrow_schema = to_arrow_schema(target_schema) field_names = [field.name for field in arrow_schema] else: # This is PyArrow StructType field_names = [field.name for field in target_schema] - + return pa.RecordBatch.from_arrays( [batch.column(name) for name in field_names], names=field_names, @@ -433,7 +431,7 @@ def _cast_array( if arr.type == target_type: return arr - + try: return arr.cast(target_type=target_type, safe=safecheck) except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e: @@ -444,14 +442,15 @@ def _cast_array( f"Arrow type mismatch. Expected: {target_type}, but got: {arr.type}." ) from e + class PandasBatchTransformer: """ Pure functions that transform between pandas DataFrames/Series and Arrow RecordBatches. They should have no side effects (no I/O, no writing to streams). - + This class provides utility methods for converting between pandas and Arrow formats, used primarily by Pandas UDF wrappers and serializers. - + """ @classmethod @@ -460,28 +459,28 @@ def concat_series_batches( ) -> List["pd.Series"]: """ Concatenate multiple batches of pandas Series column-wise. - + Takes a list of batches where each batch is a list of Series (one per column), and concatenates all Series column-by-column to produce a single list of concatenated Series. - + Parameters ---------- series_batches : List[List[pd.Series]] List of batches, each batch is a list of Series (one Series per column) arg_offsets : Optional[List[int]] If provided and series_batches is empty, determines the number of empty Series to create - + Returns ------- List[pd.Series] List of concatenated Series, one per column - + Used by ------- - SQL_GROUPED_AGG_PANDAS_UDF mapper - SQL_WINDOW_AGG_PANDAS_UDF mapper - + Examples -------- >>> import pandas as pd @@ -494,7 +493,7 @@ def concat_series_batches( [1, 2, 5, 6] """ import pandas as pd - + if not series_batches: # Empty batches - create empty Series if arg_offsets: @@ -502,7 +501,7 @@ def concat_series_batches( else: num_columns = 0 return [pd.Series(dtype=object) for _ in range(num_columns)] - + # Concatenate Series by column num_columns = len(series_batches[0]) return [ @@ -514,26 +513,26 @@ def concat_series_batches( def series_batches_to_dataframe(cls, series_batches) -> "pd.DataFrame": """ Convert an iterator of Series lists to a single DataFrame. - + Each batch is a list of Series (one per column). This method concatenates Series within each batch horizontally (axis=1), then concatenates all resulting DataFrames vertically (axis=0). 
- + Parameters ---------- series_batches : Iterator[List[pd.Series]] Iterator where each element is a list of Series representing one batch - + Returns ------- pd.DataFrame Combined DataFrame with all data, or empty DataFrame if no batches - + Used by ------- - wrap_grouped_map_pandas_udf - wrap_grouped_map_pandas_iter_udf - + Examples -------- >>> import pandas as pd @@ -546,10 +545,10 @@ def series_batches_to_dataframe(cls, series_batches) -> "pd.DataFrame": ['a', 'b'] """ import pandas as pd - + # Materialize iterator and convert each batch to DataFrame dataframes = [pd.concat(series_list, axis=1) for series_list in series_batches] - + # Concatenate all DataFrames vertically return pd.concat(dataframes, axis=0) if dataframes else pd.DataFrame() @@ -634,7 +633,9 @@ def to_arrow( else "Invalid return type. Please make sure that the UDF returns a " ) if not as_struct: - error_msg += "pandas.DataFrame when the specified return type is StructType." + error_msg += ( + "pandas.DataFrame when the specified return type is StructType." + ) else: error_msg += f"a pandas.DataFrame but got: {type(s)}" raise PySparkValueError(error_msg) @@ -1723,9 +1724,7 @@ def create_array( if error_class is not None: try: try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=safecheck - ) + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) except pa.lib.ArrowException: if arrow_cast: return pa.Array.from_pandas(series, mask=mask).cast( diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d17d2317a982..5a332f956cf6 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -217,16 +217,16 @@ def __repr__(self): class ArrowStreamGroupSerializer(ArrowStreamSerializer): """ Unified serializer for Arrow stream operations with optional grouping support. - + This serializer handles: - Non-grouped operations: SQL_MAP_ARROW_ITER_UDF (num_dfs=0) - Grouped operations: SQL_GROUPED_MAP_ARROW_UDF, SQL_GROUPED_MAP_PANDAS_UDF (num_dfs=1) - Cogrouped operations: SQL_COGROUPED_MAP_ARROW_UDF, SQL_COGROUPED_MAP_PANDAS_UDF (num_dfs=2) - Grouped aggregations: SQL_GROUPED_AGG_ARROW_UDF, SQL_GROUPED_AGG_PANDAS_UDF (num_dfs=1) - + The serializer handles Arrow stream I/O and START signal, while transformation logic (flatten/wrap struct, pandas conversion) is handled by worker wrappers. 
- + Used by ------- - SQL_MAP_ARROW_ITER_UDF: DataFrame.mapInArrow() @@ -243,54 +243,41 @@ class ArrowStreamGroupSerializer(ArrowStreamSerializer): - SQL_WINDOW_AGG_ARROW_UDF: Window aggregation with arrow UDF - SQL_SCALAR_ARROW_UDF: Scalar arrow UDF - SQL_SCALAR_ARROW_ITER_UDF: Scalar arrow iter UDF - + Parameters ---------- - timezone : str, optional - Timezone for timestamp conversion (stored for compatibility) safecheck : bool, optional - Safecheck flag (stored for compatibility) - int_to_decimal_coercion_enabled : bool, optional - Decimal coercion flag (stored for compatibility) + Safecheck flag for ArrowBatchUDFSerializer subclass num_dfs : int, optional Number of DataFrames per group: - 0: Non-grouped mode (default) - yields all batches as single stream - 1: Grouped mode - yields one iterator of batches per group - 2: Cogrouped mode - yields tuple of two iterators per group - assign_cols_by_name : bool, optional - If True, assign DataFrame columns by name; otherwise by position (default: True) arrow_cast : bool, optional - Arrow cast flag (stored for compatibility, default: False) + Arrow cast flag for ArrowBatchUDFSerializer subclass (default: False) """ def __init__( self, - timezone=None, safecheck=None, - int_to_decimal_coercion_enabled: bool = False, num_dfs: int = 0, - assign_cols_by_name: bool = True, arrow_cast: bool = False, ): super().__init__() - # Store parameters for compatibility - self._timezone = timezone self._safecheck = safecheck - self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled self._num_dfs = num_dfs - self._assign_cols_by_name = assign_cols_by_name self._arrow_cast = arrow_cast def load_stream(self, stream): """ Deserialize Arrow record batches from stream. - + Returns ------- Iterator - num_dfs=0: Iterator[pa.RecordBatch] - all batches in a single stream - num_dfs=1: Iterator[Iterator[pa.RecordBatch]] - one iterator per group - - num_dfs=2: Iterator[Tuple[Iterator[pa.RecordBatch], Iterator[pa.RecordBatch]]] - + - num_dfs=2: Iterator[Tuple[Iterator[pa.RecordBatch], Iterator[pa.RecordBatch]]] - tuple of two iterators per cogrouped group """ if self._num_dfs > 0: @@ -310,11 +297,11 @@ def load_stream(self, stream): def dump_stream(self, iterator, stream): """ Serialize Arrow record batches to stream with START signal. - + The START_ARROW_STREAM marker is sent before the first batch to signal the JVM that Arrow data is about to be transmitted. This allows proper error handling during batch creation. - + Parameters ---------- iterator : Iterator[pa.RecordBatch] @@ -365,7 +352,9 @@ def apply_type_coercion(): # Batch is already wrapped into a struct column by worker (wrap_arrow_udtf) # Unwrap it first to access individual columns - if batch.num_columns == 1 and batch.column(0).type == pa.struct(list(arrow_return_type)): + if batch.num_columns == 1 and batch.column(0).type == pa.struct( + list(arrow_return_type) + ): # Batch is wrapped, unwrap it unwrapped_batch = ArrowBatchTransformer.flatten_struct(batch, column_index=0) elif batch.num_columns == 0: @@ -431,7 +420,7 @@ def apply_type_coercion(): class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ Serializer for UDFs that handles Arrow RecordBatch serialization. - + This is a thin wrapper around ArrowStreamSerializer kept for backward compatibility and future extensibility. 
Currently it doesn't override any methods - all conversion logic has been moved to wrappers/callers, and this serializer only handles pure @@ -462,6 +451,7 @@ def __init__( def __repr__(self): return "ArrowStreamUDFSerializer" + class ArrowBatchUDFSerializer(ArrowStreamGroupSerializer): """ Serializer used by Python worker to evaluate Arrow Python UDFs @@ -1092,6 +1082,7 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to stream. """ + def flatten_iterator(): # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]] for packed in iterator: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d1f274c2e16d..41e1b75a0658 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -591,10 +591,6 @@ def to_batch(res): def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - def verify_result(result): if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): raise PySparkTypeError( @@ -627,13 +623,13 @@ def wrapper(batch_iter): result = verify_result(f(batch_iter)) for res in map(verify_element, result): yield ArrowBatchTransformer.wrap_struct(res) - + return wrapper def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): import pyarrow as pa - + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -664,11 +660,11 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): from pyspark.sql.conversion import PandasBatchTransformer - + arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) - + def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): import pandas as pd @@ -684,19 +680,19 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se verify_pandas_result( result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False ) - + # Convert pandas DataFrame to Arrow RecordBatch and yield (consistent with Arrow cogrouped) batch = PandasBatchTransformer.to_arrow( [(result, arrow_return_type)], timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - struct_in_pandas='dict', + struct_in_pandas="dict", assign_cols_by_name=runner_conf.assign_cols_by_name, arrow_cast=True, ) yield batch - + return wrapped @@ -824,7 +820,7 @@ def wrapped(key_batch, value_batches): def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf): import pyarrow as pa - + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -894,10 +890,10 @@ def flatten_wrapper(k, v): # Build list of (Series, arrow_type) tuples # If assign_cols_by_name and DataFrame has named columns, match by name # Otherwise, match by position - if runner_conf.assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - series_list = [ - (df[field.name], field.type) for field in arrow_return_type - ] + if runner_conf.assign_cols_by_name and any( + isinstance(name, str) for name in df.columns + ): + series_list = 
[(df[field.name], field.type) for field in arrow_return_type] else: series_list = [ (df[df.columns[i]].rename(field.name), field.type) @@ -964,10 +960,10 @@ def flatten_wrapper(k, v): # Build list of (Series, arrow_type) tuples # If assign_cols_by_name and DataFrame has named columns, match by name # Otherwise, match by position - if runner_conf.assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - series_list = [ - (df[field.name], field.type) for field in arrow_return_type - ] + if runner_conf.assign_cols_by_name and any( + isinstance(name, str) for name in df.columns + ): + series_list = [(df[field.name], field.type) for field in arrow_return_type] else: series_list = [ (df[df.columns[i]].rename(field.name), field.type) @@ -1158,7 +1154,7 @@ def verify_element(result): def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): from pyspark.sql.conversion import PandasBatchTransformer - + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1170,7 +1166,7 @@ def wrapped(*series): result = func(*series) result_series = pd.Series([result]) - + # Convert to Arrow RecordBatch in wrapper return PandasBatchTransformer.to_arrow( [(result_series, arrow_return_type)], @@ -1232,7 +1228,7 @@ def wrapped(batch_iter): def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): from pyspark.sql.conversion import PandasBatchTransformer - + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1247,7 +1243,7 @@ def wrapped(series_iter): # This has already been adapted by the mapper function in read_udfs result = func(series_iter) result_series = pd.Series([result]) - + # Convert to Arrow RecordBatch in wrapper return PandasBatchTransformer.to_arrow( [(result_series, arrow_return_type)], @@ -1662,13 +1658,8 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF: input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if runner_conf.use_legacy_pandas_udtf_conversion: - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - # UDTF uses ArrowStreamGroupSerializer with as_struct=True - ser = ArrowStreamGroupSerializer( - timezone=runner_conf.timezone, - safecheck=runner_conf.safecheck, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - ) + # UDTF uses ArrowStreamGroupSerializer + ser = ArrowStreamGroupSerializer(safecheck=runner_conf.safecheck) else: ser = ArrowStreamGroupSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: @@ -2895,17 +2886,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, ): - ser = ArrowStreamGroupSerializer( - num_dfs=1, assign_cols_by_name=runner_conf.assign_cols_by_name - ) + ser = ArrowStreamGroupSerializer(num_dfs=1) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, ): - ser = ArrowStreamGroupSerializer( - safecheck=True, arrow_cast=True, num_dfs=1 - ) + ser = ArrowStreamGroupSerializer(safecheck=True, arrow_cast=True, num_dfs=1) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, @@ -2914,10 +2901,8 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): # Use 
ArrowStreamGroupSerializer for agg/window UDFs that still use old pattern # load_stream returns raw batches, to_pandas conversion done in worker ser = ArrowStreamGroupSerializer( - timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, num_dfs=1, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif eval_type in ( PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, @@ -2925,13 +2910,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): ): # Wrapper calls to_arrow directly, so use MapIter serializer in grouped mode # dump_stream handles RecordBatch directly from wrapper - ser = ArrowStreamGroupSerializer( - num_dfs=1, assign_cols_by_name=runner_conf.assign_cols_by_name - ) + ser = ArrowStreamGroupSerializer(num_dfs=1) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: - ser = ArrowStreamGroupSerializer( - num_dfs=2, assign_cols_by_name=runner_conf.assign_cols_by_name - ) + ser = ArrowStreamGroupSerializer(num_dfs=2) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # Mapper calls to_arrow directly, so use Arrow serializer # load_stream returns raw batches, to_pandas conversion done in worker @@ -3014,9 +2995,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): ) ser = ArrowStreamGroupSerializer( - timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) @@ -3058,7 +3037,7 @@ def func(_, iterator): # For MAP_ARROW_ITER, flatten struct before processing if is_map_arrow_iter: iterator = map(ArrowBatchTransformer.flatten_struct, iterator) - + # Convert Arrow batches to pandas if needed if needs_arrow_to_pandas: iterator = map( @@ -3070,7 +3049,7 @@ def func(_, iterator): ndarray_as_list=pandas_udf_ndarray_as_list, df_for_struct=pandas_udf_df_for_struct, ), - iterator + iterator, ) num_input_rows = 0 @@ -3092,6 +3071,7 @@ def count_rows(batch): nonlocal num_input_rows num_input_rows += batch.num_rows return batch + iterator = map(count_rows, iterator) result_iter = udf(iterator) else: @@ -3182,9 +3162,7 @@ def extract_key_value_indexes(grouped_arg_offsets): def mapper(batch_iter): # Convert Arrow batches to pandas Series lists in worker layer series_iter = map( - lambda batch: ArrowBatchTransformer.to_pandas( - batch, timezone=runner_conf.timezone - ), + lambda batch: ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), batch_iter, ) @@ -3403,20 +3381,22 @@ def mapper(a): def mapper(a): # a is tuple[list[pa.RecordBatch], list[pa.RecordBatch]] left_batches, right_batches = a[0], a[1] - + # Convert batches to tables (batches already have flattened struct columns) left_table = pa.Table.from_batches(left_batches) if left_batches else pa.table({}) right_table = pa.Table.from_batches(right_batches) if right_batches else pa.table({}) - + # Convert tables to pandas Series lists left_series = ArrowBatchTransformer.to_pandas(left_table, timezone=runner_conf.timezone) - right_series = ArrowBatchTransformer.to_pandas(right_table, timezone=runner_conf.timezone) - + right_series = ArrowBatchTransformer.to_pandas( + right_table, timezone=runner_conf.timezone + ) + df1_keys = [left_series[o] for o in parsed_offsets[0][0]] df1_vals = [left_series[o] for o in parsed_offsets[0][1]] df2_keys = [right_series[o] for o in parsed_offsets[1][0]] df2_vals = [right_series[o] for o in parsed_offsets[1][1]] - + return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == 
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: @@ -3496,7 +3476,7 @@ def mapper(a): result_batches = [ f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs ] - + # Merge RecordBatches from all UDFs horizontally return ArrowBatchTransformer.zip_batches(result_batches) @@ -3521,7 +3501,7 @@ def mapper(batch_iter): # Each UDF returns pa.RecordBatch (conversion done in wrapper) result_batches = [f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs] - + # Merge all RecordBatches horizontally return ArrowBatchTransformer.zip_batches(result_batches) @@ -3537,9 +3517,7 @@ def mapper(batch_iter): is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF def mapper(a): - result = tuple( - f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs - ) + result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) # For arrow batched UDF, return raw result (serializer handles conversion) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. @@ -3580,7 +3558,7 @@ def func(_, it): ndarray_as_list=pandas_udf_ndarray_as_list, df_for_struct=pandas_udf_df_for_struct, ), - it + it, ) return map(mapper, it) From 9b34b0bce4fda5c2e02f52dad388688011349f9f Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 01:55:52 +0000 Subject: [PATCH 31/39] refactor: extract common verification functions and centralize Arrow-to-pandas conversion - Extract common verification functions: verify_result_length, verify_result_type, verify_is_iterable, verify_element_type - Simplify wrapper functions using common verification utilities - Centralize Arrow-to-pandas conversion in read_udfs mapper/func - Remove unused pandas_udf_* variables from read_udfs - Fix is_scalar_pandas_iter to not convert Arrow iter UDFs to pandas --- python/pyspark/sql/pandas/serializers.py | 37 +-- python/pyspark/worker.py | 345 +++++++++-------------- 2 files changed, 131 insertions(+), 251 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 5a332f956cf6..5d7a8f599935 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -227,46 +227,18 @@ class ArrowStreamGroupSerializer(ArrowStreamSerializer): The serializer handles Arrow stream I/O and START signal, while transformation logic (flatten/wrap struct, pandas conversion) is handled by worker wrappers. 
- Used by - ------- - - SQL_MAP_ARROW_ITER_UDF: DataFrame.mapInArrow() - - SQL_GROUPED_MAP_ARROW_UDF: GroupedData.applyInArrow() - - SQL_GROUPED_MAP_ARROW_ITER_UDF: GroupedData.applyInArrow() with iter - - SQL_GROUPED_MAP_PANDAS_UDF: GroupedData.apply() - - SQL_GROUPED_MAP_PANDAS_ITER_UDF: GroupedData.apply() with iter - - SQL_COGROUPED_MAP_ARROW_UDF: DataFrame.groupby().cogroup().applyInArrow() - - SQL_COGROUPED_MAP_PANDAS_UDF: DataFrame.groupby().cogroup().apply() - - SQL_GROUPED_AGG_ARROW_UDF: GroupedData.agg() with arrow UDF - - SQL_GROUPED_AGG_ARROW_ITER_UDF: GroupedData.agg() with arrow iter UDF - - SQL_GROUPED_AGG_PANDAS_UDF: GroupedData.agg() with pandas UDF - - SQL_GROUPED_AGG_PANDAS_ITER_UDF: GroupedData.agg() with pandas iter UDF - - SQL_WINDOW_AGG_ARROW_UDF: Window aggregation with arrow UDF - - SQL_SCALAR_ARROW_UDF: Scalar arrow UDF - - SQL_SCALAR_ARROW_ITER_UDF: Scalar arrow iter UDF - Parameters ---------- - safecheck : bool, optional - Safecheck flag for ArrowBatchUDFSerializer subclass num_dfs : int, optional Number of DataFrames per group: - 0: Non-grouped mode (default) - yields all batches as single stream - 1: Grouped mode - yields one iterator of batches per group - 2: Cogrouped mode - yields tuple of two iterators per group - arrow_cast : bool, optional - Arrow cast flag for ArrowBatchUDFSerializer subclass (default: False) """ - def __init__( - self, - safecheck=None, - num_dfs: int = 0, - arrow_cast: bool = False, - ): + def __init__(self, num_dfs: int = 0): super().__init__() - self._safecheck = safecheck self._num_dfs = num_dfs - self._arrow_cast = arrow_cast def load_stream(self, stream): """ @@ -478,12 +450,9 @@ def __init__( int_to_decimal_coercion_enabled: bool, binary_as_bytes: bool, ): - super().__init__( - safecheck=safecheck, - arrow_cast=True, - num_dfs=0, - ) + super().__init__(num_dfs=0) assert isinstance(input_type, StructType) + self._safecheck = safecheck self._input_type = input_type self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled self._binary_as_bytes = binary_as_bytes diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 41e1b75a0658..a18e85517dbd 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -220,41 +220,81 @@ def wrap_udf(f, args_offsets, kwargs_offsets, return_type): return args_kwargs_offsets, lambda *a: func(*a) +# ============================================================ +# Common verification functions for UDF wrappers +# ============================================================ + + +def verify_result_length(result, expected_length, udf_type): + """Verify that the result length matches the expected length.""" + if len(result) != expected_length: + raise PySparkRuntimeError( + errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", + messageParameters={ + "udf_type": udf_type, + "expected": str(expected_length), + "actual": str(len(result)), + }, + ) + return result + + +def verify_result_type(result, expected_type_name): + """Verify that the result has __len__ attribute (is array-like).""" + if not hasattr(result, "__len__"): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": expected_type_name, + "actual": type(result).__name__, + }, + ) + return result + + +def verify_is_iterable(result, expected_type_name): + """Verify that the result is an iterator or iterable.""" + if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": 
expected_type_name, + "actual": type(result).__name__, + }, + ) + return result + + +def verify_element_type(elem, expected_class, expected_type_name): + """Verify that an element has the expected type.""" + if not isinstance(elem, expected_class): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": expected_type_name, + "actual": "iterator of {}".format(type(elem).__name__), + }, + ) + return elem + + def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) + pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result - - def verify_result_length(result, length): - if len(result) != length: - raise PySparkRuntimeError( - errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", - messageParameters={ - "udf_type": "pandas_udf", - "expected": str(length), - "actual": str(len(result)), - }, - ) + def verify_and_convert(result, length): + verify_result_type(result, pd_type) + verify_result_length(result, length, "pandas_udf") return result return ( args_kwargs_offsets, lambda *a: PandasBatchTransformer.to_arrow( - [(verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type)], + [(verify_and_convert(func(*a), len(a[0])), arrow_return_type)], timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, @@ -271,34 +311,15 @@ def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_c return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pyarrow.Array" - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result - - def verify_result_length(result, length): - if len(result) != length: - raise PySparkRuntimeError( - errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", - messageParameters={ - "udf_type": "arrow_udf", - "expected": str(length), - "actual": str(len(result)), - }, - ) + def verify_and_convert(result, length): + verify_result_type(result, "pyarrow.Array") + verify_result_length(result, length, "arrow_udf") return result return ( args_kwargs_offsets, lambda *a: ArrowBatchTransformer.zip_batches( - [(verify_result_length(verify_result_type(func(*a)), len(a[0])), arrow_return_type)], + [(verify_and_convert(func(*a), len(a[0])), arrow_return_type)], ), ) @@ -445,38 +466,19 @@ def verify_result_length(result, length): def wrap_pandas_batch_iter_udf(f, return_type, runner_conf): + import pandas as pd + arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - 
messageParameters={ - "expected": "iterator of {}".format(iter_type_label), - "actual": type(result).__name__, - }, - ) - return result + expected_class = pd.DataFrame if type(return_type) == StructType else pd.Series def verify_element(elem): - import pandas as pd - - if not isinstance(elem, pd.DataFrame if type(return_type) == StructType else pd.Series): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of {}".format(iter_type_label), - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) - + verify_element_type(elem, expected_class, "iterator of {}".format(iter_type_label)) verify_pandas_result( elem, return_type, assign_cols_by_name=True, truncate_return_schema=True ) - return elem def to_batch(res): @@ -489,7 +491,12 @@ def to_batch(res): struct_in_pandas="dict", ) - return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) + def wrapped(*iterator): + result = f(*iterator) + verify_is_iterable(result, "iterator of {}".format(iter_type_label)) + return map(to_batch, map(verify_element, result)) + + return wrapped def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema): @@ -553,75 +560,40 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu def wrap_arrow_array_iter_udf(f, return_type, runner_conf): + import pyarrow as pa + arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) - - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.Array", - "actual": type(result).__name__, - }, - ) - return result - - def verify_element(elem): - import pyarrow as pa - - if not isinstance(elem, pa.Array): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.Array", - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) - - return elem + iter_type = "iterator of pyarrow.Array" def to_batch(res): - return ArrowBatchTransformer.zip_batches( - [(res, arrow_return_type)], + return ArrowBatchTransformer.zip_batches([(res, arrow_return_type)]) + + def wrapped(*iterator): + result = f(*iterator) + verify_is_iterable(result, iter_type) + return map( + to_batch, + map(lambda elem: verify_element_type(elem, pa.Array, iter_type), result), ) - return lambda *iterator: map(to_batch, map(verify_element, verify_result(f(*iterator)))) + return wrapped def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.RecordBatch", - "actual": type(result).__name__, - }, - ) - return result - - def verify_element(elem): - import pyarrow as pa - - if not isinstance(elem, pa.RecordBatch): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.RecordBatch", - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) + import pyarrow as pa - return elem + iter_type = "iterator of pyarrow.RecordBatch" # For mapInArrow: batches are already flattened in func before calling wrapper # User function receives flattened batches and returns flattened batches # We just need to wrap results back into 
struct for serialization def wrapper(batch_iter): - result = verify_result(f(batch_iter)) - for res in map(verify_element, result): + result = f(batch_iter) + verify_is_iterable(result, iter_type) + for res in result: + verify_element_type(res, pa.RecordBatch, iter_type) yield ArrowBatchTransformer.wrap_struct(res) return wrapper @@ -1659,7 +1631,7 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if runner_conf.use_legacy_pandas_udtf_conversion: # UDTF uses ArrowStreamGroupSerializer - ser = ArrowStreamGroupSerializer(safecheck=runner_conf.safecheck) + ser = ArrowStreamGroupSerializer() else: ser = ArrowStreamGroupSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: @@ -2835,11 +2807,11 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): state_server_port = None key_schema = None - # Initialize transformer parameters (will be set if needed for specific UDF types) - pandas_udf_input_type = None - pandas_udf_struct_in_pandas = "dict" - pandas_udf_ndarray_as_list = False - pandas_udf_df_for_struct = False + # Check if legacy arrow batched UDF mode (needs to read input type from stream) + is_legacy_arrow_batched_udf = ( + eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + and runner_conf.use_legacy_pandas_udf_conversion + ) if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, @@ -2881,47 +2853,26 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): state_server_port = utf8_deserializer.loads(infile) key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + # Grouped UDFs: num_dfs=1 (single DataFrame per group) if eval_type in ( PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, - ): - ser = ArrowStreamGroupSerializer(num_dfs=1) - elif eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, - PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, - ): - ser = ArrowStreamGroupSerializer(safecheck=True, arrow_cast=True, num_dfs=1) - elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, + PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): - # Use ArrowStreamGroupSerializer for agg/window UDFs that still use old pattern - # load_stream returns raw batches, to_pandas conversion done in worker - ser = ArrowStreamGroupSerializer( - safecheck=runner_conf.safecheck, - num_dfs=1, - ) + ser = ArrowStreamGroupSerializer(num_dfs=1) + # Cogrouped UDFs: num_dfs=2 (two DataFrames per group) elif eval_type in ( - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, ): - # Wrapper calls to_arrow directly, so use MapIter serializer in grouped mode - # dump_stream handles RecordBatch directly from wrapper - ser = ArrowStreamGroupSerializer(num_dfs=1) - elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: ser = ArrowStreamGroupSerializer(num_dfs=2) - elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - # Mapper calls to_arrow directly, so use Arrow serializer - # load_stream returns raw batches, to_pandas conversion done in worker - # dump_stream handles RecordBatch directly from mapper - ser = 
ArrowStreamGroupSerializer( - safecheck=runner_conf.safecheck, - arrow_cast=True, - num_dfs=2, - ) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: ser = ApplyInPandasWithStateSerializer( runner_conf.timezone, @@ -2956,14 +2907,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): ser = TransformWithStateInPySparkRowInitStateSerializer( runner_conf.arrow_max_records_per_batch ) - elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - ser = ArrowStreamGroupSerializer() - elif eval_type in ( - PythonEvalType.SQL_SCALAR_ARROW_UDF, - PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, - ): - # Arrow cast and safe check are always enabled - ser = ArrowStreamGroupSerializer(safecheck=True, arrow_cast=True) + # SQL_ARROW_BATCHED_UDF with new conversion uses ArrowBatchUDFSerializer elif ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not runner_conf.use_legacy_pandas_udf_conversion @@ -2976,27 +2920,15 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): runner_conf.binary_as_bytes, ) else: - # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of - # pandas Series. See SPARK-27240. - pandas_udf_df_for_struct = eval_type in ( - PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - ) - # Arrow-optimized Python UDF takes a struct type argument as a Row - is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF - pandas_udf_struct_in_pandas = "row" if is_arrow_batched_udf else "dict" - pandas_udf_ndarray_as_list = is_arrow_batched_udf - # Arrow-optimized Python UDF takes input types - pandas_udf_input_type = ( + # All other UDFs use ArrowStreamGroupSerializer with num_dfs=0 + # This includes: SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, + # SQL_MAP_PANDAS_ITER_UDF, SQL_MAP_ARROW_ITER_UDF, SQL_SCALAR_ARROW_UDF, + # SQL_SCALAR_ARROW_ITER_UDF, SQL_ARROW_BATCHED_UDF (legacy mode) + if is_legacy_arrow_batched_udf: + # Read input type from stream (required for protocol), but conversion + # is handled in wrap_arrow_batch_udf_legacy _parse_datatype_json_string(utf8_deserializer.loads(infile)) - if is_arrow_batched_udf - else None - ) - - ser = ArrowStreamGroupSerializer( - safecheck=runner_conf.safecheck, - ) + ser = ArrowStreamGroupSerializer() else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) @@ -3012,16 +2944,10 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, ) + is_scalar_pandas_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF - # Check if we need to convert Arrow batches to pandas - needs_arrow_to_pandas = eval_type in ( - PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - ) - if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: # TODO: Better error message for num_udfs != 1 if is_scalar_iter: @@ -3038,16 +2964,11 @@ def func(_, iterator): if is_map_arrow_iter: iterator = map(ArrowBatchTransformer.flatten_struct, iterator) - # Convert Arrow batches to pandas if needed - if needs_arrow_to_pandas: + # For pandas iter UDFs, convert Arrow batches to pandas DataFrames + if is_scalar_pandas_iter or is_map_pandas_iter: iterator = map( 
lambda batch: ArrowBatchTransformer.to_pandas( - batch, - timezone=runner_conf.timezone, - schema=pandas_udf_input_type, - struct_in_pandas=pandas_udf_struct_in_pandas, - ndarray_as_list=pandas_udf_ndarray_as_list, - df_for_struct=pandas_udf_df_for_struct, + batch, timezone=runner_conf.timezone, df_for_struct=True ), iterator, ) @@ -3508,15 +3429,18 @@ def mapper(batch_iter): else: import pyarrow as pa - # Check if we need to convert Arrow batches to pandas (for scalar pandas UDF) - needs_arrow_to_pandas = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - # For SQL_ARROW_BATCHED_UDF, the wrapper returns (result, arrow_return_type, return_type) # tuple that the serializer handles. For SQL_SCALAR_PANDAS_UDF and SQL_SCALAR_ARROW_UDF, # the wrappers return RecordBatches that need to be zipped. is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + is_scalar_pandas_udf = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF def mapper(a): + # For scalar pandas UDF, convert Arrow batch to pandas first + if is_scalar_pandas_udf: + a = ArrowBatchTransformer.to_pandas( + a, timezone=runner_conf.timezone, df_for_struct=True + ) result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) # For arrow batched UDF, return raw result (serializer handles conversion) # In the special case of a single UDF this will return a single result rather @@ -3547,19 +3471,6 @@ def func(_, it): else: def func(_, it): - # Convert Arrow batches to pandas if needed - if needs_arrow_to_pandas: - it = map( - lambda batch: ArrowBatchTransformer.to_pandas( - batch, - timezone=runner_conf.timezone, - schema=pandas_udf_input_type, - struct_in_pandas=pandas_udf_struct_in_pandas, - ndarray_as_list=pandas_udf_ndarray_as_list, - df_for_struct=pandas_udf_df_for_struct, - ), - it, - ) return map(mapper, it) # profiling is not supported for UDF From afa0a2599d67bb19c86597e27eb76e906b829ad8 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 02:41:08 +0000 Subject: [PATCH 32/39] refactor: simplify UDF wrappers and data flow - Separate iter UDF branches by type for clarity (scalar pandas/arrow, map pandas, map arrow) - Wrapper functions now return (result, arrow_return_type) tuples, with output conversion centralized in mapper/func - Grouped map pandas UDFs now receive Iterator[DataFrame] directly instead of Iterator[List[Series]] - Inline concat_series_batches and series_batches_to_dataframe methods --- python/pyspark/sql/conversion.py | 99 -------- python/pyspark/worker.py | 413 ++++++++++++++++++------------- 2 files changed, 238 insertions(+), 274 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index d4fdae54f28c..695a383e840f 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -453,105 +453,6 @@ class PandasBatchTransformer: """ - @classmethod - def concat_series_batches( - cls, series_batches: List[List["pd.Series"]], arg_offsets: Optional[List[int]] = None - ) -> List["pd.Series"]: - """ - Concatenate multiple batches of pandas Series column-wise. - - Takes a list of batches where each batch is a list of Series (one per column), - and concatenates all Series column-by-column to produce a single list of - concatenated Series. 
- - Parameters - ---------- - series_batches : List[List[pd.Series]] - List of batches, each batch is a list of Series (one Series per column) - arg_offsets : Optional[List[int]] - If provided and series_batches is empty, determines the number of empty Series to create - - Returns - ------- - List[pd.Series] - List of concatenated Series, one per column - - Used by - ------- - - SQL_GROUPED_AGG_PANDAS_UDF mapper - - SQL_WINDOW_AGG_PANDAS_UDF mapper - - Examples - -------- - >>> import pandas as pd - >>> batch1 = [pd.Series([1, 2]), pd.Series([3, 4])] - >>> batch2 = [pd.Series([5, 6]), pd.Series([7, 8])] - >>> result = PandasBatchTransformer.concat_series_batches([batch1, batch2]) - >>> len(result) - 2 - >>> result[0].tolist() - [1, 2, 5, 6] - """ - import pandas as pd - - if not series_batches: - # Empty batches - create empty Series - if arg_offsets: - num_columns = max(arg_offsets) + 1 if arg_offsets else 0 - else: - num_columns = 0 - return [pd.Series(dtype=object) for _ in range(num_columns)] - - # Concatenate Series by column - num_columns = len(series_batches[0]) - return [ - pd.concat([batch[i] for batch in series_batches], ignore_index=True) - for i in range(num_columns) - ] - - @classmethod - def series_batches_to_dataframe(cls, series_batches) -> "pd.DataFrame": - """ - Convert an iterator of Series lists to a single DataFrame. - - Each batch is a list of Series (one per column). This method concatenates - Series within each batch horizontally (axis=1), then concatenates all - resulting DataFrames vertically (axis=0). - - Parameters - ---------- - series_batches : Iterator[List[pd.Series]] - Iterator where each element is a list of Series representing one batch - - Returns - ------- - pd.DataFrame - Combined DataFrame with all data, or empty DataFrame if no batches - - Used by - ------- - - wrap_grouped_map_pandas_udf - - wrap_grouped_map_pandas_iter_udf - - Examples - -------- - >>> import pandas as pd - >>> batch1 = [pd.Series([1, 2], name='a'), pd.Series([3, 4], name='b')] - >>> batch2 = [pd.Series([5, 6], name='a'), pd.Series([7, 8], name='b')] - >>> df = PandasBatchTransformer.series_batches_to_dataframe([batch1, batch2]) - >>> df.shape - (4, 2) - >>> df.columns.tolist() - ['a', 'b'] - """ - import pandas as pd - - # Materialize iterator and convert each batch to DataFrame - dataframes = [pd.concat(series_list, axis=1) for series_list in series_batches] - - # Concatenate all DataFrames vertically - return pd.concat(dataframes, axis=0) if dataframes else pd.DataFrame() - @classmethod def to_arrow( cls, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a18e85517dbd..1639285ff6c3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -279,6 +279,14 @@ def verify_element_type(elem, expected_class, expected_type_name): def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + """Wrap a scalar pandas UDF. + + The wrapper only validates and calls the user function. + Output conversion (pandas → Arrow) is done in the mapper. 
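+    For example, with an ``IntegerType`` return type the wrapped call produces a
+    ``(pandas.Series, pyarrow.int32())`` pair; the mapper later converts it into an
+    Arrow ``RecordBatch``.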
+ + Returns: + (arg_offsets, wrapped_func) where wrapped_func returns (result, arrow_return_type) + """ func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -286,42 +294,37 @@ def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_ ) pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - def verify_and_convert(result, length): + def wrapped(*args): + result = func(*args) verify_result_type(result, pd_type) - verify_result_length(result, length, "pandas_udf") - return result + verify_result_length(result, len(args[0]), "pandas_udf") + return (result, arrow_return_type) - return ( - args_kwargs_offsets, - lambda *a: PandasBatchTransformer.to_arrow( - [(verify_and_convert(func(*a), len(a[0])), arrow_return_type)], - timezone=runner_conf.timezone, - safecheck=runner_conf.safecheck, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - arrow_cast=True, - struct_in_pandas="dict", - ), - ) + return (args_kwargs_offsets, wrapped) def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + """Wrap a scalar arrow UDF. + + The wrapper only validates and calls the user function. + Output conversion is done in the mapper. + + Returns: + (arg_offsets, wrapped_func) where wrapped_func returns (result, arrow_return_type) + """ func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) - def verify_and_convert(result, length): + def wrapped(*args): + result = func(*args) verify_result_type(result, "pyarrow.Array") - verify_result_length(result, length, "arrow_udf") - return result + verify_result_length(result, len(args[0]), "arrow_udf") + return (result, arrow_return_type) - return ( - args_kwargs_offsets, - lambda *a: ArrowBatchTransformer.zip_batches( - [(verify_and_convert(func(*a), len(a[0])), arrow_return_type)], - ), - ) + return (args_kwargs_offsets, wrapped) def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): @@ -466,6 +469,14 @@ def verify_result_length(result, length): def wrap_pandas_batch_iter_udf(f, return_type, runner_conf): + """Wrap a pandas batch iter UDF. + + The wrapper only validates and calls the user function. + Output conversion (pandas → Arrow) is done in func. + + Returns: + wrapped_func that yields (result, arrow_return_type) tuples + """ import pandas as pd arrow_return_type = to_arrow_type( @@ -481,20 +492,11 @@ def verify_element(elem): ) return elem - def to_batch(res): - return PandasBatchTransformer.to_arrow( - [(res, arrow_return_type)], - timezone=runner_conf.timezone, - safecheck=runner_conf.safecheck, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - arrow_cast=True, - struct_in_pandas="dict", - ) - def wrapped(*iterator): result = f(*iterator) verify_is_iterable(result, "iterator of {}".format(iter_type_label)) - return map(to_batch, map(verify_element, result)) + # Yield (result, arrow_type) pairs - output conversion done in func + return ((verify_element(elem), arrow_return_type) for elem in result) return wrapped @@ -560,6 +562,14 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu def wrap_arrow_array_iter_udf(f, return_type, runner_conf): + """Wrap an arrow array iter UDF. + + The wrapper only validates and calls the user function. 
+ Output conversion is done in func. + + Returns: + wrapped_func that yields (result, arrow_return_type) tuples + """ import pyarrow as pa arrow_return_type = to_arrow_type( @@ -567,15 +577,12 @@ def wrap_arrow_array_iter_udf(f, return_type, runner_conf): ) iter_type = "iterator of pyarrow.Array" - def to_batch(res): - return ArrowBatchTransformer.zip_batches([(res, arrow_return_type)]) - def wrapped(*iterator): result = f(*iterator) verify_is_iterable(result, iter_type) - return map( - to_batch, - map(lambda elem: verify_element_type(elem, pa.Array, iter_type), result), + # Yield (result, arrow_type) pairs - output conversion done in func + return ( + (verify_element_type(elem, pa.Array, iter_type), arrow_return_type) for elem in result ) return wrapped @@ -819,10 +826,12 @@ def wrapped(key_batch, value_batches): def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - def wrapped(key_series, value_batches): + def wrapped(key_series, value_dfs): import pandas as pd - value_df = PandasBatchTransformer.series_batches_to_dataframe(value_batches) + # Concat all DataFrames in the group vertically + dfs = list(value_dfs) + value_df = pd.concat(dfs, axis=0) if dfs else pd.DataFrame() if len(argspec.args) == 1: result = f(value_df) @@ -885,19 +894,14 @@ def flatten_wrapper(k, v): def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf): - def wrapped(key_series, value_batches): - import pandas as pd - - # value_batches is an Iterator[List[pd.Series]] (one list per batch) - # Convert each list of Series into a DataFrame - dataframe_iter = map(lambda series_list: pd.concat(series_list, axis=1), value_batches) - + def wrapped(key_series, value_dfs): + # value_dfs is already Iterator[pd.DataFrame] from mapper if len(argspec.args) == 1: - result = f(dataframe_iter) + result = f(value_dfs) elif len(argspec.args) == 2: # Extract key from pandas Series, preserving numpy types key = tuple(s.iloc[0] for s in key_series) - result = f(key, dataframe_iter) + result = f(key, value_dfs) def verify_element(df): verify_pandas_result( @@ -2940,102 +2944,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): for i in range(num_udfs) ] - is_scalar_iter = eval_type in ( - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, - ) - is_scalar_pandas_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF - - if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: - # TODO: Better error message for num_udfs != 1 - if is_scalar_iter: - assert num_udfs == 1, "One SCALAR_ITER UDF expected here." - if is_map_pandas_iter: - assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." - if is_map_arrow_iter: - assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
- - arg_offsets, udf = udfs[0] - - def func(_, iterator): - # For MAP_ARROW_ITER, flatten struct before processing - if is_map_arrow_iter: - iterator = map(ArrowBatchTransformer.flatten_struct, iterator) - - # For pandas iter UDFs, convert Arrow batches to pandas DataFrames - if is_scalar_pandas_iter or is_map_pandas_iter: - iterator = map( - lambda batch: ArrowBatchTransformer.to_pandas( - batch, timezone=runner_conf.timezone, df_for_struct=True - ), - iterator, - ) - - num_input_rows = 0 - - def map_batch(batch): - nonlocal num_input_rows - - udf_args = [batch[offset] for offset in arg_offsets] - num_input_rows += len(udf_args[0]) - if len(udf_args) == 1: - return udf_args[0] - else: - return tuple(udf_args) - - # For MAP_ARROW_ITER, pass the whole batch to UDF, not extracted columns - if is_map_arrow_iter: - # Count input rows for verification - def count_rows(batch): - nonlocal num_input_rows - num_input_rows += batch.num_rows - return batch - - iterator = map(count_rows, iterator) - result_iter = udf(iterator) - else: - iterator = map(map_batch, iterator) - result_iter = udf(iterator) - - num_output_rows = 0 - for result_batch in result_iter: - num_output_rows += len(result_batch.column(0)) - # This check is for Scalar Iterator UDF to fail fast. - # The length of the entire input can only be explicitly known - # by consuming the input iterator in user side. Therefore, - # it's very unlikely the output length is higher than - # input length. - if is_scalar_iter and num_output_rows > num_input_rows: - raise PySparkRuntimeError( - errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} - ) - yield result_batch - - if is_scalar_iter: - try: - next(iterator) - except StopIteration: - pass - else: - raise PySparkRuntimeError( - errorClass="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF", - messageParameters={}, - ) - - if num_output_rows != num_input_rows: - raise PySparkRuntimeError( - errorClass="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF", - messageParameters={ - "output_length": str(num_output_rows), - "input_length": str(num_input_rows), - }, - ) - - # profiling is not supported for UDF - return func, None, ser, ser - def extract_key_value_indexes(grouped_arg_offsets): """ Helper function to extract the key and value indexes from arg_offsets for the grouped and @@ -3065,6 +2973,124 @@ def extract_key_value_indexes(grouped_arg_offsets): idx += offsets_len return parsed + if eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, + ): + # Scalar iterator UDFs: require input/output length matching + assert num_udfs == 1, "One SCALAR_ITER UDF expected here." 
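+        # Only the pandas variant converts input batches to pandas; the Arrow variant
+        # passes pyarrow data through unchanged (see ``input_mapper``/``output_mapper`` below).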
+ arg_offsets, udf = udfs[0] + needs_pandas = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + + def input_mapper(batch): + if needs_pandas: + return ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True + ) + return batch + + def output_mapper(item): + result, arrow_return_type = item + if needs_pandas: + return PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + return ArrowBatchTransformer.zip_batches([(result, arrow_return_type)]) + + def func(_, iterator): + num_input_rows = 0 + + def mapper(batch): + nonlocal num_input_rows + converted = input_mapper(batch) + udf_args = [converted[offset] for offset in arg_offsets] + num_input_rows += len(udf_args[0]) + return udf_args[0] if len(udf_args) == 1 else tuple(udf_args) + + iterator = map(mapper, iterator) + result_iter = udf(iterator) + + num_output_rows = 0 + for item in result_iter: + result_batch = output_mapper(item) + num_output_rows += len(result_batch.column(0)) + if num_output_rows > num_input_rows: + raise PySparkRuntimeError( + errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} + ) + yield result_batch + + # Verify consumed all input and output length matches + try: + next(iterator) + except StopIteration: + pass + else: + raise PySparkRuntimeError( + errorClass="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF", + messageParameters={}, + ) + if num_output_rows != num_input_rows: + raise PySparkRuntimeError( + errorClass="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF", + messageParameters={ + "output_length": str(num_output_rows), + "input_length": str(num_input_rows), + }, + ) + + return func, None, ser, ser + + elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: + # Map pandas iterator UDF: no length verification needed + assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." + arg_offsets, udf = udfs[0] + + def input_mapper(batch): + return ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True + ) + + def output_mapper(item): + result, arrow_return_type = item + return PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + + def func(_, iterator): + def mapper(batch): + converted = input_mapper(batch) + udf_args = [converted[offset] for offset in arg_offsets] + return udf_args[0] if len(udf_args) == 1 else tuple(udf_args) + + iterator = map(mapper, iterator) + result_iter = udf(iterator) + return map(output_mapper, result_iter) + + return func, None, ser, ser + + elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: + # Map arrow iterator UDF (mapInArrow): wrapper handles conversion + assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
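Aside on the mapInPandas-style branch above: the key property is that conversion stays lazy, so one Arrow batch at a time is converted, handed to the UDF, and converted back. A generic sketch of that pipeline shape (the real input_mapper/output_mapper do Arrow/pandas conversion; here they are placeholders):

def build_pipeline(batches, udf, input_mapper, output_mapper):
    # map() keeps everything lazy: one batch is converted and processed at a time,
    # so a long input stream never has to be materialized in memory.
    converted = map(input_mapper, batches)
    results = udf(converted)
    return map(output_mapper, results)

out = build_pipeline(iter(range(3)), lambda it: (x * 2 for x in it), int, str)
print(list(out))  # ['0', '2', '4']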
+ _, udf = udfs[0] + + def func(_, iterator): + # Flatten struct → call wrapper → wrapper returns RecordBatch with wrap_struct + iterator = map(ArrowBatchTransformer.flatten_struct, iterator) + return udf(iterator) + + return func, None, ser, ser + if ( eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF @@ -3081,24 +3107,27 @@ def extract_key_value_indexes(grouped_arg_offsets): parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(batch_iter): - # Convert Arrow batches to pandas Series lists in worker layer - series_iter = map( - lambda batch: ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), + import pandas as pd + + # Convert Arrow batches to pandas DataFrames + df_iter = map( + lambda batch: pd.concat( + ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), axis=1 + ), batch_iter, ) # Materialize first batch to extract keys (keys are same for all batches in group) - first_series_list = next(series_iter) - key_series = [first_series_list[o] for o in parsed_offsets[0][0]] + first_df = next(df_iter) + key_series = [first_df.iloc[:, o] for o in parsed_offsets[0][0]] - # Create generator for value Series: yields List[pd.Series] per batch - value_series_gen = ( - [series_list[o] for o in parsed_offsets[0][1]] - for series_list in itertools.chain((first_series_list,), series_iter) + # Create generator for value DataFrames: yields DataFrame per batch + value_df_gen = ( + df.iloc[:, parsed_offsets[0][1]] for df in itertools.chain((first_df,), df_iter) ) # Wrapper yields wrapped RecordBatches (one or more per group) - yield from f(key_series, value_series_gen) + yield from f(key_series, value_df_gen) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: # We assume there is only one UDF here because grouped map doesn't @@ -3411,14 +3440,24 @@ def mapper(a): # batch_iter is now Iterator[pa.RecordBatch] (raw batches from serializer) # Convert to pandas and concatenate into single Series per column def mapper(batch_iter): + import pandas as pd + # Convert raw Arrow batches to pandas Series in worker layer series_batches = [ ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone) for batch in batch_iter ] - # Concatenate all batches column-wise - all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] - concatenated = PandasBatchTransformer.concat_series_batches(series_batches, all_offsets) + # Concatenate all batches column-wise into single Series per column + if series_batches: + num_columns = len(series_batches[0]) + concatenated = [ + pd.concat([batch[i] for batch in series_batches], ignore_index=True) + for i in range(num_columns) + ] + else: + all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] + num_columns = max(all_offsets) + 1 if all_offsets else 0 + concatenated = [pd.Series(dtype=object) for _ in range(num_columns)] # Each UDF returns pa.RecordBatch (conversion done in wrapper) result_batches = [f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs] @@ -3429,29 +3468,53 @@ def mapper(batch_iter): else: import pyarrow as pa - # For SQL_ARROW_BATCHED_UDF, the wrapper returns (result, arrow_return_type, return_type) - # tuple that the serializer handles. For SQL_SCALAR_PANDAS_UDF and SQL_SCALAR_ARROW_UDF, - # the wrappers return RecordBatches that need to be zipped. 
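Aside on the mapInArrow branch above, which only flattens the JVM's single struct column before calling the wrapper: a small pyarrow-only sketch of the flatten/wrap round trip, assuming nothing beyond the public pyarrow API:

import pyarrow as pa

# A batch as the JVM sends it: one struct column named "_0".
struct = pa.StructArray.from_arrays([pa.array([1, 2]), pa.array(["a", "b"])], names=["x", "y"])
wrapped = pa.RecordBatch.from_arrays([struct], ["_0"])

# Flatten: child arrays of the struct become top-level columns.
col = wrapped.column(0)
flat = pa.RecordBatch.from_arrays(col.flatten(), schema=pa.schema(list(col.type)))

# Wrap again for the JVM: the columns go back into a single struct column.
rewrapped = pa.RecordBatch.from_arrays(
    [pa.StructArray.from_arrays(flat.columns, fields=pa.struct(list(flat.schema)))], ["_0"]
)
assert rewrapped.column(0).to_pylist() == wrapped.column(0).to_pylist()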
+ # Determine UDF type for appropriate handling is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF is_scalar_pandas_udf = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + is_scalar_arrow_udf = eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF - def mapper(a): - # For scalar pandas UDF, convert Arrow batch to pandas first + def mapper(batch): + """Process a single batch through all UDFs. + + For scalar pandas/arrow UDFs: + 1. Input conversion: Arrow → Pandas (for pandas UDF only) + 2. Call each wrapper, collect (result, arrow_type) pairs + 3. Output conversion: Pandas/Arrow → Arrow RecordBatch + """ + # Step 1: Input conversion (Arrow → Pandas for pandas UDF) if is_scalar_pandas_udf: - a = ArrowBatchTransformer.to_pandas( - a, timezone=runner_conf.timezone, df_for_struct=True + data = ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True ) - result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) + else: + data = batch + + # Step 2: Call each wrapper, collect results + # Wrappers return (result, arrow_return_type) tuples + results_with_types = [f(*[data[o] for o in arg_offsets]) for arg_offsets, f in udfs] + + # Step 3: Output conversion # For arrow batched UDF, return raw result (serializer handles conversion) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. if is_arrow_batched_udf: - if len(result) == 1: - return result[0] + if len(results_with_types) == 1: + return results_with_types[0] else: - return result - # For pandas/arrow scalar UDFs, each returns a RecordBatch - merge them - return ArrowBatchTransformer.zip_batches(result) + return results_with_types + + # For scalar pandas UDF, convert all results to Arrow at once + if is_scalar_pandas_udf: + return PandasBatchTransformer.to_arrow( + results_with_types, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + + # For scalar arrow UDF, zip Arrow arrays into a RecordBatch + if is_scalar_arrow_udf: + return ArrowBatchTransformer.zip_batches(results_with_types) # For grouped/cogrouped map UDFs: # All wrappers yield batches, so mapper returns generator → need flatten From 63b6fbe37fab8b2af771ed944adbff92e7eb3e0f Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 30 Jan 2026 21:27:28 -0800 Subject: [PATCH 33/39] fix: remove unused imports to fix linting errors --- python/pyspark/sql/connect/session.py | 1 - python/pyspark/sql/pandas/serializers.py | 16 +--------------- python/pyspark/worker.py | 7 ------- 3 files changed, 1 insertion(+), 23 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index a67b08f5c3f9..bedf442a944d 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -69,7 +69,6 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager -from pyspark.sql.pandas.serializers import ArrowStreamUDFSerializer from pyspark.sql.pandas.types import ( to_arrow_schema, to_arrow_type, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 5d7a8f599935..2a5cd1d2df6b 100644 --- 
a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,7 +20,6 @@ """ from itertools import groupby -from typing import TYPE_CHECKING, Iterator, Optional import pyspark from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError @@ -35,18 +34,11 @@ from pyspark.sql.conversion import ( LocalDataToArrowConversion, ArrowTableToRowsConversion, - ArrowArrayToPandasConversion, ArrowBatchTransformer, PandasBatchTransformer, ) -from pyspark.sql.pandas.types import ( - from_arrow_type, - is_variant, - to_arrow_type, - _create_converter_from_pandas, -) +from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import ( - DataType, StringType, StructType, BinaryType, @@ -55,11 +47,6 @@ IntegerType, ) -if TYPE_CHECKING: - import pandas as pd - import pyarrow as pa - - class SpecialLengths: END_OF_DATA_SECTION = -1 PYTHON_EXCEPTION_THROWN = -2 @@ -991,7 +978,6 @@ def load_stream(self, stream): Please refer the doc of inner function `generate_data_batches` for more details how this function works in overall. """ - import pyarrow as pa import pandas as pd from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 1639285ff6c3..587f9b4a3abb 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -52,7 +52,6 @@ ArrowTableToRowsConversion, ArrowBatchTransformer, PandasBatchTransformer, - ArrowArrayToPandasConversion, ) from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( @@ -607,8 +606,6 @@ def wrapper(batch_iter): def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): - import pyarrow as pa - if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -798,8 +795,6 @@ def wrapped(key_batch, value_batches): def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf): - import pyarrow as pa - if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -3434,8 +3429,6 @@ def mapper(a): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): - import pandas as pd - # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF, # batch_iter is now Iterator[pa.RecordBatch] (raw batches from serializer) # Convert to pandas and concatenate into single Series per column From 79d7c412836c54975d7428f1e7075f58f3600b24 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 30 Jan 2026 22:45:24 -0800 Subject: [PATCH 34/39] fix: correct SQL_BATCHED_UDF and legacy Arrow UDF handling --- python/pyspark/sql/pandas/serializers.py | 34 ++++++++++++---- python/pyspark/worker.py | 49 +++++++++++++++++++++--- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 2a5cd1d2df6b..6986eac6dea8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -47,6 +47,7 @@ IntegerType, ) + class SpecialLengths: END_OF_DATA_SECTION = -1 PYTHON_EXCEPTION_THROWN = -2 @@ -1037,20 +1038,37 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to 
stream. """ + from pyspark.sql.conversion import ArrowBatchTransformer - def flatten_iterator(): - # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]] + def create_batches(): + # iterator: iter[list[(iter[pandas.DataFrame], arrow_return_type)]] for packed in iterator: - iter_pdf, _ = packed[0] + iter_pdf, arrow_return_type = packed[0] for pdf in iter_pdf: - yield PandasBatchTransformer.to_arrow( - pdf, + # Extract columns from DataFrame and pair with Arrow types + # Similar to wrap_grouped_map_pandas_udf pattern + if self._assign_cols_by_name and any( + isinstance(name, str) for name in pdf.columns + ): + series_list = [(pdf[field.name], field.type) for field in arrow_return_type] + else: + series_list = [ + (pdf[pdf.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] + batch = PandasBatchTransformer.to_arrow( + series_list, timezone=self._timezone, safecheck=self._safecheck, int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, + arrow_cast=self._arrow_cast, ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) - super().dump_stream(flatten_iterator(), stream) + # Write START_ARROW_STREAM marker before first batch + batches = self._write_stream_start(create_batches(), stream) + ArrowStreamSerializer.dump_stream(self, batches, stream) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): @@ -1190,7 +1208,9 @@ def row_stream(): else EMPTY_DATAFRAME.copy(), ) - _batches = super().load_stream(stream) + # Call ArrowStreamSerializer.load_stream directly to get raw Arrow batches + # (not the parent's load_stream which returns processed data with mode info) + _batches = ArrowStreamSerializer.load_stream(self, stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 587f9b4a3abb..3a8d7c00705a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2924,9 +2924,10 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): # SQL_MAP_PANDAS_ITER_UDF, SQL_MAP_ARROW_ITER_UDF, SQL_SCALAR_ARROW_UDF, # SQL_SCALAR_ARROW_ITER_UDF, SQL_ARROW_BATCHED_UDF (legacy mode) if is_legacy_arrow_batched_udf: - # Read input type from stream (required for protocol), but conversion - # is handled in wrap_arrow_batch_udf_legacy - _parse_datatype_json_string(utf8_deserializer.loads(infile)) + # Read input type from stream - needed for proper UDT conversion + legacy_input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) + else: + legacy_input_type = None ser = ArrowStreamGroupSerializer() else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) @@ -3458,6 +3459,19 @@ def mapper(batch_iter): # Merge all RecordBatches horizontally return ArrowBatchTransformer.zip_batches(result_batches) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: + # Regular batched UDF uses pickle serialization, no pyarrow needed + # Input: individual rows (list of column values) + # Output: individual results + def mapper(row): + # row is a list of column values [col0, col1, col2, ...] + # Each UDF wrapper takes its columns and returns the result + results = [f(*[row[o] for o in arg_offsets]) for arg_offsets, f in udfs] + if len(results) == 1: + return results[0] + else: + return tuple(results) + else: import pyarrow as pa @@ -3474,11 +3488,22 @@ def mapper(batch): 2. Call each wrapper, collect (result, arrow_type) pairs 3. 
Output conversion: Pandas/Arrow → Arrow RecordBatch """ - # Step 1: Input conversion (Arrow → Pandas for pandas UDF) + # Step 1: Input conversion (Arrow → Pandas for pandas UDF or legacy arrow batched UDF) if is_scalar_pandas_udf: data = ArrowBatchTransformer.to_pandas( batch, timezone=runner_conf.timezone, df_for_struct=True ) + elif is_legacy_arrow_batched_udf: + # Legacy mode: convert Arrow to pandas Series for wrap_arrow_batch_udf_legacy + # which expects *args: pd.Series input + # Use legacy_input_type schema for proper UDT conversion + data = ArrowBatchTransformer.to_pandas( + batch, + timezone=runner_conf.timezone, + schema=legacy_input_type, + struct_in_pandas="row", + ndarray_as_list=True, + ) else: data = batch @@ -3487,13 +3512,25 @@ def mapper(batch): results_with_types = [f(*[data[o] for o in arg_offsets]) for arg_offsets, f in udfs] # Step 3: Output conversion - # For arrow batched UDF, return raw result (serializer handles conversion) - if is_arrow_batched_udf: + # For arrow batched UDF (non-legacy), return raw result (serializer handles conversion) + if is_arrow_batched_udf and not is_legacy_arrow_batched_udf: if len(results_with_types) == 1: return results_with_types[0] else: return results_with_types + # For legacy arrow batched UDF, convert pandas results to Arrow + # Results are (pandas.Series, arrow_type, spark_type) tuples + if is_legacy_arrow_batched_udf: + return PandasBatchTransformer.to_arrow( + results_with_types, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="row", + ) + # For scalar pandas UDF, convert all results to Arrow at once if is_scalar_pandas_udf: return PandasBatchTransformer.to_arrow( From 78728c211252b66fb7865477e307a879dd1cf27a Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 01:00:09 -0800 Subject: [PATCH 35/39] wip --- python/pyspark/errors/error-conditions.json | 5 + python/pyspark/sql/connect/session.py | 25 +- python/pyspark/sql/conversion.py | 33 +- python/pyspark/sql/pandas/conversion.py | 31 +- python/pyspark/sql/pandas/serializers.py | 16 +- python/pyspark/sql/tests/test_conversion.py | 398 ++++++++++++++++++++ 6 files changed, 463 insertions(+), 45 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ee35e237b898..d256e786e00c 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -380,6 +380,11 @@ " index out of range, got ''." ] }, + "INVALID_ARROW_BATCH_ZIP": { + "message": [ + "Cannot zip Arrow batches/arrays: ." + ] + }, "INVALID_ARROW_UDTF_RETURN_TYPE": { "message": [ "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '' method returned a value of type with value: ." 
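Aside on the new INVALID_ARROW_BATCH_ZIP condition above: it guards the column-wise combination of UDF results, which is only well defined when every array has the same length. A hedged sketch of that invariant with an illustrative helper (not the pyspark API):

import pyarrow as pa

def zip_arrays(named_arrays):
    # Combine columns side by side into one RecordBatch; mismatched lengths
    # cannot form a rectangular batch, so reject them up front.
    names = [name for name, _ in named_arrays]
    arrays = [arr for _, arr in named_arrays]
    if len({len(arr) for arr in arrays}) > 1:
        raise ValueError("Cannot zip Arrow batches/arrays: column lengths differ")
    return pa.RecordBatch.from_arrays(arrays, names=names)

batch = zip_arrays([("id", pa.array([1, 2])), ("name", pa.array(["a", "b"]))])
print(batch.num_columns, batch.num_rows)  # 2 2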
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index bedf442a944d..2332bfe49ce2 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -69,6 +69,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager +from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.pandas.types import ( to_arrow_schema, to_arrow_type, @@ -629,21 +630,19 @@ def createDataFrame( safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] - # Convert pandas data to Arrow RecordBatch - from pyspark.sql.conversion import PandasBatchTransformer - - batch_data = [ - (c, at, st) for (_, c), at, st in zip(data.items(), arrow_types, spark_types) - ] - record_batch = PandasBatchTransformer.to_arrow( - batch_data, - timezone=cast(str, timezone), - safecheck=safecheck == "true", - int_to_decimal_coercion_enabled=False, + _table = pa.Table.from_batches( + [ + PandasBatchTransformer.to_arrow( + [ + (c, at, st) + for (_, c), at, st in zip(data.items(), arrow_types, spark_types) + ], + timezone=cast(str, timezone), + safecheck=safecheck == "true", + ) + ] ) - _table = pa.Table.from_batches([record_batch]) - if isinstance(schema, StructType): assert arrow_schema is not None _table = _table.rename_columns( diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 695a383e840f..a60c6b0c510a 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -18,7 +18,18 @@ import array import datetime import decimal -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, + overload, +) from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError from pyspark.sql.pandas.types import ( @@ -242,7 +253,7 @@ def zip_batches( if isinstance(first_item, pa.RecordBatch): # Handle RecordBatches - batches = items + batches = cast(List["pa.RecordBatch"], items) if len(batches) == 1: return batches[0] @@ -514,9 +525,11 @@ def to_arrow( and isinstance(series[2], DataType) ) ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - normalized_series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) + items: List[Any] = [series] + else: + items = list(series) + tupled = ((s, None) if not isinstance(s, (list, tuple)) else s for s in items) + normalized_series = ((s[0], s[1], None) if len(s) == 2 else s for s in tupled) arrs = [] for s, arrow_type, spark_type in normalized_series: @@ -554,19 +567,19 @@ def to_arrow( for i, field in enumerate(arrow_type): # Get Series and spark_type based on matching strategy if use_name_matching: - series = s[field.name] + col_series: "pd.Series" = s[field.name] field_spark_type = ( spark_type[field.name].dataType if spark_type is not None else None ) else: - series = s[s.columns[i]].rename(field.name) + col_series = s[s.columns[i]].rename(field.name) field_spark_type = ( spark_type[i].dataType if spark_type is not None else None ) struct_arrs.append( PandasSeriesToArrowConversion.create_array( - series, + col_series, field.type, timezone=timezone, safecheck=safecheck, @@ -1603,7 +1616,7 @@ def create_array( import pandas as pd if isinstance(series.dtype, 
pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) + series = series.astype(series.dtype.categories.dtype) if arrow_type is not None: dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) @@ -1637,7 +1650,7 @@ def create_array( raise PySparkRuntimeError( errorClass=error_class, messageParameters={ - "col_name": series.name, + "col_name": str(series.name), "col_type": str(series.dtype), "arrow_type": str(arrow_type), }, diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 4bbae85d02c4..7e524b294a18 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -18,6 +18,7 @@ from typing import ( Any, Callable, + Iterator, List, Optional, Sequence, @@ -31,7 +32,6 @@ from pyspark.errors.exceptions.captured import unwrap_spark_exception from pyspark.util import _load_from_socket -from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.pandas.serializers import ArrowCollectSerializer from pyspark.sql.pandas.types import _dedup_names from pyspark.sql.types import ( @@ -808,7 +808,8 @@ def _create_from_pandas_with_arrow( assert isinstance(self, SparkSession) - from pyspark.sql.pandas.serializers import ArrowStreamUDFSerializer + from pyspark.sql.pandas.serializers import ArrowStreamSerializer + from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.types import TimestampType from pyspark.sql.pandas.types import ( from_arrow_type, @@ -879,8 +880,8 @@ def _create_from_pandas_with_arrow( step = step if step > 0 else len(pdf) pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step)) - # Create list of Arrow (columns, arrow_type, spark_type) for conversion to RecordBatch - pandas_data = [ + # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream + arrow_data = [ [ ( c, @@ -894,20 +895,16 @@ def _create_from_pandas_with_arrow( for pdf_slice in pdf_slices ] - # Convert pandas data to Arrow RecordBatches before serialization - arrow_data = map( - lambda batch_data: PandasBatchTransformer.to_arrow( - batch_data, - timezone=timezone, - safecheck=safecheck, - int_to_decimal_coercion_enabled=False, - ), - pandas_data, - ) - jsparkSession = self._jsparkSession - ser = ArrowStreamUDFSerializer(timezone, safecheck, False) + # Convert pandas data to Arrow batches + def create_batches() -> Iterator["pa.RecordBatch"]: + for batch_data in arrow_data: + yield PandasBatchTransformer.to_arrow( + batch_data, timezone=timezone, safecheck=safecheck + ) + + ser = ArrowStreamSerializer() @no_type_check def reader_func(temp_filename): @@ -918,7 +915,7 @@ def create_iter_server(): return self._jvm.ArrowIteratorServer() # Create Spark DataFrame from Arrow stream file, using one batch per partition - jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server) + jiter = self._sc._serialize_to_jvm(create_batches(), ser, reader_func, create_iter_server) assert self._jvm is not None jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession) df = DataFrame(jdf, self) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6986eac6dea8..a64f96685021 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1161,16 +1161,22 @@ def row_stream(): for batch in batches: self._update_batch_size_stats(batch) - data_pandas = to_pandas(flatten_columns(batch, "inputData")) - 
init_data_pandas = to_pandas(flatten_columns(batch, "initState")) + data_table = flatten_columns(batch, "inputData") + init_table = flatten_columns(batch, "initState") - assert not (bool(init_data_pandas) and bool(data_pandas)) + # Check column count - empty table has no columns + has_data = data_table.num_columns > 0 + has_init = init_table.num_columns > 0 - if bool(data_pandas): + assert not (has_data and has_init) + + if has_data: + data_pandas = to_pandas(data_table) for row in pd.concat(data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.key_offsets) yield (batch_key, row, None) - elif bool(init_data_pandas): + elif has_init: + init_data_pandas = to_pandas(init_table) for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.init_key_offsets) yield (batch_key, None, row) diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index c3fa1fd19304..54aa58864bc3 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -24,6 +24,8 @@ LocalDataToArrowConversion, ArrowTimestampConversion, ArrowBatchTransformer, + PandasBatchTransformer, + PandasSeriesToArrowConversion, ) from pyspark.sql.types import ( ArrayType, @@ -143,6 +145,402 @@ def test_wrap_struct_empty_batch(self): self.assertEqual(wrapped.num_rows, 0) self.assertEqual(wrapped.num_columns, 1) + def test_concat_batches_basic(self): + """Test concatenating multiple batches vertically.""" + import pyarrow as pa + + batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + + self.assertEqual(result.num_rows, 4) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3, 4]) + + def test_concat_batches_single(self): + """Test concatenating a single batch returns the same batch.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch]) + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_concat_batches_empty_batches(self): + """Test concatenating empty batches.""" + import pyarrow as pa + + schema = pa.schema([("x", pa.int64())]) + batch1 = pa.RecordBatch.from_arrays([pa.array([], type=pa.int64())], schema=schema) + batch2 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + + self.assertEqual(result.num_rows, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + + def test_zip_batches_record_batches(self): + """Test zipping multiple RecordBatches horizontally.""" + import pyarrow as pa + + batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["a"]) + batch2 = pa.RecordBatch.from_arrays([pa.array(["x", "y"])], ["b"]) + + result = ArrowBatchTransformer.zip_batches([batch1, batch2]) + + self.assertEqual(result.num_columns, 2) + self.assertEqual(result.num_rows, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + self.assertEqual(result.column(1).to_pylist(), ["x", "y"]) + + def test_zip_batches_arrays(self): + """Test zipping Arrow arrays directly.""" + import pyarrow as pa + + arr1 = pa.array([1, 2, 3]) + arr2 = pa.array(["a", "b", "c"]) + + result = ArrowBatchTransformer.zip_batches([arr1, arr2]) + + self.assertEqual(result.num_columns, 2) + 
self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + self.assertEqual(result.column(1).to_pylist(), ["a", "b", "c"]) + + def test_zip_batches_with_type_casting(self): + """Test zipping with type casting.""" + import pyarrow as pa + + arr = pa.array([1, 2, 3], type=pa.int32()) + result = ArrowBatchTransformer.zip_batches([(arr, pa.int64())]) + + self.assertEqual(result.column(0).type, pa.int64()) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_zip_batches_empty_raises(self): + """Test that zipping empty list raises error.""" + with self.assertRaises(PySparkValueError): + ArrowBatchTransformer.zip_batches([]) + + def test_zip_batches_single_batch(self): + """Test zipping single batch returns it unchanged.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + result = ArrowBatchTransformer.zip_batches([batch]) + + self.assertIs(result, batch) + + def test_to_pandas_basic(self): + """Test basic Arrow to pandas conversion.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC") + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].tolist(), [1, 2, 3]) + self.assertEqual(result[1].tolist(), ["a", "b", "c"]) + + def test_to_pandas_empty_table(self): + """Test converting empty table returns NoValue series.""" + import pyarrow as pa + + table = pa.Table.from_arrays([], names=[]) + + result = ArrowBatchTransformer.to_pandas(table, timezone="UTC") + + # Empty table with rows should return a series with _NoValue + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 0) + + def test_to_pandas_with_nulls(self): + """Test conversion with null values.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, None, 3]), pa.array([None, "b", None])], + names=["x", "y"], + ) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC") + + self.assertEqual(len(result), 2) + self.assertTrue(result[0].isna().tolist() == [False, True, False]) + self.assertTrue(result[1].isna().tolist() == [True, False, True]) + + def test_to_pandas_with_schema_and_udt(self): + """Test conversion with Spark schema containing UDT.""" + import pyarrow as pa + + # Create Arrow batch with raw UDT data (list representation) + batch = pa.RecordBatch.from_arrays( + [pa.array([[1.0, 2.0], [3.0, 4.0]], type=pa.list_(pa.float64()))], + names=["point"], + ) + + schema = StructType([StructField("point", ExamplePointUDT())]) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC", schema=schema) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], ExamplePoint(1.0, 2.0)) + self.assertEqual(result[0][1], ExamplePoint(3.0, 4.0)) + + def test_flatten_and_wrap_roundtrip(self): + """Test that flatten -> wrap produces equivalent result.""" + import pyarrow as pa + + struct_array = pa.StructArray.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["x", "y"], + ) + original = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + flattened = ArrowBatchTransformer.flatten_struct(original) + rewrapped = ArrowBatchTransformer.wrap_struct(flattened) + + self.assertEqual(rewrapped.num_columns, 1) + self.assertEqual(rewrapped.column(0).to_pylist(), original.column(0).to_pylist()) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class PandasBatchTransformerTests(unittest.TestCase): + def 
test_to_arrow_single_series(self): + """Test converting a single pandas Series to Arrow.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2, 3]) + result = PandasBatchTransformer.to_arrow( + (series, pa.int64(), IntegerType()), timezone="UTC" + ) + + self.assertEqual(result.num_columns, 1) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_to_arrow_empty_series(self): + """Test converting empty pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([], dtype="int64") + result = PandasBatchTransformer.to_arrow( + (series, pa.int64(), IntegerType()), timezone="UTC" + ) + + self.assertEqual(result.num_rows, 0) + + def test_to_arrow_with_nulls(self): + """Test converting Series with null values.""" + import pandas as pd + import pyarrow as pa + import numpy as np + from pyspark.sql.types import DoubleType + + series = pd.Series([1.0, np.nan, 3.0]) + result = PandasBatchTransformer.to_arrow( + (series, pa.float64(), DoubleType()), timezone="UTC" + ) + + self.assertEqual(result.column(0).to_pylist(), [1.0, None, 3.0]) + + def test_to_arrow_multiple_series(self): + """Test converting multiple series at once.""" + import pandas as pd + import pyarrow as pa + + series1 = pd.Series([1, 2]) + series2 = pd.Series(["a", "b"]) + + result = PandasBatchTransformer.to_arrow( + [(series1, pa.int64(), None), (series2, pa.string(), None)], timezone="UTC" + ) + + self.assertEqual(result.num_columns, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + self.assertEqual(result.column(1).to_pylist(), ["a", "b"]) + + def test_to_arrow_struct_mode(self): + """Test converting DataFrame to struct array.""" + import pandas as pd + import pyarrow as pa + + df = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + spark_type = StructType([StructField("x", IntegerType()), StructField("y", StringType())]) + + result = PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), timezone="UTC", as_struct=True + ) + + self.assertEqual(result.num_columns, 1) + struct_col = result.column(0) + self.assertEqual(struct_col.field(0).to_pylist(), [1, 2]) + self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b"]) + + def test_to_arrow_struct_empty_dataframe(self): + """Test converting empty DataFrame in struct mode.""" + import pandas as pd + import pyarrow as pa + + df = pd.DataFrame({"x": pd.Series([], dtype="int64"), "y": pd.Series([], dtype="str")}) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + spark_type = StructType([StructField("x", IntegerType()), StructField("y", StringType())]) + + result = PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), timezone="UTC", as_struct=True + ) + + self.assertEqual(result.num_rows, 0) + + def test_to_arrow_categorical_series(self): + """Test converting categorical pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Categorical(["a", "b", "a", "c"]) + result = PandasBatchTransformer.to_arrow( + (pd.Series(series), pa.string(), StringType()), timezone="UTC" + ) + + self.assertEqual(result.column(0).to_pylist(), ["a", "b", "a", "c"]) + + def test_to_arrow_requires_dataframe_for_struct(self): + """Test that struct mode requires DataFrame input.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2]) + arrow_type = pa.struct([("x", pa.int64())]) + + with self.assertRaises(PySparkValueError): + PandasBatchTransformer.to_arrow( + (series, arrow_type, None), 
timezone="UTC", as_struct=True + ) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class PandasSeriesToArrowConversionTests(unittest.TestCase): + def test_create_array_basic(self): + """Test basic array creation from pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2, 3]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [1, 2, 3]) + + def test_create_array_empty(self): + """Test creating array from empty Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([], dtype="int64") + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(len(result), 0) + + def test_create_array_with_nulls(self): + """Test creating array with null values.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, None, 3], dtype="Int64") # nullable integer + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [1, None, 3]) + + def test_create_array_categorical(self): + """Test creating array from categorical Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series(pd.Categorical(["x", "y", "x"])) + result = PandasSeriesToArrowConversion.create_array( + series, pa.string(), "UTC", spark_type=StringType() + ) + + self.assertEqual(result.to_pylist(), ["x", "y", "x"]) + + def test_create_array_with_udt(self): + """Test creating array with UDT.""" + import pandas as pd + import pyarrow as pa + + points = [ExamplePoint(1.0, 2.0), ExamplePoint(3.0, 4.0)] + series = pd.Series(points) + + result = PandasSeriesToArrowConversion.create_array( + series, pa.list_(pa.float64()), "UTC", spark_type=ExamplePointUDT() + ) + + self.assertEqual(result.to_pylist(), [[1.0, 2.0], [3.0, 4.0]]) + + def test_create_array_binary(self): + """Test creating array with binary data.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([b"hello", b"world"]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.binary(), "UTC", spark_type=BinaryType() + ) + + self.assertEqual(result.to_pylist(), [b"hello", b"world"]) + + def test_create_array_all_nulls(self): + """Test creating array with all null values.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([None, None, None]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [None, None, None]) + + def test_create_array_nested_list(self): + """Test creating array with nested list type.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([[1, 2], [3, 4, 5], []]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.list_(pa.int64()), "UTC", spark_type=ArrayType(IntegerType()) + ) + + self.assertEqual(result.to_pylist(), [[1, 2], [3, 4, 5], []]) + + def test_create_array_map_type(self): + """Test creating array with map type.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([{"a": 1}, {"b": 2, "c": 3}]) + result = PandasSeriesToArrowConversion.create_array( + series, + pa.map_(pa.string(), pa.int64()), + "UTC", + spark_type=MapType(StringType(), IntegerType()), + ) + + # Maps are returned as list of tuples + self.assertEqual(result.to_pylist(), [[("a", 1)], [("b", 2), ("c", 3)]]) + 
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class ConversionTests(unittest.TestCase): From 9cbf2507244fed0ea067a99813b53b39bb41d407 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 11:53:39 -0800 Subject: [PATCH 36/39] fix: simplify ArrowStreamArrowUDTFSerializer type coercion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simplify dump_stream to fix CI failure in Arrow UDTF lateral join test. The previous implementation had complex unwrapping logic that assumed batches might come wrapped from the worker. However, the worker always yields unwrapped batches, so the detection and unwrapping code was causing issues with the data flow. Changes: - Remove batch unwrapping logic (batches are never wrapped at this point) - Use direct array casting instead of zip_batches for type coercion - Maintain the proper flow: receive unwrapped batch → coerce types → wrap for JVM --- python/pyspark/sql/pandas/serializers.py | 68 +++++++++++------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index a64f96685021..d4d275e9a9d8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -301,6 +301,9 @@ def dump_stream(self, iterator, stream): """ Override to handle type coercion for ArrowUDTF outputs. ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. + + The function performs type coercion on each batch based on arrow_return_type, + then wraps the result into a struct column before serialization. """ import pyarrow as pa @@ -310,32 +313,24 @@ def apply_type_coercion(): arrow_return_type, pa.StructType ), f"Expected pa.StructType, got {type(arrow_return_type)}" - # Batch is already wrapped into a struct column by worker (wrap_arrow_udtf) - # Unwrap it first to access individual columns - if batch.num_columns == 1 and batch.column(0).type == pa.struct( - list(arrow_return_type) - ): - # Batch is wrapped, unwrap it - unwrapped_batch = ArrowBatchTransformer.flatten_struct(batch, column_index=0) - elif batch.num_columns == 0: - # Empty batch: wrap it back to struct column + # Handle empty batch case (no columns) + if batch.num_columns == 0: + # Empty batch: no coercion needed, wrap it back to struct column coerced_batch = ArrowBatchTransformer.wrap_struct(batch) yield coerced_batch continue - else: - # Batch is not wrapped (shouldn't happen, but handle it) - unwrapped_batch = batch - # Handle empty struct case specially (no columns to coerce) + # Handle empty struct case (no fields expected) if len(arrow_return_type) == 0: - # Empty struct: wrap unwrapped batch (which should also be empty) back to struct column - coerced_batch = ArrowBatchTransformer.wrap_struct(unwrapped_batch) + # Empty struct: wrap batch back to struct column + coerced_batch = ArrowBatchTransformer.wrap_struct(batch) yield coerced_batch continue # Check field names match expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = unwrapped_batch.schema.names + actual_field_names = batch.schema.names + if expected_field_names != actual_field_names: raise PySparkTypeError( "Target schema's field names are not matching the record batch's " @@ -343,33 +338,30 @@ def apply_type_coercion(): f"Expected: {expected_field_names}, but got: {actual_field_names}." 
) - # Use zip_batches for type coercion: create (array, type) tuples - arrays_and_types = [ - (unwrapped_batch.column(i), field.type) - for i, field in enumerate(arrow_return_type) - ] - try: - coerced_batch = ArrowBatchTransformer.zip_batches( - arrays_and_types, safecheck=True - ) - except PySparkTypeError as e: - # Re-raise with UDTF-specific error - # Find the first array that failed type coercion - # arrays_and_types contains (array, field_type) tuples - for array, expected_type in arrays_and_types: - if array.type != expected_type: + # Create (array, target_type) tuples for type coercion + coerced_arrays = [] + for i, field in enumerate(arrow_return_type): + original_array = batch.column(i) + if original_array.type == field.type: + coerced_arrays.append(original_array) + else: + try: + coerced_array = original_array.cast( + target_type=field.type, safe=True + ) + coerced_arrays.append(coerced_array) + except (pa.ArrowInvalid, pa.ArrowTypeError): raise PySparkRuntimeError( errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", messageParameters={ - "expected": str(expected_type), - "actual": str(array.type), + "expected": str(field.type), + "actual": str(original_array.type), }, - ) from e - # If no type mismatch found, re-raise original error - raise + ) - # Rename columns to match expected field names - coerced_batch = coerced_batch.rename_columns(expected_field_names) + coerced_batch = pa.RecordBatch.from_arrays( + coerced_arrays, names=expected_field_names + ) # Wrap into struct column for JVM coerced_batch = ArrowBatchTransformer.wrap_struct(coerced_batch) yield coerced_batch From 70c0eb5618c3ff5f5da8186628ebf8410cc69ed1 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sat, 31 Jan 2026 12:58:23 -0800 Subject: [PATCH 37/39] wip: remove ArrowStreamArrowUDTFSerializer --- python/pyspark/sql/conversion.py | 79 ++++++++++++++++- python/pyspark/sql/pandas/serializers.py | 105 ++--------------------- python/pyspark/worker.py | 87 +++++++++---------- 3 files changed, 122 insertions(+), 149 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index a60c6b0c510a..7b12dff70abd 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -88,7 +88,7 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record Used by ------- - ArrowStreamGroupSerializer - - ArrowStreamArrowUDTFSerializer + - SQL_ARROW_UDTF mapper - SQL_MAP_ARROW_ITER_UDF mapper - SQL_GROUPED_MAP_ARROW_UDF mapper - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper @@ -110,7 +110,7 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": - wrap_grouped_map_arrow_iter_udf - wrap_cogrouped_map_arrow_udf - wrap_arrow_batch_iter_udf - - ArrowStreamArrowUDTFSerializer.dump_stream + - SQL_ARROW_UDTF mapper - TransformWithStateInPySparkRowSerializer.dump_stream """ import pyarrow as pa @@ -123,6 +123,67 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) + @classmethod + def coerce_types( + cls, batch: "pa.RecordBatch", arrow_return_type: "pa.StructType" + ) -> "pa.RecordBatch": + """ + Apply type coercion to a RecordBatch based on expected schema. + + This method: + 1. Handles empty batches (no columns or empty struct) - returns as-is + 2. Validates field names match expected schema + 3. 
Casts arrays to expected types if needed (safe casting) + + Parameters + ---------- + batch : pa.RecordBatch + Input RecordBatch to coerce + arrow_return_type : pa.StructType + Expected Arrow schema for type coercion + + Returns + ------- + pa.RecordBatch + RecordBatch with coerced column types + """ + import pyarrow as pa + + # Handle empty batch case (no columns) or empty struct case + if batch.num_columns == 0 or len(arrow_return_type) == 0: + return batch + + # Check field names match + expected_field_names = [field.name for field in arrow_return_type] + actual_field_names = batch.schema.names + + if expected_field_names != actual_field_names: + raise PySparkTypeError( + "Target schema's field names are not matching the record batch's " + "field names. " + f"Expected: {expected_field_names}, but got: {actual_field_names}." + ) + + # Apply type coercion if needed + coerced_arrays = [] + for i, field in enumerate(arrow_return_type): + original_array = batch.column(i) + if original_array.type == field.type: + coerced_arrays.append(original_array) + else: + try: + coerced_arrays.append(original_array.cast(target_type=field.type, safe=True)) + except (pa.ArrowInvalid, pa.ArrowTypeError): + raise PySparkRuntimeError( + errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", + messageParameters={ + "expected": str(field.type), + "actual": str(original_array.type), + }, + ) + + return pa.RecordBatch.from_arrays(coerced_arrays, names=expected_field_names) + @classmethod def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": """ @@ -1047,11 +1108,21 @@ def create_array( def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table": require_minimum_pyarrow_version() import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_type - assert isinstance(data, list) and len(data) > 0 - + assert isinstance(data, list) assert schema is not None and isinstance(schema, StructType) + # Handle empty data - return empty table with correct schema + if len(data) == 0: + arrow_schema = to_arrow_type( + schema, timezone="UTC", prefers_large_types=use_large_var_types + ) + pa_schema = pa.schema(list(arrow_schema)) + # Create empty arrays for each field to ensure proper table structure + empty_arrays = [pa.array([], type=f.type) for f in pa_schema] + return pa.Table.from_arrays(empty_arrays, schema=pa_schema) + column_names = schema.fieldNames() len_column_names = len(column_names) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d4d275e9a9d8..732bf58f88a1 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -22,7 +22,7 @@ from itertools import groupby import pyspark -from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.errors import PySparkRuntimeError, PySparkValueError from pyspark.serializers import ( Serializer, read_int, @@ -207,10 +207,10 @@ class ArrowStreamGroupSerializer(ArrowStreamSerializer): Unified serializer for Arrow stream operations with optional grouping support. 
This serializer handles: - - Non-grouped operations: SQL_MAP_ARROW_ITER_UDF (num_dfs=0) - - Grouped operations: SQL_GROUPED_MAP_ARROW_UDF, SQL_GROUPED_MAP_PANDAS_UDF (num_dfs=1) - - Cogrouped operations: SQL_COGROUPED_MAP_ARROW_UDF, SQL_COGROUPED_MAP_PANDAS_UDF (num_dfs=2) - - Grouped aggregations: SQL_GROUPED_AGG_ARROW_UDF, SQL_GROUPED_AGG_PANDAS_UDF (num_dfs=1) + - Non-grouped (num_dfs=0): SQL_MAP_ARROW_ITER_UDF, SQL_ARROW_TABLE_UDF, SQL_ARROW_UDTF + - Grouped (num_dfs=1): SQL_GROUPED_MAP_ARROW_UDF, SQL_GROUPED_MAP_PANDAS_UDF, + SQL_GROUPED_AGG_ARROW_UDF, SQL_GROUPED_AGG_PANDAS_UDF + - Cogrouped (num_dfs=2): SQL_COGROUPED_MAP_ARROW_UDF, SQL_COGROUPED_MAP_PANDAS_UDF The serializer handles Arrow stream I/O and START signal, while transformation logic (flatten/wrap struct, pandas conversion) is handled by worker wrappers. @@ -274,101 +274,6 @@ def dump_stream(self, iterator, stream): return super().dump_stream(batches, stream) -class ArrowStreamArrowUDTFSerializer(ArrowStreamGroupSerializer): - """ - Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays. - """ - - def __init__(self, table_arg_offsets=None): - super().__init__() - self.table_arg_offsets = table_arg_offsets if table_arg_offsets else [] - - def load_stream(self, stream): - """ - Flatten the struct into Arrow's record batches. - """ - for batch in super().load_stream(stream): - # For each column: flatten struct columns at table_arg_offsets into RecordBatch, - # keep other columns as Array - yield [ - ArrowBatchTransformer.flatten_struct(batch, column_index=i) - if i in self.table_arg_offsets - else batch.column(i) - for i in range(batch.num_columns) - ] - - def dump_stream(self, iterator, stream): - """ - Override to handle type coercion for ArrowUDTF outputs. - ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. - - The function performs type coercion on each batch based on arrow_return_type, - then wraps the result into a struct column before serialization. - """ - import pyarrow as pa - - def apply_type_coercion(): - for batch, arrow_return_type in iterator: - assert isinstance( - arrow_return_type, pa.StructType - ), f"Expected pa.StructType, got {type(arrow_return_type)}" - - # Handle empty batch case (no columns) - if batch.num_columns == 0: - # Empty batch: no coercion needed, wrap it back to struct column - coerced_batch = ArrowBatchTransformer.wrap_struct(batch) - yield coerced_batch - continue - - # Handle empty struct case (no fields expected) - if len(arrow_return_type) == 0: - # Empty struct: wrap batch back to struct column - coerced_batch = ArrowBatchTransformer.wrap_struct(batch) - yield coerced_batch - continue - - # Check field names match - expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = batch.schema.names - - if expected_field_names != actual_field_names: - raise PySparkTypeError( - "Target schema's field names are not matching the record batch's " - "field names. " - f"Expected: {expected_field_names}, but got: {actual_field_names}." 
- ) - - # Create (array, target_type) tuples for type coercion - coerced_arrays = [] - for i, field in enumerate(arrow_return_type): - original_array = batch.column(i) - if original_array.type == field.type: - coerced_arrays.append(original_array) - else: - try: - coerced_array = original_array.cast( - target_type=field.type, safe=True - ) - coerced_arrays.append(coerced_array) - except (pa.ArrowInvalid, pa.ArrowTypeError): - raise PySparkRuntimeError( - errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", - messageParameters={ - "expected": str(field.type), - "actual": str(original_array.type), - }, - ) - - coerced_batch = pa.RecordBatch.from_arrays( - coerced_arrays, names=expected_field_names - ) - # Wrap into struct column for JVM - coerced_batch = ArrowBatchTransformer.wrap_struct(coerced_batch) - yield coerced_batch - - return super().dump_stream(apply_type_coercion(), stream) - - class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ Serializer for UDFs that handles Arrow RecordBatch serialization. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3a8d7c00705a..97f84b37695e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -62,7 +62,6 @@ TransformWithStateInPySparkRowSerializer, TransformWithStateInPySparkRowInitStateSerializer, ArrowBatchUDFSerializer, - ArrowStreamArrowUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_type, TimestampType from pyspark.sql.types import ( @@ -1634,11 +1633,10 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): else: ser = ArrowStreamGroupSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: - # Read the table argument offsets + # Read the table argument offsets (used by mapper to flatten struct columns) num_table_arg_offsets = read_int(infile) table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)] - # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support - ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets) + ser = ArrowStreamGroupSerializer() else: # Each row is a group so do not batch but send one by one. ser = BatchedSerializer(CPickleSerializer(), 1) @@ -2480,12 +2478,10 @@ def check_return_value(res): yield row def convert_to_arrow(data: Iterable): + """ + Convert Python data to coerced Arrow batches. 
+ """ data = list(check_return_value(data)) - if len(data) == 0: - # Return one empty RecordBatch to match the left side of the lateral join - return [ - pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type))) - ] def raise_conversion_error(original_exception): raise PySparkRuntimeError( @@ -2520,30 +2516,24 @@ def raise_conversion_error(original_exception): except Exception as e: raise_conversion_error(e) - return verify_result(table).to_batches() + # Coerce types for each batch + table = verify_result(table) + batches = table.to_batches() + if len(batches) == 0: + # Empty table - create empty batch for lateral join semantics + batches = [pa.RecordBatch.from_pylist([], schema=table.schema)] + for batch in batches: + yield ArrowBatchTransformer.coerce_types(batch, arrow_return_type) def evaluate(*args: list, num_rows=1): if len(args) == 0: for _ in range(num_rows): - for batch in convert_to_arrow(func()): - # Handle empty batch: wrap it immediately for JVM - if batch.num_columns == 0: - yield ArrowBatchTransformer.wrap_struct(batch), arrow_return_type - else: - # Yield (batch, arrow_return_type) tuple for serializer - # Serializer will handle type coercion and wrapping - yield batch, arrow_return_type - + yield from map(ArrowBatchTransformer.wrap_struct, convert_to_arrow(func())) else: for row in zip(*args): - for batch in convert_to_arrow(func(*row)): - # Handle empty batch: wrap it immediately for JVM - if batch.num_columns == 0: - yield ArrowBatchTransformer.wrap_struct(batch), arrow_return_type - else: - # Yield (batch, arrow_return_type) tuple for serializer - # Serializer will handle type coercion and wrapping - yield batch, arrow_return_type + yield from map( + ArrowBatchTransformer.wrap_struct, convert_to_arrow(func(*row)) + ) return evaluate @@ -2576,8 +2566,8 @@ def mapper(_, it): else column.to_pylist() for column, conv in zip(a.columns, converters) ] - # The eval function yields an iterator. Each element produced by this - # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type). + # The eval function yields wrapped RecordBatches ready for serialization. + # Each batch has been type-coerced and wrapped into a struct column. yield from eval(*[pylist[o] for o in args_kwargs_offsets], num_rows=a.num_rows) if terminate is not None: yield from terminate() @@ -2611,9 +2601,6 @@ def verify_result(result): "func": f.__name__, }, ) - - # We verify the type of the result and do type corerion - # in the serializer return result # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. @@ -2644,30 +2631,33 @@ def check_return_value(res): return iter([]) def convert_to_arrow(data: Iterable): - data_iter = check_return_value(data) - - # Handle PyArrow Tables/RecordBatches directly + """ + Convert PyArrow data to coerced Arrow batches. 
+ """ is_empty = True - for item in data_iter: + for item in check_return_value(data): is_empty = False if isinstance(item, pa.Table): - yield from item.to_batches() + batches = item.to_batches() or [ + pa.RecordBatch.from_pylist([], schema=item.schema) + ] elif isinstance(item, pa.RecordBatch): - yield item + batches = [item] else: - # Arrow UDTF should only return Arrow types (RecordBatch/Table) raise PySparkRuntimeError( errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR", messageParameters={}, ) + yield from ( + ArrowBatchTransformer.coerce_types(verify_result(batch), arrow_return_type) + for batch in batches + ) if is_empty: yield pa.RecordBatch.from_pylist([], schema=pa.schema(list(arrow_return_type))) def evaluate(*args: pa.RecordBatch): - # For Arrow UDTFs, unpack the RecordBatches and pass them to the function - for batch in convert_to_arrow(func(*args)): - yield verify_result(batch), arrow_return_type + yield from map(ArrowBatchTransformer.wrap_struct, convert_to_arrow(func(*args))) return evaluate @@ -2685,9 +2675,16 @@ def evaluate(*args: pa.RecordBatch): def mapper(_, it): try: - for a in it: - # For PyArrow UDTFs, pass RecordBatches directly (no row conversion needed) - yield from eval(*[a[o] for o in args_kwargs_offsets]) + for batch in it: + # Flatten struct columns at table_arg_offsets into RecordBatches, + # keep other columns as Arrays + args = [ + ArrowBatchTransformer.flatten_struct(batch, column_index=i) + if i in table_arg_offsets + else batch.column(i) + for i in range(batch.num_columns) + ] + yield from eval(*[args[o] for o in args_kwargs_offsets]) if terminate is not None: yield from terminate() except SkipRestOfInputTableException: From 60ffb0c1438a7f497bdc61a3c87b782ce7330a46 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 1 Feb 2026 11:34:11 -0800 Subject: [PATCH 38/39] refactor: remove TransformWithState PySpark Row serializers Move the load_stream and dump_stream logic from TransformWithStateInPySparkRowSerializer and TransformWithStateInPySparkRowInitStateSerializer into worker.py. Both eval types now use ArrowStreamGroupSerializer directly, simplifying the serializer hierarchy. --- python/pyspark/sql/pandas/serializers.py | 209 -------------------- python/pyspark/worker.py | 230 +++++++++++++++++++---- 2 files changed, 196 insertions(+), 243 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 732bf58f88a1..58533188017a 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -30,7 +30,6 @@ UTF8Deserializer, CPickleSerializer, ) -from pyspark.sql import Row from pyspark.sql.conversion import ( LocalDataToArrowConversion, ArrowTableToRowsConversion, @@ -1122,211 +1121,3 @@ def row_stream(): yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) - - -class TransformWithStateInPySparkRowSerializer(ArrowStreamGroupSerializer): - """ - Serializer used by Python worker to evaluate UDF for - :meth:`pyspark.sql.GroupedData.transformWithState`. - - Parameters - ---------- - arrow_max_records_per_batch : int - Limit of the number of records that can be written to a single ArrowRecordBatch in memory. 
- """ - - def __init__(self, arrow_max_records_per_batch): - super().__init__() - self.arrow_max_records_per_batch = ( - arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 - ) - self.key_offsets = None - - def load_stream(self, stream): - """ - Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks, - and convert the data into a list of pandas.Series. - - Please refer the doc of inner function `generate_data_batches` for more details how - this function works in overall. - """ - from pyspark.sql.streaming.stateful_processor_util import ( - TransformWithStateInPandasFuncMode, - ) - import itertools - - def generate_data_batches(batches): - """ - Deserialize ArrowRecordBatches and return a generator of Row. - - The deserialization logic assumes that Arrow RecordBatches contain the data with the - ordering that data chunks for same grouping key will appear sequentially. - - This function must avoid materializing multiple Arrow RecordBatches into memory at the - same time. And data chunks from the same grouping key should appear sequentially. - """ - for batch in batches: - DataRow = Row(*batch.schema.names) - - # Iterate row by row without converting the whole batch - num_cols = batch.num_columns - for row_idx in range(batch.num_rows): - # build the key for this row - row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets) - row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) - yield row_key, row - - _batches = super().load_stream(stream) - data_batches = generate_data_batches(_batches) - - for k, g in groupby(data_batches, key=lambda x: x[0]): - chained = itertools.chain(g) - chained_values = map(lambda x: x[1], chained) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values) - - yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) - - yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) - - def dump_stream(self, iterator, stream): - """ - Read through an iterator of (iterator of Row), serialize them to Arrow - RecordBatches, and write batches to stream. - """ - import pyarrow as pa - - def flatten_iterator(): - # iterator: iter[list[(iter[Row], pdf_type)]] - for packed in iterator: - iter_row_with_type = packed[0] - iter_row = iter_row_with_type[0] - pdf_type = iter_row_with_type[1] - - rows_as_dict = [] - for row in iter_row: - row_as_dict = row.asDict(True) - rows_as_dict.append(row_as_dict) - - pdf_schema = pa.schema(list(pdf_type)) - record_batch = pa.RecordBatch.from_pylist(rows_as_dict, schema=pdf_schema) - - # Wrap the batch into a struct before yielding - wrapped_batch = ArrowBatchTransformer.wrap_struct(record_batch) - yield wrapped_batch - - return ArrowStreamGroupSerializer.dump_stream(self, flatten_iterator(), stream) - - -class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer): - """ - Serializer used by Python worker to evaluate UDF for - :meth:`pyspark.sql.GroupedData.transformWithStateInPySparkRowInitStateSerializer`. - Parameters - ---------- - Same as input parameters in TransformWithStateInPySparkRowSerializer. 
- """ - - def __init__(self, arrow_max_records_per_batch): - super().__init__(arrow_max_records_per_batch) - self.init_key_offsets = None - - def load_stream(self, stream): - import pyarrow as pa - from pyspark.sql.streaming.stateful_processor_util import ( - TransformWithStateInPandasFuncMode, - ) - from typing import Iterator, Any, Optional, Tuple - - def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]: - """ - Deserialize ArrowRecordBatches and return a generator of Row. - The deserialization logic assumes that Arrow RecordBatches contain the data with the - ordering that data chunks for same grouping key will appear sequentially. - See `TransformWithStateInPySparkPythonInitialStateRunner` for arrow batch schema sent - from JVM. - This function flattens the columns of input rows and initial state rows and feed them - into the data generator. - """ - - def extract_rows( - cur_batch, col_name, key_offsets - ) -> Optional[Iterator[Tuple[Any, Any]]]: - data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) - - # Check if the entire column is null - if data_column.null_count == len(data_column): - return None - - data_field_names = [ - data_column.type[i].name for i in range(data_column.type.num_fields) - ] - data_field_arrays = [ - data_column.field(i) for i in range(data_column.type.num_fields) - ] - - DataRow = Row(*data_field_names) - - table = pa.Table.from_arrays(data_field_arrays, names=data_field_names) - - if table.num_rows == 0: - return None - - def row_iterator(): - for row_idx in range(table.num_rows): - key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) - row = DataRow( - *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) - ) - yield (key, row) - - return row_iterator() - - """ - The arrow batch is written in the schema: - schema: StructType = new StructType() - .add("inputData", dataSchema) - .add("initState", initStateSchema) - We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. Each batch will have either init_data or input_data, not mix. 
- """ - for batch in batches: - # Detect which column has data - each batch contains only one type - input_result = extract_rows(batch, "inputData", self.key_offsets) - init_result = extract_rows(batch, "initState", self.init_key_offsets) - - assert not (input_result is not None and init_result is not None) - - if input_result is not None: - for key, input_data_row in input_result: - yield (key, input_data_row, None) - elif init_result is not None: - for key, init_state_row in init_result: - yield (key, None, init_state_row) - - _batches = super().load_stream(stream) - data_batches = generate_data_batches(_batches) - - for k, g in groupby(data_batches, key=lambda x: x[0]): - input_rows = [] - init_rows = [] - - for batch_key, input_row, init_row in g: - if input_row is not None: - input_rows.append(input_row) - if init_row is not None: - init_rows.append(init_row) - - total_len = len(input_rows) + len(init_rows) - if total_len >= self.arrow_max_records_per_batch: - ret_tuple = (iter(input_rows), iter(init_rows)) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) - input_rows = [] - init_rows = [] - - if input_rows or init_rows: - ret_tuple = (iter(input_rows), iter(init_rows)) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) - - yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) - - yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 97f84b37695e..d59e76ac109e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,6 +50,7 @@ from pyspark.sql.conversion import ( LocalDataToArrowConversion, ArrowTableToRowsConversion, + ArrowArrayToPandasConversion, ArrowBatchTransformer, PandasBatchTransformer, ) @@ -59,8 +60,6 @@ ApplyInPandasWithStateSerializer, TransformWithStateInPandasSerializer, TransformWithStateInPandasInitStateSerializer, - TransformWithStateInPySparkRowSerializer, - TransformWithStateInPySparkRowInitStateSerializer, ArrowBatchUDFSerializer, ) from pyspark.sql.pandas.types import to_arrow_type, TimestampType @@ -2380,16 +2379,46 @@ def evaluate(*args: pd.Series, num_rows=1): cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None def mapper(_, it): + def convert_output(output_iter): + """Convert (pandas.DataFrame, arrow_type, spark_type) to Arrow RecordBatch.""" + for df, arrow_type, spark_type in output_iter: + yield PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + as_struct=True, + assign_cols_by_name=False, + arrow_cast=True, + ignore_unexpected_complex_type_values=True, + error_class="UDTF_ARROW_TYPE_CAST_ERROR", + ) + try: for a in it: + # Convert Arrow columns to pandas Series for legacy UDTF + # Use struct_in_pandas="row" to match master's behavior for table arguments + series = [ + ArrowArrayToPandasConversion.convert_legacy( + a.column(i), + f.dataType, + timezone=runner_conf.timezone, + struct_in_pandas="row", + ndarray_as_list=True, + ) + for i, f in enumerate(input_type) + ] # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type). 
- yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=len(a[0])) + num_rows = a.num_rows if not series else len(series[0]) + yield from convert_output( + eval(*[series[o] for o in args_kwargs_offsets], num_rows=num_rows) + ) if terminate is not None: - yield from terminate() + yield from convert_output(terminate()) except SkipRestOfInputTableException: if terminate is not None: - yield from terminate() + yield from convert_output(terminate()) finally: if cleanup is not None: cleanup() @@ -2897,12 +2926,11 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): runner_conf.arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) - elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: - ser = TransformWithStateInPySparkRowSerializer(runner_conf.arrow_max_records_per_batch) - elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: - ser = TransformWithStateInPySparkRowInitStateSerializer( - runner_conf.arrow_max_records_per_batch - ) + elif eval_type in ( + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, + ): + ser = ArrowStreamGroupSerializer() # SQL_ARROW_BATCHED_UDF with new conversion uses ArrowBatchUDFSerializer elif ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF @@ -3192,31 +3220,71 @@ def values_gen(): # support combining multiple UDFs. assert num_udfs == 1 + import pyarrow as pa + # See TransformWithStateInPySparkExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - ser.key_offsets = parsed_offsets[0][0] + key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - def mapper(a): - mode = a[0] + def func(_, batches): + """ + Convert raw Arrow batches to (mode, key, rows) tuples, call mapper, + and convert output back to Arrow batches. 
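The row conversion that follows builds a Row class dynamically from the batch schema names (DataRow = Row(*batch.schema.names)) and then instantiates it positionally for each row. A minimal illustration of that Row(*names) pattern with made-up field names:

from pyspark.sql import Row

# Row(*field_names) produces a reusable row "class"; calling it positionally
# yields named rows whose order matches the schema it was built from.
DataRow = Row("id", "value")
print(DataRow(1, "x"))            # Row(id=1, value='x')
print(DataRow(1, "x").asDict())   # {'id': 1, 'value': 'x'}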
+ """ - if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: - key = a[1] - values = a[2] + def generate_data_batches(): + """Convert Arrow batches to (key, row) pairs.""" + for batch in batches: + DataRow = Row(*batch.schema.names) + num_cols = batch.num_columns + for row_idx in range(batch.num_rows): + row_key = tuple(batch[o][row_idx].as_py() for o in key_offsets) + row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) + yield row_key, row + + def generate_tuples(): + """Group by key and yield (mode, key, rows) tuples.""" + data_batches = generate_data_batches() + for k, g in itertools.groupby(data_batches, key=lambda x: x[0]): + chained = itertools.chain(g) + chained_values = map(lambda x: x[1], chained) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values) + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) + + def mapper(a): + mode = a[0] + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] + values = a[2] + return f(stateful_processor_api_client, mode, key, values) + else: + return f(stateful_processor_api_client, mode, None, iter([])) + + def convert_output(results): + """Convert UDF output [(iter[Row], pdf_type)] to Arrow batches.""" + for [(iter_row, pdf_type)] in results: + rows_as_dict = [row.asDict(True) for row in iter_row] + record_batch = pa.RecordBatch.from_pylist( + rows_as_dict, schema=pa.schema(list(pdf_type)) + ) + yield ArrowBatchTransformer.wrap_struct(record_batch) - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, mode, key, values) - else: - # mode == PROCESS_TIMER or mode == COMPLETE - return f(stateful_processor_api_client, mode, None, iter([])) + return convert_output(map(mapper, generate_tuples())) + + # Return early since func is already defined + return func, None, ser, ser elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 + import pyarrow as pa + # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] @@ -3226,22 +3294,116 @@ def mapper(a): # [initStateGroupingOffsets, dedupInitDataOffsets] # ] parsed_offsets = extract_key_value_indexes(arg_offsets) - ser.key_offsets = parsed_offsets[0][0] - ser.init_key_offsets = parsed_offsets[1][0] + key_offsets = parsed_offsets[0][0] + init_key_offsets = parsed_offsets[1][0] + arrow_max_records_per_batch = runner_conf.arrow_max_records_per_batch stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - def mapper(a): - mode = a[0] + def func(_, batches): + """ + Convert raw Arrow batches to (mode, key, (input_rows, init_rows)) tuples, + call mapper, and convert output back to Arrow batches. + """ - if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: - key = a[1] - values = a[2] + def extract_rows(cur_batch, col_name, offsets): + """Extract rows from a struct column in the batch.""" + data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) - # This must be generator comprehension - do not materialize. 
- return f(stateful_processor_api_client, mode, key, values) - else: - # mode == PROCESS_TIMER or mode == COMPLETE - return f(stateful_processor_api_client, mode, None, iter([])) + # Check if the entire column is null + if data_column.null_count == len(data_column): + return None + + # Flatten struct column to RecordBatch + flattened = pa.RecordBatch.from_arrays( + data_column.flatten(), schema=pa.schema(data_column.type) + ) + + if flattened.num_rows == 0: + return None + + DataRow = Row(*flattened.schema.names) + + def row_iterator(): + for row_idx in range(flattened.num_rows): + key = tuple(flattened[o][row_idx].as_py() for o in offsets) + row = DataRow( + *( + flattened.column(i)[row_idx].as_py() + for i in range(flattened.num_columns) + ) + ) + yield (key, row) + + return row_iterator() + + def generate_data_batches(): + """ + Convert Arrow batches to (key, input_row, init_row) triples. + Each batch contains either inputData or initState, not both. + """ + for batch in batches: + input_result = extract_rows(batch, "inputData", key_offsets) + init_result = extract_rows(batch, "initState", init_key_offsets) + + assert not (input_result is not None and init_result is not None) + + if input_result is not None: + for key, input_data_row in input_result: + yield (key, input_data_row, None) + elif init_result is not None: + for key, init_state_row in init_result: + yield (key, None, init_state_row) + + def generate_tuples(): + """Group by key and yield (mode, key, (input_rows, init_rows)) tuples.""" + data_batches = generate_data_batches() + + for k, g in itertools.groupby(data_batches, key=lambda x: x[0]): + input_rows = [] + init_rows = [] + + for batch_key, input_row, init_row in g: + if input_row is not None: + input_rows.append(input_row) + if init_row is not None: + init_rows.append(init_row) + + total_len = len(input_rows) + len(init_rows) + if total_len >= arrow_max_records_per_batch: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + input_rows = [] + init_rows = [] + + if input_rows or init_rows: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) + + def mapper(a): + mode = a[0] + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] + values = a[2] + return f(stateful_processor_api_client, mode, key, values) + else: + return f(stateful_processor_api_client, mode, None, iter([])) + + def convert_output(results): + """Convert UDF output [(iter[Row], pdf_type)] to Arrow batches.""" + for [(iter_row, pdf_type)] in results: + rows_as_dict = [row.asDict(True) for row in iter_row] + record_batch = pa.RecordBatch.from_pylist( + rows_as_dict, schema=pa.schema(list(pdf_type)) + ) + yield ArrowBatchTransformer.wrap_struct(record_batch) + + return convert_output(map(mapper, generate_tuples())) + + # Return early since func is already defined + return func, None, ser, ser elif ( eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF From cabf94ced67506a120fe20360951fbf21396b04f Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 1 Feb 2026 19:08:18 -0800 Subject: [PATCH 39/39] refactor: merge to `enforece_schema` transform --- python/pyspark/sql/conversion.py | 136 +++++++------------- python/pyspark/sql/tests/test_conversion.py | 98 
++++++++++++++ python/pyspark/worker.py | 38 ++++-- 3 files changed, 170 insertions(+), 102 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 7b12dff70abd..dbb6eda069a3 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -124,65 +124,82 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": return pa.RecordBatch.from_arrays([struct], ["_0"]) @classmethod - def coerce_types( - cls, batch: "pa.RecordBatch", arrow_return_type: "pa.StructType" + def enforce_schema( + cls, + batch: "pa.RecordBatch", + return_type: "StructType", + timezone: str = "UTC", + prefer_large_var_types: bool = False, + safecheck: bool = True, ) -> "pa.RecordBatch": """ - Apply type coercion to a RecordBatch based on expected schema. + Enforce target schema on a RecordBatch by reordering columns and coercing types. This method: - 1. Handles empty batches (no columns or empty struct) - returns as-is - 2. Validates field names match expected schema - 3. Casts arrays to expected types if needed (safe casting) + 1. Reorders columns to match the target schema field order by name + 2. Casts column types to match target schema types Parameters ---------- batch : pa.RecordBatch - Input RecordBatch to coerce - arrow_return_type : pa.StructType - Expected Arrow schema for type coercion + Input RecordBatch to transform + return_type : pyspark.sql.types.StructType + Target Spark schema to enforce. + timezone : str, default "UTC" + Timezone for timestamp type conversion. + prefer_large_var_types : bool, default False + If True, use large variable types (large_string, large_binary) in Arrow. + safecheck : bool, default True + If True, use safe casting (fails on overflow/truncation). Returns ------- pa.RecordBatch - RecordBatch with coerced column types + RecordBatch with columns reordered and types coerced to match target schema + + Used by + ------- + - wrap_grouped_map_arrow_udf + - wrap_grouped_map_arrow_iter_udf + - wrap_cogrouped_map_arrow_udf + - SQL_ARROW_UDTF mapper """ import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_schema # Handle empty batch case (no columns) or empty struct case - if batch.num_columns == 0 or len(arrow_return_type) == 0: + if batch.num_columns == 0: return batch - # Check field names match - expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = batch.schema.names + if len(return_type) == 0: + return batch - if expected_field_names != actual_field_names: - raise PySparkTypeError( - "Target schema's field names are not matching the record batch's " - "field names. " - f"Expected: {expected_field_names}, but got: {actual_field_names}." 
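Condensed to its two steps, enforce_schema selects columns by target-field name and casts any mismatched types. A standalone sketch of those steps on a toy batch (this is an illustration of the idea, not the method itself):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2], type=pa.int32()), pa.array(["x", "y"])], names=["b", "a"]
)
target = pa.schema([("a", pa.string()), ("b", pa.int64())])

# Select columns by target-field name, then safely cast any type mismatch.
arrays = [
    batch.column(f.name)
    if batch.column(f.name).type == f.type
    else batch.column(f.name).cast(target_type=f.type, safe=True)
    for f in target
]
reordered = pa.RecordBatch.from_arrays(arrays, names=[f.name for f in target])
print(reordered.schema.names, reordered.column(1).type)  # ['a', 'b'] int64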
- ) + # Convert Spark StructType to PyArrow schema + arrow_schema = to_arrow_schema( + return_type, timezone=timezone, prefers_large_types=prefer_large_var_types + ) - # Apply type coercion if needed + target_field_names = [field.name for field in arrow_schema] + + # Reorder columns by name and coerce types coerced_arrays = [] - for i, field in enumerate(arrow_return_type): - original_array = batch.column(i) - if original_array.type == field.type: - coerced_arrays.append(original_array) + for field in arrow_schema: + arr = batch.column(field.name) + if arr.type == field.type: + coerced_arrays.append(arr) else: try: - coerced_arrays.append(original_array.cast(target_type=field.type, safe=True)) + coerced_arrays.append(arr.cast(target_type=field.type, safe=safecheck)) except (pa.ArrowInvalid, pa.ArrowTypeError): raise PySparkRuntimeError( errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", messageParameters={ "expected": str(field.type), - "actual": str(original_array.type), + "actual": str(arr.type), }, ) - return pa.RecordBatch.from_arrays(coerced_arrays, names=expected_field_names) + return pa.RecordBatch.from_arrays(coerced_arrays, names=target_field_names) @classmethod def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": @@ -351,69 +368,6 @@ def zip_batches( # Create RecordBatch from columns return pa.RecordBatch.from_arrays(all_columns, ["_%d" % i for i in range(len(all_columns))]) - @classmethod - def reorder_columns( - cls, batch: "pa.RecordBatch", target_schema: Union["pa.StructType", "StructType"] - ) -> "pa.RecordBatch": - """ - Reorder columns in a RecordBatch to match target schema field order. - - This method is useful when columns need to be arranged in a specific order - for schema compatibility, particularly when assign_cols_by_name is enabled. - - Parameters - ---------- - batch : pa.RecordBatch - Input RecordBatch with columns to reorder - target_schema : pa.StructType or pyspark.sql.types.StructType - Target schema defining the desired column order. - Can be either PyArrow StructType or Spark StructType. 
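The prefers_large_types flag passed to to_arrow_schema above switches variable-width columns to their 64-bit-offset Arrow variants (large_string, large_binary). A one-line illustration of the lossless relationship between the two string types:

import pyarrow as pa

# string and large_string hold the same values; only the offset width differs,
# so the cast between them is lossless.
small = pa.array(["a", "b"], type=pa.string())
print(small.cast(target_type=pa.large_string(), safe=True).type)  # large_string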
- - Returns - ------- - pa.RecordBatch - New RecordBatch with columns reordered to match target schema - - Used by - ------- - - wrap_grouped_map_arrow_udf - - wrap_grouped_map_arrow_iter_udf - - wrap_cogrouped_map_arrow_udf - - Examples - -------- - >>> import pyarrow as pa - >>> from pyspark.sql.types import StructType, StructField, IntegerType - >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([3, 4])], ['b', 'a']) - >>> # Using PyArrow schema - >>> target_pa = pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.int64())]) - >>> result = ArrowBatchTransformer.reorder_columns(batch, target_pa) - >>> result.schema.names - ['a', 'b'] - >>> # Using Spark schema - >>> target_spark = StructType([StructField('a', IntegerType()), StructField('b', IntegerType())]) - >>> result = ArrowBatchTransformer.reorder_columns(batch, target_spark) - >>> result.schema.names - ['a', 'b'] - """ - import pyarrow as pa - - # Convert Spark StructType to PyArrow StructType if needed - if hasattr(target_schema, "fields") and hasattr(target_schema.fields[0], "dataType"): - # This is Spark StructType - convert to PyArrow - from pyspark.sql.pandas.types import to_arrow_schema - - arrow_schema = to_arrow_schema(target_schema) - field_names = [field.name for field in arrow_schema] - else: - # This is PyArrow StructType - field_names = [field.name for field in target_schema] - - return pa.RecordBatch.from_arrays( - [batch.column(name) for name in field_names], - names=field_names, - ) - @classmethod def to_pandas( cls, diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 54aa58864bc3..99e6a4f40351 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -308,6 +308,104 @@ def test_flatten_and_wrap_roundtrip(self): self.assertEqual(rewrapped.num_columns, 1) self.assertEqual(rewrapped.column(0).to_pylist(), original.column(0).to_pylist()) + def test_enforce_schema_reorder(self): + """Test reordering columns to match target schema.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + # Batch with columns in different order than target + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), + StructField("b", LongType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.schema.names, ["a", "b"]) + self.assertEqual(result.column(0).to_pylist(), ["a", "b"]) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + + def test_enforce_schema_coerce_types(self): + """Test coercing column types to match target schema.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + # Batch with int32 that should be coerced to int64 (LongType) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2], type=pa.int32())], + names=["x"], + ) + target = StructType([StructField("x", LongType())]) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.column(0).type, pa.int64()) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + + def test_enforce_schema_reorder_and_coerce(self): + """Test both reordering and type coercion together.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2], type=pa.int32()), pa.array(["x", "y"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), 
+ StructField("b", LongType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.schema.names, ["a", "b"]) + self.assertEqual(result.column(0).to_pylist(), ["x", "y"]) + self.assertEqual(result.column(1).type, pa.int64()) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + + def test_enforce_schema_empty_batch(self): + """Test that empty batch is returned as-is.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + batch = pa.RecordBatch.from_arrays([], names=[]) + target = StructType([StructField("x", LongType())]) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.num_columns, 0) + + def test_enforce_schema_with_large_var_types(self): + """Test using prefer_large_var_types option.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), + StructField("b", IntegerType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target, prefer_large_var_types=True) + + self.assertEqual(result.schema.names, ["a", "b"]) + # With prefer_large_var_types=True, string should be large_string + self.assertEqual(result.column(0).type, pa.large_string()) + self.assertEqual(result.column(0).to_pylist(), ["a", "b"]) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class PandasBatchTransformerTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d59e76ac109e..34fee1aaa0ec 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -623,10 +623,12 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - # Reorder columns by name if needed, then wrap each batch into struct for batch in result.to_batches(): - if runner_conf.assign_cols_by_name: - batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) yield ArrowBatchTransformer.wrap_struct(batch) return lambda kl, vl, kr, vr: wrapped(kl, vl, kr, vr) @@ -783,10 +785,13 @@ def wrapped(key_batch, value_batches): verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - # Reorder columns by name if needed, then wrap each batch into struct + # Enforce schema (reorder + coerce) and wrap each batch into struct for batch in result.to_batches(): - if runner_conf.assign_cols_by_name: - batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) yield ArrowBatchTransformer.wrap_struct(batch) return lambda k, v: wrapped(k, v) @@ -811,8 +816,11 @@ def wrapped(key_batch, value_batches): for batch in result: verify_arrow_batch(batch, runner_conf.assign_cols_by_name, expected_cols_and_types) - if runner_conf.assign_cols_by_name: - batch = ArrowBatchTransformer.reorder_columns(batch, return_type) + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) yield ArrowBatchTransformer.wrap_struct(batch) return lambda k, v: wrapped(k, v) @@ -2545,14 +2553,18 @@ def 
raise_conversion_error(original_exception): except Exception as e: raise_conversion_error(e) - # Coerce types for each batch + # Enforce schema (reorder + coerce) for each batch table = verify_result(table) batches = table.to_batches() if len(batches) == 0: # Empty table - create empty batch for lateral join semantics batches = [pa.RecordBatch.from_pylist([], schema=table.schema)] for batch in batches: - yield ArrowBatchTransformer.coerce_types(batch, arrow_return_type) + yield ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) def evaluate(*args: list, num_rows=1): if len(args) == 0: @@ -2678,7 +2690,11 @@ def convert_to_arrow(data: Iterable): messageParameters={}, ) yield from ( - ArrowBatchTransformer.coerce_types(verify_result(batch), arrow_return_type) + ArrowBatchTransformer.enforce_schema( + verify_result(batch), + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) for batch in batches )
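These call sites now pass the Spark return_type and let enforce_schema derive the Arrow schema internally via to_arrow_schema. Roughly, that mapping looks as follows (a sketch using the pyspark helper imported in this patch; the printed formatting depends on the PyArrow version):

from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.types import StructType, StructField, LongType, StringType

# A Spark StructType maps field by field onto an Arrow schema; with the
# default options, LongType becomes int64 and StringType becomes string.
spark_schema = StructType(
    [StructField("id", LongType()), StructField("name", StringType())]
)
print(to_arrow_schema(spark_schema))  # id: int64 / name: string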