|
19 | 19 | import json |
20 | 20 | import logging |
21 | 21 | import mimetypes |
| 22 | +import secrets |
| 23 | +import time |
22 | 24 | from typing import Any, Iterator, Optional, Union |
23 | 25 | from urllib.parse import urlencode |
24 | 26 |
|
| 27 | +from google import genai |
| 28 | +from google.cloud import iam_credentials_v1 |
25 | 29 | from google.genai import _api_module |
26 | 30 | from google.genai import _common |
| 31 | +from google.genai import types as genai_types |
27 | 32 | from google.genai._common import get_value_by_path as getv |
28 | 33 | from google.genai._common import set_value_by_path as setv |
29 | 34 | from google.genai.pagers import Pager |
@@ -704,6 +709,113 @@ def delete( |
704 | 709 | """ |
705 | 710 | return self._delete(name=name, config=config) |
706 | 711 |
|
| 712 | + def generate_access_token( |
| 713 | + self, |
| 714 | + service_account_email: str, |
| 715 | + sandbox_id: str, |
| 716 | + port: str = "8080", |
| 717 | + timeout: int = 3600, |
| 718 | + ) -> str: |
| 719 | + """Signs a JWT with a Google Cloud service account.""" |
| 720 | + client = iam_credentials_v1.IAMCredentialsClient() |
| 721 | + name = f"projects/-/serviceAccounts/{service_account_email}" |
| 722 | + custom_claims = {"port": port, "sandbox_id": sandbox_id} |
| 723 | + payload = { |
| 724 | + "iat": int(time.time()), |
| 725 | + "exp": int(time.time()) + timeout, |
| 726 | + "iss": service_account_email, |
| 727 | + "nonce": secrets.randbelow(1000000000) + 1, |
| 728 | + "aud": "vmaas-proxy-api", # default audience for sandbox proxy |
| 729 | + **custom_claims, |
| 730 | + } |
| 731 | + request = iam_credentials_v1.SignJwtRequest( |
| 732 | + name=name, |
| 733 | + payload=json.dumps(payload), |
| 734 | + ) |
| 735 | + response = client.sign_jwt(request=request) |
| 736 | + return response.signed_jwt |
| 737 | + |
| 738 | + def send_command( |
| 739 | + self, |
| 740 | + http_method: str, |
| 741 | + access_token: str, |
| 742 | + sandbox_environment: types.SandboxEnvironment, |
| 743 | + path: str = None, |
| 744 | + query_params: Optional[dict[str, object]] = None, |
| 745 | + headers: Optional[dict[str, str]] = None, |
| 746 | + request_dict: Optional[dict[str, object]] = None, |
| 747 | + ) -> genai_types.HttpResponse: |
| 748 | + """Sends a command to the sandbox.""" |
| 749 | + headers = headers or {} |
| 750 | + request_dict = request_dict or {} |
| 751 | + connection_info = sandbox_environment.connection_info |
| 752 | + if not connection_info: |
| 753 | + raise ValueError("Connection info is not available.") |
| 754 | + if connection_info.load_balancer_hostname: |
| 755 | + endpoint = "https://" + connection_info.load_balancer_hostname |
| 756 | + elif connection_info.load_balancer_ip: |
| 757 | + endpoint = "http://" + connection_info.load_balancer_ip |
| 758 | + else: |
| 759 | + raise ValueError("Load balancer hostname or ip is not available.") |
| 760 | + |
| 761 | + path = path or "" |
| 762 | + if query_params: |
| 763 | + path = f"{path}?{urlencode(query_params)}" |
| 764 | + headers["Authorization"] = f"Bearer {access_token}" |
| 765 | + endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path |
| 766 | + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) |
| 767 | + http_client = genai.Client(vertexai=True, http_options=http_options) |
| 768 | + # Full path is constructed in this function. The passed in path into request |
| 769 | + # function will not be used. |
| 770 | + response = http_client._api_client.request(http_method, path, request_dict) |
| 771 | + return genai_types.HttpResponse( |
| 772 | + headers=response.headers, |
| 773 | + body=response.body, |
| 774 | + ) |
| 775 | + |
| 776 | + def generate_browser_ws_headers( |
| 777 | + self, |
| 778 | + sandbox_environment: types.SandboxEnvironment, |
| 779 | + service_account_email: str, |
| 780 | + timeout: int = 3600, |
| 781 | + ) -> tuple[str, dict[str, str]]: |
| 782 | + """Generates the websocket upgrade headers for the browser.""" |
| 783 | + sandbox_id = sandbox_environment.name |
| 784 | + # port 8080 is the default port for http endpoint. |
| 785 | + http_access_token = self.generate_access_token( |
| 786 | + service_account_email, sandbox_id, "8080", timeout |
| 787 | + ) |
| 788 | + response = self.send_command( |
| 789 | + "GET", |
| 790 | + http_access_token, |
| 791 | + sandbox_environment, |
| 792 | + path="/cdp_ws_endpoint", |
| 793 | + ) |
| 794 | + if not response: |
| 795 | + raise ValueError("Failed to get the websocket endpoint.") |
| 796 | + body_dict = json.loads(response.body) |
| 797 | + ws_path = body_dict["endpoint"] |
| 798 | + |
| 799 | + ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog" |
| 800 | + if sandbox_environment and sandbox_environment.connection_info: |
| 801 | + connection_info = sandbox_environment.connection_info |
| 802 | + if connection_info.load_balancer_hostname: |
| 803 | + ws_url = "wss://" + connection_info.load_balancer_hostname |
| 804 | + elif connection_info.load_balancer_ip: |
| 805 | + ws_url = "ws://" + connection_info.load_balancer_ip |
| 806 | + else: |
| 807 | + raise ValueError("Load balancer hostname or ip is not available.") |
| 808 | + ws_url = ws_url + "/" + ws_path |
| 809 | + |
| 810 | + # port 9222 is the default port for the browser websocket endpoint. |
| 811 | + ws_access_token = self.generate_access_token( |
| 812 | + service_account_email, sandbox_id, "9222", timeout |
| 813 | + ) |
| 814 | + |
| 815 | + headers = {} |
| 816 | + headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}" |
| 817 | + return ws_url, headers |
| 818 | + |
707 | 819 |
|
708 | 820 | class AsyncSandboxes(_api_module.BaseModule): |
709 | 821 |
|
|
0 commit comments