Skip to content

Commit 1789330

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ea67a9f commit 1789330

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

machine_learning/decision_tree.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,19 @@ def train(self, x: np.ndarray, y: np.ndarray) -> None:
178178
left_indices = x <= split_point
179179
right_indices = x > split_point
180180

181-
if np.sum(left_indices) < self.min_leaf_size or \
182-
np.sum(right_indices) < self.min_leaf_size:
181+
if (
182+
np.sum(left_indices) < self.min_leaf_size
183+
or np.sum(right_indices) < self.min_leaf_size
184+
):
183185
continue
184186

185187
y_left = y[left_indices]
186188
y_right = y[right_indices]
187189

188190
# Calculate weighted MSE for this split
189-
error = (len(y_left) * self.mean_squared_error(y_left, np.mean(y_left)) +
190-
len(y_right) * self.mean_squared_error(y_right, np.mean(y_right)))
191+
error = len(y_left) * self.mean_squared_error(
192+
y_left, np.mean(y_left)
193+
) + len(y_right) * self.mean_squared_error(y_right, np.mean(y_right))
191194

192195
if error < min_error:
193196
min_error = error
@@ -201,7 +204,9 @@ def train(self, x: np.ndarray, y: np.ndarray) -> None:
201204
# Create child nodes and recursively train them
202205
self.decision_boundary = best_split
203206
self.left = DecisionTree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
204-
self.right = DecisionTree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
207+
self.right = DecisionTree(
208+
depth=self.depth - 1, min_leaf_size=self.min_leaf_size
209+
)
205210

206211
left_indices = x <= best_split
207212
right_indices = x > best_split
@@ -244,9 +249,7 @@ class TestDecisionTree:
244249
"""Decision Tree test class for verification purposes."""
245250

246251
@staticmethod
247-
def helper_mean_squared_error_test(
248-
labels: np.ndarray, prediction: float
249-
) -> float:
252+
def helper_mean_squared_error_test(labels: np.ndarray, prediction: float) -> float:
250253
"""
251254
Helper function to test mean_squared_error implementation.
252255
@@ -278,9 +281,9 @@ def main() -> None:
278281
- Error analysis: Understanding model performance
279282
"""
280283
# Example 1: Sine wave function approximation
281-
print("\n" + "="*60)
284+
print("\n" + "=" * 60)
282285
print("Example 1: Sine Wave Function Approximation")
283-
print("="*60)
286+
print("=" * 60)
284287
print("Training a decision tree to approximate f(x) = sin(x)")
285288
print("This demonstrates the tree's ability to learn non-linear patterns\n")
286289

@@ -304,9 +307,9 @@ def main() -> None:
304307
print(f"Average MSE: {avg_error:.6f}")
305308

306309
# Example 2: Linear relationship
307-
print("\n" + "="*60)
310+
print("\n" + "=" * 60)
308311
print("Example 2: Linear Relationship (House Price Analogy)")
309-
print("="*60)
312+
print("=" * 60)
310313
print("Simulating house price prediction based on square footage\n")
311314

312315
# Simple linear relationship: price = 100 * sqft + noise

0 commit comments

Comments
 (0)