@@ -81,11 +81,97 @@ class ClientTestSuite(unittest.TestCase):
8181 "access_token" : "tok" ,
8282 }
8383
84- @patch (
85- "%s.session.ThriftDatabricksClient" % PACKAGE_NAME ,
86- ThriftDatabricksClientMockFactory .new (),
87- )
88- def test_closing_connection_closes_commands (self ):
84+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME )
85+ def test_close_uses_the_correct_session_id (self , mock_client_class ):
86+ instance = mock_client_class .return_value
87+
88+ mock_open_session_resp = MagicMock (spec = TOpenSessionResp )()
89+ mock_open_session_resp .sessionHandle .sessionId = b"\x22 "
90+ instance .open_session .return_value = mock_open_session_resp
91+
92+ connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
93+ connection .close ()
94+
95+ # Check the close session request has an id of x22
96+ close_session_id = instance .close_session .call_args [0 ][0 ].sessionId
97+ self .assertEqual (close_session_id , b"\x22 " )
98+
99+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME )
100+ def test_auth_args (self , mock_client_class ):
101+ # Test that the following auth args work:
102+ # token = foo,
103+ # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True
104+ connection_args = [
105+ {
106+ "server_hostname" : "foo" ,
107+ "http_path" : None ,
108+ "access_token" : "tok" ,
109+ },
110+ {
111+ "server_hostname" : "foo" ,
112+ "http_path" : None ,
113+ "_tls_client_cert_file" : "something" ,
114+ "_use_cert_as_auth" : True ,
115+ "access_token" : None ,
116+ },
117+ ]
118+
119+ for args in connection_args :
120+ connection = databricks .sql .connect (** args )
121+ host , port , http_path , * _ = mock_client_class .call_args [0 ]
122+ self .assertEqual (args ["server_hostname" ], host )
123+ self .assertEqual (args ["http_path" ], http_path )
124+ connection .close ()
125+
126+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME )
127+ def test_http_header_passthrough (self , mock_client_class ):
128+ http_headers = [("foo" , "bar" )]
129+ databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS , http_headers = http_headers )
130+
131+ call_args = mock_client_class .call_args [0 ][3 ]
132+ self .assertIn (("foo" , "bar" ), call_args )
133+
134+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME )
135+ def test_tls_arg_passthrough (self , mock_client_class ):
136+ databricks .sql .connect (
137+ ** self .DUMMY_CONNECTION_ARGS ,
138+ _tls_verify_hostname = "hostname" ,
139+ _tls_trusted_ca_file = "trusted ca file" ,
140+ _tls_client_cert_key_file = "trusted client cert" ,
141+ _tls_client_cert_key_password = "key password" ,
142+ )
143+
144+ kwargs = mock_client_class .call_args [1 ]
145+ self .assertEqual (kwargs ["_tls_verify_hostname" ], "hostname" )
146+ self .assertEqual (kwargs ["_tls_trusted_ca_file" ], "trusted ca file" )
147+ self .assertEqual (kwargs ["_tls_client_cert_key_file" ], "trusted client cert" )
148+ self .assertEqual (kwargs ["_tls_client_cert_key_password" ], "key password" )
149+
150+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME )
151+ def test_useragent_header (self , mock_client_class ):
152+ databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
153+
154+ http_headers = mock_client_class .call_args [0 ][3 ]
155+ user_agent_header = (
156+ "User-Agent" ,
157+ "{}/{}" .format (databricks .sql .USER_AGENT_NAME , databricks .sql .__version__ ),
158+ )
159+ self .assertIn (user_agent_header , http_headers )
160+
161+ databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS , user_agent_entry = "foobar" )
162+ user_agent_header_with_entry = (
163+ "User-Agent" ,
164+ "{}/{} ({})" .format (
165+ databricks .sql .USER_AGENT_NAME , databricks .sql .__version__ , "foobar"
166+ ),
167+ )
168+ http_headers = mock_client_class .call_args [0 ][3 ]
169+ self .assertIn (user_agent_header_with_entry , http_headers )
170+
171+ @patch ("%s.session.ThriftBackend" % PACKAGE_NAME , ThriftBackendMockFactory .new ())
172+ @patch ("%s.client.ResultSet" % PACKAGE_NAME )
173+ def test_closing_connection_closes_commands (self , mock_result_set_class ):
174+ # Test once with has_been_closed_server side, once without
89175 for closed in (True , False ):
90176 with self .subTest (closed = closed ):
91177 connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
0 commit comments