Skip to content

Commit 45aaf40

Browse files
committed
Fix_ruff_issues_in_FedAvg_module
1 parent cd5f2dd commit 45aaf40

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

machine_learning/federated_averaging.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@
1111
Basic equal-weight averaging across two "clients" with two tensors each
1212
(vector and 2x2 matrix):
1313
14-
>>> A = [np.array([1.0, 2.0]), np.array([[1.0, 2.0], [3.0, 4.0]])]
15-
>>> B = [np.array([3.0, 4.0]), np.array([[5.0, 6.0], [7.0, 8.0]])]
14+
>>> A = [
15+
... np.array([1.0, 2.0]),
16+
... np.array([[1.0, 2.0], [3.0, 4.0]]),
17+
... ]
18+
>>> B = [
19+
... np.array([3.0, 4.0]),
20+
... np.array([[5.0, 6.0], [7.0, 8.0]]),
21+
... ]
1622
>>> eq = federated_average([A, B])
1723
>>> eq[0].tolist()
1824
[2.0, 3.0]
@@ -21,7 +27,10 @@
2127
2228
Weighted averaging with weights [2, 1] (normalized to [2/3, 1/3]):
2329
24-
>>> w = federated_average([A, B], weights=np.array([2.0, 1.0]))
30+
>>> w = federated_average(
31+
... [A, B],
32+
... weights=np.array([2.0, 1.0]),
33+
... )
2534
>>> w[0].tolist()
2635
[1.6666666666666665, 2.6666666666666665]
2736
>>> w[1].tolist()
@@ -46,7 +55,10 @@
4655
4756
- Mismatched tensor shapes across clients
4857
49-
>>> C2 = [np.array([1.0, 2.0]), np.array([[1.0, 2.0]])] # second tensor has different shape
58+
>>> C2 = [
59+
... np.array([1.0, 2.0]),
60+
... np.array([[1.0, 2.0]]),
61+
... ] # second tensor has different shape
5062
>>> federated_average([A, C2]) # doctest: +ELLIPSIS
5163
Traceback (most recent call last):
5264
...
@@ -64,14 +76,19 @@
6476
...
6577
ValueError: weights must sum to a positive value
6678
67-
>>> federated_average([A, B], weights=np.array([1.0, 2.0, 3.0])) # doctest: +ELLIPSIS
79+
>>> federated_average(
80+
... [A, B],
81+
... weights=np.array([1.0, 2.0, 3.0]),
82+
... ) # doctest: +ELLIPSIS
6883
Traceback (most recent call last):
6984
...
7085
ValueError: weights must have shape (2,)
7186
"""
7287

7388
from __future__ import annotations
74-
from typing import Iterable, List, Sequence
89+
90+
from collections.abc import Sequence
91+
7592
import numpy as np
7693

7794

@@ -85,14 +102,17 @@ def _validate_clients(client_models: Sequence[Sequence[np.ndarray]]) -> None:
85102
raise ValueError("All clients must have the same number of tensors")
86103
for s_ref, arr in zip(ref_shapes, cm):
87104
if tuple(arr.shape) != s_ref:
88-
raise ValueError(
89-
f"Client {idx} tensor shape {tuple(arr.shape)} does not match {s_ref}"
105+
msg = (
106+
f"Client {idx} tensor shape {tuple(arr.shape)} "
107+
f"does not match {s_ref}"
90108
)
109+
raise ValueError(msg)
91110

92111

93112
def _normalize_weights(weights: np.ndarray, num_clients: int) -> np.ndarray:
94113
if weights.shape != (num_clients,):
95-
raise ValueError(f"weights must have shape ({num_clients},)")
114+
msg = f"weights must have shape ({num_clients},)"
115+
raise ValueError(msg)
96116
if np.any(weights < 0):
97117
raise ValueError("weights must be non-negative")
98118
total = float(weights.sum())
@@ -104,7 +124,7 @@ def _normalize_weights(weights: np.ndarray, num_clients: int) -> np.ndarray:
104124
def federated_average(
105125
client_models: Sequence[Sequence[np.ndarray]],
106126
weights: np.ndarray | None = None,
107-
) -> List[np.ndarray]:
127+
) -> list[np.ndarray]:
108128
"""Compute the weighted average of clients' model tensors.
109129
110130
Parameters
@@ -118,7 +138,7 @@ def federated_average(
118138
119139
Returns
120140
-------
121-
List[np.ndarray]
141+
list[np.ndarray]
122142
The list of aggregated tensors with the same shapes as the inputs.
123143
"""
124144
_validate_clients(client_models)
@@ -131,7 +151,7 @@ def federated_average(
131151
weights_n = _normalize_weights(weights, num_clients)
132152

133153
num_tensors = len(client_models[0])
134-
aggregated: List[np.ndarray] = []
154+
aggregated: list[np.ndarray] = []
135155
for t_idx in range(num_tensors):
136156
# Stack the t_idx-th tensor from each client into shape (num_clients, ...)
137157
stacked = np.stack([np.asarray(cm[t_idx]) for cm in client_models], axis=0)

0 commit comments

Comments
 (0)