Skip to content

Commit 37d60cf

Browse files
Update adversarial_regularization.py
1 parent b30cba8 commit 37d60cf

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import collections
2121
import functools
2222
import types
23+
from keras.utils import traceback_utils
2324

2425
import attr
2526
import neural_structured_learning.configs as nsl_configs
@@ -699,6 +700,61 @@ def save(self, *args, **kwargs):
699700
'Saving `AdversarialRegularization` models is currently not supported. '
700701
'Consider using `save_weights` or saving the `base_model`.')
701702

703+
704+
@traceback_utils.filter_traceback
705+
def save(self,
706+
filepath,
707+
overwrite=True,
708+
include_optimizer=True,
709+
save_format=None,
710+
signatures=None,
711+
options=None,
712+
save_traces=True):
713+
# pylint: disable=line-too-long
714+
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
715+
Please see `tf.keras.models.save_model` or the
716+
[Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
717+
for details.
718+
Args:
719+
filepath: String, PathLike, path to SavedModel or H5 file to save the
720+
model.
721+
overwrite: Whether to silently overwrite any existing file at the
722+
target location, or provide the user with a manual prompt.
723+
include_optimizer: If True, save optimizer's state together.
724+
save_format: Either `'tf'` or `'h5'`, indicating whether to save the
725+
model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
726+
and 'h5' in TF 1.X.
727+
signatures: Signatures to save with the SavedModel. Applicable to the
728+
'tf' format only. Please see the `signatures` argument in
729+
`tf.saved_model.save` for details.
730+
options: (only applies to SavedModel format)
731+
`tf.saved_model.SaveOptions` object that specifies options for
732+
saving to SavedModel.
733+
save_traces: (only applies to SavedModel format) When enabled, the
734+
SavedModel will store the function traces for each layer. This
735+
can be disabled, so that only the configs of each layer are stored.
736+
Defaults to `True`. Disabling this will decrease serialization time
737+
and reduce file size, but it requires that all custom layers/models
738+
implement a `get_config()` method.
739+
Example:
740+
```python
741+
from keras.models import load_model
742+
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
743+
del model # deletes the existing model
744+
# returns a compiled model
745+
# identical to the previous one
746+
model = load_model('my_model.h5')
747+
```
748+
"""
749+
# pylint: enable=line-too-long
750+
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
751+
signatures, options, save_traces)
752+
753+
754+
755+
756+
757+
702758
def perturb_on_batch(self, x, **config_kwargs):
703759
"""Perturbs the given input to generates adversarial examples.
704760

0 commit comments

Comments
 (0)