
Commit 91ae8fb

Continue refactor of pyspark tests (dmlc#12007)
1 parent 12b7ad9 commit 91ae8fb

File tree: 3 files changed, +74 -696 lines changed

tests/test_distributed/test_with_spark/test_spark_local.py renamed to tests/test_distributed/test_with_spark/test_spark.py

Lines changed: 74 additions & 9 deletions
@@ -27,12 +27,18 @@
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
+from xgboost.spark.utils import _get_max_num_concurrent_tasks
 from xgboost.testing.collective import get_avail_port

 logging.getLogger("py4j").setLevel(logging.INFO)

 pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]

+DUAL_SPARK_MODES = [
+    pytest.param("local", id="local"),
+    pytest.param("local_cluster", id="local_cluster"),
+]
+
 RegData = namedtuple(
     "RegData",
     (
@@ -76,7 +82,10 @@ def no_sparse_unwrap() -> tm.PytestSkip:


 @pytest.fixture(scope="module")
-def spark() -> Generator[SparkSession, None, None]:
+def spark(request: pytest.FixtureRequest) -> Generator[SparkSession, None, None]:
+    mode = getattr(request, "param", "local")
+    if mode not in {"local", "local_cluster"}:
+        raise ValueError(f"Unknown Spark test mode: {mode}")
     os.environ["XGBOOST_PYSPARK_SHARED_SESSION"] = "1"
     config = {
         "spark.master": "local[4]",
@@ -88,12 +97,30 @@ def spark() -> Generator[SparkSession, None, None]:
         "spark.sql.pyspark.jvmStacktrace.enabled": "true",
         "spark.ui.enabled": "false",
     }
+    if mode == "local_cluster":
+        config.update(
+            {
+                "spark.master": "local-cluster[2, 1, 1024]",
+                "spark.cores.max": "2",
+                "spark.task.cpus": "1",
+                "spark.executor.cores": "1",
+            }
+        )

     builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests")
     for k, v in config.items():
         builder.config(k, v)
     logging.getLogger("pyspark").setLevel(logging.INFO)
     sess = builder.getOrCreate()
+    if mode == "local_cluster":
+        # Block until workers are connected.
+        num_slots = sess.sparkContext.defaultParallelism
+        (
+            sess.sparkContext.parallelize(range(num_slots), num_slots)
+            .barrier()
+            .mapPartitions(lambda _: [])
+            .collect()
+        )
     try:
         yield sess
     finally:
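For context on the added warm-up: `local-cluster[2, 1, 1024]` asks Spark for two worker processes with one core and 1024 MB each, and a barrier stage cannot be scheduled until a task slot exists for every task, so the `collect()` call only returns once both executors have registered. Below is a minimal standalone sketch of the same pattern (not part of the commit; the app name and config are illustrative, and local-cluster mode additionally requires a Spark distribution on the machine):

from pyspark.sql import SparkSession

# Hypothetical standalone warm-up mirroring the fixture above; names are
# illustrative, not from the commit.
sess = (
    SparkSession.builder.master("local-cluster[2, 1, 1024]")
    .config("spark.executor.cores", "1")
    .appName("barrier-warm-up-sketch")
    .getOrCreate()
)

# defaultParallelism reflects the registered task slots; a barrier stage with
# that many tasks can only start once every slot exists, so collect() blocks
# until the little cluster is actually usable.
num_slots = sess.sparkContext.defaultParallelism
sess.sparkContext.parallelize(range(num_slots), num_slots).barrier().mapPartitions(
    lambda _: []
).collect()
sess.stop()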
@@ -102,6 +129,11 @@ def spark() -> Generator[SparkSession, None, None]:
         os.environ.pop("XGBOOST_PYSPARK_SHARED_SESSION", None)


+@pytest.fixture(scope="module")
+def num_workers(spark: SparkSession) -> int:
+    return _get_max_num_concurrent_tasks(spark.sparkContext)
+
+
 class TestRegressor:
     @pytest.fixture(scope="class")
     def reg_data(self, spark: SparkSession) -> RegData:
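The pattern used throughout the rest of the diff is pytest's indirect parametrization: `@pytest.mark.parametrize("spark", DUAL_SPARK_MODES, indirect=True)` hands each mode to the module-scoped `spark` fixture via `request.param`, so one test body runs once per Spark mode while unparametrized tests fall back to "local". A self-contained sketch of the mechanism (fixture and test names here are hypothetical, not from the commit):

import pytest

MODES = [
    pytest.param("local", id="local"),
    pytest.param("local_cluster", id="local_cluster"),
]


@pytest.fixture(scope="module")
def mode(request: pytest.FixtureRequest) -> str:
    # With indirect=True the parametrize value arrives as request.param;
    # tests that use the fixture without parametrizing it get the default.
    return getattr(request, "param", "local")


@pytest.mark.parametrize("mode", MODES, indirect=True)
def test_mode(mode: str) -> None:
    assert mode in {"local", "local_cluster"}


def test_default_mode(mode: str) -> None:
    assert mode == "local"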
@@ -141,7 +173,10 @@ def reg_data(self, spark: SparkSession) -> RegData:
             X_train, X_test, y_train, y_test, w, base_margin, is_val, X, y, df
         )

-    def test_regressor(self, reg_data: RegData) -> None:
+    @pytest.mark.parametrize("spark", DUAL_SPARK_MODES, indirect=True)
+    def test_regressor(
+        self, spark: SparkSession, reg_data: RegData, num_workers: int
+    ) -> None:
         train_rows = np.where(~reg_data.is_val)[0]
         validation_rows = np.where(reg_data.is_val)[0]

@@ -164,6 +199,7 @@ def test_regressor(self, reg_data: RegData) -> None:
             pred_contrib_col="pred_contribs",
             weight_col="weight",
             validation_indicator_col="is_val",
+            num_workers=num_workers,
             **reg_param,
         ).fit(reg_data.df)
         pred_result = spark_regressor.transform(reg_data.df)
@@ -179,6 +215,26 @@ def test_regressor(self, reg_data: RegData) -> None:
             .toPandas()["pred_contribs"]
             .tolist()
         )
+        rounds = reg.get_booster().num_boosted_rounds()
+        iter_range = (0, max(1, min(5, rounds)))
+        iter_preds = (
+            SparkXGBRegressor(
+                weight_col="weight",
+                validation_indicator_col="is_val",
+                iteration_range=iter_range,
+                num_workers=num_workers,
+                **reg_param,
+            )
+            .fit(reg_data.df)
+            .transform(reg_data.df)
+            .orderBy("row_id")
+            .select("prediction")
+            .toPandas()["prediction"]
+            .to_numpy()
+        )
+        assert np.allclose(
+            iter_preds, reg.predict(reg_data.X, iteration_range=iter_range), rtol=1e-3
+        )
         assert np.allclose(preds, reg.predict(reg_data.X), rtol=1e-3)
         assert np.allclose(pred_contribs.sum(axis=1), preds, rtol=1e-3)
         assert np.allclose(
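The new block asserts that a SparkXGBRegressor trained with an `iteration_range` predicts the same values as the single-node booster restricted to the same rounds. The single-node property it relies on: `predict(..., iteration_range=(0, k))` uses only the first k boosting rounds, which, absent row or column subsampling, coincide with a model trained for k rounds. A small single-node sketch of that property (synthetic data, purely for illustration):

import numpy as np
from sklearn.datasets import make_regression
from xgboost import XGBRegressor

# Illustrative data, not from the test suite.
X, y = make_regression(n_samples=256, n_features=8, random_state=0)

full = XGBRegressor(n_estimators=10, max_depth=3).fit(X, y)
short = XGBRegressor(n_estimators=5, max_depth=3).fit(X, y)

# Boosting is sequential, so the first five trees of the 10-round model match
# the 5-round model, and slicing the prediction with iteration_range
# reproduces it.
assert np.allclose(
    full.predict(X, iteration_range=(0, 5)), short.predict(X), rtol=1e-5
)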
@@ -308,13 +364,13 @@ def test_valid_type(self, spark: SparkSession) -> None:
         with pytest.raises(TypeError, match="The validation indicator must be boolean"):
             reg.fit(df_train)

-    def test_callbacks(self, reg_data: RegData) -> None:
-        train_df = reg_data.df.select("features", "label")
+    @pytest.mark.parametrize("spark", DUAL_SPARK_MODES, indirect=True)
+    def test_callbacks(self, spark: SparkSession, reg_data: RegData) -> None:
+        train_df = reg_data.df.select("row_id", "features", "label")

         def custom_lr(boosting_round: int) -> float:
             return 1.0 / (boosting_round + 1)

-        cb = [LearningRateScheduler(custom_lr)]
         reg_params = {
             "n_estimators": 10,
             "max_depth": 3,
@@ -324,7 +380,9 @@ def custom_lr(boosting_round: int) -> float:

         with tempfile.TemporaryDirectory() as tmpdir:
             path = os.path.join(tmpdir, "spark-xgb-reg-cb")
-            regressor = SparkXGBRegressor(callbacks=cb, **reg_params)
+            regressor = SparkXGBRegressor(
+                callbacks=[LearningRateScheduler(custom_lr)], **reg_params
+            )
             regressor.save(path)
             regressor = SparkXGBRegressor.load(path)
             loaded_callbacks = regressor.getOrDefault(regressor.callbacks)
@@ -334,13 +392,16 @@ def custom_lr(boosting_round: int) -> float:
             model = regressor.fit(train_df)
             preds = (
                 model.transform(train_df)
+                .orderBy("row_id")
                 .select("prediction")
                 .toPandas()["prediction"]
                 .to_numpy()
             )

-            assert preds.shape == (len(reg_data.y),)
-            assert np.isfinite(preds).all()
+            ref = XGBRegressor(
+                callbacks=[LearningRateScheduler(custom_lr)], **reg_params
+            ).fit(reg_data.X, reg_data.y)
+            assert np.allclose(preds, ref.predict(reg_data.X), rtol=1e-3)

     @pytest.mark.parametrize("tree_method", ["hist", "approx"])
     def test_empty_train_data(self, spark: SparkSession, tree_method: str) -> None:
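The reworked callback test now compares the distributed predictions against a single-node `XGBRegressor` trained with the same `LearningRateScheduler`, instead of only checking shape and finiteness. For reference, the scheduler accepts either a list of learning rates or a callable mapping the boosting round to a rate; a minimal single-node sketch (synthetic data, illustrative names):

import numpy as np
from sklearn.datasets import make_regression
from xgboost import XGBRegressor
from xgboost.callback import LearningRateScheduler


def custom_lr(boosting_round: int) -> float:
    # Decay the step size as 1 / (round + 1), as in the test above.
    return 1.0 / (boosting_round + 1)


X, y = make_regression(n_samples=256, n_features=8, random_state=0)
reg = XGBRegressor(
    n_estimators=10, max_depth=3, callbacks=[LearningRateScheduler(custom_lr)]
).fit(X, y)
print(reg.predict(X)[:5])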
@@ -405,7 +466,10 @@ def clf_data(self, spark: SparkSession) -> ClfData:
             X_train, X_test, y_train, y_test, w, base_margin, is_val, X, y, df
         )

-    def test_classifier(self, clf_data: ClfData) -> None:
+    @pytest.mark.parametrize("spark", DUAL_SPARK_MODES, indirect=True)
+    def test_classifier(
+        self, spark: SparkSession, clf_data: ClfData, num_workers: int
+    ) -> None:
         train_df = clf_data.df
         X = clf_data.X
         y = clf_data.y
@@ -432,6 +496,7 @@ def test_classifier(self, clf_data: ClfData) -> None:
         spark_cls = SparkXGBClassifier(
             weight_col="weight",
             validation_indicator_col="is_val",
+            num_workers=num_workers,
             **cls_params,
         ).fit(train_df)

0 commit comments