Skip to content

Commit 5e0f844

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1664757 commit 5e0f844

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

machine_learning/random_forest_regressor.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
- https://en.wikipedia.org/wiki/Random_forest
55
- https://en.wikipedia.org/wiki/Decision_tree_learning
66
"""
7+
78
from __future__ import annotations
89

910
from typing import Any, Dict, List, Optional, Sequence, Tuple
@@ -35,7 +36,9 @@ class DecisionTreeRegressor:
3536
True
3637
"""
3738

38-
def __init__(self, max_depth: Optional[int] = None, min_samples_split: int = 2) -> None:
39+
def __init__(
40+
self, max_depth: Optional[int] = None, min_samples_split: int = 2
41+
) -> None:
3942
self.max_depth: Optional[int] = max_depth
4043
self.min_samples_split: int = min_samples_split
4144
self.tree: Optional[TreeNodeReg] = None
@@ -103,7 +106,9 @@ def _grow_tree(self, x: np.ndarray, y: np.ndarray, depth: int = 0) -> TreeNodeRe
103106
"right": right_subtree,
104107
}
105108

106-
def _best_split(self, x: np.ndarray, y: np.ndarray, n_features: int) -> Optional[Dict[str, Any]]:
109+
def _best_split(
110+
self, x: np.ndarray, y: np.ndarray, n_features: int
111+
) -> Optional[Dict[str, Any]]:
107112
"""
108113
Find the best feature and threshold to split on.
109114
@@ -133,10 +138,15 @@ def _best_split(self, x: np.ndarray, y: np.ndarray, n_features: int) -> Optional
133138
mse = self._calculate_mse(y[left_indices], y[right_indices], len(y))
134139
if mse < best_mse:
135140
best_mse = mse
136-
best_split = {"feature": int(feature), "threshold": float(threshold)}
141+
best_split = {
142+
"feature": int(feature),
143+
"threshold": float(threshold),
144+
}
137145
return best_split
138146

139-
def _calculate_mse(self, left_y: np.ndarray, right_y: np.ndarray, n_samples: int) -> float:
147+
def _calculate_mse(
148+
self, left_y: np.ndarray, right_y: np.ndarray, n_samples: int
149+
) -> float:
140150
"""
141151
Calculate weighted mean squared error for a split.
142152
@@ -289,7 +299,9 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "RandomForestRegressor":
289299
feature_indices = rng.choice(n_features, max_features, replace=False)
290300
x_bootstrap = x_bootstrap[:, feature_indices]
291301
# Train decision tree
292-
tree = DecisionTreeRegressor(max_depth=self.max_depth, min_samples_split=self.min_samples_split)
302+
tree = DecisionTreeRegressor(
303+
max_depth=self.max_depth, min_samples_split=self.min_samples_split
304+
)
293305
tree.fit(x_bootstrap, y_bootstrap)
294306
self.trees.append((tree, feature_indices))
295307
return self
@@ -328,10 +340,14 @@ def predict(self, x: np.ndarray) -> np.ndarray:
328340
from sklearn.model_selection import train_test_split
329341

330342
# Generate synthetic regression data
331-
x, y = make_regression(n_samples=200, n_features=5, n_informative=3, noise=10, random_state=42)
343+
x, y = make_regression(
344+
n_samples=200, n_features=5, n_informative=3, noise=10, random_state=42
345+
)
332346

333347
# Split the data
334-
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)
348+
x_train, x_test, y_train, y_test = train_test_split(
349+
x, y, test_size=0.3, random_state=42
350+
)
335351

336352
# Train the Random Forest Regressor
337353
rf_regressor = RandomForestRegressor(n_estimators=10, max_depth=5, random_state=42)

0 commit comments

Comments (0)