2727 SparkXGBRegressor ,
2828 SparkXGBRegressorModel ,
2929)
30+ from xgboost .spark .utils import _get_max_num_concurrent_tasks
3031from xgboost .testing .collective import get_avail_port
3132
3233logging .getLogger ("py4j" ).setLevel (logging .INFO )
3334
3435pytestmark = [tm .timeout (60 ), pytest .mark .skipif (** tm .no_spark ())]
3536
# Spark session flavors exercised by the dual-mode tests: a plain in-process
# local master and a multi-executor local-cluster master.  Each mode doubles
# as its own pytest parameter id for readable test names.
DUAL_SPARK_MODES = [pytest.param(mode, id=mode) for mode in ("local", "local_cluster")]

3642RegData = namedtuple (
3743 "RegData" ,
3844 (
@@ -76,7 +82,10 @@ def no_sparse_unwrap() -> tm.PytestSkip:
7682
7783
7884@pytest .fixture (scope = "module" )
79- def spark () -> Generator [SparkSession , None , None ]:
85+ def spark (request : pytest .FixtureRequest ) -> Generator [SparkSession , None , None ]:
86+ mode = getattr (request , "param" , "local" )
87+ if mode not in {"local" , "local_cluster" }:
88+ raise ValueError (f"Unknown Spark test mode: { mode } " )
8089 os .environ ["XGBOOST_PYSPARK_SHARED_SESSION" ] = "1"
8190 config = {
8291 "spark.master" : "local[4]" ,
@@ -88,12 +97,30 @@ def spark() -> Generator[SparkSession, None, None]:
8897 "spark.sql.pyspark.jvmStacktrace.enabled" : "true" ,
8998 "spark.ui.enabled" : "false" ,
9099 }
100+ if mode == "local_cluster" :
101+ config .update (
102+ {
103+ "spark.master" : "local-cluster[2, 1, 1024]" ,
104+ "spark.cores.max" : "2" ,
105+ "spark.task.cpus" : "1" ,
106+ "spark.executor.cores" : "1" ,
107+ }
108+ )
91109
92110 builder = SparkSession .builder .appName ("XGBoost PySpark Python API Tests" )
93111 for k , v in config .items ():
94112 builder .config (k , v )
95113 logging .getLogger ("pyspark" ).setLevel (logging .INFO )
96114 sess = builder .getOrCreate ()
115+ if mode == "local_cluster" :
116+ # Block until workers are connected.
117+ num_slots = sess .sparkContext .defaultParallelism
118+ (
119+ sess .sparkContext .parallelize (range (num_slots ), num_slots )
120+ .barrier ()
121+ .mapPartitions (lambda _ : [])
122+ .collect ()
123+ )
97124 try :
98125 yield sess
99126 finally :
@@ -102,6 +129,11 @@ def spark() -> Generator[SparkSession, None, None]:
102129 os .environ .pop ("XGBOOST_PYSPARK_SHARED_SESSION" , None )
103130
104131
@pytest.fixture(scope="module")
def num_workers(spark: SparkSession) -> int:
    """Maximum number of concurrently schedulable tasks in the test session.

    Derived from the Spark context so tests request exactly as many XGBoost
    workers as the (local or local-cluster) session can actually run.
    """
    context = spark.sparkContext
    return _get_max_num_concurrent_tasks(context)


105137class TestRegressor :
106138 @pytest .fixture (scope = "class" )
107139 def reg_data (self , spark : SparkSession ) -> RegData :
@@ -141,7 +173,10 @@ def reg_data(self, spark: SparkSession) -> RegData:
141173 X_train , X_test , y_train , y_test , w , base_margin , is_val , X , y , df
142174 )
143175
144- def test_regressor (self , reg_data : RegData ) -> None :
176+ @pytest .mark .parametrize ("spark" , DUAL_SPARK_MODES , indirect = True )
177+ def test_regressor (
178+ self , spark : SparkSession , reg_data : RegData , num_workers : int
179+ ) -> None :
145180 train_rows = np .where (~ reg_data .is_val )[0 ]
146181 validation_rows = np .where (reg_data .is_val )[0 ]
147182
@@ -164,6 +199,7 @@ def test_regressor(self, reg_data: RegData) -> None:
164199 pred_contrib_col = "pred_contribs" ,
165200 weight_col = "weight" ,
166201 validation_indicator_col = "is_val" ,
202+ num_workers = num_workers ,
167203 ** reg_param ,
168204 ).fit (reg_data .df )
169205 pred_result = spark_regressor .transform (reg_data .df )
@@ -179,6 +215,26 @@ def test_regressor(self, reg_data: RegData) -> None:
179215 .toPandas ()["pred_contribs" ]
180216 .tolist ()
181217 )
218+ rounds = reg .get_booster ().num_boosted_rounds ()
219+ iter_range = (0 , max (1 , min (5 , rounds )))
220+ iter_preds = (
221+ SparkXGBRegressor (
222+ weight_col = "weight" ,
223+ validation_indicator_col = "is_val" ,
224+ iteration_range = iter_range ,
225+ num_workers = num_workers ,
226+ ** reg_param ,
227+ )
228+ .fit (reg_data .df )
229+ .transform (reg_data .df )
230+ .orderBy ("row_id" )
231+ .select ("prediction" )
232+ .toPandas ()["prediction" ]
233+ .to_numpy ()
234+ )
235+ assert np .allclose (
236+ iter_preds , reg .predict (reg_data .X , iteration_range = iter_range ), rtol = 1e-3
237+ )
182238 assert np .allclose (preds , reg .predict (reg_data .X ), rtol = 1e-3 )
183239 assert np .allclose (pred_contribs .sum (axis = 1 ), preds , rtol = 1e-3 )
184240 assert np .allclose (
@@ -308,13 +364,13 @@ def test_valid_type(self, spark: SparkSession) -> None:
308364 with pytest .raises (TypeError , match = "The validation indicator must be boolean" ):
309365 reg .fit (df_train )
310366
311- def test_callbacks (self , reg_data : RegData ) -> None :
312- train_df = reg_data .df .select ("features" , "label" )
367+ @pytest .mark .parametrize ("spark" , DUAL_SPARK_MODES , indirect = True )
368+ def test_callbacks (self , spark : SparkSession , reg_data : RegData ) -> None :
369+ train_df = reg_data .df .select ("row_id" , "features" , "label" )
313370
314371 def custom_lr (boosting_round : int ) -> float :
315372 return 1.0 / (boosting_round + 1 )
316373
317- cb = [LearningRateScheduler (custom_lr )]
318374 reg_params = {
319375 "n_estimators" : 10 ,
320376 "max_depth" : 3 ,
@@ -324,7 +380,9 @@ def custom_lr(boosting_round: int) -> float:
324380
325381 with tempfile .TemporaryDirectory () as tmpdir :
326382 path = os .path .join (tmpdir , "spark-xgb-reg-cb" )
327- regressor = SparkXGBRegressor (callbacks = cb , ** reg_params )
383+ regressor = SparkXGBRegressor (
384+ callbacks = [LearningRateScheduler (custom_lr )], ** reg_params
385+ )
328386 regressor .save (path )
329387 regressor = SparkXGBRegressor .load (path )
330388 loaded_callbacks = regressor .getOrDefault (regressor .callbacks )
@@ -334,13 +392,16 @@ def custom_lr(boosting_round: int) -> float:
334392 model = regressor .fit (train_df )
335393 preds = (
336394 model .transform (train_df )
395+ .orderBy ("row_id" )
337396 .select ("prediction" )
338397 .toPandas ()["prediction" ]
339398 .to_numpy ()
340399 )
341400
342- assert preds .shape == (len (reg_data .y ),)
343- assert np .isfinite (preds ).all ()
401+ ref = XGBRegressor (
402+ callbacks = [LearningRateScheduler (custom_lr )], ** reg_params
403+ ).fit (reg_data .X , reg_data .y )
404+ assert np .allclose (preds , ref .predict (reg_data .X ), rtol = 1e-3 )
344405
345406 @pytest .mark .parametrize ("tree_method" , ["hist" , "approx" ])
346407 def test_empty_train_data (self , spark : SparkSession , tree_method : str ) -> None :
@@ -405,7 +466,10 @@ def clf_data(self, spark: SparkSession) -> ClfData:
405466 X_train , X_test , y_train , y_test , w , base_margin , is_val , X , y , df
406467 )
407468
408- def test_classifier (self , clf_data : ClfData ) -> None :
469+ @pytest .mark .parametrize ("spark" , DUAL_SPARK_MODES , indirect = True )
470+ def test_classifier (
471+ self , spark : SparkSession , clf_data : ClfData , num_workers : int
472+ ) -> None :
409473 train_df = clf_data .df
410474 X = clf_data .X
411475 y = clf_data .y
@@ -432,6 +496,7 @@ def test_classifier(self, clf_data: ClfData) -> None:
432496 spark_cls = SparkXGBClassifier (
433497 weight_col = "weight" ,
434498 validation_indicator_col = "is_val" ,
499+ num_workers = num_workers ,
435500 ** cls_params ,
436501 ).fit (train_df )
437502
0 commit comments