Skip to content

Commit 56d9a6e

Browse files
shraman-rcmn-robot
authored and committed
Make tpu_util.write_to_variable compatible with both Keras models and models that use variable_scope.
PiperOrigin-RevId: 302173658
1 parent 9c25b7c commit 56d9a6e

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

morph_net/framework/tpu_util.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,31 +92,23 @@ def maybe_convert_to_variable(tensor):
9292

9393

9494
var_store = {}
95+
top_level_scope = tf.get_variable_scope()
9596

9697

9798
def write_to_variable(tensor, fail_if_exists=True):
9899
"""Saves a tensor for later retrieval on CPU."""
99100
# Only relevant for debugging.
100-
debug_name = 'tpu_util__' + tensor.name.split(':')[0].split('/')[-1]
101-
102-
if fail_if_exists:
103-
# Note: reuse cannot be changed from True to False, so we just check if
104-
# the variable exists.
105-
with tf.variable_scope('', reuse=True):
106-
try:
107-
tf.get_variable(debug_name)
108-
except ValueError:
109-
pass # Variable with name=debug_name does not exist; proceed.
110-
else:
111-
raise ValueError('Variable %s already exists!' % debug_name)
112-
113-
with tf.variable_scope('', reuse=tf.compat.v1.AUTO_REUSE):
101+
debug_name = 'tpu_util__' + tensor.name.split(':')[0]
102+
103+
reuse = False if fail_if_exists else tf.compat.v1.AUTO_REUSE
104+
with tf.variable_scope(top_level_scope, reuse=reuse):
114105
variable = tf.get_variable(
115106
name=debug_name,
116107
shape=tensor.shape,
117108
dtype=tensor.dtype,
118109
trainable=False,
119110
use_resource=True)
111+
120112
var_store[tensor] = variable
121113
with tf.control_dependencies([variable.assign(tensor)]):
122114
tensor_copy = tf.identity(tensor)

morph_net/framework/tpu_util_test.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,30 @@ def test_noop(self):
7979
self.assertEqual(tpu_util.maybe_convert_to_variable(relu), relu)
8080

8181
def test_write_to_variable(self):
82-
with tf.variable_scope(''):
83-
foo = tf.constant(0.)
84-
tpu_util.write_to_variable(foo)
82+
foo = tf.constant(0., name='foo')
83+
tpu_util.write_to_variable(foo)
84+
85+
with self.assertRaises(ValueError):
86+
tpu_util.write_to_variable(foo, fail_if_exists=True)
87+
88+
# Variable sharing behavior should be dictated by `fail_if_exists` which
89+
# overrides the effect of outer scopes.
90+
with tf.variable_scope('', reuse=True):
91+
# should fail to return existing variable even though reuse=True
92+
with self.assertRaises(ValueError):
93+
tpu_util.write_to_variable(foo, fail_if_exists=True)
94+
95+
with tf.variable_scope('', reuse=False):
96+
# should return existing variable even though reuse=False
97+
foo_copy = tpu_util.write_to_variable(foo, fail_if_exists=False)
98+
self.assertEqual(tpu_util.var_store[foo], tpu_util.var_store[foo_copy])
99+
self.assertLen(set(tpu_util.var_store.values()), 1)
100+
85101
with tf.variable_scope('', reuse=True):
86-
bar = tf.constant(0.)
102+
# should create new variable even though reuse=True
103+
bar = tf.constant(0., name='bar')
87104
tpu_util.write_to_variable(bar)
88-
with tf.variable_scope('', reuse=tf.compat.v1.AUTO_REUSE):
89-
zee = tf.constant(0.)
90-
tpu_util.write_to_variable(zee)
105+
self.assertLen(set(tpu_util.var_store.values()), 2)
91106

92107
if __name__ == '__main__':
93108
tf.test.main()

0 commit comments

Comments (0)