
Commit dd9dfaa

#46 Improve Readability of TableRead Implementation
1 parent 75d00d7 commit dd9dfaa

File tree

2 files changed: +18 / -27 lines


pypaimon/api/table_read.py
Lines changed: 4 additions & 4 deletions

@@ -31,14 +31,14 @@
 class TableRead(ABC):
     """To read data from data splits."""

-    @abstractmethod
-    def to_arrow(self, splits: List[Split]) -> pa.Table:
-        """Read data from splits and converted to pyarrow.Table format."""
-
     @abstractmethod
     def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
         """Read data from splits and converted to pyarrow.RecordBatchReader format."""

+    @abstractmethod
+    def to_arrow(self, splits: List[Split]) -> pa.Table:
+        """Read data from splits and converted to pyarrow.Table format."""
+
     @abstractmethod
     def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
         """Read data from splits and converted to pandas.DataFrame format."""

pypaimon/py4j/java_implementation.py
Lines changed: 14 additions & 23 deletions

@@ -177,23 +177,20 @@ def file_paths(self) -> List[str]:
 class TableRead(table_read.TableRead):

     def __init__(self, j_table_read, j_read_type, catalog_options):
-        self._j_table_read = j_table_read
-        self._j_read_type = j_read_type
-        self._catalog_options = catalog_options
-        self._j_bytes_reader = None
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
+        self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
+            j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))

-    def to_arrow(self, splits):
-        record_batch_reader = self.to_arrow_batch_reader(splits)
-        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
-
-    def to_arrow_batch_reader(self, splits):
-        self._init()
+    def to_arrow_batch_reader(self, splits) -> pa.RecordBatchReader:
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

+    def to_arrow(self, splits) -> pa.Table:
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
     def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
         return self.to_arrow(splits).to_pandas()

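Two readability changes land in this hunk: the Java-side parallel bytes reader is built once, eagerly, in __init__ (so the per-call _init() guard disappears), and to_arrow becomes a thin wrapper over to_arrow_batch_reader instead of the reverse. A self-contained pyarrow sketch (no Paimon objects; generate_batches stands in for _batch_generator) of why the batch reader is the natural base primitive:

import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

def generate_batches():
    # Stand-in for _batch_generator(): yields record batches one at a time.
    yield pa.record_batch([pa.array([1, 2]), pa.array(["a", "b"])], schema=schema)
    yield pa.record_batch([pa.array([3]), pa.array(["c"])], schema=schema)

# Streaming path: bounded memory, one batch in flight at a time.
reader = pa.RecordBatchReader.from_batches(schema, generate_batches())
for batch in reader:
    print(batch.num_rows)  # 2, then 1

# Materializing path: the same reader feeds Table.from_batches directly,
# which is exactly the shape of the new to_arrow().
reader = pa.RecordBatchReader.from_batches(schema, generate_batches())
table = pa.Table.from_batches(reader, schema=schema)
assert table.num_rows == 3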
@@ -213,19 +210,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":

         return ray.data.from_arrow(self.to_arrow(splits))

-    def _init(self):
-        if self._j_bytes_reader is None:
-            # get thread num
-            max_workers = self._catalog_options.get(constants.MAX_WORKERS)
-            if max_workers is None:
-                # default is sequential
-                max_workers = 1
-            else:
-                max_workers = int(max_workers)
-            if max_workers <= 0:
-                raise ValueError("max_workers must be greater than 0")
-            self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-                self._j_table_read, self._j_read_type, max_workers)
+    @staticmethod
+    def _get_max_workers(catalog_options):
+        # default is sequential
+        max_workers = int(catalog_options.get(constants.MAX_WORKERS, 1))
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+        return max_workers

     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
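_get_max_workers collapses the old None-check branching into a single dict.get with a default while preserving behavior: an absent option means sequential (1), string values are coerced with int(), and non-positive values are rejected. A standalone sketch of that contract (the "max.workers" key string is a guess; the real key lives in constants.MAX_WORKERS):

MAX_WORKERS = "max.workers"  # placeholder; the real key is constants.MAX_WORKERS

def get_max_workers(catalog_options: dict) -> int:
    # default is sequential
    max_workers = int(catalog_options.get(MAX_WORKERS, 1))
    if max_workers <= 0:
        raise ValueError("max_workers must be greater than 0")
    return max_workers

assert get_max_workers({}) == 1                  # absent -> sequential default
assert get_max_workers({MAX_WORKERS: "4"}) == 4  # strings coerced via int()
try:
    get_max_workers({MAX_WORKERS: 0})
except ValueError:
    pass  # non-positive values are rejected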
