Skip to content

Commit f3964a0

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Enable Vertex Model Garden Managed OSS Fine Tuning.
PiperOrigin-RevId: 815806517
1 parent 7757886 commit f3964a0

File tree

3 files changed

+93
-10
lines changed

3 files changed

+93
-10
lines changed

vertexai/tuning/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
# We just want to re-export certain classes
1818
# pylint: disable=g-multiple-import,g-importing-member
1919
from vertexai.tuning._tuning import TuningJob
20+
from vertexai.tuning._tuning import SourceModel
2021

2122
__all__ = [
23+
"SourceModel",
2224
"TuningJob",
2325
]

vertexai/tuning/_supervised_tuning.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,22 @@
2222
)
2323
from vertexai import generative_models
2424
from vertexai.tuning import _tuning
25+
from vertexai.tuning._tuning import SourceModel
2526

2627

2728
def train(
2829
*,
29-
source_model: Union[str, generative_models.GenerativeModel],
30+
source_model: Union[str, generative_models.GenerativeModel, SourceModel],
3031
train_dataset: Union[str, datasets.MultimodalDataset],
3132
validation_dataset: Optional[Union[str, datasets.MultimodalDataset]] = None,
3233
tuned_model_display_name: Optional[str] = None,
34+
tuning_mode: Optional[Literal["FULL", "PEFT_ADAPTER"]] = None,
3335
epochs: Optional[int] = None,
36+
learning_rate: Optional[float] = None,
3437
learning_rate_multiplier: Optional[float] = None,
3538
adapter_size: Optional[Literal[1, 4, 8, 16, 32]] = None,
3639
labels: Optional[Dict[str, str]] = None,
40+
output_uri: Optional[str] = None,
3741
) -> "SupervisedTuningJob":
3842
"""Tunes a model using supervised training.
3943
@@ -45,13 +49,39 @@ def train(
4549
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
4650
128 characters long and can consist of any UTF-8 characters.
4751
epochs: Number of training epoches for this tuning job.
48-
learning_rate_multiplier: Learning rate multiplier for tuning.
52+
tuning_mode: Tuning mode for this tuning job. Can only be used with OSS
53+
models.
54+
learning_rate: Learning rate for tuning. Can only be used with OSS
55+
models. Mutually exclusive with `learning_rate_multiplier`.
56+
learning_rate_multiplier: Learning rate multiplier for tuning. Mutually
57+
exclusive with `learning_rate`.
4958
adapter_size: Adapter size for tuning.
5059
labels: User-defined metadata to be associated with trained models
60+
output_uri: The Google Cloud Storage URI to write the tuned model to.
5161
5262
Returns:
5363
A `TuningJob` object.
5464
"""
65+
if tuning_mode is None:
66+
tuning_mode_value = None
67+
elif tuning_mode == "FULL":
68+
tuning_mode_value = (
69+
gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_FULL
70+
)
71+
elif tuning_mode == "PEFT_ADAPTER":
72+
tuning_mode_value = (
73+
gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_PEFT_ADAPTER
74+
)
75+
else:
76+
raise ValueError(
77+
f"Unsupported tuning mode: {tuning_mode}. The supported tuning modes are [FULL, PEFT_ADAPTER]"
78+
)
79+
80+
if learning_rate and learning_rate_multiplier:
81+
raise ValueError(
82+
"Only one of `learning_rate` and `learning_rate_multiplier` can be set."
83+
)
84+
5585
if adapter_size is None:
5686
adapter_size_value = None
5787
elif adapter_size == 1:
@@ -83,10 +113,12 @@ def train(
83113
if isinstance(validation_dataset, datasets.MultimodalDataset):
84114
validation_dataset = validation_dataset.resource_name
85115
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
116+
tuning_mode=tuning_mode_value,
86117
training_dataset_uri=train_dataset,
87118
validation_dataset_uri=validation_dataset,
88119
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
89120
epoch_count=epochs,
121+
learning_rate=learning_rate,
90122
learning_rate_multiplier=learning_rate_multiplier,
91123
adapter_size=adapter_size_value,
92124
),
@@ -95,12 +127,18 @@ def train(
95127
if isinstance(source_model, generative_models.GenerativeModel):
96128
source_model = source_model._prediction_resource_name.rpartition("/")[-1]
97129

130+
if labels is None:
131+
labels = {}
132+
if "mg-source" not in labels:
133+
labels["mg-source"] = "sdk"
134+
98135
supervised_tuning_job = (
99136
SupervisedTuningJob._create( # pylint: disable=protected-access
100137
base_model=source_model,
101138
tuning_spec=supervised_tuning_spec,
102139
tuned_model_display_name=tuned_model_display_name,
103140
labels=labels,
141+
output_uri=output_uri,
104142
)
105143
)
106144
_ipython_utils.display_model_tuning_button(supervised_tuning_job)

vertexai/tuning/_tuning.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,42 @@
4343
_LOGGER = aiplatform_base.Logger(__name__)
4444

4545

46+
class SourceModel:
47+
r"""A model that is used in managed OSS supervised tuning.
48+
49+
Usage:
50+
```
51+
model = SourceModel(
52+
base_model="meta/llama3.1-8b", # OSS model name <publisher>/<model_name>
53+
custom_base_model="gs://user-bucket/custom-weights",
54+
)
55+
sft_tuning_job = sft.train(
56+
source_model=model,
57+
train_dataset="gs://my-bucket/train.jsonl",
58+
validation_dataset="gs://my-bucket/validation.jsonl",
59+
epochs=4,
60+
tuned_model_display_name="my-tuned-model",
61+
output_uri="gs://user-bucket/tuned-model"
62+
)
63+
64+
while not sft_tuning_job.has_ended:
65+
time.sleep(60)
66+
sft_tuning_job.refresh()
67+
68+
tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
69+
```
70+
"""
71+
72+
def __init__(
73+
self,
74+
base_model: str,
75+
custom_base_model: str = "",
76+
):
77+
r"""Initializes SourceModel."""
78+
self.base_model = base_model
79+
self.custom_base_model = custom_base_model
80+
81+
4682
class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
4783
_is_temporary = True
4884
_default_version = compat.V1BETA1
@@ -133,7 +169,7 @@ def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
133169
def _create(
134170
cls,
135171
*,
136-
base_model: str,
172+
base_model: Union[str, SourceModel],
137173
tuning_spec: Union[
138174
gca_tuning_job_types.SupervisedTuningSpec,
139175
gca_tuning_job_types.DistillationSpec,
@@ -144,15 +180,13 @@ def _create(
144180
project: Optional[str] = None,
145181
location: Optional[str] = None,
146182
credentials: Optional[auth_credentials.Credentials] = None,
183+
output_uri: Optional[str] = None,
147184
) -> "TuningJob":
148185
r"""Submits TuningJob.
149186
150187
Args:
151-
base_model (str):
152-
Model name for tuning, e.g., "gemini-1.0-pro"
153-
or "gemini-1.0-pro-001".
154-
155-
This field is a member of `oneof`_ ``source_model``.
188+
base_model: Model for tuning.
189+
Supported types: str, SourceModel.
156190
tuning_spec: Tuning Spec for Fine Tuning.
157191
Supported types: SupervisedTuningSpec, DistillationSpec.
158192
tuned_model_display_name: The display name of the
@@ -179,6 +213,7 @@ def _create(
179213
Overrides location set in aiplatform.init.
180214
credentials: Custom credentials to use to call tuning job service.
181215
Overrides credentials set in aiplatform.init.
216+
output_uri: The Google Cloud Storage location to write the artifacts. This is only used for OSS models.
182217
183218
Returns:
184219
Submitted TuningJob.
@@ -192,17 +227,25 @@ def _create(
192227
tuned_model_display_name = cls._generate_display_name()
193228

194229
gca_tuning_job = gca_tuning_job_types.TuningJob(
195-
base_model=base_model,
196230
tuned_model_display_name=tuned_model_display_name,
197231
description=description,
198232
labels=labels,
199-
# The tuning_spec one_of is set later
233+
# The tuning_spec one_of is set later.
234+
output_uri=output_uri,
200235
)
201236

202237
if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
203238
gca_tuning_job.supervised_tuning_spec = tuning_spec
239+
if isinstance(base_model, SourceModel):
240+
gca_tuning_job.base_model = base_model.base_model
241+
gca_tuning_job.custom_base_model = base_model.custom_base_model
242+
else:
243+
gca_tuning_job.base_model = base_model
204244
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
205245
gca_tuning_job.distillation_spec = tuning_spec
246+
if isinstance(base_model, SourceModel):
247+
raise RuntimeError("Distillation is not supported for custom models.")
248+
gca_tuning_job.base_model = base_model
206249
else:
207250
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
208251

0 commit comments

Comments
 (0)