11import time
22import functools
3+ from typing import Optional
34from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
45from databricks .sql .telemetry .models .event import (
56 SqlExecutionEvent ,
67)
8+ from databricks .sql .telemetry .models .enums import ExecutionResultFormat , StatementType
9+ from databricks .sql .utils import ColumnQueue , CloudFetchQueue , ArrowQueue
10+ from uuid import UUID
11+
12+
13+ class TelemetryExtractor :
14+ def __init__ (self , obj ):
15+ self ._obj = obj
16+
17+ def __getattr__ (self , name ):
18+ return getattr (self ._obj , name )
19+
20+ def get_session_id_hex (self ): pass
21+ def get_statement_id (self ): pass
22+ def get_statement_type (self ): pass
23+ def get_is_compressed (self ): pass
24+ def get_execution_result (self ): pass
25+ def get_retry_count (self ): pass
26+
27+
28+ class CursorExtractor (TelemetryExtractor ):
29+ def get_statement_id (self ) -> Optional [str ]:
30+ return self .query_id
31+
32+ def get_session_id_hex (self ) -> Optional [str ]:
33+ return self .connection .get_session_id_hex ()
34+
35+ def get_is_compressed (self ) -> bool :
36+ return self .connection .lz4_compression
37+
38+ def get_execution_result (self ) -> ExecutionResultFormat :
39+ if self .active_result_set is None :
40+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
41+
42+ if isinstance (self .active_result_set .results , ColumnQueue ):
43+ return ExecutionResultFormat .COLUMNAR_INLINE
44+ elif isinstance (self .active_result_set .results , CloudFetchQueue ):
45+ return ExecutionResultFormat .EXTERNAL_LINKS
46+ elif isinstance (self .active_result_set .results , ArrowQueue ):
47+ return ExecutionResultFormat .INLINE_ARROW
48+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
49+
50+ def get_retry_count (self ) -> int :
51+ if (
52+ hasattr (self .thrift_backend , "retry_policy" )
53+ and self .thrift_backend .retry_policy
54+ ):
55+ return len (self .thrift_backend .retry_policy .history )
56+ return 0
57+
58+ def get_statement_type (self : str ) -> StatementType :
59+ # TODO: Implement this
60+ return StatementType .SQL
61+
62+
63+ class ResultSetExtractor (TelemetryExtractor ):
64+ def get_statement_id (self ) -> Optional [str ]:
65+ if self .command_id :
66+ return str (UUID (bytes = self .command_id .operationId .guid ))
67+ return None
68+
69+ def get_session_id_hex (self ) -> Optional [str ]:
70+ return self .connection .get_session_id_hex ()
71+
72+ def get_is_compressed (self ) -> bool :
73+ return self .lz4_compressed
74+
75+ def get_execution_result (self ) -> ExecutionResultFormat :
76+ if isinstance (self .results , ColumnQueue ):
77+ return ExecutionResultFormat .COLUMNAR_INLINE
78+ elif isinstance (self .results , CloudFetchQueue ):
79+ return ExecutionResultFormat .EXTERNAL_LINKS
80+ elif isinstance (self .results , ArrowQueue ):
81+ return ExecutionResultFormat .INLINE_ARROW
82+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
83+
84+ def get_statement_type (self : str ) -> StatementType :
85+ # TODO: Implement this
86+ return StatementType .SQL
87+
88+ def get_retry_count (self ) -> int :
89+ if (
90+ hasattr (self .thrift_backend , "retry_policy" )
91+ and self .thrift_backend .retry_policy
92+ ):
93+ return len (self .thrift_backend .retry_policy .history )
94+ return 0
95+
96+
97+ def get_extractor (obj ):
98+ if obj .__class__ .__name__ == 'Cursor' :
99+ return CursorExtractor (obj )
100+ elif obj .__class__ .__name__ == 'ResultSet' :
101+ return ResultSetExtractor (obj )
102+ else :
103+ return TelemetryExtractor (obj )
7104
8105
9106def log_latency ():
@@ -19,14 +116,15 @@ def wrapper(self, *args, **kwargs):
19116 end_time = time .perf_counter ()
20117 duration_ms = int ((end_time - start_time ) * 1000 )
21118
22- session_id_hex = self .get_session_id_hex ()
23- statement_id = self .get_statement_id ()
119+ extractor = get_extractor (self )
120+ session_id_hex = extractor .get_session_id_hex ()
121+ statement_id = extractor .get_statement_id ()
24122
25123 sql_exec_event = SqlExecutionEvent (
26- statement_type = self .get_statement_type (func . __name__ ),
27- is_compressed = self .get_is_compressed (),
28- execution_result = self .get_execution_result (),
29- retry_count = self .get_retry_count (),
124+ statement_type = extractor .get_statement_type (),
125+ is_compressed = extractor .get_is_compressed (),
126+ execution_result = extractor .get_execution_result (),
127+ retry_count = extractor .get_retry_count (),
30128 )
31129
32130 telemetry_client = TelemetryClientFactory .get_telemetry_client (
0 commit comments