Skip to content

Commit ceccbaa

Browse files
committed
support thresh in dropna
1 parent e43d15d commit ceccbaa

File tree

4 files changed

+130
-37
lines changed

4 files changed

+130
-37
lines changed

bigframes/core/block_transforms.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def dropna(
523523
block: blocks.Block,
524524
column_ids: typing.Sequence[str],
525525
how: typing.Literal["all", "any"] = "any",
526+
thresh: typing.Optional[int] = None,
526527
subset: Optional[typing.Sequence[str]] = None,
527528
):
528529
"""
@@ -531,18 +532,46 @@ def dropna(
531532
if subset is None:
532533
subset = column_ids
533534

534-
predicates = [
535-
ops.notnull_op.as_expr(column_id)
536-
for column_id in column_ids
537-
if column_id in subset
538-
]
539-
if len(predicates) == 0:
540-
return block
541-
if how == "any":
542-
predicate = functools.reduce(ops.and_op.as_expr, predicates)
543-
else: # "all"
544-
predicate = functools.reduce(ops.or_op.as_expr, predicates)
545-
return block.filter(predicate)
535+
if thresh is not None:
536+
# Count non-null values per row
537+
notnull_predicates = [
538+
ops.notnull_op.as_expr(column_id)
539+
for column_id in column_ids
540+
if column_id in subset
541+
]
542+
543+
if len(notnull_predicates) == 0:
544+
return block
545+
546+
# Handle single predicate case
547+
if len(notnull_predicates) == 1:
548+
count_expr = ops.AsTypeOp(pd.Int64Dtype()).as_expr(notnull_predicates[0])
549+
else:
550+
# Sum the boolean expressions to count non-null values
551+
count_expr = functools.reduce(
552+
lambda a, b: ops.add_op.as_expr(
553+
ops.AsTypeOp(pd.Int64Dtype()).as_expr(a),
554+
ops.AsTypeOp(pd.Int64Dtype()).as_expr(b),
555+
),
556+
notnull_predicates,
557+
)
558+
559+
# Filter rows where count >= thresh
560+
thresh_predicate = ops.ge_op.as_expr(count_expr, ex.const(thresh))
561+
return block.filter(thresh_predicate)
562+
else:
563+
predicates = [
564+
ops.notnull_op.as_expr(column_id)
565+
for column_id in column_ids
566+
if column_id in subset
567+
]
568+
if len(predicates) == 0:
569+
return block
570+
if how == "any":
571+
predicate = functools.reduce(ops.and_op.as_expr, predicates)
572+
else: # "all"
573+
predicate = functools.reduce(ops.or_op.as_expr, predicates)
574+
return block.filter(predicate)
546575

547576

548577
def nsmallest(

bigframes/dataframe.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,7 +2801,8 @@ def dropna(
28012801
self,
28022802
*,
28032803
axis: int | str = 0,
2804-
how: str = "any",
2804+
how: typing.Literal["all", "any"] = "any",
2805+
thresh: typing.Optional[int] = None,
28052806
subset: typing.Union[None, blocks.Label, Sequence[blocks.Label]] = None,
28062807
inplace: bool = False,
28072808
ignore_index=False,
@@ -2810,6 +2811,10 @@ def dropna(
28102811
raise NotImplementedError(
28112812
f"'inplace'=True not supported. {constants.FEEDBACK_LINK}"
28122813
)
2814+
if thresh is not None and how != "any":
2815+
raise TypeError(
2816+
"You cannot set both the how and thresh arguments at the same time."
2817+
)
28132818
if how not in ("any", "all"):
28142819
raise ValueError("'how' must be one of 'any', 'all'")
28152820

@@ -2833,21 +2838,41 @@ def dropna(
28332838
for id_ in self._block.label_to_col_id[label]
28342839
]
28352840

2836-
result = block_ops.dropna(self._block, self._block.value_columns, how=how, subset=subset_ids) # type: ignore
2841+
result = block_ops.dropna(
2842+
self._block,
2843+
self._block.value_columns,
2844+
how=how,
2845+
thresh=thresh,
2846+
subset=subset_ids,
2847+
) # type: ignore
28372848
if ignore_index:
28382849
result = result.reset_index()
28392850
return DataFrame(result)
28402851
else:
2841-
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
2842-
if how == "any":
2843-
null_locations = DataFrame(isnull_block).any().to_pandas()
2844-
else: # 'all'
2845-
null_locations = DataFrame(isnull_block).all().to_pandas()
2846-
keep_columns = [
2847-
col
2848-
for col, to_drop in zip(self._block.value_columns, null_locations)
2849-
if not to_drop
2850-
]
2852+
if thresh is not None:
2853+
# Count non-null values per column
2854+
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
2855+
notnull_block = self._block.multi_apply_unary_op(ops.notnull_op)
2856+
2857+
# Sum non-null values for each column
2858+
notnull_counts = DataFrame(notnull_block).sum().to_pandas()
2859+
2860+
keep_columns = [
2861+
col
2862+
for col, count in zip(self._block.value_columns, notnull_counts)
2863+
if count >= thresh
2864+
]
2865+
else:
2866+
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
2867+
if how == "any":
2868+
null_locations = DataFrame(isnull_block).any().to_pandas()
2869+
else: # 'all'
2870+
null_locations = DataFrame(isnull_block).all().to_pandas()
2871+
keep_columns = [
2872+
col
2873+
for col, to_drop in zip(self._block.value_columns, null_locations)
2874+
if not to_drop
2875+
]
28512876
return DataFrame(self._block.select_columns(keep_columns))
28522877

28532878
def any(

tests/system/small/test_dataframe.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,26 +1181,41 @@ def test_assign_callable_lambda(scalars_dfs):
11811181

11821182

11831183
@pytest.mark.parametrize(
1184-
("axis", "how", "ignore_index", "subset"),
1184+
("axis", "how", "ignore_index", "subset", "thresh"),
11851185
[
1186-
(0, "any", False, None),
1187-
(0, "any", True, None),
1188-
(0, "all", False, ["bool_col", "time_col"]),
1189-
(0, "any", False, ["bool_col", "time_col"]),
1190-
(0, "all", False, "time_col"),
1191-
(1, "any", False, None),
1192-
(1, "all", False, None),
1186+
(0, "any", False, None, None),
1187+
(0, "any", True, None, None),
1188+
(0, "all", False, ["bool_col", "time_col"], None),
1189+
(0, "any", False, ["bool_col", "time_col"], None),
1190+
(0, "all", False, "time_col", None),
1191+
(1, "any", False, None, None),
1192+
(1, "all", False, None, None),
1193+
(0, "any", False, None, 2),
1194+
(0, "any", True, None, 3),
1195+
(1, "any", False, None, 2),
11931196
],
11941197
)
1195-
def test_df_dropna(scalars_dfs, axis, how, ignore_index, subset):
1198+
def test_df_dropna(scalars_dfs, axis, how, ignore_index, subset, thresh):
11961199
# TODO: supply a reason why this isn't compatible with pandas 1.x
11971200
pytest.importorskip("pandas", minversion="2.0.0")
11981201
scalars_df, scalars_pandas_df = scalars_dfs
1199-
df = scalars_df.dropna(axis=axis, how=how, ignore_index=ignore_index, subset=subset)
1202+
1203+
if thresh is not None:
1204+
df = scalars_df.dropna(
1205+
axis=axis, thresh=thresh, ignore_index=ignore_index, subset=subset
1206+
)
1207+
pd_result = scalars_pandas_df.dropna(
1208+
axis=axis, thresh=thresh, ignore_index=ignore_index, subset=subset
1209+
)
1210+
else:
1211+
df = scalars_df.dropna(
1212+
axis=axis, how=how, ignore_index=ignore_index, subset=subset
1213+
)
1214+
pd_result = scalars_pandas_df.dropna(
1215+
axis=axis, how=how, ignore_index=ignore_index, subset=subset
1216+
)
1217+
12001218
bf_result = df.to_pandas()
1201-
pd_result = scalars_pandas_df.dropna(
1202-
axis=axis, how=how, ignore_index=ignore_index, subset=subset
1203-
)
12041219

12051220
# Pandas uses int64 instead of Int64 (nullable) dtype.
12061221
pd_result.index = pd_result.index.astype(pd.Int64Dtype())

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,7 @@ def dropna(
17621762
*,
17631763
axis: int | str = 0,
17641764
how: str = "any",
1765+
thresh: Optional[int] = None,
17651766
subset=None,
17661767
inplace: bool = False,
17671768
ignore_index=False,
@@ -1812,6 +1813,25 @@ def dropna(
18121813
<BLANKLINE>
18131814
[3 rows x 3 columns]
18141815
1816+
Keep rows with at least 2 non-null values.
1817+
1818+
>>> df.dropna(thresh=2)
1819+
name toy born
1820+
1 Batman Batmobile 1940-04-25
1821+
2 Catwoman Bullwhip <NA>
1822+
<BLANKLINE>
1823+
[2 rows x 3 columns]
1824+
1825+
Keep columns with at least 2 non-null values:
1826+
1827+
>>> df.dropna(axis='columns', thresh=2)
1828+
name toy
1829+
0 Alfred <NA>
1830+
1 Batman Batmobile
1831+
2 Catwoman Bullwhip
1832+
<BLANKLINE>
1833+
[3 rows x 2 columns]
1834+
18151835
Define in which columns to look for missing values.
18161836
18171837
>>> df.dropna(subset=['name', 'toy'])
@@ -1834,6 +1854,8 @@ def dropna(
18341854
18351855
* 'any' : If any NA values are present, drop that row or column.
18361856
* 'all' : If all values are NA, drop that row or column.
1857+
typing(int, optional):
1858+
Require that many non-NA values. Cannot be combined with how.
18371859
subset (column label or sequence of labels, optional):
18381860
Labels along other axis to consider, e.g. if you are dropping
18391861
rows these would be a list of columns to include.
@@ -1851,6 +1873,8 @@ def dropna(
18511873
Raises:
18521874
ValueError:
18531875
If ``how`` is not one of ``any`` or ``all``.
1876+
TyperError:
1877+
If both ``how`` and ``thresh`` are specified.
18541878
"""
18551879
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
18561880

0 commit comments

Comments
 (0)