Skip to content

Commit 58ecda3

Browse files
author
Jesse
authored
Parameters: Add type inference for BIGINT and TINYINT types (#246)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 317a471 commit 58ecda3

File tree

2 files changed

+56
-21
lines changed

2 files changed

+56
-21
lines changed

src/databricks/sql/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,16 @@ def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
533533
return dbsqlparams
534534

535535

536+
def resolve_databricks_sql_integer_type(integer):
    """Pick the narrowest integer DbSqlType able to hold ``integer``.

    Only TINYINT, INTEGER and BIGINT are considered; SMALLINT is never
    returned.

    - ``-128 <= integer <= 127`` yields ``DbSqlType.TINYINT``
    - otherwise, a value within the signed 32-bit range yields
      ``DbSqlType.INTEGER``
    - every other value yields ``DbSqlType.BIGINT``

    No bounds check is made for the BIGINT case, so a value outside the
    signed 64-bit range is still reported as BIGINT.
    """
    # Guard clauses ordered narrowest-first, so the first match wins.
    fits_in_one_byte = -(2**7) <= integer <= 2**7 - 1
    if fits_in_one_byte:
        return DbSqlType.TINYINT

    fits_in_four_bytes = -(2**31) <= integer <= 2**31 - 1
    if fits_in_four_bytes:
        return DbSqlType.INTEGER

    # Fallback: anything wider than 32 bits.
    return DbSqlType.BIGINT
544+
545+
536546
def infer_types(params: list[DbSqlParameter]):
537547
type_lookup_table = {
538548
str: DbSqlType.STRING,
@@ -568,6 +578,10 @@ def infer_types(params: list[DbSqlParameter]):
568578
cast_exp = calculate_decimal_cast_string(param.value)
569579
_type = DbsqlDynamicDecimalType(cast_exp)
570580

581+
# Python int values require special handling because one Python type can map to multiple SQL types (TINYINT, INTEGER, BIGINT)
582+
if _type == DbSqlType.INTEGER:
583+
_type = resolve_databricks_sql_integer_type(param.value)
584+
571585
# VOID / NULL types must be passed in a unique way as TSparkParameters with no value
572586
if _type == DbSqlType.VOID:
573587
new_params.append(DbSqlParameter(name=_name, type=DbSqlType.VOID))

tests/unit/test_parameters.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,36 @@
1919

2020

2121
class TestTSparkParameterConversion(object):
22-
def test_conversion_e2e(self):
22+
@pytest.mark.parametrize(
23+
"input_value, expected_type",
24+
[
25+
("a", "STRING"),
26+
(1, "TINYINT"),
27+
(1000, "INTEGER"),
28+
(9223372036854775807, "BIGINT"), # Max value of a signed 64-bit integer
29+
(True, "BOOLEAN"),
30+
(1.0, "FLOAT"),
31+
],
32+
)
33+
def test_conversion_e2e(self, input_value, expected_type):
2334
"""This behaviour falls back to Python's default string formatting of numbers"""
24-
assert named_parameters_to_tsparkparams(
25-
["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
26-
) == [
27-
TSparkParameter(
28-
name="", type="STRING", value=TSparkParameterValue(stringValue="a")
29-
),
30-
TSparkParameter(
31-
name="", type="INTEGER", value=TSparkParameterValue(stringValue="1")
32-
),
33-
TSparkParameter(
34-
name="", type="BOOLEAN", value=TSparkParameterValue(stringValue="True")
35-
),
36-
TSparkParameter(
37-
name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
38-
),
35+
output = named_parameters_to_tsparkparams([input_value])
36+
expected = TSparkParameter(
37+
name="",
38+
type=expected_type,
39+
value=TSparkParameterValue(stringValue=str(input_value)),
40+
)
41+
assert output == [expected]
42+
43+
def test_conversion_e2e_decimal(self):
44+
input = DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)
45+
output = named_parameters_to_tsparkparams([input])
46+
assert output == [
3947
TSparkParameter(
4048
name="",
4149
type="DECIMAL(2,1)",
4250
value=TSparkParameterValue(stringValue="1.0"),
43-
),
51+
)
4452
]
4553

4654
def test_basic_conversions_v1(self):
@@ -69,10 +77,24 @@ def test_infer_types_dict(self):
6977
with pytest.raises(ValueError):
7078
infer_types([DbSqlParameter("", {1: 1})])
7179

72-
def test_infer_types_integer(self):
73-
input = DbSqlParameter("", 1)
80+
@pytest.mark.parametrize(
81+
"input_value, expected_type",
82+
[
83+
(-128, DbSqlType.TINYINT),
84+
(127, DbSqlType.TINYINT),
85+
(-2147483649, DbSqlType.BIGINT),
86+
(-2147483648, DbSqlType.INTEGER),
87+
(2147483647, DbSqlType.INTEGER),
88+
(-9223372036854775808, DbSqlType.BIGINT),
89+
(9223372036854775807, DbSqlType.BIGINT),
90+
],
91+
)
92+
def test_infer_types_integer(self, input_value, expected_type):
93+
input = DbSqlParameter("", input_value)
7494
output = infer_types([input])
75-
assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)]
95+
assert output == [
96+
DbSqlParameter("", str(input_value), expected_type)
97+
], f"{output[0].type} received, expected {expected_type}"
7698

7799
def test_infer_types_boolean(self):
78100
input = DbSqlParameter("", True)
@@ -101,7 +123,6 @@ def test_infer_types_decimal(self):
101123
assert x.type.value == "DECIMAL(2,1)"
102124

103125
def test_infer_types_none(self):
104-
105126
input = DbSqlParameter("", None)
106127
output: List[DbSqlParameter] = infer_types([input])
107128

0 commit comments

Comments
 (0)