Skip to content

Commit e064e25

Browse files
author
Jesse
authored
Query parameters: implement support for binding NoneType parameters (#233)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 13cb417 commit e064e25

File tree

4 files changed

+91
-27
lines changed

4 files changed

+91
-27
lines changed

src/databricks/sql/utils.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import copy
43
import datetime
54
import decimal
65
from abc import ABC, abstractmethod
@@ -495,6 +494,7 @@ class DbSqlType(Enum):
495494
BOOLEAN = "BOOLEAN"
496495
INTERVAL_MONTH = "INTERVAL MONTH"
497496
INTERVAL_DAY = "INTERVAL DAY"
497+
VOID = "VOID"
498498

499499

500500
class DbSqlParameter:
@@ -542,20 +542,41 @@ def infer_types(params: list[DbSqlParameter]):
542542
datetime.date: DbSqlType.DATE,
543543
bool: DbSqlType.BOOLEAN,
544544
Decimal: DbSqlType.DECIMAL,
545+
type(None): DbSqlType.VOID,
545546
}
546-
new_params = copy.deepcopy(params)
547-
for param in new_params:
548-
if not param.type:
549-
if type(param.value) in type_lookup_table:
550-
param.type = type_lookup_table[type(param.value)]
551-
else:
552-
raise ValueError("Parameter type cannot be inferred")
553-
554-
if param.type == DbSqlType.DECIMAL:
547+
548+
new_params = []
549+
550+
# cycle through each parameter we've been passed
551+
for param in params:
552+
_name: str = param.name
553+
_value: Any = param.value
554+
_type: Union[DbSqlType, DbsqlDynamicDecimalType, Enum, None]
555+
556+
if param.type:
557+
_type = param.type
558+
else:
559+
# figure out what type to use
560+
_type = type_lookup_table.get(type(_value), None)
561+
if not _type:
562+
raise ValueError(
563+
f"Could not infer parameter type from {type(param.value)} - {param.value}"
564+
)
565+
566+
# Decimal require special handling because one column type in Databricks can have multiple precisions
567+
if _type == DbSqlType.DECIMAL:
555568
cast_exp = calculate_decimal_cast_string(param.value)
556-
param.type = DbsqlDynamicDecimalType(cast_exp)
569+
_type = DbsqlDynamicDecimalType(cast_exp)
570+
571+
# VOID / NULL types must be passed in a unique way as TSparkParameters with no value
572+
if _type == DbSqlType.VOID:
573+
new_params.append(DbSqlParameter(name=_name, type=DbSqlType.VOID))
574+
continue
575+
else:
576+
_value = str(param.value)
577+
578+
new_params.append(DbSqlParameter(name=_name, value=_value, type=_type))
557579

558-
param.value = str(param.value)
559580
return new_params
560581

561582

@@ -594,11 +615,15 @@ def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]
594615
dbsql_params = named_parameters_to_dbsqlparams_v2(parameters)
595616
inferred_type_parameters = infer_types(dbsql_params)
596617
for param in inferred_type_parameters:
597-
tspark_params.append(
598-
TSparkParameter(
599-
type=param.type.value,
600-
name=param.name,
601-
value=TSparkParameterValue(stringValue=param.value),
618+
# The only way to pass a VOID/NULL to DBR is to declare TSparkParameter without declaring
619+
# its value or type arguments. If we set these to NoneType, the request will fail with a
620+
# thrift transport error
621+
if param.type == DbSqlType.VOID:
622+
this_tspark_param = TSparkParameter(name=param.name)
623+
else:
624+
this_tspark_param_value = TSparkParameterValue(stringValue=param.value)
625+
this_tspark_param = TSparkParameter(
626+
type=param.type.value, name=param.name, value=this_tspark_param_value
602627
)
603-
)
628+
tspark_params.append(this_tspark_param)
604629
return tspark_params

tests/e2e/common/parameterized_query_tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def _quantize(self, input: Union[float, int], place_value=2) -> Decimal:
3434

3535
return Decimal(str(input)).quantize(Decimal("0." + "0" * place_value))
3636

37+
def test_primitive_inferred_none(self):
38+
39+
params = {"p": None}
40+
result = self._get_one_result(self.QUERY, params)
41+
assert result.col == None
42+
3743
def test_primitive_inferred_bool(self):
3844

3945
params = {"p": True}
@@ -79,6 +85,12 @@ def test_primitive_inferred_decimal(self):
7985
result = self._get_one_result(self.QUERY, params)
8086
assert result.col == Decimal("1234.56")
8187

88+
def test_dbsqlparam_inferred_none(self):
89+
90+
params = [DbSqlParameter(name="p", value=None, type=None)]
91+
result = self._get_one_result(self.QUERY, params)
92+
assert result.col == None
93+
8294
def test_dbsqlparam_inferred_bool(self):
8395

8496
params = [DbSqlParameter(name="p", value=True, type=None)]
@@ -124,6 +136,12 @@ def test_dbsqlparam_inferred_decimal(self):
124136
result = self._get_one_result(self.QUERY, params)
125137
assert result.col == Decimal("1234.56")
126138

139+
def test_dbsqlparam_explicit_none(self):
140+
141+
params = [DbSqlParameter(name="p", value=None, type=DbSqlType.VOID)]
142+
result = self._get_one_result(self.QUERY, params)
143+
assert result.col == None
144+
127145
def test_dbsqlparam_explicit_bool(self):
128146

129147
params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]

tests/e2e/test_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def test_escape_single_quotes(self):
341341
assert rows[0]["col_1"] == "you're"
342342

343343
# Test escape syntax in parameter
344-
cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"})
344+
cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE :var".format(table_name, table_name), parameters={"var": "you're"})
345345
rows = cursor.fetchall()
346346
assert rows[0]["col_1"] == "you're"
347347

tests/unit/test_parameters.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
named_parameters_to_dbsqlparams_v1,
55
named_parameters_to_dbsqlparams_v2,
66
calculate_decimal_cast_string,
7-
DbsqlDynamicDecimalType
7+
DbsqlDynamicDecimalType,
88
)
99
from databricks.sql.thrift_api.TCLIService.ttypes import (
1010
TSparkParameter,
@@ -37,7 +37,9 @@ def test_conversion_e2e(self):
3737
name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
3838
),
3939
TSparkParameter(
40-
name="", type="DECIMAL(2,1)", value=TSparkParameterValue(stringValue="1.0")
40+
name="",
41+
type="DECIMAL(2,1)",
42+
value=TSparkParameterValue(stringValue="1.0"),
4143
),
4244
]
4345

@@ -91,26 +93,45 @@ def test_infer_types_decimal(self):
9193
# The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1)
9294
input = DbSqlParameter("", Decimal("1.0"))
9395
output: List[DbSqlParameter] = infer_types([input])
94-
96+
9597
x = output[0]
9698

9799
assert x.value == "1.0"
98100
assert isinstance(x.type, DbsqlDynamicDecimalType)
99101
assert x.type.value == "DECIMAL(2,1)"
100-
101102

102-
class TestCalculateDecimalCast(object):
103+
def test_infer_types_none(self):
104+
105+
input = DbSqlParameter("", None)
106+
output: List[DbSqlParameter] = infer_types([input])
107+
108+
x = output[0]
109+
110+
assert x.value == None
111+
assert x.type == DbSqlType.VOID
112+
assert x.type.value == "VOID"
103113

114+
def test_infer_types_unsupported(self):
115+
class ArbitraryType:
116+
pass
117+
118+
input = DbSqlParameter("", ArbitraryType())
119+
120+
with pytest.raises(ValueError, match="Could not infer parameter type from"):
121+
infer_types([input])
122+
123+
124+
class TestCalculateDecimalCast(object):
104125
def test_38_38(self):
105126
input = Decimal(".12345678912345678912345678912345678912")
106127
output = calculate_decimal_cast_string(input)
107128
assert output == "DECIMAL(38,38)"
108-
129+
109130
def test_18_9(self):
110131
input = Decimal("123456789.123456789")
111132
output = calculate_decimal_cast_string(input)
112133
assert output == "DECIMAL(18,9)"
113-
134+
114135
def test_38_0(self):
115136
input = Decimal("12345678912345678912345678912345678912")
116137
output = calculate_decimal_cast_string(input)
@@ -119,4 +140,4 @@ def test_38_0(self):
119140
def test_6_2(self):
120141
input = Decimal("1234.56")
121142
output = calculate_decimal_cast_string(input)
122-
assert output == "DECIMAL(6,2)"
143+
assert output == "DECIMAL(6,2)"

0 commit comments

Comments
 (0)