This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2a33b15

T2T Team authored and copybara-github committed
Support callable decay rates in Adafactor
PiperOrigin-RevId: 421333767
1 parent 81c2b2e commit 2a33b15

1 file changed: +1 -1 lines changed


tensor2tensor/utils/adafactor.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def _resource_apply_dense(self, grad, handle):
     grad = tf.to_float(grad)
     grad_squared = tf.square(grad) + self._epsilon1
     grad_squared_mean = tf.reduce_mean(grad_squared)
-    decay_rate = self._decay_rate
+    decay_rate = self._call_if_callable(self._decay_rate)
     update_scale = self._call_if_callable(self._learning_rate)
     update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
     update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
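
The effect of this change can be illustrated with a minimal, framework-free sketch. It assumes `_call_if_callable` returns `param()` when `param` is callable and `param` unchanged otherwise, which is how the helper is already applied to the learning rate in the diff. The class `DecayRateHolder`, the `step` dictionary, and the schedule below are hypothetical stand-ins for illustration, not tensor2tensor code.

```python
# Hypothetical stand-in for the optimizer; only the decay-rate lookup is modeled.
class DecayRateHolder:
    def __init__(self, decay_rate):
        # decay_rate may be a plain number or a zero-argument callable.
        self._decay_rate = decay_rate

    @staticmethod
    def _call_if_callable(param):
        # Assumed behavior of the helper used in the diff.
        return param() if callable(param) else param

    def current_decay_rate(self):
        # Mirrors the changed line: resolve the value at the time of use.
        return self._call_if_callable(self._decay_rate)


# A constant decay rate behaves the same before and after the change.
constant = DecayRateHolder(0.8)

# A callable decay rate is re-evaluated on every lookup, so it can follow
# external state such as a step counter (illustrative schedule only).
step = {"value": 1}
scheduled = DecayRateHolder(lambda: 1.0 - step["value"] ** -0.8)

print(constant.current_decay_rate())
print(scheduled.current_decay_rate())
step["value"] = 4
print(scheduled.current_decay_rate())
```

Resolving the value at use time, rather than storing a fixed number, would let a caller pass a schedule without otherwise changing how the optimizer consumes `decay_rate`.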
