From 7d3492f424c66beee4bd62763c2bc3c25e0f2beb Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Fri, 12 Sep 2025 22:00:36 +0000
Subject: [PATCH 1/6] feat: support 'binary' for precision_score

---
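Note: a minimal usage sketch of the API this patch adds (the frame below is
made up for illustration; with pos_label=1 it yields 1/3, since one of the
three positive predictions is correct):

    import bigframes.pandas as bpd
    from bigframes.ml import metrics

    df = bpd.DataFrame(
        {
            "y_true": [1, 1, 1, 0, 0],
            "y_pred": [0, 0, 1, 1, 1],
        }
    )

    # average="binary" (the new default) scores only the positive class
    # and returns a float.
    binary = metrics.precision_score(
        df["y_true"], df["y_pred"], average="binary", pos_label=1
    )

    # average=None keeps the existing behavior: a per-label pandas Series.
    per_label = metrics.precision_score(df["y_true"], df["y_pred"], average=None)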
 bigframes/ml/metrics/_metrics.py              | 90 +++++++++++++++----
 tests/system/small/ml/test_metrics.py         | 51 +++++++++++
 .../sklearn/metrics/_classification.py        |  2 +-
 3 files changed, 126 insertions(+), 17 deletions(-)

diff --git a/bigframes/ml/metrics/_metrics.py b/bigframes/ml/metrics/_metrics.py
index c9639f4b16..03160057fb 100644
--- a/bigframes/ml/metrics/_metrics.py
+++ b/bigframes/ml/metrics/_metrics.py
@@ -15,9 +15,11 @@
 """Metrics functions for evaluating models. This module is styled after
 scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html."""
 
+from __future__ import annotations
+
 import inspect
 import typing
-from typing import Tuple, Union
+from typing import Literal, overload, Tuple, Union
 
 import bigframes_vendored.constants as constants
 import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification
@@ -259,31 +261,64 @@ def recall_score(
 recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score)
 
 
+@overload
 def precision_score(
-    y_true: Union[bpd.DataFrame, bpd.Series],
-    y_pred: Union[bpd.DataFrame, bpd.Series],
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
     *,
-    average: typing.Optional[str] = "binary",
+    pos_label: int | float | bool | str = ...,
+    average: Literal["binary"] = ...,
+) -> float:
+    ...
+
+
+@overload
+def precision_score(
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
+    *,
+    pos_label: int | float | bool | str = ...,
+    average: None = ...,
 ) -> pd.Series:
-    # TODO(ashleyxu): support more average type, default to "binary"
-    if average is not None:
-        raise NotImplementedError(
-            f"Only average=None is supported. {constants.FEEDBACK_LINK}"
-        )
+    ...
+
 
+def precision_score(
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
+    *,
+    pos_label: int | float | bool | str = 1,
+    average: Literal["binary"] | None = "binary",
+) -> pd.Series | float:
     y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
 
-    is_accurate = y_true_series == y_pred_series
+    if average is None:
+        return _precision_score_per_class(y_true_series, y_pred_series)
+
+    if average == "binary":
+        return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)
+
+    raise NotImplementedError(
+        f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}"
+    )
+
+
+precision_score.__doc__ = inspect.getdoc(
+    vendored_metrics_classification.precision_score
+)
+
+
+def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
+    is_accurate = y_true == y_pred
     unique_labels = (
-        bpd.concat([y_true_series, y_pred_series], join="outer")
+        bpd.concat([y_true, y_pred], join="outer")
         .drop_duplicates()
         .sort_values(inplace=False)
     )
     index = unique_labels.to_list()
 
     precision = (
-        is_accurate.groupby(y_pred_series).sum()
-        / is_accurate.groupby(y_pred_series).count()
+        is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
     ).to_pandas()
 
     precision_score = pd.Series(0, index=index)
@@ -293,9 +328,32 @@ def precision_score(
     return precision_score
 
 
-precision_score.__doc__ = inspect.getdoc(
-    vendored_metrics_classification.precision_score
-)
+def _precision_score_binary_pos_only(
+    y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
+) -> float:
+    if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
+        raise ValueError(
+            "Target is multiclass but average='binary'. Please choose another average setting."
+        )
+
+    total_labels = set(
+        y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
+    )
+
+    if len(total_labels) != 2:
+        raise ValueError(
+            "Target is multiclass but average='binary'. Please choose another average setting."
+        )
+
+    if pos_label not in total_labels:
+        raise ValueError(
+            f"pos_label={pos_label} is not a valid label. It should be one of {list(total_labels)}"
+        )
+
+    target_elem_idx = y_pred == pos_label
+    is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]
+
+    return is_accurate.sum() / is_accurate.count()
 
 
 def f1_score(
diff --git a/tests/system/small/ml/test_metrics.py b/tests/system/small/ml/test_metrics.py
index fd5dbef2e3..7acaf3e8c3 100644
--- a/tests/system/small/ml/test_metrics.py
+++ b/tests/system/small/ml/test_metrics.py
@@ -743,6 +743,57 @@ def test_precision_score_series(session):
     )
 
 
+@pytest.mark.parametrize(
+    ("pos_label", "expected_score"),
+    [
+        ("a", 1 / 3),
+        ("b", 0),
+    ],
+)
+def test_precision_score_binary(session, pos_label, expected_score):
+    pd_df = pd.DataFrame(
+        {
+            "y_true": ["a", "a", "a", "b", "b"],
+            "y_pred": ["b", "b", "a", "a", "a"],
+        }
+    )
+    df = session.read_pandas(pd_df)
+
+    precision_score = metrics.precision_score(
+        df["y_true"], df["y_pred"], average="binary", pos_label=pos_label
+    )
+
+    assert precision_score == pytest.approx(expected_score)
+
+
+@pytest.mark.parametrize(
+    ("y_true", "y_pred", "pos_label"),
+    [
+        pytest.param(
+            pd.Series([1, 2, 3]), pd.Series([1, 0]), 1, id="y_true-non-binary-label"
+        ),
+        pytest.param(
+            pd.Series([1, 0]), pd.Series([1, 2, 3]), 1, id="y_pred-non-binary-label"
+        ),
+        pytest.param(
+            pd.Series([1, 0]), pd.Series([1, 2]), 1, id="combined-non-binary-label"
+        ),
+        pytest.param(pd.Series([1, 0]), pd.Series([1, 0]), 2, id="invalid-pos_label"),
+    ],
+)
+def test_precision_score_binary_invalid_input_raise_error(
+    session, y_true, y_pred, pos_label
+):
+
+    bf_y_true = session.read_pandas(y_true)
+    bf_y_pred = session.read_pandas(y_pred)
+
+    with pytest.raises(ValueError):
+        metrics.precision_score(
+            bf_y_true, bf_y_pred, average="binary", pos_label=pos_label
+        )
+
+
 def test_f1_score(session):
     pd_df = pd.DataFrame(
         {
diff --git a/third_party/bigframes_vendored/sklearn/metrics/_classification.py b/third_party/bigframes_vendored/sklearn/metrics/_classification.py
index c1a909e849..fd6e8678ea 100644
--- a/third_party/bigframes_vendored/sklearn/metrics/_classification.py
+++ b/third_party/bigframes_vendored/sklearn/metrics/_classification.py
@@ -201,7 +201,7 @@ def precision_score(
             default='binary'
             This parameter is required for multiclass/multilabel targets.
             Possible values are 'None', 'micro', 'macro', 'samples', 'weighted', 'binary'.
-            Only average=None is supported.
+            Only None and 'binary' are supported.
 
     Returns:
         precision: float (if average is not None) or Series of float of shape \

From ab90b72fb5bac2ec57d4267508470c019cbb32aa Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Fri, 12 Sep 2025 22:02:45 +0000
Subject: [PATCH 2/6] add test

---
 tests/system/small/ml/test_metrics.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/system/small/ml/test_metrics.py b/tests/system/small/ml/test_metrics.py
index 7acaf3e8c3..040d4d97f6 100644
--- a/tests/system/small/ml/test_metrics.py
+++ b/tests/system/small/ml/test_metrics.py
@@ -766,6 +766,20 @@ def test_precision_score_binary(session, pos_label, expected_score):
     assert precision_score == pytest.approx(expected_score)
 
 
+def test_precision_score_binary_default_arguments(session):
+    pd_df = pd.DataFrame(
+        {
+            "y_true": [1, 1, 1, 0, 0],
+            "y_pred": [0, 0, 1, 1, 1],
+        }
+    )
+    df = session.read_pandas(pd_df)
+
+    precision_score = metrics.precision_score(df["y_true"], df["y_pred"])
+
+    assert precision_score == pytest.approx(1 / 3)
+
+
 @pytest.mark.parametrize(
     ("y_true", "y_pred", "pos_label"),
     [

From 06392d20dadc686be83fc5437138e60cb7ba3a8a Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Mon, 15 Sep 2025 18:51:01 +0000
Subject: [PATCH 3/6] use unique(keep_order=False) to count unique items

---
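Note: a sketch of the rationale. unique(keep_order=False) gives up pandas'
first-appearance ordering guarantee, which should let the generated query
skip an ordering step; order is irrelevant here because the labels are only
counted and collected into a set.

    # Cheaper than drop_duplicates() when the order of labels does not matter.
    n_classes = y_true.unique(keep_order=False).count()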
 bigframes/ml/metrics/_metrics.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/bigframes/ml/metrics/_metrics.py b/bigframes/ml/metrics/_metrics.py
index 03160057fb..b396933303 100644
--- a/bigframes/ml/metrics/_metrics.py
+++ b/bigframes/ml/metrics/_metrics.py
@@ -331,13 +331,17 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
 def _precision_score_binary_pos_only(
     y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
 ) -> float:
-    if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
+    if (
+        y_true.unique(keep_order=False).count() != 2
+        or y_pred.unique(keep_order=False).count() != 2
+    ):
         raise ValueError(
             "Target is multiclass but average='binary'. Please choose another average setting."
         )
 
     total_labels = set(
-        y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
+        y_true.unique(keep_order=False).to_list()
+        + y_pred.unique(keep_order=False).to_list()
     )
 
     if len(total_labels) != 2:

From 8d5d5731c9ee63e4c75d7fbd933b43366359ce4b Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Mon, 15 Sep 2025 18:56:15 +0000
Subject: [PATCH 4/6] use local variables to hold unique classes

---
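Note: a sketch of the intent, assuming a reused bigframes expression is
cheaper than rebuilding it. Binding each deduplicated Series to a local lets
count() and to_list() share one unique() expression per input instead of
constructing it twice.

    y_true_classes = y_true.unique(keep_order=False)  # built once
    n_classes = y_true_classes.count()                # reused here...
    class_list = y_true_classes.to_list()             # ...and here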
 bigframes/ml/metrics/_metrics.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/bigframes/ml/metrics/_metrics.py b/bigframes/ml/metrics/_metrics.py
index b396933303..d676aada47 100644
--- a/bigframes/ml/metrics/_metrics.py
+++ b/bigframes/ml/metrics/_metrics.py
@@ -331,18 +331,15 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
 def _precision_score_binary_pos_only(
     y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
 ) -> float:
-    if (
-        y_true.unique(keep_order=False).count() != 2
-        or y_pred.unique(keep_order=False).count() != 2
-    ):
+    y_true_classes = y_true.unique(keep_order=False)
+    y_pred_classes = y_pred.unique(keep_order=False)
+
+    if y_true_classes.count() != 2 or y_pred_classes.count() != 2:
         raise ValueError(
             "Target is multiclass but average='binary'. Please choose another average setting."
         )
 
-    total_labels = set(
-        y_true.unique(keep_order=False).to_list()
-        + y_pred.unique(keep_order=False).to_list()
-    )
+    total_labels = set(y_true_classes.to_list() + y_pred_classes.to_list())
 
     if len(total_labels) != 2:
         raise ValueError(

From 96758ff9a3b435dbcef93908bb9cc7599795ed87 Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Tue, 16 Sep 2025 04:50:49 +0000
Subject: [PATCH 5/6] use concat before checking unique labels

---
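Note: concatenating the inputs first collapses the two per-input checks into a
single check on the union of labels. This is also slightly more permissive
than before: for example, y_true=[1, 1, 1] with y_pred=[1, 1, 0] was rejected
by the old per-input checks (y_true has only one class) but passes now, since
the union {0, 1} is binary; scikit-learn accepts such inputs for
average='binary' as well.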
 bigframes/ml/metrics/_metrics.py | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/bigframes/ml/metrics/_metrics.py b/bigframes/ml/metrics/_metrics.py
index d676aada47..9194343098 100644
--- a/bigframes/ml/metrics/_metrics.py
+++ b/bigframes/ml/metrics/_metrics.py
@@ -331,24 +331,16 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
 def _precision_score_binary_pos_only(
     y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
 ) -> float:
-    y_true_classes = y_true.unique(keep_order=False)
-    y_pred_classes = y_pred.unique(keep_order=False)
+    unique_labels = bpd.concat([y_true, y_pred]).unique(keep_order=False)
 
-    if y_true_classes.count() != 2 or y_pred_classes.count() != 2:
+    if unique_labels.count() != 2:
         raise ValueError(
             "Target is multiclass but average='binary'. Please choose another average setting."
         )
 
-    total_labels = set(y_true_classes.to_list() + y_pred_classes.to_list())
-
-    if len(total_labels) != 2:
-        raise ValueError(
-            "Target is multiclass but average='binary'. Please choose another average setting."
-        )
-
-    if pos_label not in total_labels:
+    if pos_label not in unique_labels:
         raise ValueError(
-            f"pos_label={pos_label} is not a valid label. It should be one of {list(total_labels)}"
+            f"pos_label={pos_label} is not a valid label. It should be one of {unique_labels.to_list()}"
         )
 
     target_elem_idx = y_pred == pos_label

From e1c032bdfb0b5e44bc3dd6f6cf79d4274afb8c07 Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Tue, 16 Sep 2025 04:58:39 +0000
Subject: [PATCH 6/6] fix test

---
 bigframes/ml/metrics/_metrics.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bigframes/ml/metrics/_metrics.py b/bigframes/ml/metrics/_metrics.py
index 9194343098..8787a68c58 100644
--- a/bigframes/ml/metrics/_metrics.py
+++ b/bigframes/ml/metrics/_metrics.py
@@ -293,7 +293,7 @@ def precision_score(
     y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
 
     if average is None:
-        return _precision_score_per_class(y_true_series, y_pred_series)
+        return _precision_score_per_label(y_true_series, y_pred_series)
 
     if average == "binary":
         return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)
@@ -308,7 +308,7 @@
 )
 
 
-def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
+def _precision_score_per_label(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
     is_accurate = y_true == y_pred
     unique_labels = (
         bpd.concat([y_true, y_pred], join="outer")
@@ -338,7 +338,7 @@ def _precision_score_binary_pos_only(
             "Target is multiclass but average='binary'. Please choose another average setting."
         )
 
-    if pos_label not in unique_labels:
+    if not (unique_labels == pos_label).any():
         raise ValueError(
             f"pos_label={pos_label} is not a valid label. It should be one of {unique_labels.to_list()}"
         )
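Note on the final fix: `pos_label not in unique_labels` uses Python's `in`,
which on a pandas-style Series tests the index rather than the values, so the
validity check did not test what was intended; bigframes mirrors pandas here,
which would explain the failing test this patch addresses. The elementwise
comparison reduced with .any() tests the values. Illustrated with plain
pandas:

    import pandas as pd

    s = pd.Series(["a", "b"])  # default index [0, 1]
    "a" in s                   # False: `in` consults the index
    0 in s                     # True, although 0 is not a value
    bool((s == "a").any())     # True: compares against the values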