4545from __future__ import print_function
4646
4747import collections
48+ import copy
4849import json
4950from enum import Enum
5051
5152import tensorflow .compat .v1 as tf
53+ from tensorflow .compat .v1 import layers as tf_layers
5254from tensorflow .contrib import framework
53- from tensorflow .contrib import layers
55+ from tensorflow .contrib import layers as contrib_layers
56+ from tensorflow .contrib import slim as slim_layers
5457# gfile = tf.gfile # Aliase needed for mock.
5558
5659VANISHED = 0.0
57- NUM_OUTPUTS = 'num_outputs'
60+ _DEFAULT_NUM_OUTPUTS_KWARG = 'num_outputs'
61+
62+ _DEFAULT_FUNCTION_DICT = {
63+ 'fully_connected' : contrib_layers .fully_connected ,
64+ 'conv2d' : contrib_layers .conv2d ,
65+ 'separable_conv2d' : contrib_layers .separable_conv2d ,
66+ 'concat' : tf .concat ,
67+ 'add_n' : tf .add_n ,
68+ 'avg_pool2d' : contrib_layers .avg_pool2d ,
69+ 'max_pool2d' : contrib_layers .max_pool2d ,
70+ 'batch_norm' : contrib_layers .batch_norm ,
71+ }
72+
73+ _OP_SCOPE_DEFAULTS = {
74+ tf_layers .conv2d : 'conv2d' ,
75+ slim_layers .conv2d : 'Conv' ,
76+ contrib_layers .conv2d : 'Conv' ,
77+
78+ tf_layers .separable_conv2d : 'separable_conv2d' ,
79+ slim_layers .separable_conv2d : 'SeparableConv2d' ,
80+ contrib_layers .separable_conv2d : 'SeparableConv2d' ,
81+
82+ tf_layers .dense : 'dense' ,
83+ slim_layers .fully_connected : 'fully_connected' ,
84+ contrib_layers .fully_connected : 'fully_connected' ,
85+ }
86+
87+ # Maps function names to the suffix of the name of the regularized ops.
88+ _SUFFIX_DICT = {
89+ 'fully_connected' : 'MatMul' ,
90+ 'conv2d' : 'Conv2D' ,
91+ 'separable_conv2d' : 'separable_conv2d' ,
92+ }
93+
94+
95+ def get_function_dict (overrides = None ):
96+ """Get mapping from function name to function for ConfigurableOps.
97+
98+ Args:
99+ overrides: Dict: str -> function. Optionally replace entries in
100+ `_DEFAULT_FUNCTION_DICT`.
101+
102+ Returns:
103+ Dict: function name (str) to function.
104+ """
105+ overrides = overrides or {}
106+ function_dict = copy .deepcopy (_DEFAULT_FUNCTION_DICT )
107+ function_dict .update (overrides )
108+ return function_dict
58109
59110
60111def is_vanished (maybe_tensor ):
@@ -80,25 +131,6 @@ class FallbackRule(Enum):
80131 zero = 'zero'
81132
82133
83- DEFAULT_FUNCTION_DICT = {
84- 'fully_connected' : layers .fully_connected ,
85- 'conv2d' : layers .conv2d ,
86- 'separable_conv2d' : layers .separable_conv2d ,
87- 'concat' : tf .concat ,
88- 'add_n' : tf .add_n ,
89- 'avg_pool2d' : layers .avg_pool2d ,
90- 'max_pool2d' : layers .max_pool2d ,
91- 'batch_norm' : layers .batch_norm ,
92- }
93-
94- # Maps function names to the suffix of the name of the regularized ops.
95- SUFFIX_DICT = {
96- 'fully_connected' : 'MatMul' ,
97- 'conv2d' : 'Conv2D' ,
98- 'separable_conv2d' : 'separable_conv2d' ,
99- }
100-
101-
102134class ConfigurableOps (object ):
103135 """A class that facilitates structure modification of a Tensorflow graph.
104136
@@ -134,7 +166,7 @@ def __init__(self,
134166 integer which is the target NUM_OUTPUTS.
135167 function_dict: A dict between names of ops (strings) and functions
136168 which accept num_outputs as the second argument. If None defaults to
137- DEFAULT_FUNCTION_DICT .
169+ _DEFAULT_FUNCTION_DICT .
138170 fallback_rule: A `FallbackRule` enum which controls fallback behavior:
139171 * 'pass_through' provided NUM_OUTPUTS is passed to decorated
140172 function (default).
@@ -152,14 +184,17 @@ def __init__(self,
152184 isinstance (fallback_rule , str )):
153185 raise ValueError ('fallback_rule must be a string or FallbackRule Enum' )
154186
155- self ._function_dict = function_dict or DEFAULT_FUNCTION_DICT
156- self ._suffix_dict = SUFFIX_DICT
187+ self ._function_dict = function_dict or _DEFAULT_FUNCTION_DICT
188+ self ._suffix_dict = _SUFFIX_DICT
157189 self ._constructed_ops = collections .OrderedDict ()
158190 if isinstance (fallback_rule , str ):
159191 fallback_rule = FallbackRule [fallback_rule ] # Converts from string.
160192 self ._default_to_zero = fallback_rule == FallbackRule .zero
161193 self ._strict = fallback_rule == FallbackRule .strict
162194
195+ # To keep track of the number of identical scopes encountered
196+ self ._scope_counts = {}
197+
163198 @property
164199 def parameterization (self ):
165200 """Returns the parameterization dict mapping op names to num_outputs."""
@@ -235,16 +270,17 @@ def separable_conv2d(self, *args, **kwargs):
235270 Raises:
236271 ValueError: If kwargs does not contain a key named 'scope'.
237272 """
238- num_outputs = _get_from_args_or_kwargs (NUM_OUTPUTS , 1 , args , kwargs ,
239- False )
273+ # This function actually only decorates the num_outputs of the Conv2D after
274+ # the depthwise convolution, as the former does not have any free params.
275+ fn , suffix = self ._get_function_and_suffix ('separable_conv2d' )
276+ num_outputs_kwarg_name = self ._get_num_outputs_kwarg_name (fn )
277+ num_outputs = _get_from_args_or_kwargs (
278+ num_outputs_kwarg_name , 1 , args , kwargs , False )
240279 if num_outputs is None :
241280 tf .logging .warning (
242281 'Trying to decorate separable_conv2d with num_outputs = None' )
243- kwargs [NUM_OUTPUTS ] = None
244- # This function actually only decorates the num_outputs of the Conv2D after
245- # the depthwise convolution, as the former does not have any free params.
282+ kwargs [num_outputs_kwarg_name ] = None
246283
247- fn , suffix = self ._get_function_and_suffix ('separable_conv2d' )
248284 return self ._mask (fn , suffix , * args , ** kwargs )
249285
250286 def _mask (self , function , suffix , * args , ** kwargs ):
@@ -262,7 +298,7 @@ def _mask(self, function, suffix, *args, **kwargs):
262298
263299 Args:
264300 function: A callable function to mask the NUM_OUTPUTS parameter from.
265- Examples for functions are in DEFAULT_FUNCTION_DICT .
301+ Examples for functions are in _DEFAULT_FUNCTION_DICT .
266302 The callable function must accept a NUM_OUTPUTS parameter as the
267303 second argument.
268304 suffix: A string with the suffix of the op name.
@@ -277,22 +313,42 @@ def _mask(self, function, suffix, *args, **kwargs):
277313 Raises:
278314 ValueError: If kwargs does not contain a key named 'scope'.
279315 """
280- if ('scope' not in kwargs ) and ('name' not in kwargs ):
281- raise ValueError ('kwargs must contain key \' scope\' or \' name\' ' )
282316 inputs = args [0 ] if args else kwargs .pop ('inputs' )
283317 if is_vanished (inputs ):
284318 return VANISHED
285319
286- # Support for tf.contrib.layers and tf.layers API.
287- op_scope = kwargs .get ('scope' ) or kwargs .get ('name' )
288320 current_scope = framework .get_name_scope () or ''
289321 if current_scope and not current_scope .endswith ('/' ):
290322 current_scope += '/'
291- op_name = '' .join ([current_scope , op_scope , '/' , suffix ])
323+
324+ op_scope = kwargs .get ('scope' ) or kwargs .get ('name' )
325+ if op_scope :
326+ if op_scope .endswith ('/' ):
327+ raise ValueError (
328+ 'Scope `{}` ends with `/` which leads to unexpected '
329+ 'behavior.' .format (op_scope ))
330+ full_scope = current_scope + op_scope
331+ else :
332+ # Use default scope, optionally appending a unique ID if scope exists
333+ if function not in _OP_SCOPE_DEFAULTS :
334+ raise ValueError (
335+ 'No `scope` or `name` found in kwargs, and no default scope '
336+ 'defined for {}' .format (_get_function_name (function )))
337+ op_scope = _OP_SCOPE_DEFAULTS [function ]
338+ full_scope = current_scope + op_scope
339+ if full_scope in self ._scope_counts :
340+ new_scope = full_scope + '_' + str (self ._scope_counts [full_scope ])
341+ self ._scope_counts [full_scope ] += 1
342+ full_scope = new_scope
343+ else :
344+ self ._scope_counts [full_scope ] = 1
345+
346+ op_name = full_scope + '/' + suffix
292347
293348 # Assumes `inputs` is the first argument and `num_outputs` is the second
294349 # argument.
295- num_outputs = self ._parse_num_outputs (op_name , args , kwargs )
350+ num_outputs = self ._parse_num_outputs (
351+ op_name , self ._get_num_outputs_kwarg_name (function ), args , kwargs )
296352 args = args [2 :] # Possibly and empty list of < 3 positional args are used.
297353
298354 self ._insert_to_parameterization_log (op_name , num_outputs )
@@ -336,7 +392,16 @@ def batch_norm(self, *args, **kwargs):
336392 return self ._pass_through_mask (
337393 self ._function_dict ['batch_norm' ], * args , ** kwargs )
338394
339- def _parse_num_outputs (self , op_name , args , kwargs ):
395+ def _get_num_outputs_kwarg_name (self , function ):
396+ """Gets the `num_outputs`-equivalent kwarg for a supported function."""
397+ alt_num_outputs_kwarg = {
398+ tf_layers .conv2d : 'filters' ,
399+ tf_layers .separable_conv2d : 'filters' ,
400+ tf_layers .dense : 'units' ,
401+ }
402+ return alt_num_outputs_kwarg .get (function , _DEFAULT_NUM_OUTPUTS_KWARG )
403+
404+ def _parse_num_outputs (self , op_name , num_outputs_kwarg_name , args , kwargs ):
340405 """Computes the target NUM_OUTPUTS and adjusts kwargs in place.
341406
342407 Will try to extract the number of outputs from the op_name's
@@ -346,6 +411,8 @@ def _parse_num_outputs(self, op_name, args, kwargs):
346411
347412 Args:
348413 op_name: A string, the name of the op to get NUM_OUTPUTS for.
414+ num_outputs_kwarg_name: A string, the name of the `num_outputs`-equivalent
415+ kwarg.
349416 args: Position arguments for the callable. Assumes that NUM_OUTPUTS
350417 position is 1.
351418 kwargs: key word arguments for the callable.
@@ -361,8 +428,9 @@ def _parse_num_outputs(self, op_name, args, kwargs):
361428 raise KeyError ('op_name \" %s\" not found in parameterization' % op_name )
362429
363430 # Assumes that the position of num_outputs is 1.
364- base_num_outputs = _get_from_args_or_kwargs (NUM_OUTPUTS , 1 , args , kwargs )
365- kwargs .pop (NUM_OUTPUTS , None ) # Removes num_outputs from kwargs if there.
431+ base_num_outputs = _get_from_args_or_kwargs (
432+ num_outputs_kwarg_name , 1 , args , kwargs )
433+ kwargs .pop (num_outputs_kwarg_name , None ) # Removes num_outputs from kwargs.
366434
367435 default_num_outputs = 0 if self ._default_to_zero else base_num_outputs
368436 return self ._parameterization .get (op_name , default_num_outputs )
@@ -423,6 +491,11 @@ def _get_from_args_or_kwargs(name, index, args, kwargs, is_required=True):
423491 return None
424492
425493
494+ def _get_function_name (function ):
495+ """Get a descriptive identifier for `function`."""
496+ return '{}.{}' .format (function .__module__ , function .__name__ )
497+
498+
426499def hijack_module_functions (configurable_ops , module ):
427500 """Hijacks the functions from module using configurable_ops.
428501
@@ -458,7 +531,7 @@ def build_layer_not_affected(inputs):
458531
459532 Args:
460533 configurable_ops: An ConfigurableOps object, to use functions as defined in
461- 'DEFAULT_FUNCTION_DICT '.
534+ '_DEFAULT_FUNCTION_DICT '.
462535 module: A module name to override its functions.
463536
464537 Returns:
@@ -480,7 +553,7 @@ def maybe_setattr(attr):
480553 originals [attr ] = getattr (module , attr )
481554 setattr (module , attr , getattr (configurable_ops , attr ))
482555
483- for fn in DEFAULT_FUNCTION_DICT :
556+ for fn in _DEFAULT_FUNCTION_DICT :
484557 maybe_setattr (fn )
485558 return originals
486559
@@ -490,7 +563,7 @@ def recover_module_functions(originals, module):
490563
491564 Args:
492565 originals: Dict of functions to recover. Assumes keys are a contained in
493- 'DEFAULT_FUNCTION_DICT '.
566+ '_DEFAULT_FUNCTION_DICT '.
494567 module: A module name to recover its functions.
495568
496569 """
@@ -499,7 +572,7 @@ def recover_module_functions(originals, module):
499572
500573
501574def decorator_from_parameterization_file (
502- filename , fallback_rule = FallbackRule .pass_through ):
575+ filename , fallback_rule = FallbackRule .pass_through , ** kwargs ):
503576 """Create a ConfigurableOps from a parameterization file.
504577
505578 Loads a json parameterization file from disk
@@ -510,11 +583,13 @@ def decorator_from_parameterization_file(
510583 filename: Path to a parameterization file in json format.
511584 fallback_rule: A `FallbackRule` enum which controls fallback behavior
512585 (see __init__ for more detail.)
586+ **kwargs: Miscellaneous args for ConfigurableOps.
513587
514588 Returns:
515589 An ConfigurableOps instance with the parameterization from `filename`.
516590 """
517591 with tf .gfile .Open (filename , 'r' ) as f :
518592 parameterization = json .loads (f .read ())
519593 return ConfigurableOps (
520- parameterization = parameterization , fallback_rule = fallback_rule )
594+ parameterization = parameterization , fallback_rule = fallback_rule ,
595+ ** kwargs )
0 commit comments