Skip to content

Commit 73c3d6a

Browse files
Niall Egansusodapop
authored andcommitted
Simple fetchall and execute methods (redux)
While we're waiting for dev tools to upgrade the Jenkins AMI so we can use Pandas, I have just removed the use of Pandas to avoid the problems in the commit builder we were seeing (see the `_deserialize_arrow_ipc_stream` method). Author: Niall Egan <niall.egan@databricks.com>
1 parent 863e271 commit 73c3d6a

File tree

3 files changed

+330
-84
lines changed

3 files changed

+330
-84
lines changed

cmdexec/clients/python/command_exec_client.py

Lines changed: 191 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from typing import Dict, Tuple, List
33

44
import grpc
5-
from google.protobuf import message
5+
import pyarrow
66

77
import cmdexec.clients.python.sql_command_service_pb2 as command_pb2
88
from cmdexec.clients.python.sql_command_service_pb2_grpc import SqlCommandServiceStub
9-
from cmdexec.clients.python.errors import OperationalError, InterfaceError
9+
from cmdexec.clients.python.errors import OperationalError, InterfaceError, DatabaseError, Error
10+
11+
import time
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -23,31 +25,203 @@ def __init__(self, **kwargs):
2325
except KeyError:
2426
raise InterfaceError("Please include arguments HOST and PORT in kwargs for Connection")
2527

26-
self._base_client = CmdExecBaseHttpClient(self.host, self.port, kwargs.get("metadata", []))
28+
self.base_client = CmdExecBaseHttpClient(self.host, self.port, kwargs.get("metadata", []))
2729
open_session_request = command_pb2.OpenSessionRequest(
2830
configuration={},
2931
client_session_id=None,
3032
session_info_fields=None,
3133
)
3234

33-
try:
34-
resp = self._base_client.make_request(self._base_client.stub.OpenSession,
35-
open_session_request)
36-
self.session_id = resp.id
37-
logger.info("Successfully opened session " + str(self.session_id.hex()))
38-
except grpc.RpcError as error:
39-
raise OperationalError("Error during database connection", error)
35+
resp = self.base_client.make_request(self.base_client.stub.OpenSession,
36+
open_session_request)
37+
self.session_id = resp.id
38+
self.open = True
39+
logger.info("Successfully opened session " + str(self.session_id.hex()))
40+
self._cursors = []
4041

4142
def cursor(self):
42-
pass
43+
if not self.open:
44+
raise Error("Cannot create cursor from closed connection")
45+
cursor = Cursor(self)
46+
self._cursors.append(cursor)
47+
return cursor
4348

4449
def close(self):
50+
close_session_request = command_pb2.CloseSessionRequest(id=self.session_id)
51+
self.base_client.make_request(self.base_client.stub.CloseSession, close_session_request)
52+
self.open = False
53+
54+
for cursor in self._cursors:
55+
cursor.close()
56+
57+
58+
class Cursor:
59+
def __init__(self, connection):
60+
self.connection = connection
61+
self.description = None
62+
self.rowcount = -1
63+
self.arraysize = 1
64+
65+
self.active_result_set = None
66+
# Note that Cursor closed => active result set closed, but not vice versa
67+
self.open = True
68+
69+
def _response_to_result_set(self, execute_command_response, status):
70+
command_id = execute_command_response.command_id
71+
arrow_results = execute_command_response.results.arrow_ipc_stream
72+
has_been_closed_server_side = execute_command_response.closed
73+
74+
return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
75+
arrow_results)
76+
77+
def _close_and_clear_active_result_set(self):
4578
try:
46-
close_session_request = command_pb2.CloseSessionRequest(id=self.session_id)
47-
self._base_client.make_request(self._base_client.stub.CloseSession,
48-
close_session_request)
49-
except grpc.RpcError as error:
50-
raise OperationalError("Error during database connection close", error)
79+
if self.active_result_set:
80+
self.active_result_set.close()
81+
finally:
82+
self.active_result_set = None
83+
84+
def _check_not_closed(self):
85+
if not self.open:
86+
raise Error("Attempting operation on closed cursor")
87+
88+
def _check_response_for_error(self, resp, command_id):
89+
status = resp.status.state
90+
if status == command_pb2.ERROR:
91+
raise DatabaseError(
92+
"Command %s failed with error message %s" % (command_id, resp.status.error_message))
93+
elif status == command_pb2.CLOSED:
94+
raise DatabaseError("Command %s closed before results could be fetched" % command_id)
95+
96+
def _poll_for_state(self, command_id):
97+
get_status_request = command_pb2.GetCommandStatusRequest(id=command_id)
98+
99+
resp = self.connection.base_client.make_request(
100+
self.connection.base_client.stub.GetCommandStatus, get_status_request)
101+
102+
logger.info("Status for command %s is: %s" % (command_id, resp.status))
103+
return resp
104+
105+
def _wait_until_command_done(self, command_id, initial_status):
106+
status = initial_status
107+
print("initial status: %s" % status)
108+
while status in [command_pb2.PENDING, command_pb2.RUNNING]:
109+
resp = self._poll_for_state(command_id)
110+
status = resp.status.state
111+
self._check_response_for_error(resp, command_id)
112+
print("status is: %s" % status)
113+
114+
# TODO: Remove this sleep once we have long-polling on the server (SC-77653)
115+
time.sleep(1)
116+
return status
117+
118+
def execute(self, operation, query_params=None, metadata=None):
119+
self._check_not_closed()
120+
self._close_and_clear_active_result_set()
121+
122+
# Execute the command
123+
execute_command_request = command_pb2.ExecuteCommandRequest(
124+
session_id=self.connection.session_id,
125+
client_command_id=None,
126+
command=operation,
127+
conf_overlay=None,
128+
row_limit=None,
129+
result_options=None,
130+
)
131+
132+
execute_command_response = self.connection.base_client.make_request(
133+
self.connection.base_client.stub.ExecuteCommand, execute_command_request)
134+
initial_status = execute_command_response.status.state
135+
command_id = execute_command_response.command_id
136+
137+
self._check_response_for_error(execute_command_response, command_id)
138+
final_status = self._wait_until_command_done(command_id, initial_status)
139+
self.active_result_set = self._response_to_result_set(execute_command_response,
140+
final_status)
141+
142+
return self
143+
144+
def fetchall(self):
145+
self._check_not_closed()
146+
if self.active_result_set:
147+
return self.active_result_set.fetchall()
148+
else:
149+
raise Error("There is no active result set")
150+
151+
def close(self):
152+
self.open = False
153+
if self.active_result_set:
154+
self._close_and_clear_active_result_set()
155+
156+
157+
class ResultSet:
158+
def __init__(self,
159+
connection,
160+
command_id,
161+
status,
162+
has_been_closed_server_side,
163+
arrow_ipc_stream=None):
164+
self.connection = connection
165+
self.command_id = command_id
166+
self.status = status
167+
self.has_been_closed_server_side = has_been_closed_server_side
168+
169+
assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])
170+
171+
if arrow_ipc_stream:
172+
self.results = self._deserialize_arrow_ipc_stream(arrow_ipc_stream)
173+
else:
174+
self.results = None
175+
176+
def _deserialize_arrow_ipc_stream(self, ipc_stream):
177+
# TODO: Proper results deserialization, taking into account the logical schema, convert
178+
# via pd df for efficiency (SC-77871)
179+
pyarrow_table = pyarrow.ipc.open_stream(ipc_stream).read_all()
180+
dict_repr = pyarrow_table.to_pydict()
181+
n_rows, n_cols = pyarrow_table.shape
182+
list_repr = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
183+
return list_repr
184+
185+
def _fetch_results(self):
186+
# TODO: Offsets, number of rows (SC-77872)
187+
188+
fetch_results_request = command_pb2.FetchCommandResultsRequest(
189+
id=self.command_id,
190+
options=command_pb2.CommandResultOptions(
191+
max_rows=1,
192+
include_metadata=True,
193+
))
194+
195+
return self.connection.base_client.make_request(
196+
self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
197+
198+
def fetchall(self):
199+
# TODO: Check that these results are in the right place (SC-77872)
200+
if self.status == command_pb2.SUCCESS:
201+
return self.results
202+
elif self.status in [command_pb2.PENDING, command_pb2.RUNNING]:
203+
# TODO: Pre-fetch results (SC-77868)
204+
result_message = self._fetch_results()
205+
self.results = self._deserialize_arrow_ipc_stream(result_message.arrow_ipc_stream)
206+
return self.results
207+
elif self.status == command_pb2.CLOSED:
208+
raise Error("Can't fetch results on closed command %s" % self.command_id)
209+
elif self.status == command_pb2.ERROR:
210+
raise DatabaseError("Command %s failed" % self.command_id)
211+
else:
212+
raise Error(
213+
"Command %s is in an unrecognised state: %s" % (self.command_id, self.status))
214+
215+
def close(self):
216+
try:
217+
if self.status != command_pb2.CLOSED and not self.has_been_closed_server_side \
218+
and not self.connection.closed:
219+
close_command_request = command_pb2.CloseCommandRequest(id=self.command_id)
220+
self.connection.base_client.make_request(
221+
self.connection.base_client.stub.CloseCommand, close_command_request)
222+
finally:
223+
self.has_been_closed_server_side = True
224+
self.status = command_pb2.CLOSED
51225

52226

53227
class CmdExecBaseHttpClient:
@@ -68,4 +242,4 @@ def make_request(self, method, request):
68242
return response
69243
except grpc.RpcError as error:
70244
logger.error("Received error during gRPC request: %s", error)
71-
raise error
245+
raise OperationalError("Error during gRPC request", error)

0 commit comments

Comments
 (0)