|
4 | 4 | - https://en.wikipedia.org/wiki/Random_forest |
5 | 5 | - https://en.wikipedia.org/wiki/Decision_tree_learning |
6 | 6 | """ |
| 7 | + |
7 | 8 | from __future__ import annotations |
8 | 9 |
|
9 | 10 | from typing import Any, Dict, List, Optional, Sequence, Tuple |
@@ -35,7 +36,9 @@ class DecisionTreeRegressor: |
35 | 36 | True |
36 | 37 | """ |
37 | 38 |
|
38 | | - def __init__(self, max_depth: Optional[int] = None, min_samples_split: int = 2) -> None: |
| 39 | + def __init__( |
| 40 | + self, max_depth: Optional[int] = None, min_samples_split: int = 2 |
| 41 | + ) -> None: |
39 | 42 | self.max_depth: Optional[int] = max_depth |
40 | 43 | self.min_samples_split: int = min_samples_split |
41 | 44 | self.tree: Optional[TreeNodeReg] = None |
@@ -103,7 +106,9 @@ def _grow_tree(self, x: np.ndarray, y: np.ndarray, depth: int = 0) -> TreeNodeRe |
103 | 106 | "right": right_subtree, |
104 | 107 | } |
105 | 108 |
|
106 | | - def _best_split(self, x: np.ndarray, y: np.ndarray, n_features: int) -> Optional[Dict[str, Any]]: |
| 109 | + def _best_split( |
| 110 | + self, x: np.ndarray, y: np.ndarray, n_features: int |
| 111 | + ) -> Optional[Dict[str, Any]]: |
107 | 112 | """ |
108 | 113 | Find the best feature and threshold to split on. |
109 | 114 |
|
@@ -133,10 +138,15 @@ def _best_split(self, x: np.ndarray, y: np.ndarray, n_features: int) -> Optional |
133 | 138 | mse = self._calculate_mse(y[left_indices], y[right_indices], len(y)) |
134 | 139 | if mse < best_mse: |
135 | 140 | best_mse = mse |
136 | | - best_split = {"feature": int(feature), "threshold": float(threshold)} |
| 141 | + best_split = { |
| 142 | + "feature": int(feature), |
| 143 | + "threshold": float(threshold), |
| 144 | + } |
137 | 145 | return best_split |
138 | 146 |
|
139 | | - def _calculate_mse(self, left_y: np.ndarray, right_y: np.ndarray, n_samples: int) -> float: |
| 147 | + def _calculate_mse( |
| 148 | + self, left_y: np.ndarray, right_y: np.ndarray, n_samples: int |
| 149 | + ) -> float: |
140 | 150 | """ |
141 | 151 | Calculate weighted mean squared error for a split. |
142 | 152 |
|
@@ -289,7 +299,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "RandomForestRegressor": |
289 | 299 | feature_indices = rng.choice(n_features, max_features, replace=False) |
290 | 300 | x_bootstrap = x_bootstrap[:, feature_indices] |
291 | 301 | # Train decision tree |
292 | | - tree = DecisionTreeRegressor(max_depth=self.max_depth, min_samples_split=self.min_samples_split) |
| 302 | + tree = DecisionTreeRegressor( |
| 303 | + max_depth=self.max_depth, min_samples_split=self.min_samples_split |
| 304 | + ) |
293 | 305 | tree.fit(x_bootstrap, y_bootstrap) |
294 | 306 | self.trees.append((tree, feature_indices)) |
295 | 307 | return self |
@@ -328,10 +340,14 @@ def predict(self, x: np.ndarray) -> np.ndarray: |
328 | 340 | from sklearn.model_selection import train_test_split |
329 | 341 |
|
330 | 342 | # Generate synthetic regression data |
331 | | - x, y = make_regression(n_samples=200, n_features=5, n_informative=3, noise=10, random_state=42) |
| 343 | + x, y = make_regression( |
| 344 | + n_samples=200, n_features=5, n_informative=3, noise=10, random_state=42 |
| 345 | + ) |
332 | 346 |
|
333 | 347 | # Split the data |
334 | | - x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42) |
| 348 | + x_train, x_test, y_train, y_test = train_test_split( |
| 349 | + x, y, test_size=0.3, random_state=42 |
| 350 | + ) |
335 | 351 |
|
336 | 352 | # Train the Random Forest Regressor |
337 | 353 | rf_regressor = RandomForestRegressor(n_estimators=10, max_depth=5, random_state=42) |
|
0 commit comments