@@ -42,7 +42,7 @@ def create_client_datasets(
4242 samples_each : int ,
4343 n_features : int ,
4444 noise : float = 0.1 ,
45- seed : int = 42
45+ seed : int = 42 ,
4646) -> List [Tuple [np .ndarray , np .ndarray ]]:
4747 """
4848 Generates synthetic linear regression datasets for multiple clients.
@@ -81,17 +81,19 @@ def initialize_parameters(n_params: int, seed: int = 0) -> np.ndarray:
8181
def mean_squared_error(params: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
    """Computes mean squared error for predictions on dataset (X, y).

    The model is linear: predictions = X @ params.

    Args:
        params: Weight vector of shape (n_features,).
        X: Design matrix of shape (n_samples, n_features).
        y: Target vector of shape (n_samples,).

    Returns:
        The mean of the squared residuals as a plain Python float.

    >>> params = np.array([0.0, 1.0])
    >>> X = np.array([[1.0, 0.0], [1.0, 1.0]])
    >>> y = np.array([0.0, 2.0])
    >>> mean_squared_error(params, X, y)
    0.5
    """
    predictions = X @ params
    # float(...) unwraps the 0-d numpy scalar so callers get a builtin float.
    return float(np.mean((predictions - y) ** 2))
9292
9393
94- def evaluate_global_model (params : np .ndarray , client_data : List [Tuple [np .ndarray , np .ndarray ]]) -> float :
94+ def evaluate_global_model (
95+ params : np .ndarray , client_data : List [Tuple [np .ndarray , np .ndarray ]]
96+ ) -> float :
9597 """Evaluates the average global MSE across all client datasets."""
9698 total_loss , total_samples = 0.0 , 0
9799
@@ -108,7 +110,7 @@ def client_update(
108110 y : np .ndarray ,
109111 lr : float = 0.01 ,
110112 epochs : int = 1 ,
111- batch_size : int = 0
113+ batch_size : int = 0 ,
112114) -> np .ndarray :
113115 """
114116 Performs local training on a client's dataset.
@@ -127,7 +129,7 @@ def client_update(
127129 # Mini-batch gradient descent
128130 order = np .random .permutation (n_samples )
129131 for i in range (0 , n_samples , batch_size ):
130- idx = order [i : i + batch_size ]
132+ idx = order [i : i + batch_size ]
131133 Xb , yb = X [idx ], y [idx ]
132134 preds = Xb @ updated_params
133135 grad = (2 / len (yb )) * (Xb .T @ (preds - yb ))
@@ -162,7 +164,7 @@ def run_federated_training(
162164 local_epochs : int = 1 ,
163165 lr : float = 0.01 ,
164166 batch_size : int = 0 ,
165- seed : int = 0
167+ seed : int = 0 ,
166168) -> Tuple [np .ndarray , List [float ]]:
167169 """
168170 Runs the full FedAvg simulation for the given client datasets.
@@ -179,7 +181,9 @@ def run_federated_training(
179181 client_models , client_sizes = [], []
180182
181183 for X , y in clients :
182- local_params = client_update (global_params , X , y , lr , local_epochs , batch_size )
184+ local_params = client_update (
185+ global_params , X , y , lr , local_epochs , batch_size
186+ )
183187 client_models .append (local_params )
184188 client_sizes .append (len (y ))
185189
@@ -194,14 +198,19 @@ def run_federated_training(
194198
195199if __name__ == "__main__" :
196200 # Example demonstration
197- datasets = create_client_datasets (n_clients = 5 , samples_each = 200 , n_features = 3 , noise = 0.5 , seed = 123 )
198- final_model , loss_curve = run_federated_training (datasets , rounds = 12 , local_epochs = 2 , lr = 0.05 )
201+ datasets = create_client_datasets (
202+ n_clients = 5 , samples_each = 200 , n_features = 3 , noise = 0.5 , seed = 123
203+ )
204+ final_model , loss_curve = run_federated_training (
205+ datasets , rounds = 12 , local_epochs = 2 , lr = 0.05
206+ )
199207
200208 print ("\n Final model parameters:\n " , np .round (final_model , 4 ))
201209
202210 try :
203211 import matplotlib .pyplot as plt
204- plt .plot (loss_curve , marker = 'o' )
212+
213+ plt .plot (loss_curve , marker = "o" )
205214 plt .title ("Federated Averaging - Training Loss Curve" )
206215 plt .xlabel ("Round" )
207216 plt .ylabel ("Mean Squared Error" )
@@ -213,7 +222,7 @@ def run_federated_training(
213222"""
214223 for testing:
215224 Create "tests/test_federated_averaging.py"
216-
225+
217226 " import numpy as np
218227 from machine_learning.federated_learning import federated_averaging as fed
219228
@@ -227,4 +236,4 @@ def test_loss_reduction_in_fedavg():
227236 seed=0
228237 ) "
229238
230- """
239+ """
0 commit comments