Skip to content

Commit 8863de1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3e7ace7 commit 8863de1

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Package for machine_learning.federated_learning
2-
__all__ = ["federated_averaging"]
2+
__all__ = ["federated_averaging"]

machine_learning/federated_learning/federated_averaging.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def create_client_datasets(
4242
samples_each: int,
4343
n_features: int,
4444
noise: float = 0.1,
45-
seed: int = 42
45+
seed: int = 42,
4646
) -> List[Tuple[np.ndarray, np.ndarray]]:
4747
"""
4848
Generates synthetic linear regression datasets for multiple clients.
@@ -81,17 +81,19 @@ def initialize_parameters(n_params: int, seed: int = 0) -> np.ndarray:
8181

8282
def mean_squared_error(params: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
8383
"""Computes mean squared error for predictions on dataset (X, y).
84-
>>> params = np.array([0.0, 1.0])
85-
>>> X = np.array([[1.0, 0.0], [1.0, 1.0]])
86-
>>> y = np.array([0.0, 2.0])
87-
>>> mean_squared_error(params, X, y)
88-
0.5
84+
>>> params = np.array([0.0, 1.0])
85+
>>> X = np.array([[1.0, 0.0], [1.0, 1.0]])
86+
>>> y = np.array([0.0, 2.0])
87+
>>> mean_squared_error(params, X, y)
88+
0.5
8989
"""
9090
predictions = X @ params
9191
return float(np.mean((predictions - y) ** 2))
9292

9393

94-
def evaluate_global_model(params: np.ndarray, client_data: List[Tuple[np.ndarray, np.ndarray]]) -> float:
94+
def evaluate_global_model(
95+
params: np.ndarray, client_data: List[Tuple[np.ndarray, np.ndarray]]
96+
) -> float:
9597
"""Evaluates the average global MSE across all client datasets."""
9698
total_loss, total_samples = 0.0, 0
9799

@@ -108,7 +110,7 @@ def client_update(
108110
y: np.ndarray,
109111
lr: float = 0.01,
110112
epochs: int = 1,
111-
batch_size: int = 0
113+
batch_size: int = 0,
112114
) -> np.ndarray:
113115
"""
114116
Performs local training on a client's dataset.
@@ -127,7 +129,7 @@ def client_update(
127129
# Mini-batch gradient descent
128130
order = np.random.permutation(n_samples)
129131
for i in range(0, n_samples, batch_size):
130-
idx = order[i:i + batch_size]
132+
idx = order[i : i + batch_size]
131133
Xb, yb = X[idx], y[idx]
132134
preds = Xb @ updated_params
133135
grad = (2 / len(yb)) * (Xb.T @ (preds - yb))
@@ -162,7 +164,7 @@ def run_federated_training(
162164
local_epochs: int = 1,
163165
lr: float = 0.01,
164166
batch_size: int = 0,
165-
seed: int = 0
167+
seed: int = 0,
166168
) -> Tuple[np.ndarray, List[float]]:
167169
"""
168170
Runs the full FedAvg simulation for the given client datasets.
@@ -179,7 +181,9 @@ def run_federated_training(
179181
client_models, client_sizes = [], []
180182

181183
for X, y in clients:
182-
local_params = client_update(global_params, X, y, lr, local_epochs, batch_size)
184+
local_params = client_update(
185+
global_params, X, y, lr, local_epochs, batch_size
186+
)
183187
client_models.append(local_params)
184188
client_sizes.append(len(y))
185189

@@ -194,14 +198,19 @@ def run_federated_training(
194198

195199
if __name__ == "__main__":
196200
# Example demonstration
197-
datasets = create_client_datasets(n_clients=5, samples_each=200, n_features=3, noise=0.5, seed=123)
198-
final_model, loss_curve = run_federated_training(datasets, rounds=12, local_epochs=2, lr=0.05)
201+
datasets = create_client_datasets(
202+
n_clients=5, samples_each=200, n_features=3, noise=0.5, seed=123
203+
)
204+
final_model, loss_curve = run_federated_training(
205+
datasets, rounds=12, local_epochs=2, lr=0.05
206+
)
199207

200208
print("\nFinal model parameters:\n", np.round(final_model, 4))
201209

202210
try:
203211
import matplotlib.pyplot as plt
204-
plt.plot(loss_curve, marker='o')
212+
213+
plt.plot(loss_curve, marker="o")
205214
plt.title("Federated Averaging - Training Loss Curve")
206215
plt.xlabel("Round")
207216
plt.ylabel("Mean Squared Error")
@@ -213,7 +222,7 @@ def run_federated_training(
213222
"""
214223
for testing:
215224
Create "tests/test_federated_averaging.py"
216-
225+
217226
" import numpy as np
218227
from machine_learning.federated_learning import federated_averaging as fed
219228
@@ -227,4 +236,4 @@ def test_loss_reduction_in_fedavg():
227236
seed=0
228237
) "
229238
230-
"""
239+
"""

0 commit comments

Comments
 (0)