@@ -209,7 +209,8 @@ def call(self, x, training):
209209
210210 # Build the entropy model for the hyperprior (z).
211211 em_z = tfc .ContinuousBatchedEntropyModel (
212- self .hyperprior , coding_rank = 3 , compression = False )
212+ self .hyperprior , coding_rank = 3 , compression = False ,
213+ offset_heuristic = False )
213214
214215 # When training, z_bpp is based on the noisy version of z (z_tilde).
215216 _ , z_bits = em_z (z , training = training )
@@ -255,7 +256,7 @@ def call(self, x, training):
255256
256257 # For the synthesis transform, use rounding. Note that quantize()
257258 # overrides the gradient to create a straight-through estimator.
258- y_hat_slice = em_y .quantize (y_slice , sigma , loc = mu )
259+ y_hat_slice = em_y .quantize (y_slice , loc = mu )
259260
260261 # Add latent residual prediction (LRP).
261262 lrp_support = tf .concat ([mean_support , y_hat_slice ], axis = - 1 )
@@ -318,7 +319,8 @@ def fit(self, *args, **kwargs):
318319 retval = super ().fit (* args , ** kwargs )
319320 # After training, fix range coding tables.
320321 self .em_z = tfc .ContinuousBatchedEntropyModel (
321- self .hyperprior , coding_rank = 3 , compression = True )
322+ self .hyperprior , coding_rank = 3 , compression = True ,
323+ offset_heuristic = False )
322324 self .em_y = tfc .LocationScaleIndexedEntropyModel (
323325 tfc .NoisyNormal , num_scales = self .num_scales , scale_fn = self .scale_fn ,
324326 coding_rank = 3 , compression = True )
0 commit comments