Skip to content

Commit f21f429

Browse files
committed
#44 Make Split and TableRead Serializable
1 parent 75d00d7 commit f21f429

File tree

5 files changed

+116
-30
lines changed

5 files changed

+116
-30
lines changed
885 Bytes
Binary file not shown.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.python;
20+
21+
import java.io.ByteArrayInputStream;
22+
import java.io.ByteArrayOutputStream;
23+
import java.io.IOException;
24+
import java.io.ObjectInputStream;
25+
import java.io.ObjectOutputStream;
26+
27+
public class SerializationUtil {
28+
public static byte[] serialize(Object obj) throws IOException {
29+
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
30+
ObjectOutputStream objStream = new ObjectOutputStream(byteStream)) {
31+
objStream.writeObject(obj);
32+
return byteStream.toByteArray();
33+
}
34+
}
35+
36+
public static Object deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
37+
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytes);
38+
ObjectInputStream objStream = new ObjectInputStream(byteStream)) {
39+
return objStream.readObject();
40+
}
41+
}
42+
}

pypaimon/api/table_read.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@
3131
class TableRead(ABC):
3232
"""To read data from data splits."""
3333

34-
@abstractmethod
35-
def to_arrow(self, splits: List[Split]) -> pa.Table:
36-
"""Read data from splits and converted to pyarrow.Table format."""
37-
3834
@abstractmethod
3935
def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
4036
"""Read data from splits and converted to pyarrow.RecordBatchReader format."""
4137

38+
@abstractmethod
39+
def to_arrow(self, splits: List[Split]) -> pa.Table:
40+
"""Read data from splits and converted to pyarrow.Table format."""
41+
4242
@abstractmethod
4343
def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
4444
"""Read data from splits and converted to pandas.DataFrame format."""

pypaimon/py4j/java_implementation.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from pypaimon.py4j.java_gateway import get_gateway
2525
from pypaimon.py4j.util import java_utils, constants
26+
from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object
2627
from pypaimon.api import \
2728
(catalog, table, read_builder, table_scan, split, row_type,
2829
table_read, write_builder, table_write, commit_message,
@@ -109,8 +110,9 @@ def new_scan(self) -> 'TableScan':
109110
return TableScan(j_table_scan)
110111

111112
def new_read(self) -> 'TableRead':
112-
j_table_read = self._j_read_builder.newRead().executeFilter()
113-
return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
113+
j_table_read_bytes = serialize_java_object(self._j_read_builder.newRead().executeFilter())
114+
j_read_type_bytes = serialize_java_object(self._j_read_builder.readType())
115+
return TableRead(j_table_read_bytes, j_read_type_bytes, self._catalog_options)
114116

115117
def new_predicate_builder(self) -> 'PredicateBuilder':
116118
return PredicateBuilder(self._j_row_type)
@@ -145,55 +147,66 @@ def __init__(self, j_splits):
145147
self._j_splits = j_splits
146148

147149
def splits(self) -> List['Split']:
148-
return list(map(lambda s: Split(s), self._j_splits))
150+
return list(map(lambda s: self._build_single_split(s), self._j_splits))
151+
152+
def _build_single_split(self, j_split) -> 'Split':
153+
j_split_bytes = serialize_java_object(j_split)
154+
row_count = j_split.rowCount()
155+
files_optional = j_split.convertToRawFiles()
156+
if not files_optional.isPresent():
157+
file_size = 0
158+
file_paths = []
159+
else:
160+
files = files_optional.get()
161+
file_size = sum(file.length() for file in files)
162+
file_paths = [file.path() for file in files]
163+
return Split(j_split_bytes, row_count, file_size, file_paths)
149164

150165

151166
class Split(split.Split):
152167

153-
def __init__(self, j_split):
154-
self._j_split = j_split
168+
def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]):
169+
self._j_split_bytes = j_split_bytes
170+
self._row_count = row_count
171+
self._file_size = file_size
172+
self._file_paths = file_paths
155173

156174
def to_j_split(self):
157-
return self._j_split
175+
return deserialize_java_object(self._j_split_bytes)
158176

159177
def row_count(self) -> int:
160-
return self._j_split.rowCount()
178+
return self._row_count
161179

162180
def file_size(self) -> int:
163-
files_optional = self._j_split.convertToRawFiles()
164-
if not files_optional.isPresent():
165-
return 0
166-
files = files_optional.get()
167-
return sum(file.length() for file in files)
181+
return self._file_size
168182

169183
def file_paths(self) -> List[str]:
170-
files_optional = self._j_split.convertToRawFiles()
171-
if not files_optional.isPresent():
172-
return []
173-
files = files_optional.get()
174-
return [file.path() for file in files]
184+
return self._file_paths
175185

176186

177187
class TableRead(table_read.TableRead):
178188

179-
def __init__(self, j_table_read, j_read_type, catalog_options):
180-
self._j_table_read = j_table_read
181-
self._j_read_type = j_read_type
189+
def __init__(self, j_table_read_bytes, j_read_type_bytes, catalog_options):
190+
self._j_table_read_bytes = j_table_read_bytes
191+
self._j_read_type_bytes = j_read_type_bytes
182192
self._catalog_options = catalog_options
183-
self._j_bytes_reader = None
184-
self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
185193

186-
def to_arrow(self, splits):
187-
record_batch_reader = self.to_arrow_batch_reader(splits)
188-
return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
194+
self._j_table_read = None
195+
self._j_read_type = None
196+
self._arrow_schema = None
197+
self._j_bytes_reader = None
189198

190-
def to_arrow_batch_reader(self, splits):
199+
def to_arrow_batch_reader(self, splits) -> pa.RecordBatchReader:
191200
self._init()
192201
j_splits = list(map(lambda s: s.to_j_split(), splits))
193202
self._j_bytes_reader.setSplits(j_splits)
194203
batch_iterator = self._batch_generator()
195204
return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
196205

206+
def to_arrow(self, splits) -> pa.Table:
207+
record_batch_reader = self.to_arrow_batch_reader(splits)
208+
return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
209+
197210
def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
198211
return self.to_arrow(splits).to_pandas()
199212

@@ -214,6 +227,12 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
214227
return ray.data.from_arrow(self.to_arrow(splits))
215228

216229
def _init(self):
230+
if self._j_table_read is None:
231+
self._j_table_read = deserialize_java_object(self._j_table_read_bytes)
232+
if self._j_read_type is None:
233+
self._j_read_type = deserialize_java_object(self._j_read_type_bytes)
234+
if self._arrow_schema is None:
235+
self._arrow_schema = java_utils.to_arrow_schema(self._j_read_type)
217236
if self._j_bytes_reader is None:
218237
# get thread num
219238
max_workers = self._catalog_options.get(constants.MAX_WORKERS)

pypaimon/py4j/util/java_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,28 @@ def to_arrow_schema(j_row_type):
100100
arrow_schema = schema_reader.schema
101101
schema_reader.close()
102102
return arrow_schema
103+
104+
105+
def serialize_java_object(java_obj) -> bytes:
106+
gateway = get_gateway()
107+
util = gateway.jvm.org.apache.paimon.python.SerializationUtil
108+
try:
109+
java_bytes = util.serialize(java_obj)
110+
return bytes(java_bytes)
111+
except Exception as e:
112+
raise RuntimeError(f"Java serialization failed: {e}")
113+
114+
115+
def deserialize_java_object(bytes_data):
116+
gateway = get_gateway()
117+
util = gateway.jvm.org.apache.paimon.python.SerializationUtil
118+
try:
119+
byte_buffer = gateway.jvm.java.nio.ByteBuffer.allocate(len(bytes_data))
120+
for b in bytes_data:
121+
byte_buffer.put(b if b >= 0 else b + 256)
122+
byte_buffer.flip()
123+
java_bytes = byte_buffer.array()
124+
125+
return util.deserialize(java_bytes)
126+
except Exception as e:
127+
raise RuntimeError(f"Java deserialization failed: {e}")

0 commit comments

Comments
 (0)