@@ -96,17 +96,34 @@ def instantiate_model_signature(model, signature, inputs=None, outputs=None):
9696 return wrapped_import .prune (inputs , outputs )
9797
9898
99- def compress_image (model , input_image ):
99+ def compress_image (model , input_image , rd_parameter = None ):
100100 """Compresses an image tensor into a bitstring."""
101101 sender = instantiate_model_signature (model , "sender" )
102- tensors = sender (input_image )
102+ if len (sender .inputs ) == 1 :
103+ if rd_parameter is not None :
104+ raise ValueError ("This model doesn't expect an RD parameter." )
105+ tensors = sender (input_image )
106+ elif len (sender .inputs ) == 2 :
107+ if rd_parameter is None :
108+ raise ValueError ("This model expects an RD parameter." )
109+ rd_parameter = tf .constant (rd_parameter , dtype = sender .inputs [1 ].dtype )
110+ tensors = sender (input_image , rd_parameter )
111+ # Find RD parameter and expand it to a 1D tensor so it fits into the
112+ # PackedTensors format.
113+ for i , t in enumerate (tensors ):
114+ if t .dtype .is_floating and t .shape .rank == 0 :
115+ tensors [i ] = tf .expand_dims (t , 0 )
116+ else :
117+ raise RuntimeError ("Unexpected model signature." )
103118 packed = tfc .PackedTensors ()
104119 packed .model = model
105120 packed .pack (tensors )
106121 return packed .string
107122
108123
109- def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
124+ def compress (model , input_file , output_file ,
125+ rd_parameter = None , rd_parameter_tolerance = None ,
126+ target_bpp = None , bpp_strict = False ):
110127 """Compresses a PNG file to a TFCI file."""
111128 if not output_file :
112129 output_file = input_file + ".tfci"
@@ -117,21 +134,35 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
117134
118135 if not target_bpp :
119136 # Just compress with a specific model.
120- bitstring = compress_image (model , input_image )
137+ bitstring = compress_image (model , input_image , rd_parameter = rd_parameter )
121138 else :
122139 # Get model list.
123140 models = load_cached (model + ".models" )
124141 models = models .decode ("ascii" ).split ()
125142
126- # Do a binary search over all RD points.
127- lower = - 1
128- upper = len (models )
143+ try :
144+ lower , upper = [float (m ) for m in models ]
145+ use_rd_parameter = True
146+ except ValueError :
147+ lower = - 1
148+ upper = len (models )
149+ use_rd_parameter = False
150+
151+ # Do a binary search over RD points.
129152 bpp = None
130153 best_bitstring = None
131154 best_bpp = None
132- while bpp != target_bpp and upper - lower > 1 :
133- i = (upper + lower ) // 2
134- bitstring = compress_image (models [i ], input_image )
155+ while bpp != target_bpp :
156+ if use_rd_parameter :
157+ if upper - lower <= rd_parameter_tolerance :
158+ break
159+ i = (upper + lower ) / 2
160+ bitstring = compress_image (model , input_image , rd_parameter = i )
161+ else :
162+ if upper - lower < 2 :
163+ break
164+ i = (upper + lower ) // 2
165+ bitstring = compress_image (models [i ], input_image )
135166 bpp = 8 * len (bitstring ) / num_pixels
136167 is_admissible = bpp <= target_bpp or not bpp_strict
137168 is_better = (best_bpp is None or
@@ -162,6 +193,10 @@ def decompress(input_file, output_file):
162193 packed = tfc .PackedTensors (f .read ())
163194 receiver = instantiate_model_signature (packed .model , "receiver" )
164195 tensors = packed .unpack ([t .dtype for t in receiver .inputs ])
196+ # Find potential RD parameter and turn it back into a scalar.
197+ for i , t in enumerate (tensors ):
198+ if t .dtype .is_floating and t .shape == (1 ,):
199+ tensors [i ] = tf .squeeze (t , 0 )
165200 output_image , = receiver (* tensors )
166201 write_png (output_file , output_image )
167202
@@ -247,7 +282,17 @@ def parse_args(argv):
247282 "'target_bpp' is provided, don't specify the index at the end of "
248283 "the model identifier." )
249284 compress_cmd .add_argument (
250- "--target_bpp" , type = float ,
285+ "--rd_parameter" , "-r" , type = float ,
286+ help = "Rate-distortion parameter (for some models). Ignored if "
287+ "'target_bpp' is set." )
288+ compress_cmd .add_argument (
289+ "--rd_parameter_tolerance" , type = float ,
290+ default = 2 ** - 4 ,
291+ help = "Tolerance for rate-distortion parameter. Only used if 'target_bpp' "
292+ "is set for some models, to determine when to stop the binary "
293+ "search." )
294+ compress_cmd .add_argument (
295+ "--target_bpp" , "-b" , type = float ,
251296 help = "Target bits per pixel. If provided, a binary search is used to try "
252297 "to match the given bpp as close as possible. In this case, don't "
253298 "specify the index at the end of the model identifier. It will be "
@@ -323,6 +368,7 @@ def main(args):
323368 # Invoke subcommand.
324369 if args .command == "compress" :
325370 compress (args .model , args .input_file , args .output_file ,
371+ args .rd_parameter , args .rd_parameter_tolerance ,
326372 args .target_bpp , args .bpp_strict )
327373 elif args .command == "decompress" :
328374 decompress (args .input_file , args .output_file )
0 commit comments