 
 
 def load_image(filename):
+  """Loads a PNG image file."""
+
   string = tf.read_file(filename)
   image = tf.image.decode_image(string, channels=3)
   image = tf.cast(image, tf.float32)
@@ -43,6 +45,8 @@ def load_image(filename):
 
 
 def save_image(filename, image):
+  """Saves an image to a PNG file."""
+
   image = tf.clip_by_value(image, 0, 1)
   image = tf.round(image * 255)
   image = tf.cast(image, tf.uint8)
@@ -51,6 +55,8 @@ def save_image(filename, image):
 
 
 def analysis_transform(tensor, num_filters):
+  """Builds the analysis transform."""
+
   with tf.variable_scope("analysis"):
     with tf.variable_scope("layer_0"):
       layer = tfc.SignalConv2D(
@@ -74,6 +80,8 @@ def analysis_transform(tensor, num_filters):
 
 
 def synthesis_transform(tensor, num_filters):
+  """Builds the synthesis transform."""
+
   with tf.variable_scope("synthesis"):
     with tf.variable_scope("layer_0"):
       layer = tfc.SignalConv2D(
@@ -96,11 +104,16 @@ def synthesis_transform(tensor, num_filters):
     return tensor
 
 
-def train(args):
+def train():
+  """Trains the model."""
+
+  if args.verbose:
+    tf.logging.set_verbosity(tf.logging.INFO)
+
   # Load all training images into a constant.
   images = tf.map_fn(
-    load_image, tf.matching_files(args.data_glob),
-    dtype=tf.float32, back_prop=False)
+      load_image, tf.matching_files(args.data_glob),
+      dtype=tf.float32, back_prop=False)
   with tf.Session() as sess:
     images = tf.constant(sess.run(images), name="images")
 
@@ -119,7 +132,9 @@ def train(args):
   train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
   # Mean squared error across pixels.
-  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde)) / num_pixels
+  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde))
+  # Multiply by 255^2 to correct for rescaling.
+  train_mse *= 255 ** 2 / num_pixels
 
   # The rate-distortion cost.
   train_loss = args.lmbda * train_mse + train_bpp
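
The rewritten distortion term above is the old one scaled by 255**2: assuming load_image rescales pixels to [0, 1] (the x * 255 terms added in compress() below point the same way), this puts train_mse back on the original 8-bit pixel scale. A minimal standalone numpy check of that identity (values are illustrative, not from the diff):

    import numpy as np

    x = np.array([[10., 200.], [0., 255.]])        # pixels on the 0..255 scale
    x_tilde = np.array([[12., 197.], [1., 250.]])  # a distorted reconstruction

    mse_255 = np.mean((x - x_tilde) ** 2)               # MSE on the 8-bit scale
    mse_unit = np.mean((x / 255 - x_tilde / 255) ** 2)  # MSE on the [0, 1] scale
    assert np.isclose(mse_255, 255 ** 2 * mse_unit)     # 255^2 undoes the rescaling
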
@@ -134,17 +149,25 @@ def train(args):
 
 
   train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
+  logged_tensors = [
+      tf.identity(train_loss, name="train_loss"),
+      tf.identity(train_bpp, name="train_bpp"),
+      tf.identity(train_mse, name="train_mse"),
+  ]
   hooks = [
       tf.train.StopAtStepHook(last_step=args.last_step),
       tf.train.NanTensorHook(train_loss),
+      tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
   ]
   with tf.train.MonitoredTrainingSession(
       hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess:
     while not sess.should_stop():
       sess.run(train_op)
 
 
-def compress(args):
+def compress():
+  """Compresses an image."""
+
   # Load input image and add batch dimension.
   x = load_image(args.input)
   x = tf.expand_dims(x, 0)
@@ -166,7 +189,9 @@ def compress(args):
   eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
   # Mean squared error across pixels.
-  mse = tf.reduce_sum(tf.squared_difference(x, x_hat)) / num_pixels
+  x_hat = tf.clip_by_value(x_hat, 0, 1)
+  x_hat = tf.round(x_hat * 255)
+  mse = tf.reduce_sum(tf.squared_difference(x * 255, x_hat)) / num_pixels
 
   with tf.Session() as sess:
     # Load the latest model checkpoint, get the compressed string and the tensor
@@ -176,10 +201,10 @@ def compress(args):
     string, x_shape, y_shape = sess.run([string, tf.shape(x), tf.shape(y)])
 
     # Write a binary file with the shape information and the compressed string.
-    with open(args.output, "wb") as file:
-      file.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())
-      file.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())
-      file.write(string)
+    with open(args.output, "wb") as f:
+      f.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())
+      f.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())
+      f.write(string)
 
     # If requested, transform the quantized image back and measure performance.
     if args.verbose:
@@ -193,14 +218,15 @@ def compress(args):
       print("Actual bits per pixel for this image: {:0.4}".format(bpp))
 
 
-def decompress(args):
+def decompress():
+  """Decompresses an image."""
+
   # Read the shape information and compressed string from the binary file.
-  with open(args.input, "rb") as file:
-    x_shape = np.frombuffer(file.read(4), dtype=np.uint16)
-    y_shape = np.frombuffer(file.read(4), dtype=np.uint16)
-    string = file.read()
+  with open(args.input, "rb") as f:
+    x_shape = np.frombuffer(f.read(4), dtype=np.uint16)
+    y_shape = np.frombuffer(f.read(4), dtype=np.uint16)
+    string = f.read()
 
-  bits = 8 * len(string)
   y_shape = [int(s) for s in y_shape] + [args.num_filters]
 
   # Add a batch dimension, then decompress and transform the image back.
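
Taken together, the file handles renamed above define a small container format: compress() writes the spatial dimensions of the image (x_shape[1:-1]) and of the latents (y_shape[1:-1]) as two uint16 pairs, followed by the entropy-coded string, and decompress() reads the same 4 + 4 bytes back before appending args.num_filters to rebuild the latent shape. A standalone sketch of that layout (filename, shapes and payload are made up for illustration):

    import numpy as np

    x_shape = (1, 768, 512, 3)   # example NHWC shape of the input image
    y_shape = (1, 48, 32, 128)   # example NHWC shape of the latents
    payload = b"\x00\x01\x02"    # stand-in for the entropy-coded string

    with open("example.bin", "wb") as f:
      f.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())  # 4 bytes: 768, 512
      f.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())  # 4 bytes: 48, 32
      f.write(payload)

    with open("example.bin", "rb") as f:
      x_hw = np.frombuffer(f.read(4), dtype=np.uint16)  # array([768, 512], dtype=uint16)
      y_hw = np.frombuffer(f.read(4), dtype=np.uint16)  # array([48, 32], dtype=uint16)
      string = f.read()                                 # the rest of the file
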
@@ -242,34 +268,42 @@ def decompress(args):
   parser.add_argument(
       "output", nargs="?",
       help="Output filename.")
-  parser.add_argument("--verbose", "-v", action="store_true",
+  parser.add_argument(
+      "--verbose", "-v", action="store_true",
       help="Report bitrate and distortion when training or compressing.")
-  parser.add_argument("--num_filters", type=int, default=128,
+  parser.add_argument(
+      "--num_filters", type=int, default=128,
       help="Number of filters per layer.")
-  parser.add_argument("--checkpoint_dir", default="train",
+  parser.add_argument(
+      "--checkpoint_dir", default="train",
       help="Directory where to save/load model checkpoints.")
-  parser.add_argument("--data_glob", default="images/*.png",
+  parser.add_argument(
+      "--data_glob", default="images/*.png",
       help="Glob pattern identifying training data. This pattern must expand "
            "to a list of RGB images in PNG format which all have the same "
            "shape.")
-  parser.add_argument("--batchsize", type=int, default=8,
+  parser.add_argument(
+      "--batchsize", type=int, default=8,
       help="Batch size for training.")
-  parser.add_argument("--patchsize", type=int, default=128,
+  parser.add_argument(
+      "--patchsize", type=int, default=128,
       help="Size of image patches for training.")
-  parser.add_argument("--lambda", type=float, default=0.1, dest="lmbda",
+  parser.add_argument(
+      "--lambda", type=float, default=0.1, dest="lmbda",
       help="Lambda for rate-distortion tradeoff.")
-  parser.add_argument("--last_step", type=int, default=1000000,
+  parser.add_argument(
+      "--last_step", type=int, default=1000000,
       help="Train up to this number of steps.")
 
   args = parser.parse_args()
 
   if args.command == "train":
-    train(args)
+    train()
   elif args.command == "compress":
     if args.input is None or args.output is None:
       raise ValueError("Need input and output filename for compression.")
-    compress(args)
+    compress()
   elif args.command == "decompress":
     if args.input is None or args.output is None:
       raise ValueError("Need input and output filename for decompression.")
-    decompress(args)
+    decompress()
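
The dropped parameters (train(args) becoming train(), and likewise for compress and decompress) work because args is bound at module level by parser.parse_args() before any command runs, so the functions read flags such as args.verbose and args.lmbda from that global. A minimal standalone sketch of the pattern (names here are illustrative, not part of the diff):

    import argparse

    def run():
      # Reads the module-level `args` bound below, as train()/compress() do.
      if args.verbose:
        print("verbose mode")

    if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.add_argument("--verbose", "-v", action="store_true")
      args = parser.parse_args()  # module-level name, visible inside run()
      run()
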