Skip to content

Commit d2d7392

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5d819b8 commit d2d7392

File tree

1 file changed

+21 -6 lines changed

1 file changed

+21 -6 lines changed

machine_learning/random_forest_classifier.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
- https://en.wikipedia.org/wiki/Random_forest
1111
- https://en.wikipedia.org/wiki/Decision_tree_learning
1212
"""
13+
1314
from __future__ import annotations
1415

1516
from collections import Counter
@@ -59,7 +60,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> None:
5960
"""
6061
n_total_features = x.shape[1]
6162
self.n_features = (
62-
n_total_features if self.n_features in (None, 0) else min(self.n_features, n_total_features)
63+
n_total_features
64+
if self.n_features in (None, 0)
65+
else min(self.n_features, n_total_features)
6366
)
6467
self.tree = self._grow_tree(x, y, depth=0)
6568

@@ -77,7 +80,11 @@ def _grow_tree(self, x: np.ndarray, y: np.ndarray, depth: int = 0) -> TreeNode:
7780
n_labels = len(np.unique(y))
7881

7982
# Stopping criteria
80-
if depth >= self.max_depth or n_labels == 1 or n_samples < self.min_samples_split:
83+
if (
84+
depth >= self.max_depth
85+
or n_labels == 1
86+
or n_samples < self.min_samples_split
87+
):
8188
leaf_value = self._most_common_label(y)
8289
return {"leaf": True, "value": int(leaf_value)}
8390

@@ -131,7 +138,9 @@ def _best_split(
131138
split_thresh = float(threshold)
132139
return split_idx, split_thresh
133140

134-
def _information_gain(self, y: np.ndarray, x_column: np.ndarray, threshold: float) -> float:
141+
def _information_gain(
142+
self, y: np.ndarray, x_column: np.ndarray, threshold: float
143+
) -> float:
135144
"""Calculate information gain from a split.
136145
137146
>>> y = np.array([0, 0, 1, 1])
@@ -293,7 +302,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "RandomForestClassifier":
293302
self.trees.append(tree)
294303
return self
295304

296-
def _bootstrap_sample(self, x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
305+
def _bootstrap_sample(
306+
self, x: np.ndarray, y: np.ndarray
307+
) -> Tuple[np.ndarray, np.ndarray]:
297308
"""Create a bootstrap sample from the dataset.
298309
299310
Bootstrap sampling randomly samples with replacement from the dataset.
@@ -370,7 +381,9 @@ def _most_common_label(self, y: Sequence[int]) -> int:
370381
)
371382

372383
# Split the data
373-
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
384+
x_train, x_test, y_train, y_test = train_test_split(
385+
x, y, test_size=0.2, random_state=42
386+
)
374387

375388
print(f"Training samples: {x_train.shape[0]}")
376389
print(f"Test samples: {x_test.shape[0]}")
@@ -379,7 +392,9 @@ def _most_common_label(self, y: Sequence[int]) -> int:
379392

380393
# Train Random Forest Classifier
381394
print("Training Random Forest Classifier...")
382-
rf_classifier = RandomForestClassifier(n_estimators=10, max_depth=10, min_samples_split=2)
395+
rf_classifier = RandomForestClassifier(
396+
n_estimators=10, max_depth=10, min_samples_split=2
397+
)
383398
rf_classifier.fit(x_train, y_train)
384399
print("Training complete!")
385400
print()

0 commit comments

Comments (0)