|
1 | 1 | import base64 |
| 2 | +from cryptography import x509 |
| 3 | +from cryptography.hazmat.backends import default_backend |
| 4 | +from cryptography.x509.oid import NameOID |
2 | 5 | import datetime |
3 | 6 | from decimal import Decimal |
4 | 7 | import logging |
@@ -46,6 +49,13 @@ def __init__(self, server_hostname, http_path, access_token, metadata=None, **kw |
46 | 49 | # Which port to connect to |
47 | 50 | # _skip_routing_headers: |
48 | 51 | # Don't set routing headers if set to True (for use when connecting directly to server) |
| 52 | + # _tls_verify_hostname |
| 53 | + # Set to False (Boolean) to disable SSL hostname verification, but check certificate. |
| 54 | + # _tls_trusted_ca_file |
| 55 | + # Set to the path of the file containing trusted CA certificates for server certificate |
| 56 | + # verification. If not provide, uses system truststore. |
| 57 | + # _tls_client_cert_file, _tls_client_cert_key_file |
| 58 | + # Set client SSL certificate. |
49 | 59 |
|
50 | 60 | self.host = server_hostname |
51 | 61 | self.port = kwargs.get("_port", 443) |
@@ -76,7 +86,11 @@ def __init__(self, server_hostname, http_path, access_token, metadata=None, **kw |
76 | 86 | self.base_client = CmdExecBaseHttpClient( |
77 | 87 | self.host, |
78 | 88 | self.port, (metadata or []) + base_headers, |
79 | | - enable_ssl=kwargs.get("_enable_ssl", True)) |
| 89 | + enable_ssl=kwargs.get("_enable_ssl", True), |
| 90 | + root_ca_path=kwargs.get("_tls_trusted_ca_file"), |
| 91 | + cert_chain_path=kwargs.get("_tls_client_cert_file"), |
| 92 | + cert_key_path=kwargs.get("_tls_client_cert_key_file"), |
| 93 | + verify_hostname=kwargs.get("_tls_verify_hostname", True)) |
80 | 94 |
|
81 | 95 | open_session_request = messages_pb2.OpenSessionRequest( |
82 | 96 | configuration={}, |
@@ -492,14 +506,62 @@ class CmdExecBaseHttpClient: |
492 | 506 | A thin wrapper around a gRPC channel that takes cares of headers etc. |
493 | 507 | """ |
494 | 508 |
|
495 | | - def __init__(self, host: str, port: int, http_headers: List[Tuple[str, str]], enable_ssl=True): |
| 509 | + def __init__(self, |
| 510 | + host: str, |
| 511 | + port: int, |
| 512 | + http_headers: List[Tuple[str, str]], |
| 513 | + enable_ssl=True, |
| 514 | + root_ca_path=None, |
| 515 | + cert_chain_path=None, |
| 516 | + cert_key_path=None, |
| 517 | + verify_hostname=True): |
496 | 518 | self.host_url = host + ":" + str(port) |
497 | | - self.http_headers = [(k.lower(), v) for (k, v) in http_headers] |
| 519 | + self.http_headers = [(k.lower(), str(v)) for (k, v) in http_headers] |
498 | 520 | if enable_ssl: |
| 521 | + if root_ca_path: |
| 522 | + try: |
| 523 | + with open(root_ca_path, 'rb') as f: |
| 524 | + root_ca = f.read() |
| 525 | + except OSError as e: |
| 526 | + raise OperationalError( |
| 527 | + "Error while trying to read root SSL certificate %s:" % root_ca_path, e) |
| 528 | + else: |
| 529 | + root_ca = None |
| 530 | + |
| 531 | + if cert_chain_path: |
| 532 | + try: |
| 533 | + with open(cert_chain_path, 'rb') as f: |
| 534 | + cert_chain = f.read() |
| 535 | + except OSError as e: |
| 536 | + raise OperationalError( |
| 537 | + "Error while trying to read SSL certificate chain %s:" % cert_chain, e) |
| 538 | + else: |
| 539 | + cert_chain = None |
| 540 | + |
| 541 | + if cert_key_path: |
| 542 | + try: |
| 543 | + with open(cert_chain_path, 'rb') as f: |
| 544 | + cert_key = f.read() |
| 545 | + except OSError as e: |
| 546 | + raise OperationalError( |
| 547 | + "Error while trying to read SSL certificate key %s:" % cert_key_path, e) |
| 548 | + else: |
| 549 | + cert_key = None |
| 550 | + |
| 551 | + if not verify_hostname and root_ca: |
| 552 | + # gRPC doesn't have a flag that lets us completely disable the cn name check, |
| 553 | + # so we just set the target name override so they match. |
| 554 | + cert_info = x509.load_pem_x509_certificate(root_ca, default_backend()) |
| 555 | + cn = cert_info.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value |
| 556 | + target_name_override_opt = [('grpc.ssl_target_name_override', cn)] |
| 557 | + else: |
| 558 | + target_name_override_opt = [] |
| 559 | + |
499 | 560 | self.channel = grpc.secure_channel( |
500 | 561 | self.host_url, |
501 | | - options=[('grpc.max_receive_message_length', -1)], |
502 | | - credentials=grpc.ssl_channel_credentials()) |
| 562 | + options=[('grpc.max_receive_message_length', -1)] + target_name_override_opt, |
| 563 | + credentials=grpc.ssl_channel_credentials( |
| 564 | + root_certificates=root_ca, certificate_chain=cert_chain, private_key=cert_key)) |
503 | 565 | else: |
504 | 566 | self.channel = grpc.insecure_channel( |
505 | 567 | self.host_url, |
|
0 commit comments