Skip to content

Commit f1cf808

Browse files
committed
Add_Federated_Averaging_FedAvg_module_with_doctests
1 parent e2a78d4 commit f1cf808

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
Federated Averaging (FedAvg)
3+
https://arxiv.org/abs/1602.05629
4+
5+
This module provides a minimal, educational implementation of the Federated
6+
Learning paradigm using the Federated Averaging algorithm. Multiple clients
7+
compute local model updates on their private data and the server aggregates
8+
their updates by (weighted) averaging without collecting raw data.
9+
10+
Notes
11+
-----
12+
- This implementation is framework-agnostic and uses NumPy arrays to represent
13+
model parameters for simplicity and portability within this repository.
14+
- It demonstrates the mechanics of FedAvg, not production concerns like
15+
privacy amplification (e.g., differential privacy), robustness, or security.
16+
17+
Terminology
18+
-----------
19+
- Global model: a list of NumPy arrays representing model parameters.
20+
- Client update: new model parameters produced locally, or the delta from the
21+
global model; we aggregate parameters directly here for clarity.
22+
23+
Examples
24+
--------
25+
Create three synthetic "clients" whose local training produces simple parameter
26+
arrays, then aggregate them with FedAvg.
27+
28+
>>> import numpy as np
29+
>>> # Global model with two parameter tensors
30+
>>> global_model = [np.array([0.0, 0.0]), np.array([[0.0]])]
31+
>>> # Client models after local training
32+
>>> client_models = [
33+
... [np.array([1.0, 2.0]), np.array([[1.0]])],
34+
... [np.array([3.0, 4.0]), np.array([[3.0]])],
35+
... [np.array([5.0, 6.0]), np.array([[5.0]])],
36+
... ]
37+
>>> # Equal weights -> simple average
38+
>>> new_global = federated_average(client_models)
39+
>>> [arr.tolist() for arr in new_global]
40+
[[3.0, 4.0], [[3.0]]]
41+
42+
Weighted averaging by client data sizes:
43+
44+
>>> weights = np.array([10, 20, 30], dtype=float)
45+
>>> new_global_w = federated_average(client_models, weights)
46+
>>> [arr.tolist() for arr in new_global_w]
47+
[[3.6666666666666665, 4.666666666666666], [[3.6666666666666665]]]
48+
49+
Contract
50+
--------
51+
Inputs:
52+
- client_models: list[list[np.ndarray]]: each inner list mirrors model layers
53+
- weights: Optional[np.ndarray] of shape (num_clients,), non-negative, sums to > 0
54+
Output:
55+
- list[np.ndarray]: aggregated model parameters, same shapes as client models
56+
Error modes:
57+
- ValueError for empty clients, shape mismatch, or invalid weights
58+
"""
59+
60+
from __future__ import annotations
61+
62+
from typing import Iterable, List, Sequence
63+
64+
import numpy as np
65+
66+
67+
def _validate_clients(client_models: Sequence[Sequence[np.ndarray]]) -> None:
68+
if not client_models:
69+
raise ValueError("client_models must be a non-empty list")
70+
# Ensure all clients have same number of layers and shapes
71+
ref_shapes = [tuple(arr.shape) for arr in client_models[0]]
72+
for idx, cm in enumerate(client_models, start=1):
73+
if len(cm) != len(ref_shapes):
74+
raise ValueError("All clients must have the same number of tensors")
75+
for s_ref, arr in zip(ref_shapes, cm):
76+
if tuple(arr.shape) != s_ref:
77+
raise ValueError(
78+
f"Client {idx} tensor shape {tuple(arr.shape)} does not match {s_ref}"
79+
)
80+
81+
82+
def _normalize_weights(weights: np.ndarray, n: int) -> np.ndarray:
83+
if weights.shape != (n,):
84+
raise ValueError(f"weights must have shape ({n},)")
85+
if np.any(weights < 0):
86+
raise ValueError("weights must be non-negative")
87+
total = float(weights.sum())
88+
if total <= 0.0:
89+
raise ValueError("weights must sum to a positive value")
90+
return weights / total
91+
92+
93+
def federated_average(
94+
client_models: Sequence[Sequence[np.ndarray]],
95+
weights: np.ndarray | None = None,
96+
) -> List[np.ndarray]:
97+
"""
98+
Aggregate client model parameters using (weighted) averaging.
99+
100+
Parameters
101+
----------
102+
client_models : list[list[np.ndarray]]
103+
Model parameters for each client; all clients must have same shapes.
104+
weights : np.ndarray | None
105+
Optional non-negative weights per client. If None, equal weights.
106+
107+
Returns
108+
-------
109+
list[np.ndarray]
110+
Aggregated model parameters (same shapes as client tensors).
111+
112+
Examples
113+
--------
114+
>>> import numpy as np
115+
>>> cm = [
116+
... [np.array([1.0, 2.0])],
117+
... [np.array([3.0, 4.0])],
118+
... ]
119+
>>> [arr.tolist() for arr in federated_average(cm)]
120+
[[2.0, 3.0]]
121+
>>> w = np.array([1.0, 3.0])
122+
>>> [arr.tolist() for arr in federated_average(cm, w)]
123+
[[2.5, 3.5]]
124+
"""
125+
_validate_clients(client_models)
126+
num_clients = len(client_models)
127+
128+
if weights is None:
129+
weights_n = np.full((num_clients,), 1.0 / num_clients, dtype=float)
130+
else:
131+
weights = np.asarray(weights, dtype=float)
132+
weights_n = _normalize_weights(weights, num_clients)
133+
134+
num_tensors = len(client_models[0])
135+
aggregated: List[np.ndarray] = []
136+
for t_idx in range(num_tensors):
137+
# Stack the t_idx-th tensor from each client into shape (num_clients, ...)
138+
stacked = np.stack([np.asarray(cm[t_idx]) for cm in client_models], axis=0)
139+
# Weighted sum across clients axis=0
140+
# np.tensordot weights of shape (n,) with stacked of shape (n, *dims)
141+
agg = np.tensordot(weights_n, stacked, axes=(0, 0))
142+
aggregated.append(np.asarray(agg))
143+
144+
return aggregated
145+
146+
147+
if __name__ == "__main__":
148+
import doctest
149+
150+
doctest.testmod()

0 commit comments

Comments
 (0)