Skip to content

Commit dc72e87

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 089c51e commit dc72e87

File tree

3 files changed

+5
-0
lines changed

3 files changed

+5
-0
lines changed

machine_learning/neural_network/optimizers/sgd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import List
66

7+
78
def sgd_update(weights: List[float], grads: List[float], lr: float) -> List[float]:
89
"""
910
Update weights using SGD.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from sgd import sgd_update
22

3+
34
def test_sgd():
45
weights = [0.5, -0.2]
56
grads = [0.1, -0.1]
67
updated = sgd_update(weights, grads, lr=0.01)
78
assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"
89

10+
911
if __name__ == "__main__":
1012
test_sgd()
1113
print("SGD test passed!")

test_sgd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from neural_network.optimizers.sgd import sgd_update
22

3+
34
def test_sgd():
45
weights = [0.5, -0.2]
56
grads = [0.1, -0.1]
67
updated = sgd_update(weights, grads, lr=0.01)
78
assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"
89

10+
911
if __name__ == "__main__":
1012
test_sgd()
1113
print("SGD test passed!")

0 commit comments

Comments
 (0)