Skip to content

Commit 8b9b190

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Support agent engine sandbox http request in genai sdk
PiperOrigin-RevId: 816865842
1 parent 2c362b3 commit 8b9b190

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
"opentelemetry-exporter-otlp-proto-http < 2",
169169
"pydantic >= 2.11.1, < 3",
170170
"typing_extensions",
171+
"google-cloud-iam",
171172
]
172173

173174
evaluation_extra_require = [

vertexai/_genai/sandboxes.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
import json
2020
import logging
2121
import mimetypes
22+
import secrets
23+
import time
2224
from typing import Any, Iterator, Optional, Union
2325
from urllib.parse import urlencode
2426

27+
from google import genai
28+
from google.cloud import iam_credentials_v1
2529
from google.genai import _api_module
2630
from google.genai import _common
31+
from google.genai import types as genai_types
2732
from google.genai._common import get_value_by_path as getv
2833
from google.genai._common import set_value_by_path as setv
2934
from google.genai.pagers import Pager
@@ -704,6 +709,113 @@ def delete(
704709
"""
705710
return self._delete(name=name, config=config)
706711

712+
def generate_access_token(
713+
self,
714+
service_account_email: str,
715+
sandbox_id: str,
716+
port: str = "8080",
717+
timeout: int = 3600,
718+
) -> str:
719+
"""Signs a JWT with a Google Cloud service account."""
720+
client = iam_credentials_v1.IAMCredentialsClient()
721+
name = f"projects/-/serviceAccounts/{service_account_email}"
722+
custom_claims = {"port": port, "sandbox_id": sandbox_id}
723+
payload = {
724+
"iat": int(time.time()),
725+
"exp": int(time.time()) + timeout,
726+
"iss": service_account_email,
727+
"nonce": secrets.randbelow(1000000000) + 1,
728+
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
729+
**custom_claims,
730+
}
731+
request = iam_credentials_v1.SignJwtRequest(
732+
name=name,
733+
payload=json.dumps(payload),
734+
)
735+
response = client.sign_jwt(request=request)
736+
return response.signed_jwt
737+
738+
def send_command(
739+
self,
740+
http_method: str,
741+
access_token: str,
742+
sandbox_environment: types.SandboxEnvironment,
743+
path: str = None,
744+
query_params: Optional[dict[str, object]] = None,
745+
headers: Optional[dict[str, str]] = None,
746+
request_dict: Optional[dict[str, object]] = None,
747+
) -> genai_types.HttpResponse:
748+
"""Sends a command to the sandbox."""
749+
headers = headers or {}
750+
request_dict = request_dict or {}
751+
connection_info = sandbox_environment.connection_info
752+
if not connection_info:
753+
raise ValueError("Connection info is not available.")
754+
if connection_info.load_balancer_hostname:
755+
endpoint = "https://" + connection_info.load_balancer_hostname
756+
elif connection_info.load_balancer_ip:
757+
endpoint = "http://" + connection_info.load_balancer_ip
758+
else:
759+
raise ValueError("Load balancer hostname or ip is not available.")
760+
761+
path = path or ""
762+
if query_params:
763+
path = f"{path}?{urlencode(query_params)}"
764+
headers["Authorization"] = f"Bearer {access_token}"
765+
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
766+
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
767+
http_client = genai.Client(vertexai=True, http_options=http_options)
768+
# Full path is constructed in this function. The passed in path into request
769+
# function will not be used.
770+
response = http_client._api_client.request(http_method, path, request_dict)
771+
return genai_types.HttpResponse(
772+
headers=response.headers,
773+
body=response.body,
774+
)
775+
776+
def generate_browser_ws_headers(
777+
self,
778+
sandbox_environment: types.SandboxEnvironment,
779+
service_account_email: str,
780+
timeout: int = 3600,
781+
) -> tuple[str, dict[str, str]]:
782+
"""Generates the websocket upgrade headers for the browser."""
783+
sandbox_id = sandbox_environment.name
784+
# port 8080 is the default port for http endpoint.
785+
http_access_token = self.generate_access_token(
786+
service_account_email, sandbox_id, "8080", timeout
787+
)
788+
response = self.send_command(
789+
"GET",
790+
http_access_token,
791+
sandbox_environment,
792+
path="/cdp_ws_endpoint",
793+
)
794+
if not response:
795+
raise ValueError("Failed to get the websocket endpoint.")
796+
body_dict = json.loads(response.body)
797+
ws_path = body_dict["endpoint"]
798+
799+
ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
800+
if sandbox_environment and sandbox_environment.connection_info:
801+
connection_info = sandbox_environment.connection_info
802+
if connection_info.load_balancer_hostname:
803+
ws_url = "wss://" + connection_info.load_balancer_hostname
804+
elif connection_info.load_balancer_ip:
805+
ws_url = "ws://" + connection_info.load_balancer_ip
806+
else:
807+
raise ValueError("Load balancer hostname or ip is not available.")
808+
ws_url = ws_url + "/" + ws_path
809+
810+
# port 9222 is the default port for the browser websocket endpoint.
811+
ws_access_token = self.generate_access_token(
812+
service_account_email, sandbox_id, "9222", timeout
813+
)
814+
815+
headers = {}
816+
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
817+
return ws_url, headers
818+
707819

708820
class AsyncSandboxes(_api_module.BaseModule):
709821

0 commit comments

Comments
 (0)