1111Basic equal-weight averaging across two "clients" with two tensors each
1212(vector and 2x2 matrix):
1313
14- >>> A = [np.array([1.0, 2.0]), np.array([[1.0, 2.0], [3.0, 4.0]])]
15- >>> B = [np.array([3.0, 4.0]), np.array([[5.0, 6.0], [7.0, 8.0]])]
14+ >>> A = [
15+ ... np.array([1.0, 2.0]),
16+ ... np.array([[1.0, 2.0], [3.0, 4.0]]),
17+ ... ]
18+ >>> B = [
19+ ... np.array([3.0, 4.0]),
20+ ... np.array([[5.0, 6.0], [7.0, 8.0]]),
21+ ... ]
1622>>> eq = federated_average([A, B])
1723>>> eq[0].tolist()
1824[2.0, 3.0]
2127
2228Weighted averaging with weights [2, 1] (normalized to [2/3, 1/3]):
2329
24- >>> w = federated_average([A, B], weights=np.array([2.0, 1.0]))
30+ >>> w = federated_average(
31+ ... [A, B],
32+ ... weights=np.array([2.0, 1.0]),
33+ ... )
2534>>> w[0].tolist()
2635[1.6666666666666665, 2.6666666666666665]
2736>>> w[1].tolist()
4655
4756- Mismatched tensor shapes across clients
4857
49- >>> C2 = [np.array([1.0, 2.0]), np.array([[1.0, 2.0]])] # second tensor has different shape
58+ >>> C2 = [
59+ ... np.array([1.0, 2.0]),
60+ ... np.array([[1.0, 2.0]]),
61+ ... ] # second tensor has different shape
5062>>> federated_average([A, C2]) # doctest: +ELLIPSIS
5163Traceback (most recent call last):
5264...
6476...
6577ValueError: weights must sum to a positive value
6678
67- >>> federated_average([A, B], weights=np.array([1.0, 2.0, 3.0])) # doctest: +ELLIPSIS
79+ >>> federated_average(
80+ ... [A, B],
81+ ... weights=np.array([1.0, 2.0, 3.0]),
82+ ... ) # doctest: +ELLIPSIS
6883Traceback (most recent call last):
6984...
7085ValueError: weights must have shape (2,)
7186"""
7287
7388from __future__ import annotations
74- from typing import Iterable , List , Sequence
89+
90+ from collections .abc import Sequence
91+
7592import numpy as np
7693
7794
@@ -85,14 +102,17 @@ def _validate_clients(client_models: Sequence[Sequence[np.ndarray]]) -> None:
85102 raise ValueError ("All clients must have the same number of tensors" )
86103 for s_ref , arr in zip (ref_shapes , cm ):
87104 if tuple (arr .shape ) != s_ref :
88- raise ValueError (
89- f"Client { idx } tensor shape { tuple (arr .shape )} does not match { s_ref } "
105+ msg = (
106+ f"Client { idx } tensor shape { tuple (arr .shape )} "
107+ f"does not match { s_ref } "
90108 )
109+ raise ValueError (msg )
91110
92111
93112def _normalize_weights (weights : np .ndarray , num_clients : int ) -> np .ndarray :
94113 if weights .shape != (num_clients ,):
95- raise ValueError (f"weights must have shape ({ num_clients } ,)" )
114+ msg = f"weights must have shape ({ num_clients } ,)"
115+ raise ValueError (msg )
96116 if np .any (weights < 0 ):
97117 raise ValueError ("weights must be non-negative" )
98118 total = float (weights .sum ())
@@ -104,7 +124,7 @@ def _normalize_weights(weights: np.ndarray, num_clients: int) -> np.ndarray:
104124def federated_average (
105125 client_models : Sequence [Sequence [np .ndarray ]],
106126 weights : np .ndarray | None = None ,
107- ) -> List [np .ndarray ]:
127+ ) -> list [np .ndarray ]:
108128 """Compute the weighted average of clients' model tensors.
109129
110130 Parameters
@@ -118,7 +138,7 @@ def federated_average(
118138
119139 Returns
120140 -------
121- List [np.ndarray]
141+ list [np.ndarray]
122142 The list of aggregated tensors with the same shapes as the inputs.
123143 """
124144 _validate_clients (client_models )
@@ -131,7 +151,7 @@ def federated_average(
131151 weights_n = _normalize_weights (weights , num_clients )
132152
133153 num_tensors = len (client_models [0 ])
134- aggregated : List [np .ndarray ] = []
154+ aggregated : list [np .ndarray ] = []
135155 for t_idx in range (num_tensors ):
136156 # Stack the t_idx-th tensor from each client into shape (num_clients, ...)
137157 stacked = np .stack ([np .asarray (cm [t_idx ]) for cm in client_models ], axis = 0 )
0 commit comments