|
20 | 20 | import collections |
21 | 21 | import functools |
22 | 22 | import types |
| 23 | +from keras.utils import traceback_utils |
23 | 24 |
|
24 | 25 | import attr |
25 | 26 | import neural_structured_learning.configs as nsl_configs |
@@ -699,6 +700,61 @@ def save(self, *args, **kwargs): |
699 | 700 | 'Saving `AdversarialRegularization` models is currently not supported. ' |
700 | 701 | 'Consider using `save_weights` or saving the `base_model`.') |
701 | 702 |
|
| 703 | + |
| 704 | + @traceback_utils.filter_traceback |
| 705 | + def save(self, |
| 706 | + filepath, |
| 707 | + overwrite=True, |
| 708 | + include_optimizer=True, |
| 709 | + save_format=None, |
| 710 | + signatures=None, |
| 711 | + options=None, |
| 712 | + save_traces=True): |
| 713 | + # pylint: disable=line-too-long |
| 714 | + """Saves the model to Tensorflow SavedModel or a single HDF5 file. |
| 715 | + Please see `tf.keras.models.save_model` or the |
| 716 | + [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/) |
| 717 | + for details. |
| 718 | + Args: |
| 719 | + filepath: String, PathLike, path to SavedModel or H5 file to save the |
| 720 | + model. |
| 721 | + overwrite: Whether to silently overwrite any existing file at the |
| 722 | + target location, or provide the user with a manual prompt. |
| 723 | + include_optimizer: If True, save optimizer's state together. |
| 724 | + save_format: Either `'tf'` or `'h5'`, indicating whether to save the |
| 725 | + model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, |
| 726 | + and 'h5' in TF 1.X. |
| 727 | + signatures: Signatures to save with the SavedModel. Applicable to the |
| 728 | + 'tf' format only. Please see the `signatures` argument in |
| 729 | + `tf.saved_model.save` for details. |
| 730 | + options: (only applies to SavedModel format) |
| 731 | + `tf.saved_model.SaveOptions` object that specifies options for |
| 732 | + saving to SavedModel. |
| 733 | + save_traces: (only applies to SavedModel format) When enabled, the |
| 734 | + SavedModel will store the function traces for each layer. This |
| 735 | + can be disabled, so that only the configs of each layer are stored. |
| 736 | + Defaults to `True`. Disabling this will decrease serialization time |
| 737 | + and reduce file size, but it requires that all custom layers/models |
| 738 | + implement a `get_config()` method. |
| 739 | + Example: |
| 740 | + ```python |
| 741 | + from keras.models import load_model |
| 742 | + model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' |
| 743 | + del model # deletes the existing model |
| 744 | + # returns a compiled model |
| 745 | + # identical to the previous one |
| 746 | + model = load_model('my_model.h5') |
| 747 | + ``` |
| 748 | + """ |
| 749 | + # pylint: enable=line-too-long |
| 750 | + save.save_model(self, filepath, overwrite, include_optimizer, save_format, |
| 751 | + signatures, options, save_traces) |
| 752 | + |
| 753 | + |
| 754 | + |
| 755 | + |
| 756 | + |
| 757 | + |
702 | 758 | def perturb_on_batch(self, x, **config_kwargs): |
703 | 759 | """Perturbs the given input to generates adversarial examples. |
704 | 760 |
|
|
0 commit comments