Commit 5a3e7b9

shraman-rcmn-robot authored and committed

Support for other (deprecated) layer APIs (e.g., tf.layers) and parameterization with default scopes.

PiperOrigin-RevId: 304289816

1 parent ce86b4a commit 5a3e7b9
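In short: ConfigurableOps previously assumed the tf.contrib.layers calling convention (a num_outputs second argument and a mandatory scope/name kwarg). This commit adds support for tf.layers/slim equivalents and for resolving op names from each function's default scope. A minimal usage sketch, not from the commit itself, using the names defined in the diff below:

import tensorflow.compat.v1 as tf
from morph_net.tools import configurable_ops as ops

# tf.layers.conv2d takes `filters` (not `num_outputs`) and defaults its
# scope to 'conv2d', so the parameterized op name is 'conv2d/Conv2D'.
decorator = ops.ConfigurableOps(
    parameterization={'conv2d/Conv2D': 5},
    function_dict=ops.get_function_dict(
        overrides={'conv2d': tf.layers.conv2d}))
net = decorator.conv2d(tf.ones([1, 8, 8, 3]), filters=16, kernel_size=3)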

File tree

3 files changed: +187 -45 lines changed

morph_net/tools/BUILD

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ py_library(
         "//third_party/py/tensorflow",
         "//third_party/tensorflow/contrib/framework:framework_py",
         "//third_party/tensorflow/contrib/layers:layers_py",
+        "//third_party/tensorflow/contrib/slim",
     ],
 )

morph_net/tools/configurable_ops.py

Lines changed: 120 additions & 45 deletions
@@ -45,16 +45,67 @@
 from __future__ import print_function

 import collections
+import copy
 import json
 from enum import Enum

 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import layers as tf_layers
 from tensorflow.contrib import framework
-from tensorflow.contrib import layers
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.contrib import slim as slim_layers
 # gfile = tf.gfile # Aliase needed for mock.

 VANISHED = 0.0
-NUM_OUTPUTS = 'num_outputs'
+_DEFAULT_NUM_OUTPUTS_KWARG = 'num_outputs'
+
+
+_DEFAULT_FUNCTION_DICT = {
+    'fully_connected': contrib_layers.fully_connected,
+    'conv2d': contrib_layers.conv2d,
+    'separable_conv2d': contrib_layers.separable_conv2d,
+    'concat': tf.concat,
+    'add_n': tf.add_n,
+    'avg_pool2d': contrib_layers.avg_pool2d,
+    'max_pool2d': contrib_layers.max_pool2d,
+    'batch_norm': contrib_layers.batch_norm,
+}
+
+
+_OP_SCOPE_DEFAULTS = {
+    tf_layers.conv2d: 'conv2d',
+    slim_layers.conv2d: 'Conv',
+    contrib_layers.conv2d: 'Conv',
+
+    tf_layers.separable_conv2d: 'separable_conv2d',
+    slim_layers.separable_conv2d: 'SeparableConv2d',
+    contrib_layers.separable_conv2d: 'SeparableConv2d',
+
+    tf_layers.dense: 'dense',
+    slim_layers.fully_connected: 'fully_connected',
+    contrib_layers.fully_connected: 'fully_connected',
+}
+
+# Maps function names to the suffix of the name of the regularized ops.
+_SUFFIX_DICT = {
+    'fully_connected': 'MatMul',
+    'conv2d': 'Conv2D',
+    'separable_conv2d': 'separable_conv2d',
+}
+
+
+def get_function_dict(overrides=None):
+  """Get mapping from function name to function for ConfigurableOps.
+
+  Args:
+    overrides: Dict: str -> function. Optionally replace entries in
+      `_DEFAULT_FUNCTION_DICT`.
+
+  Returns:
+    Dict: function name (str) to function.
+  """
+  overrides = overrides or {}
+  function_dict = copy.deepcopy(_DEFAULT_FUNCTION_DICT)
+  function_dict.update(overrides)
+  return function_dict


 def is_vanished(maybe_tensor):
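A quick illustration of the new helper, a sketch using the module's own aliases (tf_layers, contrib_layers): overrides replace only the named entries, and everything else keeps the contrib defaults.

fd = get_function_dict(overrides={'fully_connected': tf_layers.dense})
assert fd['fully_connected'] is tf_layers.dense
assert fd['conv2d'] is contrib_layers.conv2d  # Untouched default.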
@@ -80,25 +131,6 @@ class FallbackRule(Enum):
   zero = 'zero'


-DEFAULT_FUNCTION_DICT = {
-    'fully_connected': layers.fully_connected,
-    'conv2d': layers.conv2d,
-    'separable_conv2d': layers.separable_conv2d,
-    'concat': tf.concat,
-    'add_n': tf.add_n,
-    'avg_pool2d': layers.avg_pool2d,
-    'max_pool2d': layers.max_pool2d,
-    'batch_norm': layers.batch_norm,
-}
-
-# Maps function names to the suffix of the name of the regularized ops.
-SUFFIX_DICT = {
-    'fully_connected': 'MatMul',
-    'conv2d': 'Conv2D',
-    'separable_conv2d': 'separable_conv2d',
-}
-
-
 class ConfigurableOps(object):
   """A class that facilitates structure modification of a Tensorflow graph.
@@ -134,7 +166,7 @@ def __init__(self,
         integer which is the target NUM_OUTPUTS.
       function_dict: A dict between names of ops (strings) and functions
         which accept num_outputs as the second argument. If None defaults to
-        DEFAULT_FUNCTION_DICT.
+        _DEFAULT_FUNCTION_DICT.
       fallback_rule: A `FallbackRule` enum which controls fallback behavior:
         * 'pass_through' provided NUM_OUTPUTS is passed to decorated
           function (default).
@@ -152,14 +184,17 @@ def __init__(self,
         isinstance(fallback_rule, str)):
       raise ValueError('fallback_rule must be a string or FallbackRule Enum')

-    self._function_dict = function_dict or DEFAULT_FUNCTION_DICT
-    self._suffix_dict = SUFFIX_DICT
+    self._function_dict = function_dict or _DEFAULT_FUNCTION_DICT
+    self._suffix_dict = _SUFFIX_DICT
     self._constructed_ops = collections.OrderedDict()
     if isinstance(fallback_rule, str):
       fallback_rule = FallbackRule[fallback_rule]  # Converts from string.
     self._default_to_zero = fallback_rule == FallbackRule.zero
     self._strict = fallback_rule == FallbackRule.strict

+    # To keep track of the number of identical scopes encountered
+    self._scope_counts = {}
+
   @property
   def parameterization(self):
     """Returns the parameterization dict mapping op names to num_outputs."""
@@ -235,16 +270,17 @@ def separable_conv2d(self, *args, **kwargs):
     Raises:
       ValueError: If kwargs does not contain a key named 'scope'.
     """
-    num_outputs = _get_from_args_or_kwargs(NUM_OUTPUTS, 1, args, kwargs,
-                                           False)
+    # This function actually only decorates the num_outputs of the Conv2D after
+    # the depthwise convolution, as the former does not have any free params.
+    fn, suffix = self._get_function_and_suffix('separable_conv2d')
+    num_outputs_kwarg_name = self._get_num_outputs_kwarg_name(fn)
+    num_outputs = _get_from_args_or_kwargs(
+        num_outputs_kwarg_name, 1, args, kwargs, False)
     if num_outputs is None:
       tf.logging.warning(
           'Trying to decorate separable_conv2d with num_outputs = None')
-      kwargs[NUM_OUTPUTS] = None
-    # This function actually only decorates the num_outputs of the Conv2D after
-    # the depthwise convolution, as the former does not have any free params.
+      kwargs[num_outputs_kwarg_name] = None

-    fn, suffix = self._get_function_and_suffix('separable_conv2d')
     return self._mask(fn, suffix, *args, **kwargs)

   def _mask(self, function, suffix, *args, **kwargs):
@@ -262,7 +298,7 @@ def _mask(self, function, suffix, *args, **kwargs):

     Args:
       function: A callable function to mask the NUM_OUTPUTS parameter from.
-        Examples for functions are in DEFAULT_FUNCTION_DICT.
+        Examples for functions are in _DEFAULT_FUNCTION_DICT.
         The callable function must accept a NUM_OUTPUTS parameter as the
         second argument.
       suffix: A string with the suffix of the op name.
@@ -277,22 +313,42 @@ def _mask(self, function, suffix, *args, **kwargs):
     Raises:
       ValueError: If kwargs does not contain a key named 'scope'.
     """
-    if ('scope' not in kwargs) and ('name' not in kwargs):
-      raise ValueError('kwargs must contain key \'scope\' or \'name\'')
     inputs = args[0] if args else kwargs.pop('inputs')
     if is_vanished(inputs):
       return VANISHED

-    # Support for tf.contrib.layers and tf.layers API.
-    op_scope = kwargs.get('scope') or kwargs.get('name')
     current_scope = framework.get_name_scope() or ''
     if current_scope and not current_scope.endswith('/'):
       current_scope += '/'
-    op_name = ''.join([current_scope, op_scope, '/', suffix])
+
+    op_scope = kwargs.get('scope') or kwargs.get('name')
+    if op_scope:
+      if op_scope.endswith('/'):
+        raise ValueError(
+            'Scope `{}` ends with `/` which leads to unexpected '
+            'behavior.'.format(op_scope))
+      full_scope = current_scope + op_scope
+    else:
+      # Use default scope, optionally appending a unique ID if scope exists
+      if function not in _OP_SCOPE_DEFAULTS:
+        raise ValueError(
+            'No `scope` or `name` found in kwargs, and no default scope '
+            'defined for {}'.format(_get_function_name(function)))
+      op_scope = _OP_SCOPE_DEFAULTS[function]
+      full_scope = current_scope + op_scope
+      if full_scope in self._scope_counts:
+        new_scope = full_scope + '_' + str(self._scope_counts[full_scope])
+        self._scope_counts[full_scope] += 1
+        full_scope = new_scope
+      else:
+        self._scope_counts[full_scope] = 1
+
+    op_name = full_scope + '/' + suffix

     # Assumes `inputs` is the first argument and `num_outputs` is the second
     # argument.
-    num_outputs = self._parse_num_outputs(op_name, args, kwargs)
+    num_outputs = self._parse_num_outputs(
+        op_name, self._get_num_outputs_kwarg_name(function), args, kwargs)
     args = args[2:]  # Possibly and empty list of < 3 positional args are used.

     self._insert_to_parameterization_log(op_name, num_outputs)
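To make the new scope bookkeeping concrete, a sketch (mirroring testDefaultScopesRepeated in the test diff below): two unscoped calls to the slim-style conv2d resolve to 'Conv' and then 'Conv_1', matching TensorFlow's own scope uniquification.

import tensorflow.compat.v1 as tf
from morph_net.tools.configurable_ops import ConfigurableOps

decorator = ConfigurableOps(
    parameterization={'Conv/Conv2D': 4, 'Conv_1/Conv2D': 6})
inputs = tf.ones([1, 8, 8, 3])
net = decorator.conv2d(inputs, num_outputs=8, kernel_size=3)  # 'Conv/Conv2D'
net = decorator.conv2d(net, num_outputs=8, kernel_size=3)     # 'Conv_1/Conv2D'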
@@ -336,7 +392,16 @@ def batch_norm(self, *args, **kwargs):
     return self._pass_through_mask(
         self._function_dict['batch_norm'], *args, **kwargs)

-  def _parse_num_outputs(self, op_name, args, kwargs):
+  def _get_num_outputs_kwarg_name(self, function):
+    """Gets the `num_outputs`-equivalent kwarg for a supported function."""
+    alt_num_outputs_kwarg = {
+        tf_layers.conv2d: 'filters',
+        tf_layers.separable_conv2d: 'filters',
+        tf_layers.dense: 'units',
+    }
+    return alt_num_outputs_kwarg.get(function, _DEFAULT_NUM_OUTPUTS_KWARG)
+
+  def _parse_num_outputs(self, op_name, num_outputs_kwarg_name, args, kwargs):
     """Computes the target NUM_OUTPUTS and adjusts kwargs in place.

@@ -346,6 +411,8 @@ def _parse_num_outputs(self, op_name, args, kwargs):

     Args:
       op_name: A string, the name of the op to get NUM_OUTPUTS for.
+      num_outputs_kwarg_name: A string, the name of the `num_outputs`-equivalent
+        kwarg.
       args: Position arguments for the callable. Assumes that NUM_OUTPUTS
         position is 1.
       kwargs: key word arguments for the callable.
@@ -361,8 +428,9 @@
       raise KeyError('op_name \"%s\" not found in parameterization' % op_name)

     # Assumes that the position of num_outputs is 1.
-    base_num_outputs = _get_from_args_or_kwargs(NUM_OUTPUTS, 1, args, kwargs)
-    kwargs.pop(NUM_OUTPUTS, None)  # Removes num_outputs from kwargs if there.
+    base_num_outputs = _get_from_args_or_kwargs(
+        num_outputs_kwarg_name, 1, args, kwargs)
+    kwargs.pop(num_outputs_kwarg_name, None)  # Removes num_outputs from kwargs.

     default_num_outputs = 0 if self._default_to_zero else base_num_outputs
     return self._parameterization.get(op_name, default_num_outputs)
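For example (a sketch; compare testDefaultScopes_Dense in the test diff below), with tf.layers.dense the output-size argument arrives as `units`, and `_parse_num_outputs` pops it under that name before applying the parameterized target:

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import layers as tf_layers
from morph_net.tools.configurable_ops import ConfigurableOps

decorator = ConfigurableOps(
    parameterization={'dense/MatMul': 3},
    function_dict={'fully_connected': tf_layers.dense})
net = decorator.fully_connected(tf.ones([1, 2]), units=16)  # Built with 3 units.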
@@ -423,6 +491,11 @@ def _get_from_args_or_kwargs(name, index, args, kwargs, is_required=True):
     return None


+def _get_function_name(function):
+  """Get a descriptive identifier for `function`."""
+  return '{}.{}'.format(function.__module__, function.__name__)
+
+
 def hijack_module_functions(configurable_ops, module):
   """Hijacks the functions from module using configurable_ops.
@@ -458,7 +531,7 @@ def build_layer_not_affected(inputs):

   Args:
     configurable_ops: An ConfigurableOps object, to use functions as defined in
-      'DEFAULT_FUNCTION_DICT'.
+      '_DEFAULT_FUNCTION_DICT'.
     module: A module name to override its functions.

   Returns:
@@ -480,7 +553,7 @@ def maybe_setattr(attr):
       originals[attr] = getattr(module, attr)
       setattr(module, attr, getattr(configurable_ops, attr))

-  for fn in DEFAULT_FUNCTION_DICT:
+  for fn in _DEFAULT_FUNCTION_DICT:
     maybe_setattr(fn)
   return originals
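A usage sketch for the hijack/recover pair; `my_module` and its `build_network` are hypothetical stand-ins for any module that calls conv2d and friends internally:

from morph_net.tools import configurable_ops as ops

decorator = ops.ConfigurableOps(parameterization={'Conv/Conv2D': 4})
originals = ops.hijack_module_functions(decorator, my_module)  # Swap in.
net = my_module.build_network(inputs)  # conv2d etc. now route via decorator.
ops.recover_module_functions(originals, my_module)  # Restore the originals.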

@@ -490,7 +563,7 @@ def recover_module_functions(originals, module):

   Args:
     originals: Dict of functions to recover. Assumes keys are a contained in
-      'DEFAULT_FUNCTION_DICT'.
+      '_DEFAULT_FUNCTION_DICT'.
     module: A module name to recover its functions.

   """
@@ -499,7 +572,7 @@


 def decorator_from_parameterization_file(
-    filename, fallback_rule=FallbackRule.pass_through):
+    filename, fallback_rule=FallbackRule.pass_through, **kwargs):
   """Create a ConfigurableOps from a parameterization file.

   Loads a json parameterization file from disk
@@ -510,11 +583,13 @@ def decorator_from_parameterization_file(
     filename: Path to a parameterization file in json format.
     fallback_rule: A `FallbackRule` enum which controls fallback behavior
       (see __init__ for more detail.)
+    **kwargs: Miscellaneous args for ConfigurableOps.

   Returns:
     An ConfigurableOps instance with the parameterization from `filename`.
   """
   with tf.gfile.Open(filename, 'r') as f:
     parameterization = json.loads(f.read())
   return ConfigurableOps(
-      parameterization=parameterization, fallback_rule=fallback_rule)
+      parameterization=parameterization, fallback_rule=fallback_rule,
+      **kwargs)
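A sketch of the extended entry point; the file path is illustrative, and the function_dict is forwarded to ConfigurableOps through the new **kwargs:

import tensorflow.compat.v1 as tf
from morph_net.tools import configurable_ops as ops

decorator = ops.decorator_from_parameterization_file(
    '/tmp/parameterization.json',  # Hypothetical json: op name -> num_outputs.
    fallback_rule=ops.FallbackRule.zero,
    function_dict=ops.get_function_dict(
        overrides={'conv2d': tf.layers.conv2d}))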

morph_net/tools/configurable_ops_test.py

Lines changed: 66 additions & 0 deletions
@@ -14,6 +14,7 @@
 from morph_net.tools import configurable_ops as ops

 import tensorflow.compat.v1 as tf
+import tensorflow.contrib as tf_contrib

 from tensorflow.contrib import layers
 from tensorflow.contrib.framework import add_arg_scope
@@ -424,6 +425,71 @@ def testBatchNorm(self):
     self.assertAllEqual(decorator_regular_output, tf_output)
     self.assertTrue(ops.is_vanished(decorator_zero_output))

+  @parameterized.named_parameters(
+      ('_SlimLayers', tf_contrib.slim.conv2d, 'num_outputs', 'Conv'),
+      ('_ContribLayers', tf_contrib.layers.conv2d, 'num_outputs', 'Conv'),
+      ('_TfLayer', tf.layers.conv2d, 'filters', 'conv2d'))
+  def testDefaultScopes_Conv(
+      self, conv_fn, num_outputs_kwarg, expected_op_scope):
+    inputs = tf.ones([1, 3, 3, 2])
+    parameterization = {
+        '{}/Conv2D'.format(expected_op_scope): 5
+    }
+    decorator = ops.ConfigurableOps(
+        parameterization=parameterization, function_dict={'conv2d': conv_fn})
+    _ = decorator.conv2d(inputs, **{num_outputs_kwarg: 8, 'kernel_size': 2})
+    self.assertDictEqual(parameterization, decorator.constructed_ops)
+
+  @parameterized.named_parameters(
+      ('_SlimLayers',
+       tf_contrib.slim.fully_connected, 'num_outputs', 'fully_connected'),
+      ('_ContribLayers',
+       tf_contrib.layers.fully_connected, 'num_outputs', 'fully_connected'),
+      ('_TfLayer',
+       tf.layers.dense, 'units', 'dense'))
+  def testDefaultScopes_Dense(
+      self, dense_fn, num_outputs_kwarg, expected_op_scope):
+    inputs = tf.ones([1, 2])
+    parameterization = {
+        '{}/MatMul'.format(expected_op_scope): 5
+    }
+    decorator = ops.ConfigurableOps(
+        parameterization=parameterization,
+        function_dict={'fully_connected': dense_fn})
+    _ = decorator.fully_connected(inputs, **{num_outputs_kwarg: 8})
+    self.assertDictEqual(parameterization, decorator.constructed_ops)
+
+  def testDefaultScopesRepeated(self):
+    inputs = tf.ones([1, 3, 3, 2])
+    parameterization = {
+        's1/SeparableConv2d/separable_conv2d': 1,
+        's1/SeparableConv2d_1/separable_conv2d': 2,
+        's1/s2/SeparableConv2d/separable_conv2d': 3,
+        's1/s2/SeparableConv2d_1/separable_conv2d': 4,
+    }
+    decorator = ops.ConfigurableOps(
+        parameterization=parameterization,
+        function_dict={'separable_conv2d': tf_contrib.slim.separable_conv2d})
+
+    with tf.variable_scope('s1'):
+      # first call in s1: op scope should be `s1/SeparableConv2d`
+      _ = decorator.separable_conv2d(inputs, num_outputs=8, kernel_size=2)
+
+      with tf.variable_scope('s2'):
+        # first call in s2: op scope should be `s1/s2/SeparableConv2d`
+        _ = decorator.separable_conv2d(inputs, num_outputs=8, kernel_size=2)
+
+        # second call in s2: op scope should be `s1/s2/SeparableConv2d_1`
+        _ = decorator.separable_conv2d(inputs, num_outputs=8, kernel_size=2)
+
+      # second call in s1: op scope should be `s1/SeparableConv2d_1`
+      _ = decorator.separable_conv2d(inputs, num_outputs=8, kernel_size=2)
+
+    conv_op_names = [op.name for op in tf.get_default_graph().get_operations()
+                     if op.name.endswith('separable_conv2d')]
+    self.assertCountEqual(parameterization, conv_op_names)
+    self.assertDictEqual(parameterization, decorator.constructed_ops)
+

 class Fake(object):
   # This Class is a cheap simulation of a module.