This repository was archived by the owner on Jul 7, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +16
-3
lines changed
Expand file tree Collapse file tree 2 files changed +16
-3
lines changed Original file line number Diff line number Diff line change @@ -4043,9 +4043,6 @@ def _data_dep_init(self, inputs):
40434043
40444044 def build (self , input_shape = None ):
40454045 """Build `Layer`."""
4046- input_shape = tf .TensorShape (input_shape ).as_list ()
4047- self .input_spec = layers ().InputSpec (shape = input_shape )
4048-
40494046 if not self .layer .built :
40504047 self .layer .build (input_shape )
40514048 self .layer .built = False
@@ -4072,6 +4069,7 @@ def build(self, input_shape=None):
40724069 self ._compute_weights ()
40734070
40744071 self .layer .built = True
4072+ self .input_spec = self .layer .input_spec
40754073
40764074 super (WeightNorm , self ).build ()
40774075 self .built = True
Original file line number Diff line number Diff line change @@ -965,5 +965,20 @@ def fn_recompute(x):
965965 self .assertAllClose (g1 , g2 )
966966
967967
968+ class WeightNormTest (tf .test .TestCase ):
969+
970+ def testInputSpec (self ):
971+ """Test that WeighNorm does not overspecify the input_spec."""
972+ conv = common_layers .WeightNorm (
973+ tf .keras .layers .Conv1D (filters = 8 , kernel_size = 3 ))
974+ # Call with one batch size:
975+ conv (tf .zeros ([1 , 16 , 2 ]))
976+ # Should allow call with another batch size.
977+ conv (tf .zeros ([2 , 16 , 2 ]))
978+ # Input spec does detect incorrect input feature dim.
979+ with self .assertRaises (ValueError ):
980+ conv (tf .zeros ([2 , 16 , 3 ]))
981+
982+
968983if __name__ == "__main__" :
969984 tf .test .main ()
You can’t perform that action at this time.
0 commit comments