@@ -178,16 +178,19 @@ def train(self, x: np.ndarray, y: np.ndarray) -> None:
178178 left_indices = x <= split_point
179179 right_indices = x > split_point
180180
181- if np .sum (left_indices ) < self .min_leaf_size or \
182- np .sum (right_indices ) < self .min_leaf_size :
181+ if (
182+ np .sum (left_indices ) < self .min_leaf_size
183+ or np .sum (right_indices ) < self .min_leaf_size
184+ ):
183185 continue
184186
185187 y_left = y [left_indices ]
186188 y_right = y [right_indices ]
187189
188190 # Calculate weighted MSE for this split
189- error = (len (y_left ) * self .mean_squared_error (y_left , np .mean (y_left )) +
190- len (y_right ) * self .mean_squared_error (y_right , np .mean (y_right )))
191+ error = len (y_left ) * self .mean_squared_error (
192+ y_left , np .mean (y_left )
193+ ) + len (y_right ) * self .mean_squared_error (y_right , np .mean (y_right ))
191194
192195 if error < min_error :
193196 min_error = error
@@ -201,7 +204,9 @@ def train(self, x: np.ndarray, y: np.ndarray) -> None:
201204 # Create child nodes and recursively train them
202205 self .decision_boundary = best_split
203206 self .left = DecisionTree (depth = self .depth - 1 , min_leaf_size = self .min_leaf_size )
204- self .right = DecisionTree (depth = self .depth - 1 , min_leaf_size = self .min_leaf_size )
207+ self .right = DecisionTree (
208+ depth = self .depth - 1 , min_leaf_size = self .min_leaf_size
209+ )
205210
206211 left_indices = x <= best_split
207212 right_indices = x > best_split
@@ -244,9 +249,7 @@ class TestDecisionTree:
244249 """Decision Tree test class for verification purposes."""
245250
246251 @staticmethod
247- def helper_mean_squared_error_test (
248- labels : np .ndarray , prediction : float
249- ) -> float :
252+ def helper_mean_squared_error_test (labels : np .ndarray , prediction : float ) -> float :
250253 """
251254 Helper function to test mean_squared_error implementation.
252255
@@ -278,9 +281,9 @@ def main() -> None:
278281 - Error analysis: Understanding model performance
279282 """
280283 # Example 1: Sine wave function approximation
281- print ("\n " + "=" * 60 )
284+ print ("\n " + "=" * 60 )
282285 print ("Example 1: Sine Wave Function Approximation" )
283- print ("=" * 60 )
286+ print ("=" * 60 )
284287 print ("Training a decision tree to approximate f(x) = sin(x)" )
285288 print ("This demonstrates the tree's ability to learn non-linear patterns\n " )
286289
@@ -304,9 +307,9 @@ def main() -> None:
304307 print (f"Average MSE: { avg_error :.6f} " )
305308
306309 # Example 2: Linear relationship
307- print ("\n " + "=" * 60 )
310+ print ("\n " + "=" * 60 )
308311 print ("Example 2: Linear Relationship (House Price Analogy)" )
309- print ("=" * 60 )
312+ print ("=" * 60 )
310313 print ("Simulating house price prediction based on square footage\n " )
311314
312315 # Simple linear relationship: price = 100 * sqft + noise
0 commit comments