Skip to content

Commit 3e7ace7

Browse files
committed
added federated learning
1 parent e2a78d4 commit 3e7ace7

File tree

2 files changed

+232
-0
lines changed

2 files changed

+232
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Package for machine_learning.federated_learning
2+
__all__ = ["federated_averaging"]
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""
2+
Federated Learning: enables machine learning on distributed data by moving the training to the data, instead of moving the data to the training.
3+
Here’s a one-liner explanation:
4+
- Centralized machine learning: move the data to the computation.
- Federated (machine) learning: move the computation to the data.
6+
7+
This script demonstrates the working of the Federated Averaging algorithm (FedAvg)
8+
using a minimal, from-scratch approach with only NumPy.
9+
10+
Overview:
11+
- Synthetic data is generated and distributed among several clients.
12+
- Each client performs local gradient-based updates on its dataset.
13+
- The central server combines all updated client models by averaging them
14+
according to the number of samples each client has.
15+
- The process repeats for multiple communication rounds.
16+
17+
Key Functions:
18+
▪ create_client_datasets(...)
19+
▪ initialize_parameters(...)
20+
▪ client_update(...)
21+
▪ aggregate_models(...)
22+
▪ evaluate_global_model(...)
23+
▪ run_federated_training(...)
24+
25+
Example Usage:
26+
Run with: ``python federated_learning_simulation.py``
27+
28+
Reference :
29+
30+
-(GitHub): "Federated Learning from Scratch (NumPy-based)"
31+
Ex: https://github.com/omar-fl/federated-learning-from-scratch
32+
-(Medium article): “Federated Learning from Scratch with NumPy”
33+
Ex: https://medium.com/@niveditapatnaik/federated-learning-from-scratch-with-numpy-ff9c62a2a4a9
34+
"""
35+
36+
from typing import List, Optional, Tuple

import numpy as np
38+
39+
40+
def create_client_datasets(
    n_clients: int,
    samples_each: int,
    n_features: int,
    noise: float = 0.1,
    seed: int = 42
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Build synthetic linear-regression datasets, one per client.

    A single ground-truth weight vector (bias included) is drawn once;
    every client then gets features with a prepended bias column and
    noisy targets generated from that shared vector.

    Returns:
        A list containing tuples of (X, y) for each client.
        X has shape (samples_each, n_features + 1).
    """
    rng = np.random.default_rng(seed)
    # One shared ground truth so all clients' data comes from the same model.
    true_weights = rng.normal(0, 1, n_features + 1)

    datasets: List[Tuple[np.ndarray, np.ndarray]] = []
    for _ in range(n_clients):
        features = rng.normal(0, 1, (samples_each, n_features))
        design = np.c_[np.ones((samples_each, 1)), features]
        targets = design @ true_weights + rng.normal(0, noise, samples_each)
        datasets.append((design, targets))

    return datasets
66+
67+
68+
def initialize_parameters(n_params: int, seed: int = 0) -> np.ndarray:
    """
    Draw small random starting weights (bias included) for the global model.

    >>> params = initialize_parameters(3, seed=0)
    >>> len(params)
    3
    >>> isinstance(params, np.ndarray)
    True
    """
    generator = np.random.default_rng(seed)
    # Tiny standard deviation keeps the initial model close to zero.
    return generator.normal(0, 0.01, n_params)
80+
81+
82+
def mean_squared_error(params: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
    """Return the mean squared prediction error of *params* on (X, y).

    >>> params = np.array([0.0, 1.0])
    >>> X = np.array([[1.0, 0.0], [1.0, 1.0]])
    >>> y = np.array([0.0, 2.0])
    >>> mean_squared_error(params, X, y)
    0.5
    """
    residuals = X @ params - y
    return float(np.mean(np.square(residuals)))
92+
93+
94+
def evaluate_global_model(params: np.ndarray, client_data: List[Tuple[np.ndarray, np.ndarray]]) -> float:
    """Pooled MSE of *params* over every sample held by any client."""
    squared_error = 0.0
    sample_count = 0

    for features, targets in client_data:
        residuals = features @ params - targets
        squared_error += np.sum(residuals ** 2)
        sample_count += len(targets)

    return float(squared_error / sample_count)
103+
104+
105+
def client_update(
    params: np.ndarray,
    X: np.ndarray,
    y: np.ndarray,
    lr: float = 0.01,
    epochs: int = 1,
    batch_size: int = 0,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Run local gradient-descent training on one client's dataset.

    Args:
        params: Current global parameters; never modified in place.
        X: Client design matrix (bias column included).
        y: Client targets.
        lr: Learning rate.
        epochs: Number of local passes over the data.
        batch_size: <= 0 (or >= number of samples) selects full-batch
            descent; otherwise shuffled mini-batches of this size are used.
        rng: Optional NumPy generator for the mini-batch shuffle, making
            local training reproducible. When omitted, the legacy global
            ``np.random`` state is used (kept for backward compatibility).

    Returns:
        The locally updated parameter vector (a new float array).
    """
    # astype(float) both copies (protecting the caller's array) and avoids
    # an in-place cast error if an integer-dtype vector is ever passed in.
    updated_params = params.astype(float, copy=True)
    n_samples = len(y)

    for _ in range(epochs):
        if batch_size <= 0 or batch_size >= n_samples:
            # Full-batch gradient of the MSE loss.
            preds = X @ updated_params
            grad = (2 / n_samples) * (X.T @ (preds - y))
            updated_params -= lr * grad
        else:
            # Mini-batch gradient descent with a fresh shuffle per epoch.
            if rng is not None:
                order = rng.permutation(n_samples)
            else:
                order = np.random.permutation(n_samples)
            for start in range(0, n_samples, batch_size):
                idx = order[start:start + batch_size]
                Xb, yb = X[idx], y[idx]
                preds = Xb @ updated_params
                grad = (2 / len(yb)) * (Xb.T @ (preds - yb))
                updated_params -= lr * grad

    return updated_params
137+
138+
139+
def aggregate_models(models: List[np.ndarray], sizes: List[int]) -> np.ndarray:
    """
    FedAvg aggregation: sample-count-weighted average of client models.

    >>> w1 = np.array([1.0, 2.0])
    >>> w2 = np.array([3.0, 4.0])
    >>> aggregate_models([w1, w2], [1, 1])
    array([2., 3.])
    """
    total = sum(sizes)
    if total == 0:
        raise ValueError("Cannot aggregate: total sample size is zero.")

    blended = np.zeros_like(models[0], dtype=float)
    for weights, count in zip(models, sizes):
        # Clients with more samples pull the average harder.
        blended += (count / total) * weights
    return blended
157+
158+
159+
def run_federated_training(
    clients: List[Tuple[np.ndarray, np.ndarray]],
    rounds: int = 10,
    local_epochs: int = 1,
    lr: float = 0.01,
    batch_size: int = 0,
    seed: int = 0
) -> Tuple[np.ndarray, List[float]]:
    """
    Execute the FedAvg loop over the supplied client datasets.

    Each round, every client trains locally from the current global model,
    the server averages the results weighted by client size, and the
    pooled MSE of the new global model is recorded.

    Returns:
        final_parameters : np.ndarray
        loss_history : list of MSE values over communication rounds
    """
    global_params = initialize_parameters(clients[0][0].shape[1], seed)
    history: List[float] = []

    for round_num in range(1, rounds + 1):
        local_models = [
            client_update(global_params, X, y, lr, local_epochs, batch_size)
            for X, y in clients
        ]
        sizes = [len(y) for _, y in clients]

        global_params = aggregate_models(local_models, sizes)
        mse = evaluate_global_model(global_params, clients)
        history.append(mse)
        print(f"Round {round_num}/{rounds} - Global MSE: {mse:.6f}")

    return global_params, history
193+
194+
195+
if __name__ == "__main__":
    # Demo run: five clients sharing one ground truth, twelve FedAvg rounds.
    datasets = create_client_datasets(n_clients=5, samples_each=200, n_features=3, noise=0.5, seed=123)
    final_model, loss_curve = run_federated_training(datasets, rounds=12, local_epochs=2, lr=0.05)

    print("\nFinal model parameters:\n", np.round(final_model, 4))

    # Plotting is optional: skip silently when matplotlib is not installed.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        pass
    else:
        plt.plot(loss_curve, marker='o')
        plt.title("Federated Averaging - Training Loss Curve")
        plt.xlabel("Round")
        plt.ylabel("Mean Squared Error")
        plt.grid(True)
        plt.show()
212+
213+
"""
214+
for testing:
215+
Create "tests/test_federated_averaging.py"
216+
217+
" import numpy as np
218+
from machine_learning.federated_learning import federated_averaging as fed
219+
220+
def test_loss_reduction_in_fedavg():
221+
# Define a small, reproducible test scenario
222+
clients = fed.create_synthetic_clients(
223+
n_clients=3,
224+
samples_per_client=80,
225+
n_features=2,
226+
noise_level=0.3,
227+
seed=0
228+
) "
229+
230+
"""

0 commit comments

Comments
 (0)