Skip to content

Commit 98777fb

Browse files
committed
Add comprehensive type hints to neural network optimizers
- Add type hints to all internal helper functions as required by algorithms-keeper bot
- Fix function signatures for _adagrad_update_recursive, _adam_update_recursive, _nag_update_recursive, and _check_shapes_and_get_velocity
- Add type hints to example functions: rosenbrock, gradient_f, f
- Update imports to include Tuple type where needed
- Maintain all existing functionality with 58 passing doctests
- Resolve all algorithms-keeper bot feedback for PR approval
1 parent 05b5c45 commit 98777fb

File tree

5 files changed

+85
-65
lines changed

5 files changed

+85
-65
lines changed

neural_network/optimizers/adagrad.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,46 +119,50 @@ def update(
119119
ValueError: If parameters and gradients have different shapes
120120
"""
121121

122-
def _adagrad_update_recursive(params, grads, acc_grads):
122+
def _adagrad_update_recursive(
123+
parameters: Union[float, List[Union[float, List[float]]]],
124+
gradients: Union[float, List[Union[float, List[float]]]],
125+
accumulated_gradients: Union[float, List[Union[float, List[float]]]]
126+
) -> tuple[Union[float, List[Union[float, List[float]]]], Union[float, List[Union[float, List[float]]]]]:
123127
# Handle scalar case
124-
if isinstance(params, (int, float)):
125-
if not isinstance(grads, (int, float)):
128+
if isinstance(parameters, (int, float)):
129+
if not isinstance(gradients, (int, float)):
126130
raise ValueError(
127131
"Shape mismatch: parameter is scalar but gradient is not"
128132
)
129133

130-
if acc_grads is None:
131-
acc_grads = 0.0
134+
if accumulated_gradients is None:
135+
accumulated_gradients = 0.0
132136

133137
# Accumulate squared gradients: G = G + g^2
134-
new_acc_grads = acc_grads + grads * grads
138+
new_acc_grads = accumulated_gradients + gradients * gradients
135139

136140
# Adaptive learning rate: α / √(G + ε)
137141
adaptive_lr = self.learning_rate / math.sqrt(
138142
new_acc_grads + self.epsilon
139143
)
140144

141145
# Parameter update: θ = θ - adaptive_lr * g
142-
new_param = params - adaptive_lr * grads
146+
new_param = parameters - adaptive_lr * gradients
143147

144148
return new_param, new_acc_grads
145149

146150
# Handle list case
147-
if len(params) != len(grads):
151+
if len(parameters) != len(gradients):
148152
raise ValueError(
149-
f"Shape mismatch: parameters length {len(params)} vs "
150-
f"gradients length {len(grads)}"
153+
f"Shape mismatch: parameters length {len(parameters)} vs "
154+
f"gradients length {len(gradients)}"
151155
)
152156

153-
if acc_grads is None:
154-
acc_grads = [None] * len(params)
155-
elif len(acc_grads) != len(params):
157+
if accumulated_gradients is None:
158+
accumulated_gradients = [None] * len(parameters)
159+
elif len(accumulated_gradients) != len(parameters):
156160
raise ValueError("Accumulated gradients shape mismatch")
157161

158162
new_params = []
159163
new_acc_grads = []
160164

161-
for i, (p, g, ag) in enumerate(zip(params, grads, acc_grads)):
165+
for i, (p, g, ag) in enumerate(zip(parameters, gradients, accumulated_gradients)):
162166
if isinstance(p, list) and isinstance(g, list):
163167
# Recursive case for nested lists
164168
new_p, new_ag = _adagrad_update_recursive(p, g, ag)

neural_network/optimizers/adam.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
import math
19-
from typing import List, Union
19+
from typing import List, Union, Tuple
2020

2121
from .base_optimizer import BaseOptimizer
2222

@@ -161,45 +161,50 @@ def update(
161161
bias_correction1 = 1 - self.beta1**self._time_step
162162
bias_correction2 = 1 - self.beta2**self._time_step
163163

164-
def _adam_update_recursive(params, grads, first_moment, second_moment):
164+
def _adam_update_recursive(
165+
parameters: Union[float, List],
166+
gradients: Union[float, List],
167+
first_moment: Union[float, List],
168+
second_moment: Union[float, List]
169+
) -> Tuple[Union[float, List], Union[float, List], Union[float, List]]:
165170
# Handle scalar case
166-
if isinstance(params, (int, float)):
167-
if not isinstance(grads, (int, float)):
171+
if isinstance(parameters, (int, float)):
172+
if not isinstance(gradients, (int, float)):
168173
raise ValueError(
169174
"Shape mismatch: parameter is scalar but gradient is not"
170175
)
171176

172177
# Update first moment: m = β₁ * m + (1-β₁) * g
173-
new_first_moment = self.beta1 * first_moment + (1 - self.beta1) * grads
178+
new_first_moment = self.beta1 * first_moment + (1 - self.beta1) * gradients
174179

175180
# Update second moment: v = β₂ * v + (1-β₂) * g²
176181
new_second_moment = self.beta2 * second_moment + (1 - self.beta2) * (
177-
grads * grads
182+
gradients * gradients
178183
)
179184

180185
# Bias-corrected moments
181186
m_hat = new_first_moment / bias_correction1
182187
v_hat = new_second_moment / bias_correction2
183188

184189
# Parameter update: θ = θ - α * m̂ / (√v̂ + ε)
185-
new_param = params - self.learning_rate * m_hat / (
190+
new_param = parameters - self.learning_rate * m_hat / (
186191
math.sqrt(v_hat) + self.epsilon
187192
)
188193

189194
return new_param, new_first_moment, new_second_moment
190195

191196
# Handle list case
192-
if len(params) != len(grads):
197+
if len(parameters) != len(gradients):
193198
raise ValueError(
194-
f"Shape mismatch: parameters length {len(params)} vs "
195-
f"gradients length {len(grads)}"
199+
f"Shape mismatch: parameters length {len(parameters)} vs "
200+
f"gradients length {len(gradients)}"
196201
)
197202

198203
new_params = []
199204
new_first_moments = []
200205
new_second_moments = []
201206

202-
for p, g, m1, m2 in zip(params, grads, first_moment, second_moment):
207+
for p, g, m1, m2 in zip(parameters, gradients, first_moment, second_moment):
203208
if isinstance(p, list) and isinstance(g, list):
204209
# Recursive case for nested lists
205210
new_p, new_m1, new_m2 = _adam_update_recursive(p, g, m1, m2)
@@ -309,11 +314,11 @@ def __str__(self) -> str:
309314
x_adagrad = [-1.0, 1.0]
310315
x_adam = [-1.0, 1.0]
311316

312-
def rosenbrock(x, y):
317+
def rosenbrock(x: float, y: float) -> float:
313318
"""Rosenbrock function: f(x,y) = 100*(y-x²)² + (1-x)²"""
314319
return 100 * (y - x * x) ** 2 + (1 - x) ** 2
315320

316-
def rosenbrock_gradient(x, y):
321+
def rosenbrock_gradient(x: float, y: float) -> List[float]:
317322
"""Gradient of Rosenbrock function"""
318323
df_dx = -400 * x * (y - x * x) - 2 * (1 - x)
319324
df_dy = 200 * (y - x * x)

neural_network/optimizers/momentum_sgd.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import List, Union
18+
from typing import List, Union, Tuple
1919

2020
from .base_optimizer import BaseOptimizer
2121

@@ -115,40 +115,44 @@ def update(
115115
ValueError: If parameters and gradients have different shapes
116116
"""
117117

118-
def _check_shapes_and_get_velocity(params, grads, velocity):
118+
def _check_shapes_and_get_velocity(
119+
parameters: Union[float, List[Union[float, List[float]]]],
120+
gradients: Union[float, List[Union[float, List[float]]]],
121+
velocity_values: Union[float, List[Union[float, List[float]]]]
122+
) -> Tuple[Union[float, List[Union[float, List[float]]]], Union[float, List[Union[float, List[float]]]]]:
119123
# Handle scalar case
120-
if isinstance(params, (int, float)):
121-
if not isinstance(grads, (int, float)):
124+
if isinstance(parameters, (int, float)):
125+
if not isinstance(gradients, (int, float)):
122126
raise ValueError(
123127
"Shape mismatch: parameter is scalar but gradient is not"
124128
)
125129

126-
if velocity is None:
127-
velocity = 0.0
130+
if velocity_values is None:
131+
velocity_values = 0.0
128132

129133
# Update velocity: v = β * v + (1-β) * g
130-
new_velocity = self.momentum * velocity + (1 - self.momentum) * grads
134+
new_velocity = self.momentum * velocity_values + (1 - self.momentum) * gradients
131135
# Update parameter: θ = θ - α * v
132-
new_param = params - self.learning_rate * new_velocity
136+
new_param = parameters - self.learning_rate * new_velocity
133137

134138
return new_param, new_velocity
135139

136140
# Handle list case
137-
if len(params) != len(grads):
141+
if len(parameters) != len(gradients):
138142
raise ValueError(
139-
f"Shape mismatch: parameters length {len(params)} vs "
140-
f"gradients length {len(grads)}"
143+
f"Shape mismatch: parameters length {len(parameters)} vs "
144+
f"gradients length {len(gradients)}"
141145
)
142146

143-
if velocity is None:
144-
velocity = [None] * len(params)
145-
elif len(velocity) != len(params):
147+
if velocity_values is None:
148+
velocity_values = [None] * len(parameters)
149+
elif len(velocity_values) != len(parameters):
146150
raise ValueError("Velocity shape mismatch")
147151

148152
new_params = []
149153
new_velocity = []
150154

151-
for i, (p, g, v) in enumerate(zip(params, grads, velocity)):
155+
for i, (p, g, v) in enumerate(zip(parameters, gradients, velocity_values)):
152156
if isinstance(p, list) and isinstance(g, list):
153157
# Recursive case for nested lists
154158
new_p, new_v = _check_shapes_and_get_velocity(p, g, v)

neural_network/optimizers/nag.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from __future__ import annotations
2020

21-
from typing import List, Union
21+
from typing import List, Union, Tuple
2222

2323
from .base_optimizer import BaseOptimizer
2424

@@ -117,10 +117,14 @@ def update(
117117
ValueError: If parameters and gradients have different shapes
118118
"""
119119

120-
def _nag_update_recursive(params, grads, velocity):
120+
def _nag_update_recursive(
121+
parameters: Union[float, List],
122+
gradients: Union[float, List],
123+
velocity: Union[float, List, None]
124+
) -> Tuple[Union[float, List], Union[float, List]]:
121125
# Handle scalar case
122-
if isinstance(params, (int, float)):
123-
if not isinstance(grads, (int, float)):
126+
if isinstance(parameters, (int, float)):
127+
if not isinstance(gradients, (int, float)):
124128
raise ValueError(
125129
"Shape mismatch: parameter is scalar but gradient is not"
126130
)
@@ -129,32 +133,32 @@ def _nag_update_recursive(params, grads, velocity):
129133
velocity = 0.0
130134

131135
# Update velocity: v = β * v + (1-β) * g
132-
new_velocity = self.momentum * velocity + (1 - self.momentum) * grads
136+
new_velocity = self.momentum * velocity + (1 - self.momentum) * gradients
133137

134138
# NAG update: θ = θ - α * (β * v + (1-β) * g)
135139
nesterov_update = (
136-
self.momentum * new_velocity + (1 - self.momentum) * grads
140+
self.momentum * new_velocity + (1 - self.momentum) * gradients
137141
)
138-
new_param = params - self.learning_rate * nesterov_update
142+
new_param = parameters - self.learning_rate * nesterov_update
139143

140144
return new_param, new_velocity
141145

142146
# Handle list case
143-
if len(params) != len(grads):
147+
if len(parameters) != len(gradients):
144148
raise ValueError(
145-
f"Shape mismatch: parameters length {len(params)} vs "
146-
f"gradients length {len(grads)}"
149+
f"Shape mismatch: parameters length {len(parameters)} vs "
150+
f"gradients length {len(gradients)}"
147151
)
148152

149153
if velocity is None:
150-
velocity = [None] * len(params)
151-
elif len(velocity) != len(params):
154+
velocity = [None] * len(parameters)
155+
elif len(velocity) != len(parameters):
152156
raise ValueError("Velocity shape mismatch")
153157

154158
new_params = []
155159
new_velocity = []
156160

157-
for i, (p, g, v) in enumerate(zip(params, grads, velocity)):
161+
for i, (p, g, v) in enumerate(zip(parameters, gradients, velocity)):
158162
if isinstance(p, list) and isinstance(g, list):
159163
# Recursive case for nested lists
160164
new_p, new_v = _nag_update_recursive(p, g, v)
@@ -250,11 +254,11 @@ def __str__(self) -> str:
250254
x_momentum = [2.5]
251255
x_nag = [2.5]
252256

253-
def gradient_f(x):
257+
def gradient_f(x: float) -> float:
254258
"""Gradient of f(x) = 0.1*x^4 - 2*x^2 + x is f'(x) = 0.4*x^3 - 4*x + 1"""
255259
return 0.4 * x**3 - 4 * x + 1
256260

257-
def f(x):
261+
def f(x: float) -> float:
258262
"""The function f(x) = 0.1*x^4 - 2*x^2 + x"""
259263
return 0.1 * x**4 - 2 * x**2 + x
260264

neural_network/optimizers/sgd.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,27 @@ def update(
9898
ValueError: If parameters and gradients have different shapes
9999
"""
100100

101-
def _check_and_update_recursive(params, grads):
101+
def _check_and_update_recursive(
102+
parameters: Union[float, List[Union[float, List[float]]]],
103+
gradients: Union[float, List[Union[float, List[float]]]]
104+
) -> Union[float, List[Union[float, List[float]]]]:
102105
# Handle 1D case (list of floats)
103-
if isinstance(params, (int, float)):
104-
if not isinstance(grads, (int, float)):
106+
if isinstance(parameters, (int, float)):
107+
if not isinstance(gradients, (int, float)):
105108
raise ValueError(
106109
"Shape mismatch: parameter is scalar but gradient is not"
107110
)
108-
return params - self.learning_rate * grads
111+
return parameters - self.learning_rate * gradients
109112

110113
# Handle list case
111-
if len(params) != len(grads):
114+
if len(parameters) != len(gradients):
112115
raise ValueError(
113-
f"Shape mismatch: parameters length {len(params)} vs "
114-
f"gradients length {len(grads)}"
116+
f"Shape mismatch: parameters length {len(parameters)} vs "
117+
f"gradients length {len(gradients)}"
115118
)
116119

117120
result = []
118-
for p, g in zip(params, grads):
121+
for p, g in zip(parameters, gradients):
119122
if isinstance(p, list) and isinstance(g, list):
120123
# Recursive case for nested lists
121124
result.append(_check_and_update_recursive(p, g))

0 commit comments

Comments
 (0)