Skip to content

Commit fdf088f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Support agent engine sandbox http request in genai sdk
PiperOrigin-RevId: 816865842
1 parent df0976e commit fdf088f

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
"opentelemetry-exporter-otlp-proto-http < 2",
169169
"pydantic >= 2.11.1, < 3",
170170
"typing_extensions",
171+
"google-cloud-iam",
171172
]
172173

173174
evaluation_extra_require = [

vertexai/_genai/sandboxes.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
import json
2020
import logging
2121
import mimetypes
22+
import secrets
23+
import time
2224
from typing import Any, Iterator, Optional, Union
2325
from urllib.parse import urlencode
2426

27+
from google import genai
28+
from google.cloud import iam_credentials_v1
2529
from google.genai import _api_module
2630
from google.genai import _common
31+
from google.genai import types as genai_types
2732
from google.genai._common import get_value_by_path as getv
2833
from google.genai._common import set_value_by_path as setv
2934
from google.genai.pagers import Pager
@@ -704,6 +709,161 @@ def delete(
704709
"""
705710
return self._delete(name=name, config=config)
706711

712+
def generate_access_token(
713+
self,
714+
service_account_email: str,
715+
sandbox_id: str,
716+
port: str = "8080",
717+
timeout: int = 3600,
718+
) -> str:
719+
"""Signs a JWT with a Google Cloud service account.
720+
721+
Args:
722+
service_account_email (str):
723+
Required. The email of the service account to use for signing.
724+
sandbox_id (str):
725+
Required. The resource name of the sandbox to generate a token for.
726+
port (str):
727+
Optional. The port to use for the token. Defaults to "8080".
728+
timeout (int):
729+
Optional. The timeout in seconds for the token. Defaults to 3600.
730+
731+
Returns:
732+
str: The signed JWT.
733+
"""
734+
client = iam_credentials_v1.IAMCredentialsClient()
735+
name = f"projects/-/serviceAccounts/{service_account_email}"
736+
custom_claims = {"port": port, "sandbox_id": sandbox_id}
737+
payload = {
738+
"iat": int(time.time()),
739+
"exp": int(time.time()) + timeout,
740+
"iss": service_account_email,
741+
"nonce": secrets.randbelow(1000000000) + 1,
742+
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
743+
**custom_claims,
744+
}
745+
request = iam_credentials_v1.SignJwtRequest(
746+
name=name,
747+
payload=json.dumps(payload),
748+
)
749+
response = client.sign_jwt(request=request)
750+
return response.signed_jwt
751+
752+
def send_command(
753+
self,
754+
*,
755+
http_method: str,
756+
access_token: str,
757+
sandbox_environment: types.SandboxEnvironment,
758+
path: str = None,
759+
query_params: Optional[dict[str, object]] = None,
760+
headers: Optional[dict[str, str]] = None,
761+
request_dict: Optional[dict[str, object]] = None,
762+
) -> genai_types.HttpResponse:
763+
"""Sends a command to the sandbox.
764+
765+
Args:
766+
http_method (str):
767+
Required. The HTTP method to use for the command.
768+
access_token (str):
769+
Required. The access token to use for authorization.
770+
sandbox_environment (types.SandboxEnvironment):
771+
Required. The sandbox environment to send the command to.
772+
path (str):
773+
Optional. The path to send the command to.
774+
query_params (dict[str, object]):
775+
Optional. The query parameters to include in the command.
776+
headers (dict[str, str]):
777+
Optional. The headers to include in the command.
778+
request_dict (dict[str, object]):
779+
Optional. The request body to include in the command.
780+
781+
Returns:
782+
genai_types.HttpResponse: The response from the sandbox.
783+
"""
784+
headers = headers or {}
785+
request_dict = request_dict or {}
786+
connection_info = sandbox_environment.connection_info
787+
if not connection_info:
788+
raise ValueError("Connection info is not available.")
789+
if connection_info.load_balancer_hostname:
790+
endpoint = "https://" + connection_info.load_balancer_hostname
791+
elif connection_info.load_balancer_ip:
792+
endpoint = "http://" + connection_info.load_balancer_ip
793+
else:
794+
raise ValueError("Load balancer hostname or ip is not available.")
795+
796+
path = path or ""
797+
if query_params:
798+
path = f"{path}?{urlencode(query_params)}"
799+
headers["Authorization"] = f"Bearer {access_token}"
800+
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
801+
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
802+
http_client = genai.Client(vertexai=True, http_options=http_options)
803+
# Full path is constructed in this function. The passed in path into request
804+
# function will not be used.
805+
response = http_client._api_client.request(http_method, path, request_dict)
806+
return genai_types.HttpResponse(
807+
headers=response.headers,
808+
body=response.body,
809+
)
810+
811+
def generate_browser_ws_headers(
812+
self,
813+
sandbox_environment: types.SandboxEnvironment,
814+
service_account_email: str,
815+
timeout: int = 3600,
816+
) -> tuple[str, dict[str, str]]:
817+
"""Generates the websocket upgrade headers for the browser.
818+
819+
Args:
820+
sandbox_environment (types.SandboxEnvironment):
821+
Required. The sandbox environment to generate websocket headers for.
822+
service_account_email (str):
823+
Required. The email of the service account to use for signing.
824+
timeout (int):
825+
Optional. The timeout in seconds for the token. Defaults to 3600.
826+
827+
Returns:
828+
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
829+
the headers for websocket upgrade.
830+
"""
831+
sandbox_id = sandbox_environment.name
832+
# port 8080 is the default port for http endpoint.
833+
http_access_token = self.generate_access_token(
834+
service_account_email, sandbox_id, "8080", timeout
835+
)
836+
response = self.send_command(
837+
http_method="GET",
838+
access_token=http_access_token,
839+
sandbox_environment=sandbox_environment,
840+
path="/cdp_ws_endpoint",
841+
)
842+
if not response:
843+
raise ValueError("Failed to get the websocket endpoint.")
844+
body_dict = json.loads(response.body)
845+
ws_path = body_dict["endpoint"]
846+
847+
ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
848+
if sandbox_environment and sandbox_environment.connection_info:
849+
connection_info = sandbox_environment.connection_info
850+
if connection_info.load_balancer_hostname:
851+
ws_url = "wss://" + connection_info.load_balancer_hostname
852+
elif connection_info.load_balancer_ip:
853+
ws_url = "ws://" + connection_info.load_balancer_ip
854+
else:
855+
raise ValueError("Load balancer hostname or ip is not available.")
856+
ws_url = ws_url + "/" + ws_path
857+
858+
# port 9222 is the default port for the browser websocket endpoint.
859+
ws_access_token = self.generate_access_token(
860+
service_account_email, sandbox_id, "9222", timeout
861+
)
862+
863+
headers = {}
864+
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
865+
return ws_url, headers
866+
707867

708868
class AsyncSandboxes(_api_module.BaseModule):
709869

0 commit comments

Comments
 (0)