diff --git a/bigframes/ml/ensemble.py b/bigframes/ml/ensemble.py index 2633f13411..98d62ad951 100644 --- a/bigframes/ml/ensemble.py +++ b/bigframes/ml/ensemble.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, Union import bigframes_vendored.sklearn.ensemble._forest import bigframes_vendored.xgboost.sklearn @@ -78,6 +78,7 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + **kwargs: Union[str, str | int | bool | float | List[str]], ): self.n_estimators = n_estimators self.booster = booster @@ -99,6 +100,7 @@ def __init__( self.xgboost_version = xgboost_version self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() + self._extra_bqml_options = kwargs @classmethod def _from_bq( @@ -117,7 +119,7 @@ def _from_bq( @property def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: """The model options as they will be set for BQML""" - return { + options = { "model_type": "BOOSTED_TREE_REGRESSOR", "data_split_method": "NO_SPLIT", "early_stop": True, @@ -139,6 +141,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "enable_global_explain": self.enable_global_explain, "xgboost_version": self.xgboost_version, } + options.update(self._extra_bqml_options) + return options # type: ignore def _fit( self, @@ -237,6 +241,7 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + **kwargs: Union[str, str | int | bool | float | List[str]], ): self.n_estimators = n_estimators self.booster = booster @@ -258,6 +263,7 @@ def __init__( self.xgboost_version = xgboost_version self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() + self._extra_bqml_options = kwargs @classmethod def _from_bq( @@ -276,7 
+282,7 @@ def _from_bq( @property def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: """The model options as they will be set for BQML""" - return { + options = { "model_type": "BOOSTED_TREE_CLASSIFIER", "data_split_method": "NO_SPLIT", "early_stop": True, @@ -298,6 +304,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "enable_global_explain": self.enable_global_explain, "xgboost_version": self.xgboost_version, } + options.update(self._extra_bqml_options) + return options # type: ignore def _fit( self, diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 7f6843aacf..de3b67c7b9 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import textwrap from unittest import mock from google.cloud import bigquery @@ -19,7 +20,7 @@ import pytest import bigframes -from bigframes.ml import core, decomposition, linear_model +from bigframes.ml import core, decomposition, ensemble, linear_model import bigframes.ml.core import bigframes.pandas as bpd @@ -286,3 +287,83 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X): "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))", allow_large_results=True, ) + + +def test_xgb_classifier_kwargs_params_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.XGBClassifier(category_encoding_method="LABEL_ENCODING") + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + textwrap.dedent( + """ + CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id` + OPTIONS( + model_type='BOOSTED_TREE_CLASSIFIER', + data_split_method='NO_SPLIT', + early_stop=True, + num_parallel_tree=1, + booster_type='gbtree', + 
tree_method='auto', + min_tree_child_weight=1, + colsample_bytree=1.0, + colsample_bylevel=1.0, + colsample_bynode=1.0, + min_split_loss=0.0, + max_tree_depth=6, + subsample=1.0, + l1_reg=0.0, + l2_reg=1.0, + learn_rate=0.3, + max_iterations=20, + min_rel_progress=0.01, + enable_global_explain=False, + xgboost_version='0.9', + category_encoding_method='LABEL_ENCODING', + INPUT_LABEL_COLS=['input_column_label']) + AS input_X_y_no_index_sql + """ + ).strip() + ) + + +def test_xgb_regressor_kwargs_params_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.XGBRegressor(category_encoding_method="LABEL_ENCODING") + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + textwrap.dedent( + """ + CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id` + OPTIONS( + model_type='BOOSTED_TREE_REGRESSOR', + data_split_method='NO_SPLIT', + early_stop=True, + num_parallel_tree=1, + booster_type='gbtree', + tree_method='auto', + min_tree_child_weight=1, + colsample_bytree=1.0, + colsample_bylevel=1.0, + colsample_bynode=1.0, + min_split_loss=0.0, + max_tree_depth=6, + subsample=1.0, + l1_reg=0.0, + l2_reg=1.0, + learn_rate=0.3, + max_iterations=20, + min_rel_progress=0.01, + enable_global_explain=False, + xgboost_version='0.9', + category_encoding_method='LABEL_ENCODING', + INPUT_LABEL_COLS=['input_column_label']) + AS input_X_y_no_index_sql + """ + ).strip() + ) diff --git a/third_party/bigframes_vendored/xgboost/sklearn.py b/third_party/bigframes_vendored/xgboost/sklearn.py index 60a22e83d0..133b45cd4c 100644 --- a/third_party/bigframes_vendored/xgboost/sklearn.py +++ b/third_party/bigframes_vendored/xgboost/sklearn.py @@ -98,9 +98,18 @@ class XGBRegressor(XGBModel, XGBRegressorBase): tol (Optional[float]): Minimum relative loss improvement necessary to continue training. Default to 0.01. 
enable_global_explain (Optional[bool]): - Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. + Whether to compute global explanations using explainable AI to + evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + kwargs (dict): + Keyword arguments for the ``model_option_list`` of the boosted tree + BQML model. See + https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree + + For example, to set ``CATEGORY_ENCODING_METHOD`` to + ``LABEL_ENCODING``, pass in the keyword argument + ``category_encoding_method='LABEL_ENCODING'``. """ @@ -148,4 +157,12 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + kwargs (dict): + Keyword arguments for the ``model_option_list`` of the boosted tree + BQML model. See + https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree + + For example, to set ``CATEGORY_ENCODING_METHOD`` to + ``LABEL_ENCODING``, pass in the keyword argument + ``category_encoding_method='LABEL_ENCODING'``. """