2222)
2323from vertexai import generative_models
2424from vertexai .tuning import _tuning
25+ from vertexai .tuning ._tuning import SourceModel
2526
2627
2728def train (
2829 * ,
29- source_model : Union [str , generative_models .GenerativeModel ],
30+ source_model : Union [str , generative_models .GenerativeModel , SourceModel ],
3031 train_dataset : Union [str , datasets .MultimodalDataset ],
3132 validation_dataset : Optional [Union [str , datasets .MultimodalDataset ]] = None ,
3233 tuned_model_display_name : Optional [str ] = None ,
34+ tuning_mode : Optional [Literal ["FULL" , "PEFT_ADAPTER" ]] = None ,
3335 epochs : Optional [int ] = None ,
36+ learning_rate : Optional [float ] = None ,
3437 learning_rate_multiplier : Optional [float ] = None ,
3538 adapter_size : Optional [Literal [1 , 4 , 8 , 16 , 32 ]] = None ,
3639 labels : Optional [Dict [str , str ]] = None ,
40+ output_uri : Optional [str ] = None ,
3741) -> "SupervisedTuningJob" :
3842 """Tunes a model using supervised training.
3943
@@ -45,13 +49,39 @@ def train(
4549 [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
4650 128 characters long and can consist of any UTF-8 characters.
4751 epochs: Number of training epoches for this tuning job.
48- learning_rate_multiplier: Learning rate multiplier for tuning.
52+ tuning_mode: Tuning mode for this tuning job. Can only be used with OSS
53+ models.
54+ learning_rate: Learning rate for tuning. Can only be used with OSS
55+ models. Mutually exclusive with `learning_rate_multiplier`.
56+ learning_rate_multiplier: Learning rate multiplier for tuning. Mutually
57+ exclusive with `learning_rate`.
4958 adapter_size: Adapter size for tuning.
5059 labels: User-defined metadata to be associated with trained models
60+ output_uri: The Google Cloud Storage URI to write the tuned model to.
5161
5262 Returns:
5363 A `TuningJob` object.
5464 """
65+ if tuning_mode is None :
66+ tuning_mode_value = None
67+ elif tuning_mode == "FULL" :
68+ tuning_mode_value = (
69+ gca_tuning_job_types .SupervisedTuningSpec .TuningMode .TUNING_MODE_FULL
70+ )
71+ elif tuning_mode == "PEFT_ADAPTER" :
72+ tuning_mode_value = (
73+ gca_tuning_job_types .SupervisedTuningSpec .TuningMode .TUNING_MODE_PEFT_ADAPTER
74+ )
75+ else :
76+ raise ValueError (
77+ f"Unsupported tuning mode: { tuning_mode } . The supported tuning modes are [FULL, PEFT_ADAPTER]"
78+ )
79+
80+ if learning_rate and learning_rate_multiplier :
81+ raise ValueError (
82+ "Only one of `learning_rate` and `learning_rate_multiplier` can be set."
83+ )
84+
5585 if adapter_size is None :
5686 adapter_size_value = None
5787 elif adapter_size == 1 :
@@ -83,10 +113,12 @@ def train(
83113 if isinstance (validation_dataset , datasets .MultimodalDataset ):
84114 validation_dataset = validation_dataset .resource_name
85115 supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
116+ tuning_mode = tuning_mode_value ,
86117 training_dataset_uri = train_dataset ,
87118 validation_dataset_uri = validation_dataset ,
88119 hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
89120 epoch_count = epochs ,
121+ learning_rate = learning_rate ,
90122 learning_rate_multiplier = learning_rate_multiplier ,
91123 adapter_size = adapter_size_value ,
92124 ),
@@ -95,12 +127,18 @@ def train(
95127 if isinstance (source_model , generative_models .GenerativeModel ):
96128 source_model = source_model ._prediction_resource_name .rpartition ("/" )[- 1 ]
97129
130+ if labels is None :
131+ labels = {}
132+ if "mg-source" not in labels :
133+ labels ["mg-source" ] = "sdk"
134+
98135 supervised_tuning_job = (
99136 SupervisedTuningJob ._create ( # pylint: disable=protected-access
100137 base_model = source_model ,
101138 tuning_spec = supervised_tuning_spec ,
102139 tuned_model_display_name = tuned_model_display_name ,
103140 labels = labels ,
141+ output_uri = output_uri ,
104142 )
105143 )
106144 _ipython_utils .display_model_tuning_button (supervised_tuning_job )
0 commit comments