
Commit 81c2b2e

T2T Team authored and copybara-github committed
Update adafactor so it can accept a callable learning rate.
PiperOrigin-RevId: 413718409
1 parent 86caf01 commit 81c2b2e

File tree

2 files changed: +49 −2 lines changed


tensor2tensor/utils/adafactor.py

Lines changed: 5 additions & 2 deletions
@@ -122,7 +122,7 @@ def __init__(self,
 
     Args:
       multiply_by_parameter_scale: a boolean
-      learning_rate: an optional Scalar.
+      learning_rate: an optional Scalar or callable.
       decay_rate: an optional Scalar.
       beta1: a float value between 0 and 1
       clipping_threshold: an optional float >= 1
@@ -218,7 +218,9 @@ def _resource_apply_dense(self, grad, handle):
     grad_squared = tf.square(grad) + self._epsilon1
     grad_squared_mean = tf.reduce_mean(grad_squared)
     decay_rate = self._decay_rate
-    update_scale = self._learning_rate
+    update_scale = self._call_if_callable(self._learning_rate)
+    update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
+    update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
     old_val = var
     if var.dtype.base_dtype == tf.bfloat16:
       old_val = tf.to_float(self._parameter_encoding.decode(old_val))
@@ -272,6 +274,7 @@ def _resource_apply_dense(self, grad, handle):
       new_val = quantization.simulated_quantize(
           var - subtrahend, self._simulated_quantize_bits,
          self._quantization_noise)
+    new_val = tf.cast(new_val, var.dtype)
     var_update = tf.assign(var, new_val, use_locking=self._use_locking)
     updates = [var_update] + updates
     return tf.group(*updates)
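For context, a minimal sketch of how the changed constructor can be used after this commit. AdafactorOptimizer, its learning_rate argument, and the _call_if_callable resolution come from the diff above; the step variable and the decay function are hypothetical illustration, not part of the commit:

import tensorflow as tf
from tensor2tensor.utils import adafactor

step = tf.Variable(0.0, trainable=False)

def decayed_lr():
  # Re-evaluated each time gradients are applied, because
  # _resource_apply_dense now resolves the learning rate via
  # _call_if_callable before converting and casting it.
  return 0.01 / (1.0 + step)

opt = adafactor.AdafactorOptimizer(learning_rate=decayed_lr)

Because the callable is resolved inside _resource_apply_dense, a schedule that reads a step counter is picked up on every update. The added tf.cast(new_val, var.dtype) presumably keeps the final tf.assign dtype-consistent when the variable dtype (e.g. bfloat16) differs from the float32 update math.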
tensor2tensor/utils/adafactor_test.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Copyright 2021 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for adafactor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensor2tensor.utils import adafactor
+
+import tensorflow as tf
+
+
+class AdafactorTest(tf.test.TestCase):
+
+  def testCallableLearningRate(self):
+    def lr():
+      return 0.01
+
+    opt = adafactor.AdafactorOptimizer(learning_rate=lr)
+    v1 = tf.Variable([1., 2.])
+    v2 = tf.Variable([3., 4.])
+    with tf.GradientTape() as tape:
+      tape.watch([v1, v2])
+      loss = v1 * v2
+    v1_grad, v2_grad = tape.gradient(loss, [v1, v2])
+    opt.apply_gradients(((v1_grad, v1), (v2_grad, v2)))
+
+
+if __name__ == '__main__':
+  tf.test.main()
