
Commit 4735329

Merge branch 'main' into telemetry-server-flag
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
2 parents 725cce9 + 2f8b1ab commit 4735329

File tree

13 files changed, +448 -153 lines changed


src/databricks/sql/backend/sea/backend.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 if TYPE_CHECKING:
@@ -322,6 +323,11 @@ def _extract_description_from_manifest(
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             name = col_data.get("name", "")
             type_name = col_data.get("type_name", "")
+
+            # Normalize SEA type to Thrift conventions before any processing
+            type_name = normalize_sea_type_to_thrift(type_name, col_data)
+
+            # Now strip _TYPE suffix and convert to lowercase
             type_name = (
                 type_name[:-5] if type_name.endswith("_TYPE") else type_name
             ).lower()
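
For context, a minimal sketch of how the new normalization step feeds the existing _TYPE-stripping logic in _extract_description_from_manifest. The column dictionary below is illustrative only, not taken from the diff:

from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

col_data = {"name": "id", "type_name": "LONG"}  # hypothetical manifest entry
type_name = col_data.get("type_name", "")
type_name = normalize_sea_type_to_thrift(type_name, col_data)  # "LONG" -> "BIGINT"
type_name = (type_name[:-5] if type_name.endswith("_TYPE") else type_name).lower()  # -> "bigint"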

src/databricks/sql/backend/sea/result_set.py

Lines changed: 9 additions & 10 deletions
@@ -92,20 +92,19 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
         converted_row = []
 
         for i, value in enumerate(row):
+            column_name = self.description[i][0]
             column_type = self.description[i][1]
             precision = self.description[i][4]
             scale = self.description[i][5]
 
-            try:
-                converted_value = SqlTypeConverter.convert_value(
-                    value, column_type, precision=precision, scale=scale
-                )
-                converted_row.append(converted_value)
-            except Exception as e:
-                logger.warning(
-                    f"Error converting value '{value}' to {column_type}: {e}"
-                )
-                converted_row.append(value)
+            converted_value = SqlTypeConverter.convert_value(
+                value,
+                column_type,
+                column_name=column_name,
+                precision=precision,
+                scale=scale,
+            )
+            converted_row.append(converted_value)
 
         return converted_row
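
As a reminder of why indices 0, 1, 4, and 5 are used above: each description entry follows the DB-API 7-tuple noted in backend.py. A purely illustrative entry:

# (name, type_code, display_size, internal_size, precision, scale, null_ok)
description_entry = ("amount", "decimal", None, None, 10, 2, True)  # hypothetical
column_name = description_entry[0]  # "amount"
column_type = description_entry[1]  # "decimal"
precision = description_entry[4]    # 10
scale = description_entry[5]        # 2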

src/databricks/sql/backend/sea/utils/conversion.py

Lines changed: 44 additions & 31 deletions
@@ -50,60 +50,65 @@ def _convert_decimal(
 
 class SqlType:
     """
-    SQL type constants
+    SQL type constants based on Thrift TTypeId values.
 
-    The list of types can be found in the SEA REST API Reference:
-    https://docs.databricks.com/api/workspace/statementexecution/executestatement
+    These correspond to the normalized type names that come from the SEA backend
+    after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix).
     """
 
     # Numeric types
-    BYTE = "byte"
-    SHORT = "short"
-    INT = "int"
-    LONG = "long"
-    FLOAT = "float"
-    DOUBLE = "double"
-    DECIMAL = "decimal"
+    TINYINT = "tinyint"  # Maps to TTypeId.TINYINT_TYPE
+    SMALLINT = "smallint"  # Maps to TTypeId.SMALLINT_TYPE
+    INT = "int"  # Maps to TTypeId.INT_TYPE
+    BIGINT = "bigint"  # Maps to TTypeId.BIGINT_TYPE
+    FLOAT = "float"  # Maps to TTypeId.FLOAT_TYPE
+    DOUBLE = "double"  # Maps to TTypeId.DOUBLE_TYPE
+    DECIMAL = "decimal"  # Maps to TTypeId.DECIMAL_TYPE
 
     # Boolean type
-    BOOLEAN = "boolean"
+    BOOLEAN = "boolean"  # Maps to TTypeId.BOOLEAN_TYPE
 
     # Date/Time types
-    DATE = "date"
-    TIMESTAMP = "timestamp"
-    INTERVAL = "interval"
+    DATE = "date"  # Maps to TTypeId.DATE_TYPE
+    TIMESTAMP = "timestamp"  # Maps to TTypeId.TIMESTAMP_TYPE
+    INTERVAL_YEAR_MONTH = (
+        "interval_year_month"  # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
+    )
+    INTERVAL_DAY_TIME = "interval_day_time"  # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE
 
     # String types
-    CHAR = "char"
-    STRING = "string"
+    CHAR = "char"  # Maps to TTypeId.CHAR_TYPE
+    VARCHAR = "varchar"  # Maps to TTypeId.VARCHAR_TYPE
+    STRING = "string"  # Maps to TTypeId.STRING_TYPE
 
     # Binary type
-    BINARY = "binary"
+    BINARY = "binary"  # Maps to TTypeId.BINARY_TYPE
 
     # Complex types
-    ARRAY = "array"
-    MAP = "map"
-    STRUCT = "struct"
+    ARRAY = "array"  # Maps to TTypeId.ARRAY_TYPE
+    MAP = "map"  # Maps to TTypeId.MAP_TYPE
+    STRUCT = "struct"  # Maps to TTypeId.STRUCT_TYPE
 
     # Other types
-    NULL = "null"
-    USER_DEFINED_TYPE = "user_defined_type"
+    NULL = "null"  # Maps to TTypeId.NULL_TYPE
+    UNION = "union"  # Maps to TTypeId.UNION_TYPE
+    USER_DEFINED = "user_defined"  # Maps to TTypeId.USER_DEFINED_TYPE
 
 
 class SqlTypeConverter:
     """
     Utility class for converting SQL types to Python types.
-    Based on the types supported by the Databricks SDK.
+    Based on the Thrift TTypeId types after normalization.
     """
 
     # SQL type to conversion function mapping
     # TODO: complex types
     TYPE_MAPPING: Dict[str, Callable] = {
         # Numeric types
-        SqlType.BYTE: lambda v: int(v),
-        SqlType.SHORT: lambda v: int(v),
+        SqlType.TINYINT: lambda v: int(v),
+        SqlType.SMALLINT: lambda v: int(v),
         SqlType.INT: lambda v: int(v),
-        SqlType.LONG: lambda v: int(v),
+        SqlType.BIGINT: lambda v: int(v),
         SqlType.FLOAT: lambda v: float(v),
         SqlType.DOUBLE: lambda v: float(v),
         SqlType.DECIMAL: _convert_decimal,
@@ -112,30 +117,34 @@ class SqlTypeConverter:
         # Date/Time types
         SqlType.DATE: lambda v: datetime.date.fromisoformat(v),
         SqlType.TIMESTAMP: lambda v: parser.parse(v),
-        SqlType.INTERVAL: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_YEAR_MONTH: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_DAY_TIME: lambda v: v,  # Keep as string for now
         # String types - no conversion needed
         SqlType.CHAR: lambda v: v,
+        SqlType.VARCHAR: lambda v: v,
         SqlType.STRING: lambda v: v,
         # Binary type
         SqlType.BINARY: lambda v: bytes.fromhex(v),
         # Other types
         SqlType.NULL: lambda v: None,
         # Complex types and user-defined types return as-is
-        SqlType.USER_DEFINED_TYPE: lambda v: v,
+        SqlType.USER_DEFINED: lambda v: v,
     }
 
     @staticmethod
     def convert_value(
         value: str,
         sql_type: str,
+        column_name: Optional[str],
        **kwargs,
     ) -> object:
        """
         Convert a string value to the appropriate Python type based on SQL type.
 
         Args:
             value: The string value to convert
-            sql_type: The SQL type (e.g., 'int', 'decimal')
+            sql_type: The SQL type (e.g., 'tinyint', 'decimal')
+            column_name: The name of the column being converted
             **kwargs: Additional keyword arguments for the conversion function
 
         Returns:
@@ -155,6 +164,10 @@ def convert_value(
                 return converter_func(value, precision, scale)
             else:
                 return converter_func(value)
-        except (ValueError, TypeError, decimal.InvalidOperation) as e:
-            logger.warning(f"Error converting value '{value}' to {sql_type}: {e}")
+        except Exception as e:
+            warning_message = f"Error converting value '{value}' to {sql_type}"
+            if column_name:
+                warning_message += f" in column {column_name}"
+            warning_message += f": {e}"
+            logger.warning(warning_message)
             return value
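
A rough usage sketch of the updated convert_value signature, based only on the hunks above (the exact Decimal result depends on _convert_decimal, which is not shown in this diff):

from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

SqlTypeConverter.convert_value("42", "bigint", column_name="id")  # -> 42
SqlTypeConverter.convert_value("12.34", "decimal", column_name="amount", precision=10, scale=2)  # -> a Decimal
SqlTypeConverter.convert_value("oops", "int", column_name="id")  # logs a warning mentioning column id, returns "oops"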
src/databricks/sql/backend/sea/utils/normalize.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+"""
+Type normalization utilities for SEA backend.
+
+This module provides functionality to normalize SEA type names to match
+Thrift type naming conventions.
+"""
+
+from typing import Dict, Any
+
+# SEA types that need to be translated to Thrift types
+# The list of all SEA types is available in the REST reference at:
+# https://docs.databricks.com/api/workspace/statementexecution/executestatement
+# The list of all Thrift types can be found in the ttypes.TTypeId definition
+# The SEA types that do not align with Thrift are explicitly mapped below
+SEA_TO_THRIFT_TYPE_MAP = {
+    "BYTE": "TINYINT",
+    "SHORT": "SMALLINT",
+    "LONG": "BIGINT",
+    "INTERVAL": "INTERVAL",  # Default mapping, will be overridden if type_interval_type is present
+}
+
+
+def normalize_sea_type_to_thrift(type_name: str, col_data: Dict[str, Any]) -> str:
+    """
+    Normalize SEA type names to match Thrift type naming conventions.
+
+    Args:
+        type_name: The type name from SEA (e.g., "BYTE", "LONG", "INTERVAL")
+        col_data: The full column data dictionary from manifest (for accessing type_interval_type)
+
+    Returns:
+        Normalized type name matching Thrift conventions
+    """
+    # Early return if type doesn't need mapping
+    if type_name not in SEA_TO_THRIFT_TYPE_MAP:
+        return type_name
+
+    normalized_type = SEA_TO_THRIFT_TYPE_MAP[type_name]
+
+    # Special handling for interval types
+    if type_name == "INTERVAL":
+        type_interval_type = col_data.get("type_interval_type")
+        if type_interval_type:
+            return (
+                "INTERVAL_YEAR_MONTH"
+                if any(t in type_interval_type.upper() for t in ["YEAR", "MONTH"])
+                else "INTERVAL_DAY_TIME"
+            )
+
+    return normalized_type
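
A quick illustration of the mapping behavior, using made-up column dictionaries (the exact type_interval_type strings SEA returns may differ):

from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

normalize_sea_type_to_thrift("SHORT", {"type_name": "SHORT"})  # -> "SMALLINT"
normalize_sea_type_to_thrift("STRING", {"type_name": "STRING"})  # -> "STRING" (unmapped, passes through)
normalize_sea_type_to_thrift("INTERVAL", {"type_name": "INTERVAL", "type_interval_type": "YEAR TO MONTH"})  # -> "INTERVAL_YEAR_MONTH"
normalize_sea_type_to_thrift("INTERVAL", {"type_name": "INTERVAL", "type_interval_type": "DAY TO SECOND"})  # -> "INTERVAL_DAY_TIME"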

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
     def close(self):
         pass
 
+    def _flush(self):
+        pass
+
 
 class TelemetryClient(BaseTelemetryClient):
     """

tests/e2e/test_complex_types.py

Lines changed: 5 additions & 3 deletions
@@ -39,9 +39,11 @@ def table_fixture(self, connection_details):
                 )
                 """
             )
-            yield
-            # Clean up the table after the test
-            cursor.execute("DELETE FROM pysql_test_complex_types_table")
+            try:
+                yield
+            finally:
+                # Clean up the table after the test
+                cursor.execute("DELETE FROM pysql_test_complex_types_table")
 
     @pytest.mark.parametrize(
         "field,expected_type",

0 commit comments
