Skip to content

Commit 089c51e

Browse files
committed
feat: add SGD optimizer with unit test and doctest
1 parent e2a78d4 commit 089c51e

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

machine_learning/neural_network/optimizers/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
Stochastic Gradient Descent (SGD) optimizer.
3+
"""
4+
5+
from typing import List
6+
7+
def sgd_update(weights: List[float], grads: List[float], lr: float) -> List[float]:
8+
"""
9+
Update weights using SGD.
10+
11+
Args:
12+
weights (List[float]): Current weights
13+
grads (List[float]): Gradients
14+
lr (float): Learning rate
15+
16+
Returns:
17+
List[float]: Updated weights
18+
19+
Example:
20+
>>> sgd_update([0.5, -0.2], [0.1, -0.1], 0.01)
21+
[0.499, -0.199]
22+
"""
23+
return [w - lr * g for w, g in zip(weights, grads)]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from sgd import sgd_update
2+
3+
def test_sgd():
4+
weights = [0.5, -0.2]
5+
grads = [0.1, -0.1]
6+
updated = sgd_update(weights, grads, lr=0.01)
7+
assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"
8+
9+
if __name__ == "__main__":
10+
test_sgd()
11+
print("SGD test passed!")

test_sgd.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from neural_network.optimizers.sgd import sgd_update
2+
3+
def test_sgd():
4+
weights = [0.5, -0.2]
5+
grads = [0.1, -0.1]
6+
updated = sgd_update(weights, grads, lr=0.01)
7+
assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"
8+
9+
if __name__ == "__main__":
10+
test_sgd()
11+
print("SGD test passed!")

0 commit comments

Comments
 (0)