1010- https://en.wikipedia.org/wiki/Random_forest
1111- https://en.wikipedia.org/wiki/Decision_tree_learning
1212"""
13+
1314from __future__ import annotations
1415
1516from collections import Counter
@@ -59,7 +60,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> None:
5960 """
6061 n_total_features = x .shape [1 ]
6162 self .n_features = (
62- n_total_features if self .n_features in (None , 0 ) else min (self .n_features , n_total_features )
63+ n_total_features
64+ if self .n_features in (None , 0 )
65+ else min (self .n_features , n_total_features )
6366 )
6467 self .tree = self ._grow_tree (x , y , depth = 0 )
6568
@@ -77,7 +80,11 @@ def _grow_tree(self, x: np.ndarray, y: np.ndarray, depth: int = 0) -> TreeNode:
7780 n_labels = len (np .unique (y ))
7881
7982 # Stopping criteria
80- if depth >= self .max_depth or n_labels == 1 or n_samples < self .min_samples_split :
83+ if (
84+ depth >= self .max_depth
85+ or n_labels == 1
86+ or n_samples < self .min_samples_split
87+ ):
8188 leaf_value = self ._most_common_label (y )
8289 return {"leaf" : True , "value" : int (leaf_value )}
8390
@@ -131,7 +138,9 @@ def _best_split(
131138 split_thresh = float (threshold )
132139 return split_idx , split_thresh
133140
134- def _information_gain (self , y : np .ndarray , x_column : np .ndarray , threshold : float ) -> float :
141+ def _information_gain (
142+ self , y : np .ndarray , x_column : np .ndarray , threshold : float
143+ ) -> float :
135144 """Calculate information gain from a split.
136145
137146 >>> y = np.array([0, 0, 1, 1])
@@ -293,7 +302,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "RandomForestClassifier":
293302 self .trees .append (tree )
294303 return self
295304
296- def _bootstrap_sample (self , x : np .ndarray , y : np .ndarray ) -> Tuple [np .ndarray , np .ndarray ]:
305+ def _bootstrap_sample (
306+ self , x : np .ndarray , y : np .ndarray
307+ ) -> Tuple [np .ndarray , np .ndarray ]:
297308 """Create a bootstrap sample from the dataset.
298309
299310 Bootstrap sampling randomly samples with replacement from the dataset.
@@ -370,7 +381,9 @@ def _most_common_label(self, y: Sequence[int]) -> int:
370381 )
371382
372383 # Split the data
373- x_train , x_test , y_train , y_test = train_test_split (x , y , test_size = 0.2 , random_state = 42 )
384+ x_train , x_test , y_train , y_test = train_test_split (
385+ x , y , test_size = 0.2 , random_state = 42
386+ )
374387
375388 print (f"Training samples: { x_train .shape [0 ]} " )
376389 print (f"Test samples: { x_test .shape [0 ]} " )
@@ -379,7 +392,9 @@ def _most_common_label(self, y: Sequence[int]) -> int:
379392
380393 # Train Random Forest Classifier
381394 print ("Training Random Forest Classifier..." )
382- rf_classifier = RandomForestClassifier (n_estimators = 10 , max_depth = 10 , min_samples_split = 2 )
395+ rf_classifier = RandomForestClassifier (
396+ n_estimators = 10 , max_depth = 10 , min_samples_split = 2
397+ )
383398 rf_classifier .fit (x_train , y_train )
384399 print ("Training complete!" )
385400 print ()
0 commit comments