|
53 | 53 | _GEMINI_2_FLASH_EXP_ENDPOINT = "gemini-2.0-flash-exp" |
54 | 54 | _GEMINI_2_FLASH_001_ENDPOINT = "gemini-2.0-flash-001" |
55 | 55 | _GEMINI_2_FLASH_LITE_001_ENDPOINT = "gemini-2.0-flash-lite-001" |
| 56 | +_GEMINI_2P5_PRO_PREVIEW_ENDPOINT = "gemini-2.5-pro-preview-05-06" |
56 | 57 | _GEMINI_ENDPOINTS = ( |
57 | 58 | _GEMINI_1P5_PRO_PREVIEW_ENDPOINT, |
58 | 59 | _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, |
|
104 | 105 |
|
105 | 106 | _REMOVE_DEFAULT_MODEL_WARNING = "Since upgrading the default model can cause unintended breakages, the default model will be removed in BigFrames 3.0. Please supply an explicit model to avoid this message." |
106 | 107 |
|
| 108 | +_GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING = ( |
| 109 | + "The model '{model_name}' may not be fully supported by GeminiTextGenerator for Multimodal prompts. " |
| 110 | + "GeminiTextGenerator is known to support the following models for Multimodal prompts: {known_models}. " |
| 111 | + "If you proceed with '{model_name}', it might not work as expected or could lead to errors with multimodal inputs." |
| 112 | +) |
| 113 | + |
107 | 114 |
|
108 | 115 | @log_adapter.class_logger |
109 | 116 | class TextEmbeddingGenerator(base.RetriableRemotePredictor): |
@@ -540,9 +547,10 @@ def fit( |
540 | 547 | GeminiTextGenerator: Fitted estimator. |
541 | 548 | """ |
542 | 549 | if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: |
543 | | - raise NotImplementedError( |
| 550 | + msg = exceptions.format_message( |
544 | 551 | "fit() only supports gemini-1.5-pro-002, or gemini-1.5-flash-002 model." |
545 | 552 | ) |
| 553 | + warnings.warn(msg) |
546 | 554 |
|
547 | 555 | X, y = utils.batch_convert_to_dataframe(X, y) |
548 | 556 |
|
@@ -651,9 +659,13 @@ def predict( |
651 | 659 |
|
652 | 660 | if prompt: |
653 | 661 | if self.model_name not in _GEMINI_MULTIMODAL_ENDPOINTS: |
654 | | - raise NotImplementedError( |
655 | | - f"GeminiTextGenerator only supports model_name {', '.join(_GEMINI_MULTIMODAL_ENDPOINTS)} for Multimodal prompt." |
| 662 | + msg = exceptions.format_message( |
| 663 | + _GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING.format( |
| 664 | + model_name=self.model_name, |
| 665 | + known_models=", ".join(_GEMINI_MULTIMODAL_ENDPOINTS), |
| 666 | + ) |
656 | 667 | ) |
| 668 | + warnings.warn(msg) |
657 | 669 |
|
658 | 670 | df_prompt = X[[X.columns[0]]].rename( |
659 | 671 | columns={X.columns[0]: "bigframes_placeholder_col"} |
@@ -750,9 +762,10 @@ def score( |
750 | 762 | raise RuntimeError("A model must be fitted before score") |
751 | 763 |
|
752 | 764 | if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: |
753 | | - raise NotImplementedError( |
| 765 | + msg = exceptions.format_message( |
754 | 766 | "score() only supports gemini-1.5-pro-002, gemini-1.5-flash-2, gemini-2.0-flash-001, and gemini-2.0-flash-lite-001 model." |
755 | 767 | ) |
| 768 | + warnings.warn(msg) |
756 | 769 |
|
757 | 770 | X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) |
758 | 771 |
|
|
0 commit comments