8 changes: 8 additions & 0 deletions cpp/src/arrow/csv/options.cc
@@ -62,6 +62,14 @@ Status ReadOptions::Validate() const {
return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ",
skip_rows_after_names);
}
if (ARROW_PREDICT_FALSE(max_rows == 0)) {
return Status::Invalid("ReadOptions: max_rows cannot be 0 (use -1 for unlimited): ",
max_rows);
}
if (ARROW_PREDICT_FALSE(max_rows < -1)) {
return Status::Invalid("ReadOptions: max_rows cannot be negative except -1: ",
max_rows);
}
if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) {
return Status::Invalid(
"ReadOptions: autogenerate_column_names cannot be true when column_names are "
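
Seen from the Python bindings added later in this diff, the new checks reject the two invalid cases while keeping -1 as the unlimited default. A sketch, assuming a pyarrow build that includes this patch:

```python
import io
import pyarrow as pa
from pyarrow.csv import ReadOptions, read_csv

data = b"a,b\n1,2\n3,4\n"

# max_rows == 0 and max_rows < -1 both fail ReadOptions::Validate()
for bad in (0, -5):
    try:
        read_csv(io.BytesIO(data), read_options=ReadOptions(max_rows=bad))
    except pa.ArrowInvalid as exc:
        print(exc)

# -1 (the default) reads everything
assert read_csv(io.BytesIO(data), read_options=ReadOptions(max_rows=-1)).num_rows == 2
```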
6 changes: 6 additions & 0 deletions cpp/src/arrow/csv/options.h
@@ -154,6 +154,12 @@ struct ARROW_EXPORT ReadOptions {
/// Number of rows to skip after the column names are read, if any
int32_t skip_rows_after_names = 0;

/// Maximum number of rows to read from the CSV file.
/// If -1 (the default), read all rows.
/// 0 is invalid and rejected by Validate().
/// If positive, read at most this many rows (fewer if the file is shorter).
int64_t max_rows = -1;

/// Column names for the target table.
/// If empty, fall back on autogenerate_column_names.
std::vector<std::string> column_names;
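
The intended usage from Python, once the binding below is in place (a sketch):

```python
import io
from pyarrow.csv import ReadOptions, read_csv

rows = b"a,b,c\n1,2,3\n4,5,6\n7,8,9\n"

# Keep only the first two data rows; the header line is not counted.
opts = ReadOptions(max_rows=2)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 2
```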
60 changes: 53 additions & 7 deletions cpp/src/arrow/csv/reader.cc
@@ -574,7 +574,8 @@ class ReaderMixin {
parse_options_(parse_options),
convert_options_(convert_options),
count_rows_(count_rows),
input_(std::move(input)) {}
input_(std::move(input)),
rows_read_(std::make_shared<std::atomic<int64_t>>(0)) {}

protected:
// Read header and column names from buffer, create column builders
@@ -751,6 +752,9 @@

std::shared_ptr<io::InputStream> input_;
std::shared_ptr<TaskGroup> task_group_;

// Shared atomic counter for tracking rows read when max_rows is set
std::shared_ptr<std::atomic<int64_t>> rows_read_;
};

/////////////////////////////////////////////////////////////////////////
@@ -932,16 +936,43 @@ class StreamingReaderImpl : public ReaderMixin,
MakeGeneratorStartsWith({block}, std::move(readahead_gen));

auto bytes_decoded = bytes_decoded_;
auto unwrap_and_record_bytes =
[bytes_decoded, prev_bytes_processed](
auto rows_read = rows_read_;
auto max_rows = read_options_.max_rows;
auto unwrap_record_and_limit =
[bytes_decoded, rows_read, max_rows, prev_bytes_processed](
const DecodedBlock& block) mutable -> Result<std::shared_ptr<RecordBatch>> {
bytes_decoded->fetch_add(block.bytes_processed + prev_bytes_processed);
prev_bytes_processed = 0;
return block.record_batch;

auto batch = block.record_batch;
if (max_rows <= 0 || !batch) {
// No limit set (or end of stream), return as-is
return batch;
}

// Reserve this batch's rows in the shared counter with a single
// fetch_add; a separate load-then-add could race if blocks were
// ever mapped concurrently.
int64_t batch_rows = batch->num_rows();
int64_t current_rows = rows_read->fetch_add(batch_rows, std::memory_order_acq_rel);
if (current_rows >= max_rows) {
// Limit already reached before this batch, signal end of stream
return std::shared_ptr<RecordBatch>(nullptr);
}

int64_t rows_to_return = std::min(batch_rows, max_rows - current_rows);
if (rows_to_return < batch_rows) {
// Slice so that exactly max_rows rows are emitted in total
return batch->Slice(0, rows_to_return);
}

return batch;
};

auto unwrapped =
MakeMappedGenerator(std::move(restarted_gen), std::move(unwrap_and_record_bytes));
MakeMappedGenerator(std::move(restarted_gen), std::move(unwrap_record_and_limit));

record_batch_gen_ = MakeCancellable(std::move(unwrapped), io_context_.stop_token());
return Status::OK();
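
The generator wrapper above is the core of the streaming change. A minimal pure-Python mirror of that logic (names and structure are illustrative only; the C++ version shares an atomic counter between callback invocations):

```python
def limit_batches(batches, max_rows):
    """Yield record batches until max_rows rows have been emitted (sketch)."""
    rows_read = 0
    for batch in batches:
        if max_rows <= 0:
            # -1 means unlimited: pass everything through
            yield batch
            continue
        if rows_read >= max_rows:
            # Limit reached: stop the stream
            return
        remaining = max_rows - rows_read
        rows_read += batch.num_rows
        if batch.num_rows > remaining:
            # Zero-copy slice so exactly max_rows rows are emitted
            yield batch.slice(0, remaining)
            return
        yield batch
```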
@@ -998,7 +1029,14 @@ class SerialTableReader : public BaseTableReader {
}
// Finish conversion, create schema and table
RETURN_NOT_OK(task_group_->Finish());
return MakeTable();
ARROW_ASSIGN_OR_RAISE(auto table, MakeTable());

// Apply max_rows limit if needed
if (read_options_.max_rows > 0 && table->num_rows() > read_options_.max_rows) {
table = table->Slice(0, read_options_.max_rows);
}

return table;
}

protected:
@@ -1078,7 +1116,15 @@ class AsyncThreadedTableReader
})
.Then([self]() -> Result<std::shared_ptr<Table>> {
// Finish conversion, create schema and table
return self->MakeTable();
ARROW_ASSIGN_OR_RAISE(auto table, self->MakeTable());

// Apply max_rows limit if needed
if (self->read_options_.max_rows > 0 &&
table->num_rows() > self->read_options_.max_rows) {
return table->Slice(0, self->read_options_.max_rows);
}

return table;
});
});
}
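
Both table readers above take the same approach: materialize the full table, then trim it. The equivalent operation at the pyarrow level is a zero-copy `Table.slice` (a sketch, not the readers' actual code path):

```python
import pyarrow as pa

table = pa.table({"a": [1, 4, 7, 10], "b": [2, 5, 8, 11]})
max_rows = 2

# Table.slice adjusts offsets rather than copying buffers, so trimming
# an over-read table is cheap.
if max_rows > 0 and table.num_rows > max_rows:
    table = table.slice(0, max_rows)
assert table.num_rows == 2
```

Note that unlike the streaming path, the table readers still parse the whole input; only the final result is trimmed.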
32 changes: 28 additions & 4 deletions python/pyarrow/_csv.pyx
@@ -108,7 +108,13 @@ cdef class ReadOptions(_Weakrefable):
The order of application is as follows:
- `skip_rows` is applied (if non-zero);
- column names are read (unless `column_names` is set);
- `skip_rows_after_names` is applied (if non-zero).
- `skip_rows_after_names` is applied (if non-zero);
- up to `max_rows` rows are read (if positive).
max_rows : int, optional (default -1)
Maximum number of rows to read from the CSV file.
If -1, read all rows. If a positive number, read at most that many rows
(fewer if the file has fewer data rows). The limit counts data rows only,
after `skip_rows` and `skip_rows_after_names` are applied.
column_names : list, optional
The column names of the target table. If empty, fall back on
`autogenerate_column_names`.
@@ -186,7 +192,7 @@ cdef class ReadOptions(_Weakrefable):
self.options.reset(new CCSVReadOptions(CCSVReadOptions.Defaults()))

def __init__(self, *, use_threads=None, block_size=None, skip_rows=None,
skip_rows_after_names=None, column_names=None,
skip_rows_after_names=None, max_rows=None, column_names=None,
autogenerate_column_names=None, encoding='utf8'):
if use_threads is not None:
self.use_threads = use_threads
@@ -196,6 +202,8 @@
self.skip_rows = skip_rows
if skip_rows_after_names is not None:
self.skip_rows_after_names = skip_rows_after_names
if max_rows is not None:
self.max_rows = max_rows
if column_names is not None:
self.column_names = column_names
if autogenerate_column_names is not None:
@@ -257,6 +265,21 @@
def skip_rows_after_names(self, value):
deref(self.options).skip_rows_after_names = value

@property
def max_rows(self):
"""
Maximum number of rows to read from the CSV file.

If -1 (the default), all rows are read. If a positive number,
at most that many rows are read (fewer if the file has fewer rows).
The limit is applied after skip_rows and skip_rows_after_names.
"""
return deref(self.options).max_rows

@max_rows.setter
def max_rows(self, value):
deref(self.options).max_rows = value

@property
def column_names(self):
"""
@@ -303,6 +326,7 @@
self.block_size == other.block_size and
self.skip_rows == other.skip_rows and
self.skip_rows_after_names == other.skip_rows_after_names and
self.max_rows == other.max_rows and
self.column_names == other.column_names and
self.autogenerate_column_names ==
other.autogenerate_column_names and
@@ -319,12 +343,12 @@
def __getstate__(self):
return (self.use_threads, self.block_size, self.skip_rows,
self.column_names, self.autogenerate_column_names,
self.encoding, self.skip_rows_after_names)
self.encoding, self.skip_rows_after_names, self.max_rows)

def __setstate__(self, state):
(self.use_threads, self.block_size, self.skip_rows,
self.column_names, self.autogenerate_column_names,
self.encoding, self.skip_rows_after_names) = state
self.encoding, self.skip_rows_after_names, self.max_rows) = state

def __eq__(self, other):
try:
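
Since `max_rows` is now part of `__getstate__`/`__setstate__` and the equality check, the option survives a pickle round trip, as the test suite below exercises. In short (a sketch):

```python
import pickle
from pyarrow.csv import ReadOptions

opts = ReadOptions(max_rows=10)
restored = pickle.loads(pickle.dumps(opts))
assert restored.max_rows == 10 and opts.equals(restored)
```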
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -2131,6 +2131,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
int32_t block_size
int32_t skip_rows
int32_t skip_rows_after_names
int64_t max_rows
vector[c_string] column_names
c_bool autogenerate_column_names

146 changes: 146 additions & 0 deletions python/pyarrow/tests/test_csv.py
@@ -2065,3 +2065,149 @@ def readinto(self, *args):
for i in range(20):
with pytest.raises(pa.ArrowInvalid):
read_csv(MyBytesIO(data))


def test_max_rows_basic():
"""Test basic max_rows limiting"""
rows = b"a,b,c\n1,2,3\n4,5,6\n7,8,9\n10,11,12\n"

# Read all rows (default)
table = read_csv(io.BytesIO(rows))
assert table.num_rows == 4

# Read 2 rows
opts = ReadOptions(max_rows=2)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 2
assert table.to_pydict() == {"a": [1, 4], "b": [2, 5], "c": [3, 6]}

# Read 1 row
opts = ReadOptions(max_rows=1)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 1
assert table.to_pydict() == {"a": [1], "b": [2], "c": [3]}

# Read more than available (should return all)
opts = ReadOptions(max_rows=100)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 4


def test_max_rows_with_skip_rows():
"""Test max_rows interaction with skip_rows"""
rows = b"# comment\na,b,c\n1,2,3\n4,5,6\n7,8,9\n"

opts = ReadOptions(skip_rows=1, max_rows=2)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 2
assert list(table.column_names) == ["a", "b", "c"]
assert table.to_pydict() == {"a": [1, 4], "b": [2, 5], "c": [3, 6]}


def test_max_rows_with_skip_rows_after_names():
"""Test max_rows interaction with skip_rows_after_names"""
rows = b"a,b,c\nSKIP1\nSKIP2\n1,2,3\n4,5,6\n7,8,9\n"

opts = ReadOptions(skip_rows_after_names=2, max_rows=2)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 2
assert table.to_pydict() == {"a": [1, 4], "b": [2, 5], "c": [3, 6]}


def test_max_rows_edge_cases():
"""Test edge cases for max_rows"""
rows = b"a,b\n1,2\n3,4\n"

# max_rows = 0 should raise error
opts = ReadOptions(max_rows=0)
with pytest.raises(pa.ArrowInvalid, match="max_rows cannot be 0"):
read_csv(io.BytesIO(rows), read_options=opts)

# Negative max_rows (other than -1) should raise error
opts = ReadOptions(max_rows=-5)
with pytest.raises(pa.ArrowInvalid, match="max_rows cannot be negative except -1"):
read_csv(io.BytesIO(rows), read_options=opts)

# max_rows = -1 should read all
opts = ReadOptions(max_rows=-1)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 2


def test_max_rows_with_small_blocks():
"""Test max_rows with block_size smaller than max_rows"""
# Create CSV with many rows
num_rows = 100
csv_data = "a,b,c\n"
for i in range(num_rows):
csv_data += f"{i},{i+1},{i+2}\n"
rows = csv_data.encode()

# Use small block size to force multiple blocks
opts = ReadOptions(block_size=50, max_rows=15)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 15 # Exact count

# Verify first and last row
assert table.column("a")[0].as_py() == 0
assert table.column("a")[14].as_py() == 14


def test_max_rows_multithreaded():
"""Test max_rows with use_threads=True"""
# Create large CSV to ensure parallel processing
num_rows = 1000
csv_data = "a,b,c\n"
for i in range(num_rows):
csv_data += f"{i},{i+1},{i+2}\n"
rows = csv_data.encode()

opts = ReadOptions(use_threads=True, max_rows=50)
table = read_csv(io.BytesIO(rows), read_options=opts)
assert table.num_rows == 50 # Must be exact, not approximate

# Verify rows are in order (0-49)
a_values = table.column("a").to_pylist()
expected = list(range(50))  # column "a" is inferred as int64
assert a_values == expected


def test_max_rows_streaming():
"""Test max_rows with streaming reader"""
rows = b"a,b,c\n1,2,3\n4,5,6\n7,8,9\n10,11,12\n13,14,15\n"

opts = ReadOptions(max_rows=3)
reader = open_csv(io.BytesIO(rows), read_options=opts)

# Read all batches
batches = []
while True:
try:
batch = reader.read_next_batch()
batches.append(batch)
except StopIteration:
break

# Concatenate all batches
table = pa.Table.from_batches(batches)

# Must have exactly 3 rows
assert table.num_rows == 3
assert table.to_pydict() == {
"a": ["1", "4", "7"],
"b": ["2", "5", "8"],
"c": ["3", "6", "9"]
}


def test_max_rows_pickle():
"""Test that max_rows is preserved through pickle"""
import pickle

opts = ReadOptions(max_rows=42, skip_rows=1)
pickled = pickle.dumps(opts)
unpickled = pickle.loads(pickled)

assert unpickled.max_rows == 42
assert unpickled.skip_rows == 1
assert opts.equals(unpickled)
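

A typical workflow this option enables: previewing the head of a large file through the streaming reader, where blocks past the limit are never emitted downstream (hypothetical file path; a sketch):

```python
import pyarrow.csv as csv

# Read at most the first 100 data rows without scanning the whole file.
opts = csv.ReadOptions(max_rows=100)
reader = csv.open_csv("data/large_file.csv", read_options=opts)  # hypothetical path
preview = reader.read_all()
assert preview.num_rows <= 100
```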