6161 "HyperInfo" ,
6262 "decoded latent_shape hyper_latent_shape "
6363 "nbpp side_nbpp total_nbpp qbpp side_qbpp total_qbpp "
64- "bitstring side_bitstring " ,
64+ "bitstream_tensors " ,
6565)
6666
6767
@@ -86,7 +86,7 @@ def __init__(self,
8686 model = [
8787 tf .keras .layers .Conv2D (
8888 filters = num_filters_base , kernel_size = 7 , padding = "same" ),
89- LayerNorm (),
89+ ChannelNorm (),
9090 tf .keras .layers .ReLU ()
9191 ]
9292
@@ -95,7 +95,7 @@ def __init__(self,
9595 tf .keras .layers .Conv2D (
9696 filters = num_filters_base * 2 ** (i + 1 ),
9797 kernel_size = 3 , padding = "same" , strides = 2 ),
98- LayerNorm (),
98+ ChannelNorm (),
9999 tf .keras .layers .ReLU ()])
100100
101101 model .append (
@@ -127,11 +127,11 @@ def __init__(self,
127127 num_filters_base: base number of filters.
128128 num_residual_blocks: number of residual blocks.
129129 """
130- head = [LayerNorm (),
130+ head = [ChannelNorm (),
131131 tf .keras .layers .Conv2D (
132132 filters = num_filters_base * (2 ** num_up ),
133133 kernel_size = 3 , padding = "same" ),
134- LayerNorm ()]
134+ ChannelNorm ()]
135135
136136 residual_blocks = []
137137 for block_idx in range (num_residual_blocks ):
@@ -151,7 +151,7 @@ def __init__(self,
151151 filters = filters ,
152152 kernel_size = 3 , padding = "same" ,
153153 strides = 2 ),
154- LayerNorm (),
154+ ChannelNorm (),
155155 tf .keras .layers .ReLU ()]
156156
157157 # Final conv layer.
@@ -201,19 +201,19 @@ def __init__(
201201
202202 block = [
203203 tf .keras .layers .Conv2D (** kwargs_conv2d ),
204- LayerNorm (),
204+ ChannelNorm (),
205205 tf .keras .layers .Activation (activation ),
206206 tf .keras .layers .Conv2D (** kwargs_conv2d ),
207- LayerNorm ()]
207+ ChannelNorm ()]
208208
209209 self .block = tf .keras .Sequential (name = name , layers = block )
210210
211211 def call (self , inputs , ** kwargs ):
212212 return inputs + self .block (inputs , ** kwargs )
213213
214214
215- class LayerNorm (tf .keras .layers .Layer ):
216- """Implement LayerNorm .
215+ class ChannelNorm (tf .keras .layers .Layer ):
216+ """Implement ChannelNorm .
217217
218218 Based on this paper and keras' InstanceNorm layer:
219219 Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
@@ -238,7 +238,7 @@ def __init__(self,
238238 gamma_initializer: Initializer for gamma.
239239 **kwargs: Passed to keras.
240240 """
241- super (LayerNorm , self ).__init__ (** kwargs )
241+ super (ChannelNorm , self ).__init__ (** kwargs )
242242
243243 self .axis = - 1
244244 self .epsilon = epsilon
@@ -478,6 +478,14 @@ def _make_synthesis(syn_name):
478478
479479 self ._side_entropy_model = FactorizedPriorLayer ()
480480
481+ @property
482+ def losses (self ):
483+ return self ._side_entropy_model .losses
484+
485+ @property
486+ def updates (self ):
487+ return self ._side_entropy_model .updates
488+
481489 @property
482490 def transform_layers (self ):
483491 return [self ._analysis , self ._synthesis_scale , self ._synthesis_mean ]
@@ -529,7 +537,7 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
529537
530538 compressed = None
531539 if training :
532- latents_decoded = _quantize (latents , latent_means )
540+ latents_decoded = _ste_quantize (latents , latent_means )
533541 elif validation :
534542 latents_decoded = entropy_info .quantized
535543 else :
@@ -546,16 +554,25 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
546554 qbpp = entropy_info .qbpp ,
547555 side_qbpp = side_info .total_qbpp ,
548556 total_qbpp = entropy_info .qbpp + side_info .total_qbpp ,
549- bitstring = compressed ,
550- side_bitstring = side_info .bitstring )
557+ # We put everything that's needed for real arithmetic coding into
558+ # the bitstream_tensors tuple.
559+ bitstream_tensors = (compressed , side_info .bitstring ,
560+ image_shape , latent_shape , side_info .latent_shape ))
551561
552562 tf .summary .scalar ("bpp/total/noisy" , info .total_nbpp )
553563 tf .summary .scalar ("bpp/total/quantized" , info .total_qbpp )
554564
565+ tf .summary .scalar ("bpp/latent/noisy" , entropy_info .nbpp )
566+ tf .summary .scalar ("bpp/latent/quantized" , entropy_info .qbpp )
567+
568+ tf .summary .scalar ("bpp/side/noisy" , side_info .total_nbpp )
569+ tf .summary .scalar ("bpp/side/quantized" , side_info .total_qbpp )
570+
555571 return info
556572
557573
558- def _quantize (inputs , mean ):
574+ def _ste_quantize (inputs , mean ):
575+ """Calculates quantize(inputs - mean) + mean, sets straight-through grads."""
559576 half = tf .constant (.5 , dtype = tf .float32 )
560577 outputs = inputs
561578 outputs -= mean
0 commit comments