|
19 | 19 | import json |
20 | 20 | import logging |
21 | 21 | import mimetypes |
| 22 | +import secrets |
| 23 | +import time |
22 | 24 | from typing import Any, Iterator, Optional, Union |
23 | 25 | from urllib.parse import urlencode |
24 | 26 |
|
| 27 | +from google import genai |
| 28 | +from google.cloud import iam_credentials_v1 |
25 | 29 | from google.genai import _api_module |
26 | 30 | from google.genai import _common |
| 31 | +from google.genai import types as genai_types |
27 | 32 | from google.genai._common import get_value_by_path as getv |
28 | 33 | from google.genai._common import set_value_by_path as setv |
29 | 34 | from google.genai.pagers import Pager |
@@ -704,6 +709,161 @@ def delete( |
704 | 709 | """ |
705 | 710 | return self._delete(name=name, config=config) |
706 | 711 |
|
| 712 | + def generate_access_token( |
| 713 | + self, |
| 714 | + service_account_email: str, |
| 715 | + sandbox_id: str, |
| 716 | + port: str = "8080", |
| 717 | + timeout: int = 3600, |
| 718 | + ) -> str: |
| 719 | + """Signs a JWT with a Google Cloud service account. |
| 720 | +
|
| 721 | + Args: |
| 722 | + service_account_email (str): |
| 723 | + Required. The email of the service account to use for signing. |
| 724 | + sandbox_id (str): |
| 725 | + Required. The resource name of the sandbox to generate a token for. |
| 726 | + port (str): |
| 727 | + Optional. The port to use for the token. Defaults to "8080". |
| 728 | + timeout (int): |
| 729 | + Optional. The timeout in seconds for the token. Defaults to 3600. |
| 730 | +
|
| 731 | + Returns: |
| 732 | + str: The signed JWT. |
| 733 | + """ |
| 734 | + client = iam_credentials_v1.IAMCredentialsClient() |
| 735 | + name = f"projects/-/serviceAccounts/{service_account_email}" |
| 736 | + custom_claims = {"port": port, "sandbox_id": sandbox_id} |
| 737 | + payload = { |
| 738 | + "iat": int(time.time()), |
| 739 | + "exp": int(time.time()) + timeout, |
| 740 | + "iss": service_account_email, |
| 741 | + "nonce": secrets.randbelow(1000000000) + 1, |
| 742 | + "aud": "vmaas-proxy-api", # default audience for sandbox proxy |
| 743 | + **custom_claims, |
| 744 | + } |
| 745 | + request = iam_credentials_v1.SignJwtRequest( |
| 746 | + name=name, |
| 747 | + payload=json.dumps(payload), |
| 748 | + ) |
| 749 | + response = client.sign_jwt(request=request) |
| 750 | + return response.signed_jwt |
| 751 | + |
| 752 | + def send_command( |
| 753 | + self, |
| 754 | + *, |
| 755 | + http_method: str, |
| 756 | + access_token: str, |
| 757 | + sandbox_environment: types.SandboxEnvironment, |
| 758 | + path: str = None, |
| 759 | + query_params: Optional[dict[str, object]] = None, |
| 760 | + headers: Optional[dict[str, str]] = None, |
| 761 | + request_dict: Optional[dict[str, object]] = None, |
| 762 | + ) -> genai_types.HttpResponse: |
| 763 | + """Sends a command to the sandbox. |
| 764 | +
|
| 765 | + Args: |
| 766 | + http_method (str): |
| 767 | + Required. The HTTP method to use for the command. |
| 768 | + access_token (str): |
| 769 | + Required. The access token to use for authorization. |
| 770 | + sandbox_environment (types.SandboxEnvironment): |
| 771 | + Required. The sandbox environment to send the command to. |
| 772 | + path (str): |
| 773 | + Optional. The path to send the command to. |
| 774 | + query_params (dict[str, object]): |
| 775 | + Optional. The query parameters to include in the command. |
| 776 | + headers (dict[str, str]): |
| 777 | + Optional. The headers to include in the command. |
| 778 | + request_dict (dict[str, object]): |
| 779 | + Optional. The request body to include in the command. |
| 780 | +
|
| 781 | + Returns: |
| 782 | + genai_types.HttpResponse: The response from the sandbox. |
| 783 | + """ |
| 784 | + headers = headers or {} |
| 785 | + request_dict = request_dict or {} |
| 786 | + connection_info = sandbox_environment.connection_info |
| 787 | + if not connection_info: |
| 788 | + raise ValueError("Connection info is not available.") |
| 789 | + if connection_info.load_balancer_hostname: |
| 790 | + endpoint = "https://" + connection_info.load_balancer_hostname |
| 791 | + elif connection_info.load_balancer_ip: |
| 792 | + endpoint = "http://" + connection_info.load_balancer_ip |
| 793 | + else: |
| 794 | + raise ValueError("Load balancer hostname or ip is not available.") |
| 795 | + |
| 796 | + path = path or "" |
| 797 | + if query_params: |
| 798 | + path = f"{path}?{urlencode(query_params)}" |
| 799 | + headers["Authorization"] = f"Bearer {access_token}" |
| 800 | + endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path |
| 801 | + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) |
| 802 | + http_client = genai.Client(vertexai=True, http_options=http_options) |
| 803 | + # Full path is constructed in this function. The passed in path into request |
| 804 | + # function will not be used. |
| 805 | + response = http_client._api_client.request(http_method, path, request_dict) |
| 806 | + return genai_types.HttpResponse( |
| 807 | + headers=response.headers, |
| 808 | + body=response.body, |
| 809 | + ) |
| 810 | + |
| 811 | + def generate_browser_ws_headers( |
| 812 | + self, |
| 813 | + sandbox_environment: types.SandboxEnvironment, |
| 814 | + service_account_email: str, |
| 815 | + timeout: int = 3600, |
| 816 | + ) -> tuple[str, dict[str, str]]: |
| 817 | + """Generates the websocket upgrade headers for the browser. |
| 818 | +
|
| 819 | + Args: |
| 820 | + sandbox_environment (types.SandboxEnvironment): |
| 821 | + Required. The sandbox environment to generate websocket headers for. |
| 822 | + service_account_email (str): |
| 823 | + Required. The email of the service account to use for signing. |
| 824 | + timeout (int): |
| 825 | + Optional. The timeout in seconds for the token. Defaults to 3600. |
| 826 | +
|
| 827 | + Returns: |
| 828 | + tuple[str, dict[str, str]]: A tuple containing the websocket URL and |
| 829 | + the headers for websocket upgrade. |
| 830 | + """ |
| 831 | + sandbox_id = sandbox_environment.name |
| 832 | + # port 8080 is the default port for http endpoint. |
| 833 | + http_access_token = self.generate_access_token( |
| 834 | + service_account_email, sandbox_id, "8080", timeout |
| 835 | + ) |
| 836 | + response = self.send_command( |
| 837 | + http_method="GET", |
| 838 | + access_token=http_access_token, |
| 839 | + sandbox_environment=sandbox_environment, |
| 840 | + path="/cdp_ws_endpoint", |
| 841 | + ) |
| 842 | + if not response: |
| 843 | + raise ValueError("Failed to get the websocket endpoint.") |
| 844 | + body_dict = json.loads(response.body) |
| 845 | + ws_path = body_dict["endpoint"] |
| 846 | + |
| 847 | + ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog" |
| 848 | + if sandbox_environment and sandbox_environment.connection_info: |
| 849 | + connection_info = sandbox_environment.connection_info |
| 850 | + if connection_info.load_balancer_hostname: |
| 851 | + ws_url = "wss://" + connection_info.load_balancer_hostname |
| 852 | + elif connection_info.load_balancer_ip: |
| 853 | + ws_url = "ws://" + connection_info.load_balancer_ip |
| 854 | + else: |
| 855 | + raise ValueError("Load balancer hostname or ip is not available.") |
| 856 | + ws_url = ws_url + "/" + ws_path |
| 857 | + |
| 858 | + # port 9222 is the default port for the browser websocket endpoint. |
| 859 | + ws_access_token = self.generate_access_token( |
| 860 | + service_account_email, sandbox_id, "9222", timeout |
| 861 | + ) |
| 862 | + |
| 863 | + headers = {} |
| 864 | + headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}" |
| 865 | + return ws_url, headers |
| 866 | + |
707 | 867 |
|
708 | 868 | class AsyncSandboxes(_api_module.BaseModule): |
709 | 869 |
|
|
0 commit comments