Skip to content

Commit 1b9a2b8

Browse files
committed
added statement type to command id
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 1367f69 commit 1b9a2b8

File tree

8 files changed

+40
-47
lines changed

8 files changed

+40
-47
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
from uuid import UUID
1010

1111
from databricks.sql.result_set import ThriftResultSet
12-
12+
from databricks.sql.telemetry.models.event import StatementType
1313

1414
if TYPE_CHECKING:
1515
from databricks.sql.client import Cursor
1616
from databricks.sql.result_set import ResultSet
17-
from databricks.sql.telemetry.models.event import StatementType
1817

1918
from databricks.sql.backend.types import (
2019
CommandState,
@@ -833,7 +832,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
833832
return execute_response, is_direct_results
834833

835834
def get_execution_result(
836-
self, command_id: CommandId, cursor: "Cursor", statement_type: StatementType
835+
self, command_id: CommandId, cursor: "Cursor"
837836
) -> "ResultSet":
838837
thrift_handle = command_id.to_thrift_handle()
839838
if not thrift_handle:
@@ -889,6 +888,7 @@ def get_execution_result(
889888
arrow_schema_bytes=schema_bytes,
890889
result_format=t_result_set_metadata_resp.resultFormat,
891890
)
891+
execute_response.command_id.set_statement_type(StatementType.QUERY)
892892

893893
return ThriftResultSet(
894894
connection=cursor.connection,
@@ -902,7 +902,6 @@ def get_execution_result(
902902
ssl_options=self._ssl_options,
903903
is_direct_results=is_direct_results,
904904
session_id_hex=self._session_id_hex,
905-
statement_type=statement_type,
906905
)
907906

908907
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -968,7 +967,6 @@ def execute_command(
968967
max_bytes: int,
969968
lz4_compression: bool,
970969
cursor: Cursor,
971-
statement_type: StatementType,
972970
use_cloud_fetch=True,
973971
parameters=[],
974972
async_op=False,
@@ -1030,6 +1028,8 @@ def execute_command(
10301028
if resp.directResults and resp.directResults.resultSet:
10311029
t_row_set = resp.directResults.resultSet.results
10321030

1031+
execute_response.command_id.set_statement_type(StatementType.QUERY)
1032+
10331033
return ThriftResultSet(
10341034
connection=cursor.connection,
10351035
execute_response=execute_response,
@@ -1042,7 +1042,6 @@ def execute_command(
10421042
ssl_options=self._ssl_options,
10431043
is_direct_results=is_direct_results,
10441044
session_id_hex=self._session_id_hex,
1045-
statement_type=statement_type,
10461045
)
10471046

10481047
def get_catalogs(
@@ -1051,7 +1050,6 @@ def get_catalogs(
10511050
max_rows: int,
10521051
max_bytes: int,
10531052
cursor: "Cursor",
1054-
statement_type: StatementType,
10551053
) -> "ResultSet":
10561054
thrift_handle = session_id.to_thrift_handle()
10571055
if not thrift_handle:
@@ -1073,6 +1071,8 @@ def get_catalogs(
10731071
if resp.directResults and resp.directResults.resultSet:
10741072
t_row_set = resp.directResults.resultSet.results
10751073

1074+
execute_response.command_id.set_statement_type(StatementType.METADATA)
1075+
10761076
return ThriftResultSet(
10771077
connection=cursor.connection,
10781078
execute_response=execute_response,
@@ -1085,7 +1085,6 @@ def get_catalogs(
10851085
ssl_options=self._ssl_options,
10861086
is_direct_results=is_direct_results,
10871087
session_id_hex=self._session_id_hex,
1088-
statement_id=statement_type,
10891088
)
10901089

10911090
def get_schemas(
@@ -1094,7 +1093,6 @@ def get_schemas(
10941093
max_rows: int,
10951094
max_bytes: int,
10961095
cursor: Cursor,
1097-
statement_type: StatementType,
10981096
catalog_name=None,
10991097
schema_name=None,
11001098
) -> "ResultSet":
@@ -1122,6 +1120,8 @@ def get_schemas(
11221120
if resp.directResults and resp.directResults.resultSet:
11231121
t_row_set = resp.directResults.resultSet.results
11241122

1123+
execute_response.command_id.set_statement_type(StatementType.METADATA)
1124+
11251125
return ThriftResultSet(
11261126
connection=cursor.connection,
11271127
execute_response=execute_response,
@@ -1134,7 +1134,6 @@ def get_schemas(
11341134
ssl_options=self._ssl_options,
11351135
is_direct_results=is_direct_results,
11361136
session_id_hex=self._session_id_hex,
1137-
statement_type=statement_type,
11381137
)
11391138

11401139
def get_tables(
@@ -1143,7 +1142,6 @@ def get_tables(
11431142
max_rows: int,
11441143
max_bytes: int,
11451144
cursor: Cursor,
1146-
statement_type: StatementType,
11471145
catalog_name=None,
11481146
schema_name=None,
11491147
table_name=None,
@@ -1175,6 +1173,8 @@ def get_tables(
11751173
if resp.directResults and resp.directResults.resultSet:
11761174
t_row_set = resp.directResults.resultSet.results
11771175

1176+
execute_response.command_id.set_statement_type(StatementType.METADATA)
1177+
11781178
return ThriftResultSet(
11791179
connection=cursor.connection,
11801180
execute_response=execute_response,
@@ -1187,7 +1187,6 @@ def get_tables(
11871187
ssl_options=self._ssl_options,
11881188
is_direct_results=is_direct_results,
11891189
session_id_hex=self._session_id_hex,
1190-
statement_type=statement_type,
11911190
)
11921191

11931192
def get_columns(
@@ -1196,7 +1195,6 @@ def get_columns(
11961195
max_rows: int,
11971196
max_bytes: int,
11981197
cursor: Cursor,
1199-
statement_type: StatementType,
12001198
catalog_name=None,
12011199
schema_name=None,
12021200
table_name=None,
@@ -1228,6 +1226,8 @@ def get_columns(
12281226
if resp.directResults and resp.directResults.resultSet:
12291227
t_row_set = resp.directResults.resultSet.results
12301228

1229+
execute_response.command_id.set_statement_type(StatementType.METADATA)
1230+
12311231
return ThriftResultSet(
12321232
connection=cursor.connection,
12331233
execute_response=execute_response,
@@ -1240,7 +1240,6 @@ def get_columns(
12401240
ssl_options=self._ssl_options,
12411241
is_direct_results=is_direct_results,
12421242
session_id_hex=self._session_id_hex,
1243-
statement_type=statement_type,
12441243
)
12451244

12461245
def _handle_execute_response(self, resp, cursor):
@@ -1275,7 +1274,6 @@ def fetch_results(
12751274
lz4_compressed: bool,
12761275
arrow_schema_bytes,
12771276
description,
1278-
statement_type,
12791277
chunk_id: int,
12801278
use_cloud_fetch=True,
12811279
):
@@ -1316,7 +1314,7 @@ def fetch_results(
13161314
ssl_options=self._ssl_options,
13171315
session_id_hex=self._session_id_hex,
13181316
statement_id=command_id.to_hex_guid(),
1319-
statement_type=statement_type,
1317+
statement_type=command_id.statement_type,
13201318
chunk_id=chunk_id,
13211319
)
13221320

src/databricks/sql/backend/types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55

66
from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
7+
from databricks.sql.telemetry.models.enums import StatementType
78
from databricks.sql.thrift_api.TCLIService import ttypes
89

910
logger = logging.getLogger(__name__)
@@ -281,6 +282,7 @@ def __init__(
281282
operation_type: Optional[int] = None,
282283
has_result_set: bool = False,
283284
modified_row_count: Optional[int] = None,
285+
statement_type: Optional[StatementType] = None,
284286
):
285287
"""
286288
Initialize a CommandId.
@@ -300,6 +302,7 @@ def __init__(
300302
self.operation_type = operation_type
301303
self.has_result_set = has_result_set
302304
self.modified_row_count = modified_row_count
305+
self._statement_type = statement_type
303306

304307
def __str__(self) -> str:
305308
"""
@@ -411,6 +414,19 @@ def to_hex_guid(self) -> str:
411414
else:
412415
return str(self.guid)
413416

417+
def set_statement_type(self, statement_type: StatementType):
418+
"""
419+
Set the statement type for this command.
420+
"""
421+
self._statement_type = statement_type
422+
423+
@property
424+
def statement_type(self) -> Optional[StatementType]:
425+
"""
426+
Get the statement type for this command.
427+
"""
428+
return self._statement_type
429+
414430

415431
@dataclass
416432
class ExecuteResponse:

src/databricks/sql/client.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,6 @@ def execute(
870870
async_op=False,
871871
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
872872
row_limit=self.row_limit,
873-
statement_type=self.statement_type,
874873
)
875874

876875
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -930,7 +929,6 @@ def execute_async(
930929
async_op=True,
931930
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
932931
row_limit=self.row_limit,
933-
statement_type=self.statement_type,
934932
)
935933

936934
return self
@@ -971,7 +969,7 @@ def get_async_execution_result(self):
971969
operation_state = self.get_query_state()
972970
if operation_state == CommandState.SUCCEEDED:
973971
self.active_result_set = self.backend.get_execution_result(
974-
self.active_command_id, cursor=self, statement_type=self.statement_type
972+
self.active_command_id, self
975973
)
976974

977975
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -1016,7 +1014,6 @@ def catalogs(self) -> "Cursor":
10161014
max_rows=self.arraysize,
10171015
max_bytes=self.buffer_size_bytes,
10181016
cursor=self,
1019-
statement_type=self.statement_type,
10201017
)
10211018
return self
10221019

@@ -1040,7 +1037,6 @@ def schemas(
10401037
cursor=self,
10411038
catalog_name=catalog_name,
10421039
schema_name=schema_name,
1043-
statement_type=self.statement_type,
10441040
)
10451041
return self
10461042

@@ -1071,7 +1067,6 @@ def tables(
10711067
schema_name=schema_name,
10721068
table_name=table_name,
10731069
table_types=table_types,
1074-
statement_type=self.statement_type,
10751070
)
10761071
return self
10771072

@@ -1102,7 +1097,6 @@ def columns(
11021097
schema_name=schema_name,
11031098
table_name=table_name,
11041099
column_name=column_name,
1105-
statement_type=self.statement_type,
11061100
)
11071101
return self
11081102

src/databricks/sql/result_set.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ def __init__(
194194
execute_response: "ExecuteResponse",
195195
thrift_client: "ThriftDatabricksClient",
196196
session_id_hex: Optional[str],
197-
statement_type: StatementType,
198197
buffer_size_bytes: int = 104857600,
199198
arraysize: int = 10000,
200199
use_cloud_fetch: bool = True,
@@ -218,7 +217,7 @@ def __init__(
218217
:param ssl_options: SSL options for cloud fetch
219218
:param is_direct_results: Whether there are more rows to fetch
220219
"""
221-
self.statement_type = statement_type
220+
self.statement_type = execute_response.command_id.statement_type
222221
self.chunk_id = 0
223222

224223
# Initialize ThriftResultSet-specific attributes
@@ -241,7 +240,7 @@ def __init__(
241240
ssl_options=ssl_options,
242241
session_id_hex=session_id_hex,
243242
statement_id=execute_response.command_id.to_hex_guid(),
244-
statement_type=statement_type,
243+
statement_type=self.statement_type,
245244
chunk_id=self.chunk_id,
246245
)
247246
if t_row_set and t_row_set.resultLinks:
@@ -278,7 +277,6 @@ def _fill_results_buffer(self):
278277
arrow_schema_bytes=self._arrow_schema_bytes,
279278
description=self.description,
280279
use_cloud_fetch=self._use_cloud_fetch,
281-
statement_type=self.statement_type,
282280
chunk_id=self.chunk_id,
283281
)
284282
self.results = results

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def get_extractor(obj):
144144
return None
145145

146146

147-
def log_latency(statement_type: StatementType = StatementType.NONE):
147+
def log_latency():
148148
"""
149149
Decorator for logging execution latency and telemetry information.
150150
@@ -158,11 +158,8 @@ def log_latency(statement_type: StatementType = StatementType.NONE):
158158
- Creates a SqlExecutionEvent with execution details
159159
- Sends the telemetry data asynchronously via TelemetryClient
160160
161-
Args:
162-
statement_type (StatementType): The type of SQL statement being executed.
163-
164161
Usage:
165-
@log_latency(StatementType.SQL)
162+
@log_latency()
166163
def execute(self, query):
167164
# Method implementation
168165
pass

tests/unit/test_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
129129
execute_response=mock_execute_response,
130130
thrift_client=mock_backend,
131131
session_id_hex=Mock(),
132-
statement_type=Mock(),
133132
)
134133

135134
# Mock execute_command to return our real result set
@@ -197,7 +196,6 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
197196
execute_response=Mock(),
198197
thrift_client=mock_backend,
199198
session_id_hex=Mock(),
200-
statement_type=Mock(),
201199
)
202200
result_set.results = mock_results
203201

@@ -225,7 +223,7 @@ def test_closing_result_set_hard_closes_commands(self):
225223

226224
mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0)
227225
result_set = ThriftResultSet(
228-
mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock(), statement_type=Mock(),)
226+
mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock())
229227

230228
result_set.close()
231229

@@ -271,7 +269,7 @@ def test_negative_fetch_throws_exception(self):
271269
mock_backend = Mock()
272270
mock_backend.fetch_results.return_value = (Mock(), False, 0)
273271

274-
result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock(), statement_type=Mock())
272+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock())
275273

276274
with self.assertRaises(ValueError) as e:
277275
result_set.fetchmany(-1)

0 commit comments

Comments
 (0)