Skip to content

Commit 49409e0

Browse files
authored
Merge pull request #1 from wegar-2/dev
Dev
2 parents 8aba870 + b0785e6 commit 49409e0

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

moddata/_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ def _load_btc():
5959

6060

6161
def _load_pl_banking_stocks() -> pd.DataFrame:
62-
with (
62+
return pd.read_parquet(str(
6363
resources.files('moddata.data').joinpath('pl_banking_stocks.parquet')
64-
as f
65-
):
66-
return pd.read_parquet(f)
64+
))
6765

6866

6967
def load_data(dataset: Dataset) -> pd.DataFrame | None:

tests/pipeline/test_bankchurn_pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,21 @@
44
from moddata.src.config import BankchurnPipelineConfig
55

66

7-
def test_bankchurn_pipeline_run():
7+
def test_bankchurn_pipeline_tree_like():
88
X_train, X_test, y_train, y_test = BankchurnPipeline(
99
config=BankchurnPipelineConfig(
1010
random_state=12345,
11-
train_size=0.8
11+
train_size=0.8,
12+
encoding_and_scaling_model_type="tree_like"
1213
)
1314
).run()
1415

15-
assert X_train.shape == (8_000, 10)
16-
assert X_test.shape == (2_000, 10)
16+
assert X_train.shape == (8_000, 11)
17+
assert X_test.shape == (2_000, 11)
1718
assert y_train.shape == (8_000, 1)
1819
assert y_test.shape == (2_000, 1)
1920

2021
assert np.all(np.array(y_test.index[:3]) == np.array([7867, 1402, 8606]))
22+
23+
24+
test_bankchurn_pipeline_tree_like

0 commit comments

Comments
 (0)