@@ -39,13 +39,14 @@ def soft_round(x, alpha, eps=1e-3):
39 39    Returns:
40 40      tf.Tensor
41 41    """
42-   if isinstance(alpha, (float, int)) and alpha < eps:
43-     return tf.identity(x, name="soft_round")
42+   # This guards the gradient of tf.where below against NaNs, while maintaining
43+   # correctness, as for alpha < eps the result is ignored.
44+   alpha_bounded = tf.maximum(alpha, eps)
44 45
45 46    m = tf.floor(x) + .5
46 47    r = x - m
47-   z = tf.tanh(alpha / 2.) * 2.
48-   y = m + tf.tanh(alpha * r) / z
48+   z = tf.tanh(alpha_bounded / 2.) * 2.
49+   y = m + tf.tanh(alpha_bounded * r) / z
49 50
50 51    # For very low alphas, soft_round behaves like identity
51 52    return tf.where(alpha < eps, x, y, name="soft_round")
@@ -68,12 +69,12 @@ def soft_round_inverse(y, alpha, eps=1e-3):
68 69    Returns:
69 70      tf.Tensor
70 71    """
71-   if isinstance(alpha, (float, int)) and alpha < eps:
72-     return tf.identity(y, name="soft_round_inverse")
73-
72+   # This guards the gradient of tf.where below against NaNs, while maintaining
73+   # correctness, as for alpha < eps the result is ignored.
74+   alpha_bounded = tf.maximum(alpha, eps)
74 75    m = tf.floor(y) + .5
75-   s = (y - m) * (tf.tanh(alpha / 2.) * 2.)
76-   r = tf.atanh(s) / alpha
76+   s = (y - m) * (tf.tanh(alpha_bounded / 2.) * 2.)
77+   r = tf.atanh(s) / alpha_bounded
77 78    # `r` must be between -.5 and .5 by definition. In case atanh becomes +-inf
78 79    # due to numerical instability, this prevents the forward pass from yielding
79 80    # infinite values. Note that it doesn't prevent the backward pass from
0 commit comments