
Commit c84b7af

minnend authored and copybara-github committed
Updates estimate_tails to avoid infinite loops and find better solutions:

1. Terminate the search loop if the loss is small. This avoids an infinite loop when the loss is NaN or when the initial guess is already correct (a local minimum).
2. Keep track of the best value (lowest loss) encountered so far. This helps since the final value does not always have the lowest loss.
3. Reduce the learning rate (from 0.1 down to 0.01) during the 100 steps taken after a zero-crossing is found. This typically leads to a lower final loss.

PiperOrigin-RevId: 474370252
Change-Id: I6eab66547a8d920c7a12647681766140858faeb3
1 parent d323fa6 commit c84b7af
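
For intuition on change 3: the new update divides the base step of 0.1 by `k = sqrt(count + 1)`, so the effective learning rate stays at 0.1 while the counter is zero and decays to 0.1/sqrt(100) = 0.01 by the last of the 100 counted steps. A minimal sketch of that schedule in plain Python (the helper name is illustrative, not from the commit):

```python
import math

def effective_learning_rate(count):
  # Mirrors the update in the diff below: tails -= 0.1 * m / (k * tf.sqrt(v) + 1e-20)
  # with k = sqrt(count + 1), i.e. the 0.1 base rate scaled by 1 / sqrt(count + 1).
  return 0.1 / math.sqrt(count + 1)

print(effective_learning_rate(0))   # 0.1  -- while count is still zero
print(effective_learning_rate(99))  # 0.01 -- at the last of the 100 counted steps
```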

2 files changed: +38 -12 lines

tensorflow_compression/python/distributions/helpers.py

Lines changed: 28 additions & 12 deletions
```diff
@@ -57,32 +57,48 @@ def estimate_tails(func, target, shape, dtype):
   shape = tf.convert_to_tensor(shape, tf.int32)
   target = tf.convert_to_tensor(target, dtype)
 
-  def loop_cond(tails, m, v, count):
-    del tails, m, v  # unused
-    return tf.reduce_min(count) < 100
-
-  def loop_body(tails, prev_m, prev_v, count):
+  def loop_cond(tails, m, v, loss, count, best_tails, best_loss):
+    del tails, m, v, best_tails, best_loss  # unused
+    # By checking `loss`, we catch NaNs and protect against infinite loops
+    # from perfect initial guesses where there is no zero-crossing.
+    return tf.logical_and(tf.reduce_max(loss) > 1e-8,
+                          tf.reduce_min(count) < 100)
+
+  def loop_body(tails, prev_m, prev_v, loss, count, best_tails, best_loss):
+    del loss  # always recomputed
     with tf.GradientTape(watch_accessed_variables=False) as tape:
       tape.watch(tails)
       loss = abs(func(tails) - target)
+
+    # Keep track of the best (lowest loss) value so far.
+    condition = (loss < best_loss)
+    best_tails = tf.where(condition, tails, best_tails)
+    best_loss = tf.where(condition, loss, best_loss)
+
     grad = tape.gradient(loss, tails)
     m = (prev_m + grad) / 2  # Adam mean estimate.
     v = (prev_v + tf.square(grad)) / 2  # Adam variance estimate.
-    tails -= .1 * m / (tf.sqrt(v) + 1e-20)
+
+    # Reduce learning rate as count increases. This should lead to a more
+    # accurate final value.
+    k = tf.math.sqrt(tf.cast(count + 1, m.dtype))
+    tails -= 0.1 * m / (k * tf.sqrt(v) + 1e-20)
+
     # Start counting when the gradient flips sign. Since the function is
     # monotonic, m must have the same sign in all initial iterations, until
     # the optimal point is crossed. At that point the gradient flips sign.
-    count = tf.where(
-        tf.math.logical_or(count > 0, prev_m * grad < 0),
-        count + 1, count)
-    return tails, m, v, count
+    count = tf.where(tf.math.logical_or(count > 0, prev_m * grad < 0),
+                     count + 1, count)
+    return tails, m, v, loss, count, best_tails, best_loss
 
   init_tails = tf.zeros(shape, dtype=dtype)
   init_m = tf.zeros(shape, dtype=dtype)
   init_v = tf.ones(shape, dtype=dtype)
+  init_loss = init_v * dtype.max
   init_count = tf.zeros(shape, dtype=tf.int32)
-  return tf.while_loop(
-      loop_cond, loop_body, (init_tails, init_m, init_v, init_count))[0]
+  loop_vars = (init_tails, init_m, init_v, init_loss, init_count,
+               init_tails, init_loss)
+  return tf.while_loop(loop_cond, loop_body, loop_vars)[-2]
 
 
 def quantization_offset(distribution):
```
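
For context, a hedged usage sketch of the updated function (import path taken from the file above; `estimate_tails` searches for the input where `func` crosses `target`, so for a monotonic function like `tanh` the result should approximate the analytic inverse):

```python
import tensorflow as tf
from tensorflow_compression.python.distributions import helpers

# Search for x such that tanh(x) is approximately 0.5. The analytic answer is
# atanh(0.5) ~= 0.5493; estimate_tails should land close to it.
x = helpers.estimate_tails(tf.math.tanh, target=0.5, shape=(), dtype=tf.float32)
print(float(x), float(tf.math.atanh(0.5)))
```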

tensorflow_compression/python/distributions/helpers_test.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -22,6 +22,16 @@
 
 class HelpersTest(tf.test.TestCase):
 
+  def test_nan_terminates(self):
+    # Return a NaN tensor that would otherwise have a gradient wrt x.
+    func = lambda x: tf.math.tanh(x) * float("nan")
+    helpers.estimate_tails(func, target=0.5, shape=(), dtype=tf.float32)
+
+  def test_perfect_initial_guess_terminates(self):
+    # The initial guess is zero, which causes problems if the minimum is also
+    # at zero since then there's no zero-crossing to trigger the count.
+    helpers.estimate_tails(tf.math.tanh, target=0.0, shape=(), dtype=tf.float32)
+
   def test_cauchy_quantizes_to_mode_decimal_part(self):
     dist = tfp.distributions.Cauchy(loc=1.4, scale=3.)
     self.assertAllClose(helpers.quantization_offset(dist), 0.4)
```
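
To illustrate why the second test now terminates (a sketch, not part of the commit): after one pass through the body, the recomputed loss is |tanh(0) - 0| = 0, so the new condition `tf.reduce_max(loss) > 1e-8` is False and `tf.while_loop` exits before any zero-crossing is counted. The returned best value should therefore stay at the initial guess:

```python
import tensorflow as tf
from tensorflow_compression.python.distributions import helpers

# With a perfect initial guess, the loop exits via the loss check and the
# best-so-far tracking returns the untouched initial value of zero.
x = helpers.estimate_tails(tf.math.tanh, target=0.0, shape=(), dtype=tf.float32)
tf.debugging.assert_near(x, tf.zeros_like(x))
```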
