Skip to content

Commit b0afc6b

Browse files
Niall Egansusodapop
authored andcommitted
SSL args for Cmd exec Python client
This PR adds the SSL args for using the Python client (setting the various certs etc). Author: Niall Egan <niall.egan@databricks.com>
1 parent 1151736 commit b0afc6b

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

cmdexec/clients/python/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
install_requires=[
99
"grpcio", # TODO: Minimum versions
1010
"pyarrow",
11-
"protobuf"
11+
"protobuf",
12+
"cryptography",
1213
],
1314
author="Databricks",
1415
)

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import base64
2+
from cryptography import x509
3+
from cryptography.hazmat.backends import default_backend
4+
from cryptography.x509.oid import NameOID
25
import datetime
36
from decimal import Decimal
47
import logging
@@ -46,6 +49,13 @@ def __init__(self, server_hostname, http_path, access_token, metadata=None, **kw
4649
# Which port to connect to
4750
# _skip_routing_headers:
4851
# Don't set routing headers if set to True (for use when connecting directly to server)
52+
# _tls_verify_hostname
53+
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
54+
# _tls_trusted_ca_file
55+
# Set to the path of the file containing trusted CA certificates for server certificate
56+
# verification. If not provide, uses system truststore.
57+
# _tls_client_cert_file, _tls_client_cert_key_file
58+
# Set client SSL certificate.
4959

5060
self.host = server_hostname
5161
self.port = kwargs.get("_port", 443)
@@ -76,7 +86,11 @@ def __init__(self, server_hostname, http_path, access_token, metadata=None, **kw
7686
self.base_client = CmdExecBaseHttpClient(
7787
self.host,
7888
self.port, (metadata or []) + base_headers,
79-
enable_ssl=kwargs.get("_enable_ssl", True))
89+
enable_ssl=kwargs.get("_enable_ssl", True),
90+
root_ca_path=kwargs.get("_tls_trusted_ca_file"),
91+
cert_chain_path=kwargs.get("_tls_client_cert_file"),
92+
cert_key_path=kwargs.get("_tls_client_cert_key_file"),
93+
verify_hostname=kwargs.get("_tls_verify_hostname", True))
8094

8195
open_session_request = messages_pb2.OpenSessionRequest(
8296
configuration={},
@@ -492,14 +506,62 @@ class CmdExecBaseHttpClient:
492506
A thin wrapper around a gRPC channel that takes cares of headers etc.
493507
"""
494508

495-
def __init__(self, host: str, port: int, http_headers: List[Tuple[str, str]], enable_ssl=True):
509+
def __init__(self,
510+
host: str,
511+
port: int,
512+
http_headers: List[Tuple[str, str]],
513+
enable_ssl=True,
514+
root_ca_path=None,
515+
cert_chain_path=None,
516+
cert_key_path=None,
517+
verify_hostname=True):
496518
self.host_url = host + ":" + str(port)
497-
self.http_headers = [(k.lower(), v) for (k, v) in http_headers]
519+
self.http_headers = [(k.lower(), str(v)) for (k, v) in http_headers]
498520
if enable_ssl:
521+
if root_ca_path:
522+
try:
523+
with open(root_ca_path, 'rb') as f:
524+
root_ca = f.read()
525+
except OSError as e:
526+
raise OperationalError(
527+
"Error while trying to read root SSL certificate %s:" % root_ca_path, e)
528+
else:
529+
root_ca = None
530+
531+
if cert_chain_path:
532+
try:
533+
with open(cert_chain_path, 'rb') as f:
534+
cert_chain = f.read()
535+
except OSError as e:
536+
raise OperationalError(
537+
"Error while trying to read SSL certificate chain %s:" % cert_chain, e)
538+
else:
539+
cert_chain = None
540+
541+
if cert_key_path:
542+
try:
543+
with open(cert_chain_path, 'rb') as f:
544+
cert_key = f.read()
545+
except OSError as e:
546+
raise OperationalError(
547+
"Error while trying to read SSL certificate key %s:" % cert_key_path, e)
548+
else:
549+
cert_key = None
550+
551+
if not verify_hostname and root_ca:
552+
# gRPC doesn't have a flag that lets us completely disable the cn name check,
553+
# so we just set the target name override so they match.
554+
cert_info = x509.load_pem_x509_certificate(root_ca, default_backend())
555+
cn = cert_info.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
556+
target_name_override_opt = [('grpc.ssl_target_name_override', cn)]
557+
else:
558+
target_name_override_opt = []
559+
499560
self.channel = grpc.secure_channel(
500561
self.host_url,
501-
options=[('grpc.max_receive_message_length', -1)],
502-
credentials=grpc.ssl_channel_credentials())
562+
options=[('grpc.max_receive_message_length', -1)] + target_name_override_opt,
563+
credentials=grpc.ssl_channel_credentials(
564+
root_certificates=root_ca, certificate_chain=cert_chain, private_key=cert_key))
503565
else:
504566
self.channel = grpc.insecure_channel(
505567
self.host_url,

0 commit comments

Comments
 (0)