1919Ballé, Laparra, Simoncelli (2017):
2020End-to-end optimized image compression
2121https://arxiv.org/abs/1611.01704
22+
23+ With patches from Victor Xing <victor.t.xing@gmail.com>
2224"""
2325
2426from __future__ import absolute_import
2527from __future__ import division
2628from __future__ import print_function
2729
2830import argparse
31+ import glob
2932
3033# Dependency imports
3134
@@ -44,12 +47,16 @@ def load_image(filename):
4447 return image
4548
4649
47- def save_image (filename , image ):
48- """Saves an image to a PNG file."""
49-
def quantize_image(image):
  """Converts an image tensor to 8-bit unsigned integer pixel values.

  The input is clipped to the range [0, 1], scaled by 255, rounded to the
  nearest integer and cast to `tf.uint8`.

  Args:
    image: Tensor of pixel intensities (presumably floats nominally in 0..1;
      values outside that range are clipped).

  Returns:
    A `tf.uint8` tensor holding the quantized pixel values.
  """
  clipped = tf.clip_by_value(image, 0, 1)
  rounded = tf.round(clipped * 255)
  return tf.cast(rounded, tf.uint8)
55+
56+
def save_image(filename, image):
  """Quantizes an image, encodes it as PNG and writes it to a file.

  Args:
    filename: Path of the PNG file to write.
    image: Image tensor; it is passed through `quantize_image` before encoding.

  Returns:
    The result of `tf.write_file` for the encoded PNG string.
  """
  png_string = tf.image.encode_png(quantize_image(image))
  return tf.write_file(filename, png_string)
5562
@@ -110,17 +117,22 @@ def train():
110117 if args .verbose :
111118 tf .logging .set_verbosity (tf .logging .INFO )
112119
113- # Load all training images into a constant.
114- images = tf .map_fn (
115- load_image , tf .matching_files (args .data_glob ),
116- dtype = tf .float32 , back_prop = False )
117- with tf .Session () as sess :
118- images = tf .constant (sess .run (images ), name = "images" )
120+ # Create input data pipeline.
121+ with tf .device ('/cpu:0' ):
122+ train_files = glob .glob (args .train_glob )
123+ train_dataset = tf .data .Dataset .from_tensor_slices (train_files )
124+ train_dataset = train_dataset .shuffle (buffer_size = len (train_files )).repeat ()
125+ train_dataset = train_dataset .map (
126+ load_image , num_parallel_calls = args .preprocess_threads )
127+ train_dataset = train_dataset .map (
128+ lambda x : tf .random_crop (x , (args .patchsize , args .patchsize , 3 )))
129+ train_dataset = train_dataset .batch (args .batchsize )
130+ train_dataset = train_dataset .prefetch (32 )
131+
132+ num_pixels = args .batchsize * args .patchsize ** 2
119133
120- # Training inputs are random crops out of the images tensor.
121- crop_shape = (args .batchsize , args .patchsize , args .patchsize , 3 )
122- x = tf .random_crop (images , crop_shape )
123- num_pixels = np .prod (crop_shape [:- 1 ])
134+ # Get training patch from dataset.
135+ x = train_dataset .make_one_shot_iterator ().get_next ()
124136
125137 # Build autoencoder.
126138 y = analysis_transform (x , args .num_filters )
@@ -132,9 +144,9 @@ def train():
132144 train_bpp = tf .reduce_sum (tf .log (likelihoods )) / (- np .log (2 ) * num_pixels )
133145
134146 # Mean squared error across pixels.
135- train_mse = tf .reduce_sum (tf .squared_difference (x , x_tilde ))
147+ train_mse = tf .reduce_mean (tf .squared_difference (x , x_tilde ))
136148 # Multiply by 255^2 to correct for rescaling.
137- train_mse *= 255 ** 2 / num_pixels
149+ train_mse *= 255 ** 2
138150
139151 # The rate-distortion cost.
140152 train_loss = args .lmbda * train_mse + train_bpp
@@ -149,18 +161,24 @@ def train():
149161
150162 train_op = tf .group (main_step , aux_step , entropy_bottleneck .updates [0 ])
151163
152- logged_tensors = [
153- tf .identity (train_loss , name = "train_loss" ),
154- tf .identity (train_bpp , name = "train_bpp" ),
155- tf .identity (train_mse , name = "train_mse" ),
156- ]
164+ tf .summary .scalar ("loss" , train_loss )
165+ tf .summary .scalar ("bpp" , train_bpp )
166+ tf .summary .scalar ("mse" , train_mse )
167+
168+ tf .summary .image ("original" , quantize_image (x ))
169+ tf .summary .image ("reconstruction" , quantize_image (x_tilde ))
170+
171+ # Creates summary for the probability mass function (PMF) estimated in the
172+ # bottleneck.
173+ entropy_bottleneck .visualize ()
174+
157175 hooks = [
158176 tf .train .StopAtStepHook (last_step = args .last_step ),
159177 tf .train .NanTensorHook (train_loss ),
160- tf .train .LoggingTensorHook (logged_tensors , every_n_secs = 60 ),
161178 ]
162179 with tf .train .MonitoredTrainingSession (
163- hooks = hooks , checkpoint_dir = args .checkpoint_dir ) as sess :
180+ hooks = hooks , checkpoint_dir = args .checkpoint_dir ,
181+ save_checkpoint_secs = 300 , save_summaries_secs = 60 ) as sess :
164182 while not sess .should_stop ():
165183 sess .run (train_op )
166184
@@ -188,10 +206,14 @@ def compress():
188206 # Total number of bits divided by number of pixels.
189207 eval_bpp = tf .reduce_sum (tf .log (likelihoods )) / (- np .log (2 ) * num_pixels )
190208
191- # Mean squared error across pixels.
209+ # Bring both images back to 0..255 range.
210+ x *= 255
192211 x_hat = tf .clip_by_value (x_hat , 0 , 1 )
193212 x_hat = tf .round (x_hat * 255 )
194- mse = tf .reduce_sum (tf .squared_difference (x * 255 , x_hat )) / num_pixels
213+
214+ mse = tf .reduce_mean (tf .squared_difference (x , x_hat ))
215+ psnr = tf .squeeze (tf .image .psnr (x_hat , x , 255 ))
216+ msssim = tf .squeeze (tf .image .ssim_multiscale (x_hat , x , 255 ))
195217
196218 with tf .Session () as sess :
197219 # Load the latest model checkpoint, get the compressed string and the tensor
@@ -208,14 +230,18 @@ def compress():
208230
209231 # If requested, transform the quantized image back and measure performance.
210232 if args .verbose :
211- eval_bpp , mse , num_pixels = sess .run ([eval_bpp , mse , num_pixels ])
233+ eval_bpp , mse , psnr , msssim , num_pixels = sess .run (
234+ [eval_bpp , mse , psnr , msssim , num_pixels ])
212235
213236 # The actual bits per pixel including overhead.
214237 bpp = (8 + len (string )) * 8 / num_pixels
215238
216- print ("Mean squared error: {:0.4}" .format (mse ))
217- print ("Information content of this image in bpp: {:0.4}" .format (eval_bpp ))
218- print ("Actual bits per pixel for this image: {:0.4}" .format (bpp ))
239+ print ("Mean squared error: {:0.4f}" .format (mse ))
240+ print ("PSNR (dB): {:0.2f}" .format (psnr ))
241+ print ("Multiscale SSIM: {:0.4f}" .format (msssim ))
242+ print ("Multiscale SSIM (dB): {:0.2f}" .format (- 10 * np .log10 (1 - msssim )))
243+ print ("Information content in bpp: {:0.4f}" .format (eval_bpp ))
244+ print ("Actual bits per pixel: {:0.4f}" .format (bpp ))
219245
220246
221247def decompress ():
@@ -278,22 +304,25 @@ def decompress():
278304 "--checkpoint_dir" , default = "train" ,
279305 help = "Directory where to save/load model checkpoints." )
280306 parser .add_argument (
281- "--data_glob " , default = "images/*.png" ,
307+ "--train_glob " , default = "images/*.png" ,
282308 help = "Glob pattern identifying training data. This pattern must expand "
283- "to a list of RGB images in PNG format which all have the same "
284- "shape." )
309+ "to a list of RGB images in PNG format." )
285310 parser .add_argument (
286311 "--batchsize" , type = int , default = 8 ,
287312 help = "Batch size for training." )
288313 parser .add_argument (
289- "--patchsize" , type = int , default = 128 ,
314+ "--patchsize" , type = int , default = 256 ,
290315 help = "Size of image patches for training." )
291316 parser .add_argument (
292- "--lambda" , type = float , default = 0.1 , dest = "lmbda" ,
317+ "--lambda" , type = float , default = 0.01 , dest = "lmbda" ,
293318 help = "Lambda for rate-distortion tradeoff." )
294319 parser .add_argument (
295320 "--last_step" , type = int , default = 1000000 ,
296321 help = "Train up to this number of steps." )
322+ parser .add_argument (
323+ "--preprocess_threads" , type = int , default = 16 ,
324+ help = "Number of CPU threads to use for parallel decoding of training "
325+ "images." )
297326
298327 args = parser .parse_args ()
299328
0 commit comments