@@ -32,7 +32,7 @@ def test_noise(self):
3232 inputs = tf .placeholder (tf .float32 , (None , 1 ))
3333 layer = entropy_models .EntropyBottleneck ()
3434 noisy , _ = layer (inputs , training = True )
35- with self .test_session () as sess :
35+ with self .cached_session () as sess :
3636 sess .run (tf .global_variables_initializer ())
3737 values = np .linspace (- 50 , 50 , 100 )[:, None ]
3838 noisy , = sess .run ([noisy ], {inputs : values })
@@ -45,7 +45,7 @@ def test_quantization_init(self):
4545 inputs = tf .placeholder (tf .float32 , (None , 1 ))
4646 layer = entropy_models .EntropyBottleneck ()
4747 quantized , _ = layer (inputs , training = False )
48- with self .test_session () as sess :
48+ with self .cached_session () as sess :
4949 sess .run (tf .global_variables_initializer ())
5050 values = np .linspace (- 50 , 50 , 100 )[:, None ]
5151 quantized , = sess .run ([quantized ], {inputs : values })
@@ -61,7 +61,7 @@ def test_quantization(self):
6161 opt = tf .train .GradientDescentOptimizer (learning_rate = 1 )
6262 self .assertEqual (1 , len (layer .losses ))
6363 step = opt .minimize (layer .losses [0 ])
64- with self .test_session () as sess :
64+ with self .cached_session () as sess :
6565 sess .run (tf .global_variables_initializer ())
6666 sess .run (step )
6767 values = np .linspace (- 50 , 50 , 100 )[:, None ]
@@ -79,7 +79,7 @@ def test_codec_init(self):
7979 data_format = "channels_last" , init_scale = 30 )
8080 bitstrings = layer .compress (inputs )
8181 decoded = layer .decompress (bitstrings , tf .shape (inputs )[1 :])
82- with self .test_session () as sess :
82+ with self .cached_session () as sess :
8383 sess .run (tf .global_variables_initializer ())
8484 values = np .linspace (- 50 , 50 , 100 )[None , :, None ]
8585 decoded , = sess .run ([decoded ], {inputs : values })
@@ -98,7 +98,7 @@ def test_codec(self):
9898 opt = tf .train .GradientDescentOptimizer (learning_rate = 1 )
9999 self .assertEqual (1 , len (layer .losses ))
100100 step = opt .minimize (layer .losses [0 ])
101- with self .test_session () as sess :
101+ with self .cached_session () as sess :
102102 sess .run (tf .global_variables_initializer ())
103103 sess .run (step )
104104 self .assertEqual (1 , len (layer .updates ))
@@ -120,7 +120,7 @@ def test_channels_last(self):
120120 quantized , _ = layer (inputs , training = False )
121121 bitstrings = layer .compress (inputs )
122122 decoded = layer .decompress (bitstrings , tf .shape (inputs )[1 :])
123- with self .test_session () as sess :
123+ with self .cached_session () as sess :
124124 sess .run (tf .global_variables_initializer ())
125125 self .assertEqual (1 , len (layer .updates ))
126126 sess .run (layer .updates [0 ])
@@ -141,7 +141,7 @@ def test_channels_first(self):
141141 quantized , _ = layer (inputs , training = False )
142142 bitstrings = layer .compress (inputs )
143143 decoded = layer .decompress (bitstrings , tf .shape (inputs )[1 :])
144- with self .test_session () as sess :
144+ with self .cached_session () as sess :
145145 sess .run (tf .global_variables_initializer ())
146146 self .assertEqual (1 , len (layer .updates ))
147147 sess .run (layer .updates [0 ])
@@ -161,7 +161,7 @@ def test_compress(self):
161161 data_format = "channels_first" , filters = (), init_scale = 2 )
162162 bitstrings = layer .compress (inputs )
163163 decoded = layer .decompress (bitstrings , tf .shape (inputs )[1 :])
164- with self .test_session () as sess :
164+ with self .cached_session () as sess :
165165 values = 8 * np .random .uniform (size = (2 , 3 , 9 )) - 4
166166 sess .run (tf .global_variables_initializer ())
167167 self .assertEqual (1 , len (layer .updates ))
@@ -213,7 +213,7 @@ def test_decompress(self):
213213 layer ._quantized_cdf = quantized_cdf
214214 layer ._cdf_length = cdf_length
215215 decoded = layer .decompress (bitstrings , input_shape [1 :])
216- with self .test_session () as sess :
216+ with self .cached_session () as sess :
217217 sess .run (tf .global_variables_initializer ())
218218 decoded , = sess .run ([decoded ], {
219219 bitstrings : self .bitstrings , input_shape : self .expected .shape ,
@@ -233,7 +233,7 @@ def test_normalization(self):
233233 inputs = tf .placeholder (tf .float32 , (None , 1 ))
234234 layer = entropy_models .EntropyBottleneck (filters = (2 ,))
235235 _ , likelihood = layer (inputs , training = True )
236- with self .test_session () as sess :
236+ with self .cached_session () as sess :
237237 sess .run (tf .global_variables_initializer ())
238238 x = np .repeat (np .arange (- 200 , 201 ), 2000 )[:, None ]
239239 likelihood , = sess .run ([likelihood ], {inputs : x })
@@ -251,7 +251,7 @@ def test_entropy_estimates(self):
251251 _ , likelihood = layer (inputs , training = False )
252252 disc_entropy = tf .reduce_sum (tf .log (likelihood )) / - np .log (2 )
253253 bitstrings = layer .compress (inputs )
254- with self .test_session () as sess :
254+ with self .cached_session () as sess :
255255 sess .run (tf .global_variables_initializer ())
256256 self .assertEqual (1 , len (layer .updates ))
257257 sess .run (layer .updates [0 ])
@@ -272,7 +272,7 @@ def test_noise(self):
272272 scale = tf .placeholder (tf .float32 , [None ])
273273 layer = self .subclass (scale , [1 ])
274274 noisy , _ = layer (inputs , training = True )
275- with self .test_session () as sess :
275+ with self .cached_session () as sess :
276276 sess .run (tf .global_variables_initializer ())
277277 values = np .linspace (- 50 , 50 , 100 )
278278 noisy , = sess .run ([noisy ], {
@@ -288,7 +288,7 @@ def test_quantization(self):
288288 scale = tf .placeholder (tf .float32 , [None ])
289289 layer = self .subclass (scale , [1 ], mean = None )
290290 quantized , _ = layer (inputs , training = False )
291- with self .test_session () as sess :
291+ with self .cached_session () as sess :
292292 sess .run (tf .global_variables_initializer ())
293293 values = np .linspace (- 50 , 50 , 100 )
294294 quantized , = sess .run ([quantized ], {
@@ -305,7 +305,7 @@ def test_quantization_mean(self):
305305 mean = tf .placeholder (tf .float32 , [None ])
306306 layer = self .subclass (scale , [1 ], mean = mean )
307307 quantized , _ = layer (inputs , training = False )
308- with self .test_session () as sess :
308+ with self .cached_session () as sess :
309309 sess .run (tf .global_variables_initializer ())
310310 values = np .linspace (- 50 , 50 , 100 )
311311 mean_values = np .random .normal (size = values .shape )
@@ -327,7 +327,7 @@ def test_codec(self):
327327 scale , [2 ** x for x in range (- 10 , 10 )], mean = None )
328328 bitstrings = layer .compress (inputs )
329329 decoded = layer .decompress (bitstrings )
330- with self .test_session () as sess :
330+ with self .cached_session () as sess :
331331 sess .run (tf .global_variables_initializer ())
332332 values = np .linspace (- 50 , 50 , 100 )[None ]
333333 decoded , = sess .run ([decoded ], {
@@ -346,7 +346,7 @@ def test_codec_mean(self):
346346 scale , [2 ** x for x in range (- 10 , 10 )], mean = mean )
347347 bitstrings = layer .compress (inputs )
348348 decoded = layer .decompress (bitstrings )
349- with self .test_session () as sess :
349+ with self .cached_session () as sess :
350350 sess .run (tf .global_variables_initializer ())
351351 values = np .linspace (- 50 , 50 , 100 )[None ]
352352 mean_values = np .random .normal (size = values .shape )
@@ -369,7 +369,7 @@ def test_multiple_dimensions(self):
369369 quantized , _ = layer (inputs , training = False )
370370 bitstrings = layer .compress (inputs )
371371 decoded = layer .decompress (bitstrings )
372- with self .test_session () as sess :
372+ with self .cached_session () as sess :
373373 sess .run (tf .global_variables_initializer ())
374374 values = 10 * np .random .normal (size = (2 , 5 , 3 , 7 ))
375375 noisy , quantized , decoded = sess .run (
@@ -391,7 +391,7 @@ def test_compress(self):
391391 layer = self .subclass (scale , scale_table , indexes = indexes )
392392 bitstrings = layer .compress (inputs )
393393 decoded = layer .decompress (bitstrings )
394- with self .test_session () as sess :
394+ with self .cached_session () as sess :
395395 values = 8 * np .random .uniform (size = shape ) - 4
396396 indexes = np .random .randint (
397397 0 , len (scale_table ), size = shape , dtype = np .int32 )
@@ -415,7 +415,7 @@ def test_decompress(self):
415415 layer = self .subclass (
416416 scale , scale_table , indexes = indexes , dtype = tf .float32 )
417417 decoded = layer .decompress (bitstrings )
418- with self .test_session () as sess :
418+ with self .cached_session () as sess :
419419 sess .run (tf .global_variables_initializer ())
420420 decoded , = sess .run ([decoded ], {
421421 bitstrings : self .bitstrings ,
@@ -437,7 +437,7 @@ def test_quantile_function(self):
437437 # Test that quantile function inverts cumulative.
438438 scale = tf .placeholder (tf .float64 , [None ])
439439 layer = self .subclass (scale , [1 ], dtype = tf .float64 )
440- with self .test_session () as sess :
440+ with self .cached_session () as sess :
441441 sess .run (tf .global_variables_initializer ())
442442 quantiles = np .array ([1e-5 , 1e-2 , .1 , .5 , .6 , .8 ])
443443 locations = layer ._standardized_quantile (quantiles )
@@ -452,7 +452,7 @@ def test_distribution(self):
452452 scale = tf .placeholder (tf .float32 , [None , None ])
453453 layer = self .subclass (scale , [1 ], scale_bound = 0 , mean = None )
454454 _ , likelihood = layer (inputs , training = False )
455- with self .test_session () as sess :
455+ with self .cached_session () as sess :
456456 sess .run (tf .global_variables_initializer ())
457457 values = np .arange (- 5 , 1 )[:, None ] # must be integers due to quantization
458458 scales = 2 ** np .linspace (- 3 , 3 , 10 )[None , :]
@@ -476,7 +476,7 @@ def test_entropy_estimates(self):
476476 disc_entropy = tf .reduce_mean (tf .log (likelihood ), axis = 1 )
477477 disc_entropy /= - np .log (2 )
478478 bitstrings = layer .compress (inputs )
479- with self .test_session () as sess :
479+ with self .cached_session () as sess :
480480 sess .run (tf .global_variables_initializer ())
481481 scales = np .repeat ([layer .scale_table ], 10000 , axis = 0 ).T
482482 values = self .scipy_class .rvs (scale = scales , size = scales .shape )
0 commit comments