Skip to content

Commit fcc262f

Browse files
author
Jesse
authored
Parameterized queries: Add e2e tests for inference (#227)
1 parent 776f34b commit fcc262f

File tree

5 files changed

+151
-6
lines changed

5 files changed

+151
-6
lines changed

src/databricks/sql/thrift_backend.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
613613
num_rows,
614614
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
615615
elif t_row_set.arrowBatches is not None:
616-
(
617-
arrow_table,
618-
num_rows,
619-
) = convert_arrow_based_set_to_arrow_table(
616+
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
620617
t_row_set.arrowBatches, lz4_compressed, schema_bytes
621618
)
622619
else:

src/databricks/sql/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ def infer_types(params: list[DbSqlParameter]):
529529
int: DbSqlType.INTEGER,
530530
float: DbSqlType.FLOAT,
531531
datetime.datetime: DbSqlType.TIMESTAMP,
532+
datetime.date: DbSqlType.DATE,
532533
bool: DbSqlType.BOOLEAN,
533534
}
534535
newParams = copy.deepcopy(params)

src/databricks/sqlalchemy/dialect/requirements.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def __some_example_requirement(self):
2424
import sqlalchemy.testing.exclusions
2525

2626
import logging
27+
2728
logger = logging.getLogger(__name__)
2829

2930
logger.warning("requirements.py is not currently employed by Databricks dialect")
3031

32+
3133
class Requirements(sqlalchemy.testing.requirements.SuiteRequirements):
32-
pass
34+
pass
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import datetime
2+
from decimal import Decimal
3+
from typing import Dict, List, Tuple, Union
4+
5+
import pytz
6+
7+
from databricks.sql.client import Connection
8+
from databricks.sql.utils import DbSqlParameter, DbSqlType
9+
10+
11+
class PySQLParameterizedQueryTestSuiteMixin:
12+
"""Namespace for tests of server-side parameterized queries"""
13+
14+
QUERY = "SELECT :p AS col"
15+
16+
def _get_one_result(self, query: str, parameters: Union[Dict, List[Dict]]) -> Tuple:
17+
with self.connection() as conn:
18+
with conn.cursor() as cursor:
19+
cursor.execute(query, parameters=parameters)
20+
return cursor.fetchone()
21+
22+
def _quantize(self, input: Union[float, int], place_value=2) -> Decimal:
23+
24+
return Decimal(str(input)).quantize(Decimal("0." + "0" * place_value))
25+
26+
def test_primitive_inferred_bool(self):
27+
28+
params = {"p": True}
29+
result = self._get_one_result(self.QUERY, params)
30+
assert result.col == True
31+
32+
def test_primitive_inferred_integer(self):
33+
34+
params = {"p": 1}
35+
result = self._get_one_result(self.QUERY, params)
36+
assert result.col == 1
37+
38+
def test_primitive_inferred_double(self):
39+
40+
params = {"p": 3.14}
41+
result = self._get_one_result(self.QUERY, params)
42+
assert self._quantize(result.col) == self._quantize(3.14)
43+
44+
def test_primitive_inferred_date(self):
45+
46+
# DATE in Databricks is mapped into a datetime.date object in Python
47+
date_value = datetime.date(2023, 9, 6)
48+
params = {"p": date_value}
49+
result = self._get_one_result(self.QUERY, params)
50+
assert result.col == date_value
51+
52+
def test_primitive_inferred_timestamp(self):
53+
54+
# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
55+
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
56+
params = {"p": date_value}
57+
result = self._get_one_result(self.QUERY, params)
58+
assert result.col == date_value
59+
60+
def test_primitive_inferred_string(self):
61+
62+
params = {"p": "Hello"}
63+
result = self._get_one_result(self.QUERY, params)
64+
assert result.col == "Hello"
65+
66+
def test_dbsqlparam_inferred_bool(self):
67+
68+
params = [DbSqlParameter(name="p", value=True, type=None)]
69+
result = self._get_one_result(self.QUERY, params)
70+
assert result.col == True
71+
72+
def test_dbsqlparam_inferred_integer(self):
73+
74+
params = [DbSqlParameter(name="p", value=1, type=None)]
75+
result = self._get_one_result(self.QUERY, params)
76+
assert result.col == 1
77+
78+
def test_dbsqlparam_inferred_double(self):
79+
80+
params = [DbSqlParameter(name="p", value=3.14, type=None)]
81+
result = self._get_one_result(self.QUERY, params)
82+
assert self._quantize(result.col) == self._quantize(3.14)
83+
84+
def test_dbsqlparam_inferred_date(self):
85+
86+
# DATE in Databricks is mapped into a datetime.date object in Python
87+
date_value = datetime.date(2023, 9, 6)
88+
params = [DbSqlParameter(name="p", value=date_value, type=None)]
89+
result = self._get_one_result(self.QUERY, params)
90+
assert result.col == date_value
91+
92+
def test_dbsqlparam_inferred_timestamp(self):
93+
94+
# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
95+
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
96+
params = [DbSqlParameter(name="p", value=date_value, type=None)]
97+
result = self._get_one_result(self.QUERY, params)
98+
assert result.col == date_value
99+
100+
def test_dbsqlparam_inferred_string(self):
101+
102+
params = [DbSqlParameter(name="p", value="Hello", type=None)]
103+
result = self._get_one_result(self.QUERY, params)
104+
assert result.col == "Hello"
105+
106+
def test_dbsqlparam_explicit_bool(self):
107+
108+
params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]
109+
result = self._get_one_result(self.QUERY, params)
110+
assert result.col == True
111+
112+
def test_dbsqlparam_explicit_integer(self):
113+
114+
params = [DbSqlParameter(name="p", value=1, type=DbSqlType.INTEGER)]
115+
result = self._get_one_result(self.QUERY, params)
116+
assert result.col == 1
117+
118+
def test_dbsqlparam_explicit_double(self):
119+
120+
params = [DbSqlParameter(name="p", value=3.14, type=DbSqlType.FLOAT)]
121+
result = self._get_one_result(self.QUERY, params)
122+
assert self._quantize(result.col) == self._quantize(3.14)
123+
124+
def test_dbsqlparam_explicit_date(self):
125+
126+
# DATE in Databricks is mapped into a datetime.date object in Python
127+
date_value = datetime.date(2023, 9, 6)
128+
params = [DbSqlParameter(name="p", value=date_value, type=DbSqlType.DATE)]
129+
result = self._get_one_result(self.QUERY, params)
130+
assert result.col == date_value
131+
132+
def test_dbsqlparam_explicit_timestamp(self):
133+
134+
# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
135+
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
136+
params = [DbSqlParameter(name="p", value=date_value, type=DbSqlType.TIMESTAMP)]
137+
result = self._get_one_result(self.QUERY, params)
138+
assert result.col == date_value
139+
140+
def test_dbsqlparam_explicit_string(self):
141+
142+
params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)]
143+
result = self._get_one_result(self.QUERY, params)
144+
assert result.col == "Hello"

tests/e2e/test_driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin
2929
from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin
3030
from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin
31+
from tests.e2e.common.parameterized_query_tests import PySQLParameterizedQueryTestSuiteMixin
3132

3233
log = logging.getLogger(__name__)
3334

@@ -142,7 +143,7 @@ def test_cloud_fetch(self):
142143
# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
143144
# tests
144145
class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin,
145-
PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin):
146+
PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLParameterizedQueryTestSuiteMixin):
146147
validate_row_value_type = True
147148
validate_result = True
148149

0 commit comments

Comments
 (0)