Skip to content

Commit 9b8a002

Browse files
Niall Egansusodapop
authored andcommitted
Simple fetchall and execute methods
This PR does a few things: 1. Introduce simple .execute() method 2. Introduce a simple .fetchall method that will return the results (this will only return 1 row, follow ups will fix) 3. Change the server side and protocol to send back a physical Arrow schema instead of just the gRPC server Some follow up tickets: 1. Fetch-ahead to internal buffer in background thread 2. Proper results deserialisation 3. Fetchmany etc + fetching multiple rows, properly checking for the number of results 4. Server holding on to poll status request Also dropped support for Spark 3.0 - New Driverlocal tests - New unit tests Author: Niall Egan <niall.egan@databricks.com>
1 parent a7f3994 commit 9b8a002

File tree

3 files changed

+326
-84
lines changed

3 files changed

+326
-84
lines changed

cmdexec/clients/python/command_exec_client.py

Lines changed: 187 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from typing import Dict, Tuple, List
33

44
import grpc
5-
from google.protobuf import message
5+
import pyarrow
66

77
import cmdexec.clients.python.sql_command_service_pb2 as command_pb2
88
from cmdexec.clients.python.sql_command_service_pb2_grpc import SqlCommandServiceStub
9-
from cmdexec.clients.python.errors import OperationalError, InterfaceError
9+
from cmdexec.clients.python.errors import OperationalError, InterfaceError, DatabaseError, Error
10+
11+
import time
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -23,31 +25,199 @@ def __init__(self, **kwargs):
2325
except KeyError:
2426
raise InterfaceError("Please include arguments HOST and PORT in kwargs for Connection")
2527

26-
self._base_client = CmdExecBaseHttpClient(self.host, self.port, kwargs.get("metadata", []))
28+
self.base_client = CmdExecBaseHttpClient(self.host, self.port, kwargs.get("metadata", []))
2729
open_session_request = command_pb2.OpenSessionRequest(
2830
configuration={},
2931
client_session_id=None,
3032
session_info_fields=None,
3133
)
3234

33-
try:
34-
resp = self._base_client.make_request(self._base_client.stub.OpenSession,
35-
open_session_request)
36-
self.session_id = resp.id
37-
logger.info("Successfully opened session " + str(self.session_id.hex()))
38-
except grpc.RpcError as error:
39-
raise OperationalError("Error during database connection", error)
35+
resp = self.base_client.make_request(self.base_client.stub.OpenSession,
36+
open_session_request)
37+
self.session_id = resp.id
38+
self.open = True
39+
logger.info("Successfully opened session " + str(self.session_id.hex()))
40+
self._cursors = []
4041

4142
def cursor(self):
42-
pass
43+
if not self.open:
44+
raise Error("Cannot create cursor from closed connection")
45+
cursor = Cursor(self)
46+
self._cursors.append(cursor)
47+
return cursor
4348

4449
def close(self):
50+
close_session_request = command_pb2.CloseSessionRequest(id=self.session_id)
51+
self.base_client.make_request(self.base_client.stub.CloseSession, close_session_request)
52+
self.open = False
53+
54+
for cursor in self._cursors:
55+
cursor.close()
56+
57+
58+
class Cursor:
59+
def __init__(self, connection):
60+
self.connection = connection
61+
self.description = None
62+
self.rowcount = -1
63+
self.arraysize = 1
64+
65+
self.active_result_set = None
66+
# Note that Cursor closed => active result set closed, but not vice versa
67+
self.open = True
68+
69+
def _response_to_result_set(self, execute_command_response, status):
70+
command_id = execute_command_response.command_id
71+
arrow_results = execute_command_response.results.arrow_ipc_stream
72+
has_been_closed_server_side = execute_command_response.closed
73+
74+
return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
75+
arrow_results)
76+
77+
def _close_and_clear_active_result_set(self):
4578
try:
46-
close_session_request = command_pb2.CloseSessionRequest(id=self.session_id)
47-
self._base_client.make_request(self._base_client.stub.CloseSession,
48-
close_session_request)
49-
except grpc.RpcError as error:
50-
raise OperationalError("Error during database connection close", error)
79+
if self.active_result_set:
80+
self.active_result_set.close()
81+
finally:
82+
self.active_result_set = None
83+
84+
def _check_not_closed(self):
85+
if not self.open:
86+
raise Error("Attempting operation on closed cursor")
87+
88+
def _check_response_for_error(self, resp, command_id):
89+
status = resp.status.state
90+
if status == command_pb2.ERROR:
91+
raise DatabaseError(
92+
"Command %s failed with error message %s" % (command_id, resp.status.error_message))
93+
elif status == command_pb2.CLOSED:
94+
raise DatabaseError("Command %s closed before results could be fetched" % command_id)
95+
96+
def _poll_for_state(self, command_id):
97+
get_status_request = command_pb2.GetCommandStatusRequest(id=command_id)
98+
99+
resp = self.connection.base_client.make_request(
100+
self.connection.base_client.stub.GetCommandStatus, get_status_request)
101+
102+
logger.info("Status for command %s is: %s" % (command_id, resp.status))
103+
return resp
104+
105+
def _wait_until_command_done(self, command_id, initial_status):
106+
status = initial_status
107+
print("initial status: %s" % status)
108+
while status in [command_pb2.PENDING, command_pb2.RUNNING]:
109+
resp = self._poll_for_state(command_id)
110+
status = resp.status.state
111+
self._check_response_for_error(resp, command_id)
112+
print("status is: %s" % status)
113+
114+
# TODO: Remove this sleep once we have long-polling on the server (SC-77653)
115+
time.sleep(1)
116+
return status
117+
118+
def execute(self, operation, query_params=None, metadata=None):
119+
self._check_not_closed()
120+
self._close_and_clear_active_result_set()
121+
122+
# Execute the command
123+
execute_command_request = command_pb2.ExecuteCommandRequest(
124+
session_id=self.connection.session_id,
125+
client_command_id=None,
126+
command=operation,
127+
conf_overlay=None,
128+
row_limit=None,
129+
result_options=None,
130+
)
131+
132+
execute_command_response = self.connection.base_client.make_request(
133+
self.connection.base_client.stub.ExecuteCommand, execute_command_request)
134+
initial_status = execute_command_response.status.state
135+
command_id = execute_command_response.command_id
136+
137+
self._check_response_for_error(execute_command_response, command_id)
138+
final_status = self._wait_until_command_done(command_id, initial_status)
139+
self.active_result_set = self._response_to_result_set(execute_command_response,
140+
final_status)
141+
142+
return self
143+
144+
def fetchall(self):
145+
self._check_not_closed()
146+
if self.active_result_set:
147+
return self.active_result_set.fetchall()
148+
else:
149+
raise Error("There is no active result set")
150+
151+
def close(self):
152+
self.open = False
153+
if self.active_result_set:
154+
self._close_and_clear_active_result_set()
155+
156+
157+
class ResultSet:
158+
def __init__(self,
159+
connection,
160+
command_id,
161+
status,
162+
has_been_closed_server_side,
163+
arrow_ipc_stream=None):
164+
self.connection = connection
165+
self.command_id = command_id
166+
self.status = status
167+
self.has_been_closed_server_side = has_been_closed_server_side
168+
169+
assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])
170+
171+
if arrow_ipc_stream:
172+
self.results = self._deserialize_arrow_ipc_stream(arrow_ipc_stream)
173+
else:
174+
self.results = None
175+
176+
def _deserialize_arrow_ipc_stream(self, ipc_stream):
177+
# TODO: Proper results deserialization, taking into account the logical schema (SC-77871)
178+
reader = pyarrow.ipc.open_stream(ipc_stream)
179+
return reader.read_all().to_pandas().values.tolist()
180+
181+
def _fetch_results(self):
182+
# TODO: Offsets, number of rows (SC-77872)
183+
184+
fetch_results_request = command_pb2.FetchCommandResultsRequest(
185+
id=self.command_id,
186+
options=command_pb2.CommandResultOptions(
187+
max_rows=1,
188+
include_metadata=True,
189+
))
190+
191+
return self.connection.base_client.make_request(
192+
self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
193+
194+
def fetchall(self):
195+
# TODO: Check that these results are in the right place (SC-77872)
196+
if self.status == command_pb2.SUCCESS:
197+
return self.results
198+
elif self.status in [command_pb2.PENDING, command_pb2.RUNNING]:
199+
# TODO: Pre-fetch results (SC-77868)
200+
result_message = self._fetch_results()
201+
self.results = self._deserialize_arrow_ipc_stream(result_message.arrow_ipc_stream)
202+
return self.results
203+
elif self.status == command_pb2.CLOSED:
204+
raise Error("Can't fetch results on closed command %s" % self.command_id)
205+
elif self.status == command_pb2.ERROR:
206+
raise DatabaseError("Command %s failed" % self.command_id)
207+
else:
208+
raise Error(
209+
"Command %s is in an unrecognised state: %s" % (self.command_id, self.status))
210+
211+
def close(self):
212+
try:
213+
if self.status != command_pb2.CLOSED and not self.has_been_closed_server_side \
214+
and not self.connection.closed:
215+
close_command_request = command_pb2.CloseCommandRequest(id=self.command_id)
216+
self.connection.base_client.make_request(
217+
self.connection.base_client.stub.CloseCommand, close_command_request)
218+
finally:
219+
self.has_been_closed_server_side = True
220+
self.status = command_pb2.CLOSED
51221

52222

53223
class CmdExecBaseHttpClient:
@@ -68,4 +238,4 @@ def make_request(self, method, request):
68238
return response
69239
except grpc.RpcError as error:
70240
logger.error("Received error during gRPC request: %s", error)
71-
raise error
241+
raise OperationalError("Error during gRPC request", error)

0 commit comments

Comments
 (0)