From 2cca131ba9da1243156186c930a2782317f9e681 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Mon, 23 Feb 2026 07:13:26 -0800 Subject: [PATCH] fix: add safeguards and warnings for remote code execution during pickle-based model deserialization PiperOrigin-RevId: 874058205 --- google/cloud/aiplatform/metadata/_models.py | 8 ++ .../cloud/aiplatform/prediction/predictor.py | 5 +- .../prediction/sklearn/predictor.py | 31 +++++++- .../prediction/xgboost/predictor.py | 46 +++++++++-- .../aiplatform/utils/prediction_utils.py | 8 ++ .../aiplatform/test_prediction_security.py | 78 +++++++++++++++++++ 6 files changed, 167 insertions(+), 9 deletions(-) create mode 100644 tests/unit/aiplatform/test_prediction_security.py diff --git a/google/cloud/aiplatform/metadata/_models.py b/google/cloud/aiplatform/metadata/_models.py index c5f8fecf07..8a6a2caa25 100644 --- a/google/cloud/aiplatform/metadata/_models.py +++ b/google/cloud/aiplatform/metadata/_models.py @@ -19,6 +19,7 @@ import os import pickle import tempfile +import warnings from typing import Any, Dict, Optional, Sequence, Union from google.auth import credentials as auth_credentials @@ -147,6 +148,13 @@ def _load_sklearn_model( f"You are using sklearn {sklearn.__version__}." "Attempting to load model..." ) + + warnings.warn( + "Loading a scikit-learn model via pickle is insecure. " + "Ensure the model artifact is from a trusted source.", + RuntimeWarning + ) + with open(model_file, "rb") as f: sk_model = pickle.load(f) diff --git a/google/cloud/aiplatform/prediction/predictor.py b/google/cloud/aiplatform/prediction/predictor.py index 0560b97d30..996ffd2150 100644 --- a/google/cloud/aiplatform/prediction/predictor.py +++ b/google/cloud/aiplatform/prediction/predictor.py @@ -40,12 +40,15 @@ def __init__(self): return @abstractmethod - def load(self, artifacts_uri: str) -> None: + def load(self, artifacts_uri: str, **kwargs) -> None: """Loads the model artifact. Args: artifacts_uri (str): Required. 
The value of the environment variable AIP_STORAGE_URI. + **kwargs: + Optional. Additional keyword arguments for security or + configuration (e.g., allowed_extensions). """ pass diff --git a/google/cloud/aiplatform/prediction/sklearn/predictor.py b/google/cloud/aiplatform/prediction/sklearn/predictor.py index e5f3d96497..0aafb24091 100644 --- a/google/cloud/aiplatform/prediction/sklearn/predictor.py +++ b/google/cloud/aiplatform/prediction/sklearn/predictor.py @@ -19,6 +19,7 @@ import numpy as np import os import pickle +import warnings from google.cloud.aiplatform.constants import prediction from google.cloud.aiplatform.utils import prediction_utils @@ -31,21 +32,45 @@ class SklearnPredictor(Predictor): def __init__(self): return - def load(self, artifacts_uri: str) -> None: + def load(self, artifacts_uri: str, **kwargs) -> None: """Loads the model artifact. Args: artifacts_uri (str): Required. The value of the environment variable AIP_STORAGE_URI. + **kwargs: + Optional. Additional keyword arguments for security or + configuration. Supported arguments: + allowed_extensions (list[str]): + The allowed file extensions for model artifacts. + If not provided, a UserWarning is issued. Raises: ValueError: If there's no required model files provided in the artifacts uri. """ + + allowed_extensions = kwargs.get("allowed_extensions", None) + + if allowed_extensions is None: + warnings.warn( + "No 'allowed_extensions' provided. 
Loading model artifacts from " + "untrusted sources may lead to remote code execution.", + UserWarning + ) + prediction_utils.download_model_artifacts(artifacts_uri) if os.path.exists(prediction.MODEL_FILENAME_JOBLIB): self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists(prediction.MODEL_FILENAME_PKL): + elif os.path.exists(prediction.MODEL_FILENAME_PKL) and prediction_utils.is_allowed( + filename=prediction.MODEL_FILENAME_PKL, + allowed_extensions=allowed_extensions + ): + warnings.warn( + f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. " + "Only load files from trusted sources.", + RuntimeWarning + ) self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) else: valid_filenames = [ @@ -53,7 +78,7 @@ def load(self, artifacts_uri: str) -> None: prediction.MODEL_FILENAME_PKL, ] raise ValueError( - f"One of the following model files must be provided: {valid_filenames}." + f"One of the following model files must be provided and allowed: {valid_filenames}." ) def preprocess(self, prediction_input: dict) -> np.ndarray: diff --git a/google/cloud/aiplatform/prediction/xgboost/predictor.py b/google/cloud/aiplatform/prediction/xgboost/predictor.py index 005efcb129..a411335c5f 100644 --- a/google/cloud/aiplatform/prediction/xgboost/predictor.py +++ b/google/cloud/aiplatform/prediction/xgboost/predictor.py @@ -19,6 +19,7 @@ import logging import os import pickle +import warnings import numpy as np import xgboost as xgb @@ -34,21 +35,48 @@ class XgboostPredictor(Predictor): def __init__(self): return - def load(self, artifacts_uri: str) -> None: + def load(self, artifacts_uri: str, **kwargs) -> None: """Loads the model artifact. Args: artifacts_uri (str): Required. The value of the environment variable AIP_STORAGE_URI. + **kwargs: + Optional. Additional keyword arguments for security or + configuration. Supported arguments: + allowed_extensions (list[str]): + The allowed file extensions for model artifacts. 
+ If not provided, a UserWarning is issued. Raises: ValueError: If there's no required model files provided in the artifacts uri. """ + allowed_extensions = kwargs.get("allowed_extensions", None) + + if allowed_extensions is None: + warnings.warn( + "No 'allowed_extensions' provided. Loading model artifacts from " + "untrusted sources may lead to remote code execution.", + UserWarning, + ) + prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists(prediction.MODEL_FILENAME_BST): + + if os.path.exists(prediction.MODEL_FILENAME_BST) and prediction_utils.is_allowed( + filename=prediction.MODEL_FILENAME_BST, + allowed_extensions=allowed_extensions + ): booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST) - elif os.path.exists(prediction.MODEL_FILENAME_JOBLIB): + elif os.path.exists(prediction.MODEL_FILENAME_JOBLIB) and prediction_utils.is_allowed( + filename=prediction.MODEL_FILENAME_JOBLIB, + allowed_extensions=allowed_extensions + ): + warnings.warn( + f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib (pickle), which is unsafe. " + "Only load files from trusted sources.", + RuntimeWarning, + ) try: booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB) except KeyError: @@ -58,7 +86,15 @@ def load(self, artifacts_uri: str) -> None: ) booster = xgb.Booster() booster.load_model(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists(prediction.MODEL_FILENAME_PKL): + elif os.path.exists(prediction.MODEL_FILENAME_PKL) and prediction_utils.is_allowed( + filename=prediction.MODEL_FILENAME_PKL, + allowed_extensions=allowed_extensions + ): + warnings.warn( + f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. 
" + "Only load files from trusted sources.", + RuntimeWarning, + ) booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) else: valid_filenames = [ @@ -67,7 +103,7 @@ def load(self, artifacts_uri: str) -> None: prediction.MODEL_FILENAME_PKL, ] raise ValueError( - f"One of the following model files must be provided: {valid_filenames}." + f"One of the following model files must be provided and allowed: {valid_filenames}." ) self._booster = booster diff --git a/google/cloud/aiplatform/utils/prediction_utils.py b/google/cloud/aiplatform/utils/prediction_utils.py index 6e71e9dcb8..0e6673e405 100644 --- a/google/cloud/aiplatform/utils/prediction_utils.py +++ b/google/cloud/aiplatform/utils/prediction_utils.py @@ -174,3 +174,11 @@ def add_flex_start_to_dedicated_resources( dedicated_resources.flex_start = gca_machine_resources_compat.FlexStart( max_runtime_duration=duration_pb2.Duration(seconds=max_runtime_duration) ) + + +def is_allowed( + filename: str, allowed_extensions: Optional[list[str]] +) -> bool: + if allowed_extensions is None: + return True + return any(filename.endswith(ext) for ext in allowed_extensions) diff --git a/tests/unit/aiplatform/test_prediction_security.py b/tests/unit/aiplatform/test_prediction_security.py new file mode 100644 index 0000000000..f1f891d0a4 --- /dev/null +++ b/tests/unit/aiplatform/test_prediction_security.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from unittest import mock
+from google.cloud.aiplatform.prediction.xgboost.predictor import XgboostPredictor
+from google.cloud.aiplatform.prediction.sklearn.predictor import SklearnPredictor
+
+
+class TestPredictorSecurity:
+    @pytest.mark.parametrize("predictor_class", [XgboostPredictor, SklearnPredictor])
+    def test_load_warns_no_allowed_extensions(self, predictor_class):
+        """Verifies UserWarning is issued when allowed_extensions is missing."""
+        predictor = predictor_class()
+        with mock.patch("google.cloud.aiplatform.utils.prediction_utils.download_model_artifacts"):
+            with mock.patch("os.path.exists", return_value=True):
+                with mock.patch("joblib.load"), \
+                    mock.patch("pickle.load"), \
+                    mock.patch("google.cloud.aiplatform.prediction.xgboost.predictor.xgb.Booster"), \
+                    mock.patch("builtins.open", mock.mock_open()):
+                    with pytest.warns(UserWarning, match="No 'allowed_extensions' provided"):
+                        predictor.load("gs://test-bucket")
+
+    def test_xgboost_load_warns_on_joblib(self):
+        """Verifies RuntimeWarning is issued when loading a .joblib file."""
+        predictor = XgboostPredictor()
+        with mock.patch(
+            "google.cloud.aiplatform.utils.prediction_utils.download_model_artifacts"
+        ):
+            with mock.patch(
+                "os.path.exists", side_effect=lambda p: p.endswith(".joblib")
+            ):
+                with mock.patch("joblib.load"):
+                    with pytest.warns(
+                        RuntimeWarning, match=r"using joblib \(pickle\), which is unsafe"
+                    ):
+                        predictor.load("gs://test-bucket", allowed_extensions=[".joblib"])
+
+    def test_xgboost_load_raises_not_allowed(self):
+        """Verifies ValueError is raised if the file exists but is not allowed."""
+        predictor = XgboostPredictor()
+        with mock.patch(
+            "google.cloud.aiplatform.utils.prediction_utils.download_model_artifacts"
+        ):
+            with mock.patch("google.cloud.aiplatform.prediction.xgboost.predictor.xgb.Booster"):
+                with mock.patch("os.path.exists", side_effect=lambda p: p.endswith(".pkl")):
+                    with pytest.raises(ValueError, match="must be provided and allowed"):
+                        predictor.load("gs://test-bucket", allowed_extensions=[".bst"])
+
+    def test_sklearn_load_warns_on_pickle(self):
+        """Verifies RuntimeWarning is issued when loading a .pkl file."""
+        predictor = SklearnPredictor()
+        with mock.patch(
+            "google.cloud.aiplatform.utils.prediction_utils.download_model_artifacts"
+        ):
+            with mock.patch("os.path.exists", side_effect=lambda p: p.endswith(".pkl")):
+                with mock.patch("builtins.open", mock.mock_open()):
+                    with mock.patch("pickle.load"):
+                        with pytest.warns(
+                            RuntimeWarning, match="using pickle, which is unsafe"
+                        ):
+                            predictor.load(
+                                "gs://test-bucket", allowed_extensions=[".pkl"]
+                            )