Skip to content

Commit 41e6ad2

Browse files
committed
feat: Add configurable display options for PyDataFrame
- Introduced a `DisplayConfig` struct to manage the display settings `max_table_bytes`, `min_table_rows`, and `max_cell_length`.
- Updated `PyDataFrame` to use `DisplayConfig` when rendering and displaying DataFrames.
- Added methods to configure and reset the display settings, allowing users to customize how their DataFrames are presented in Python.
1 parent d0315ff commit 41e6ad2

File tree

1 file changed

+71
-11
lines changed

1 file changed

+71
-11
lines changed

src/dataframe.rs

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,27 @@ impl PyTableProvider {
7272
PyTable::new(table_provider)
7373
}
7474
}
75-
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
76-
const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
77-
const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
75+
76+
/// Configuration for DataFrame display in Python environment.
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone` so that two
/// configurations can be compared directly (for example, to check whether a
/// configuration still equals the defaults).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DisplayConfig {
    /// Maximum bytes to display for table presentation (default: 2MB).
    ///
    /// Compared against the running memory-size estimate of the collected
    /// record batches.
    pub max_table_bytes: usize,
    /// Minimum number of table rows to display (default: 20).
    pub min_table_rows: usize,
    /// Maximum length of a cell before it gets minimized (default: 25).
    ///
    /// Compared against the byte length of the formatted cell text.
    pub max_cell_length: usize,
}
86+
87+
impl Default for DisplayConfig {
88+
fn default() -> Self {
89+
Self {
90+
max_table_bytes: 2 * 1024 * 1024, // 2 MB
91+
min_table_rows: 20,
92+
max_cell_length: 25,
93+
}
94+
}
95+
}
7896

7997
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
8098
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -83,12 +101,16 @@ const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
83101
#[derive(Clone)]
84102
pub struct PyDataFrame {
85103
/// The wrapped DataFrame, held behind an `Arc`.
df: Arc<DataFrame>,
104+
/// Display settings consulted when this DataFrame is rendered
/// (byte budget, minimum row count, maximum cell length); held behind an `Arc`.
config: Arc<DisplayConfig>,
86105
}
87106

88107
impl PyDataFrame {
89108
/// creates a new PyDataFrame
90109
pub fn new(df: DataFrame) -> Self {
91-
Self { df: Arc::new(df) }
110+
Self {
111+
df: Arc::new(df),
112+
config: Arc::new(DisplayConfig::default()),
113+
}
92114
}
93115
}
94116

@@ -118,7 +140,7 @@ impl PyDataFrame {
118140
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
119141
let (batches, has_more) = wait_for_future(
120142
py,
121-
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
143+
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10, &self.config),
122144
)?;
123145
if batches.is_empty() {
124146
// This should not be reached, but do it for safety since we index into the vector below
@@ -141,8 +163,9 @@ impl PyDataFrame {
141163
py,
142164
collect_record_batches_to_display(
143165
self.df.as_ref().clone(),
144-
MIN_TABLE_ROWS_TO_DISPLAY,
166+
self.config.min_table_rows,
145167
usize::MAX,
168+
&self.config,
146169
),
147170
)?;
148171
if batches.is_empty() {
@@ -218,8 +241,8 @@ impl PyDataFrame {
218241
for (col, formatter) in batch_formatter.iter().enumerate() {
219242
let cell_data = formatter.value(batch_row).to_string();
220243
// From testing, primitive data types do not typically get larger than 21 characters
221-
if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
222-
let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
244+
if cell_data.len() > self.config.max_cell_length {
245+
let short_cell_data = &cell_data[0..self.config.max_cell_length];
223246
cells.push(format!("
224247
<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
225248
<div class=\"expandable-container\">
@@ -797,6 +820,42 @@ impl PyDataFrame {
797820
fn count(&self, py: Python) -> PyDataFusionResult<usize> {
798821
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
799822
}
823+
824+
/// Get the current display configuration
825+
#[getter]
826+
fn display_config(&self) -> DisplayConfig {
827+
(*self.config).clone()
828+
}
829+
830+
/// Update display configuration
831+
#[pyo3(signature = (max_table_bytes=None, min_table_rows=None, max_cell_length=None))]
832+
fn configure_display(
833+
&mut self,
834+
max_table_bytes: Option<usize>,
835+
min_table_rows: Option<usize>,
836+
max_cell_length: Option<usize>,
837+
) {
838+
let mut new_config = (*self.config).clone();
839+
840+
if let Some(bytes) = max_table_bytes {
841+
new_config.max_table_bytes = bytes;
842+
}
843+
844+
if let Some(rows) = min_table_rows {
845+
new_config.min_table_rows = rows;
846+
}
847+
848+
if let Some(length) = max_cell_length {
849+
new_config.max_cell_length = length;
850+
}
851+
852+
self.config = Arc::new(new_config);
853+
}
854+
855+
/// Reset display configuration to default values
856+
fn reset_display_config(&mut self) {
857+
self.config = Arc::new(DisplayConfig::default());
858+
}
800859
}
801860

802861
/// Print DataFrame
@@ -886,6 +945,7 @@ async fn collect_record_batches_to_display(
886945
df: DataFrame,
887946
min_rows: usize,
888947
max_rows: usize,
948+
config: &DisplayConfig,
889949
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
890950
let partitioned_stream = df.execute_stream_partitioned().await?;
891951
let mut stream = futures::stream::iter(partitioned_stream).flatten();
@@ -894,7 +954,7 @@ async fn collect_record_batches_to_display(
894954
let mut record_batches = Vec::default();
895955
let mut has_more = false;
896956

897-
while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
957+
while (size_estimate_so_far < config.max_table_bytes && rows_so_far < max_rows)
898958
|| rows_so_far < min_rows
899959
{
900960
let mut rb = match stream.next().await {
@@ -909,8 +969,8 @@ async fn collect_record_batches_to_display(
909969
if rows_in_rb > 0 {
910970
size_estimate_so_far += rb.get_array_memory_size();
911971

912-
if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
913-
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
972+
if size_estimate_so_far > config.max_table_bytes {
973+
let ratio = config.max_table_bytes as f32 / size_estimate_so_far as f32;
914974
let total_rows = rows_in_rb + rows_so_far;
915975

916976
let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;

0 commit comments

Comments
 (0)