
Commit 776f34b

[PECO-1026] Add Parameterized Query support to Python (#217)
* Initial commit
* Added tsparkparam handling
* Added basic test
* Addressed comments
* Addressed missed comments
* Resolved comments

---------

Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent 8211337 commit 776f34b

File tree: 4 files changed (+184, -10 lines)

* src/databricks/sql/client.py
* src/databricks/sql/thrift_backend.py
* src/databricks/sql/utils.py
* tests/unit/test_parameters.py

src/databricks/sql/client.py

Lines changed: 13 additions & 6 deletions

@@ -14,7 +14,11 @@
     CursorAlreadyClosedError,
 )
 from databricks.sql.thrift_backend import ThriftBackend
-from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
+from databricks.sql.utils import (
+    ExecuteResponse,
+    ParamEscaper,
+    named_parameters_to_tsparkparams,
+)
 from databricks.sql.types import Row
 from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
 from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -482,7 +486,9 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
         )
 
     def execute(
-        self, operation: str, parameters: Optional[Dict[str, str]] = None
+        self,
+        operation: str,
+        parameters: Optional[Union[List[Any], Dict[str, str]]] = None,
     ) -> "Cursor":
         """
         Execute a query and wait for execution to complete.
@@ -493,10 +499,10 @@ def execute(
         Will result in the query "SELECT * FROM table WHERE field = 'foo' being sent to the server
         :returns self
         """
-        if parameters is not None:
-            operation = inject_parameters(
-                operation, self.escaper.escape_args(parameters)
-            )
+        if parameters is None:
+            parameters = []
+        else:
+            parameters = named_parameters_to_tsparkparams(parameters)
 
         self._check_not_closed()
         self._close_and_clear_active_result_set()
@@ -508,6 +514,7 @@
             lz4_compression=self.connection.lz4_compression,
             cursor=self,
             use_cloud_fetch=self.connection.use_cloud_fetch,
+            parameters=parameters,
         )
         self.active_result_set = ResultSet(
             self.connection,
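With this change, execute() no longer string-interpolates parameters into the SQL text via inject_parameters; they travel to the server as typed Thrift objects instead. A minimal usage sketch of the new surface (the connection values are placeholders, and the "?" marker style is an assumption for illustration; this diff shows only the client-side plumbing, not the server-side substitution syntax):

    from databricks import sql
    from databricks.sql.utils import DbSqlParameter, DbSqlType

    # Placeholder connection values; substitute real workspace details.
    connection = sql.connect(
        server_hostname="...", http_path="...", access_token="..."
    )
    cursor = connection.cursor()

    # Plain values: types are inferred (str -> STRING, int -> INTEGER, ...).
    # NOTE: the "?" placeholder style is assumed for illustration.
    cursor.execute("SELECT * FROM some_table WHERE id = ? AND name = ?", [42, "foo"])

    # A DbSqlParameter pins the type explicitly, e.g. DECIMAL.
    cursor.execute(
        "SELECT * FROM some_table WHERE price = ?",
        [DbSqlParameter(value="9.99", type=DbSqlType.DECIMAL)],
    )

    cursor.close()
    connection.close()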

src/databricks/sql/thrift_backend.py

Lines changed: 7 additions & 3 deletions

@@ -224,7 +224,7 @@ def __init__(
     def _initialize_retry_args(self, kwargs):
         # Configure retries & timing: use user-settings or defaults, and bound
         # by policy. Log.warn when given param gets restricted.
-        for (key, (type_, default, min, max)) in _retry_policy.items():
+        for key, (type_, default, min, max) in _retry_policy.items():
             given_or_default = type_(kwargs.get(key, default))
             bound = _bound(min, max, given_or_default)
             setattr(self, key, bound)
@@ -368,7 +368,6 @@ def attempt_request(attempt):
 
             error, error_message, retry_delay = None, None, None
             try:
-
                 this_method_name = getattr(method, "__name__")
 
                 logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
@@ -614,7 +613,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
                 num_rows,
             ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
+            (
+                arrow_table,
+                num_rows,
+            ) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
@@ -813,6 +815,7 @@ def execute_command(
         lz4_compression,
         cursor,
         use_cloud_fetch=False,
+        parameters=[],
     ):
         assert session_handle is not None
 
@@ -839,6 +842,7 @@ def execute_command(
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
             },
             useArrowNativeTypes=spark_arrow_types,
+            parameters=parameters,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)
         return self._handle_execute_response(resp, cursor)
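The backend change is pure plumbing: execute_command gains a parameters argument (defaulting to an empty list) and forwards it on the Thrift request. For a caller-supplied list like ["foo", 42], the objects attached via the request's parameters field look like the sketch below, following the conversion rules added in utils.py:

    from databricks.sql.thrift_api.TCLIService.ttypes import (
        TSparkParameter,
        TSparkParameterValue,
    )

    # named_parameters_to_tsparkparams(["foo", 42]) produces unnamed, typed
    # parameters; values always travel as strings on the wire.
    params = [
        TSparkParameter(
            name="", type="STRING", value=TSparkParameterValue(stringValue="foo")
        ),
        TSparkParameter(
            name="", type="INTEGER", value=TSparkParameterValue(stringValue="42")
        ),
    ]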

src/databricks/sql/utils.py

Lines changed: 89 additions & 1 deletion

@@ -1,3 +1,4 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from collections import namedtuple, OrderedDict
 from collections.abc import Iterable
@@ -8,13 +9,17 @@
 import lz4.frame
 from typing import Dict, List, Union, Any
 import pyarrow
+from enum import Enum
+import copy
 
 from databricks.sql import exc, OperationalError
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkArrowResultLink,
     TSparkRowSetType,
     TRowSet,
+    TSparkParameter,
+    TSparkParameterValue,
 )
 
 BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
@@ -404,7 +409,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
 
 
 def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
-    for (i, col) in enumerate(table.itercolumns()):
+    for i, col in enumerate(table.itercolumns()):
         if description[i][1] == "decimal":
             decimal_col = col.to_pandas().apply(
                 lambda v: v if v is None else Decimal(v)
@@ -470,3 +475,86 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
             result[i] = None
 
     return pyarrow.array(result, type=arrow_type)
+
+
+class DbSqlType(Enum):
+    STRING = "STRING"
+    DATE = "DATE"
+    TIMESTAMP = "TIMESTAMP"
+    FLOAT = "FLOAT"
+    DECIMAL = "DECIMAL"
+    INTEGER = "INTEGER"
+    BIGINT = "BIGINT"
+    SMALLINT = "SMALLINT"
+    TINYINT = "TINYINT"
+    BOOLEAN = "BOOLEAN"
+    INTERVAL_MONTH = "INTERVAL MONTH"
+    INTERVAL_DAY = "INTERVAL DAY"
+
+
+class DbSqlParameter:
+    name: str
+    value: Any
+    type: DbSqlType
+
+    def __init__(self, name="", value=None, type=None):
+        self.name = name
+        self.value = value
+        self.type = type
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+
+def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]):
+    dbsqlparams = []
+    for name, parameter in parameters.items():
+        dbsqlparams.append(DbSqlParameter(name=name, value=parameter))
+    return dbsqlparams
+
+
+def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
+    dbsqlparams = []
+    for parameter in parameters:
+        if isinstance(parameter, DbSqlParameter):
+            dbsqlparams.append(parameter)
+        else:
+            dbsqlparams.append(DbSqlParameter(value=parameter))
+    return dbsqlparams
+
+
+def infer_types(params: list[DbSqlParameter]):
+    type_lookup_table = {
+        str: DbSqlType.STRING,
+        int: DbSqlType.INTEGER,
+        float: DbSqlType.FLOAT,
+        datetime.datetime: DbSqlType.TIMESTAMP,
+        bool: DbSqlType.BOOLEAN,
+    }
+    newParams = copy.deepcopy(params)
+    for param in newParams:
+        if not param.type:
+            if type(param.value) in type_lookup_table:
+                param.type = type_lookup_table[type(param.value)]
+            else:
+                raise ValueError("Parameter type cannot be inferred")
+        param.value = str(param.value)
+    return newParams
+
+
+def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
+    tspark_params = []
+    if isinstance(parameters, dict):
+        dbsql_params = named_parameters_to_dbsqlparams_v1(parameters)
+    else:
+        dbsql_params = named_parameters_to_dbsqlparams_v2(parameters)
+    inferred_type_parameters = infer_types(dbsql_params)
+    for param in inferred_type_parameters:
+        tspark_params.append(
+            TSparkParameter(
+                type=param.type.value,
+                name=param.name,
+                value=TSparkParameterValue(stringValue=param.value),
+            )
+        )
+    return tspark_params
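Two details of infer_types are easy to miss: the lookup uses type(value) exactly, so bool maps to BOOLEAN rather than INTEGER even though bool subclasses int, and every value (inferred or explicitly typed) is stringified before transmission. A quick check against the code above:

    from databricks.sql.utils import DbSqlParameter, DbSqlType, infer_types

    # Exact type match: bool -> BOOLEAN; the value is stringified.
    inferred = infer_types([DbSqlParameter(value=True)])[0]
    print(inferred.type, inferred.value)  # DbSqlType.BOOLEAN True

    # Values without a lookup entry (None, dict, ...) raise ValueError.
    try:
        infer_types([DbSqlParameter(value=None)])
    except ValueError as err:
        print(err)  # Parameter type cannot be inferred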

tests/unit/test_parameters.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+from databricks.sql.utils import (
+    named_parameters_to_tsparkparams,
+    infer_types,
+    named_parameters_to_dbsqlparams_v1,
+    named_parameters_to_dbsqlparams_v2,
+)
+from databricks.sql.thrift_api.TCLIService.ttypes import (
+    TSparkParameter,
+    TSparkParameterValue,
+)
+from databricks.sql.utils import DbSqlParameter, DbSqlType
+import pytest
+
+
+class TestTSparkParameterConversion(object):
+    def test_conversion_e2e(self):
+        """This behaviour falls back to Python's default string formatting of numbers"""
+        assert named_parameters_to_tsparkparams(
+            ["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
+        ) == [
+            TSparkParameter(
+                name="", type="STRING", value=TSparkParameterValue(stringValue="a")
+            ),
+            TSparkParameter(
+                name="", type="INTEGER", value=TSparkParameterValue(stringValue="1")
+            ),
+            TSparkParameter(
+                name="", type="BOOLEAN", value=TSparkParameterValue(stringValue="True")
+            ),
+            TSparkParameter(
+                name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
+            ),
+            TSparkParameter(
+                name="", type="DECIMAL", value=TSparkParameterValue(stringValue="1.0")
+            ),
+        ]
+
+    def test_basic_conversions_v1(self):
+        # Test legacy codepath
+        assert named_parameters_to_dbsqlparams_v1({"1": 1, "2": "foo", "3": 2.0}) == [
+            DbSqlParameter("1", 1),
+            DbSqlParameter("2", "foo"),
+            DbSqlParameter("3", 2.0),
+        ]
+
+    def test_basic_conversions_v2(self):
+        # Test interspersing named params with unnamed
+        assert named_parameters_to_dbsqlparams_v2(
+            [DbSqlParameter("1", 1.0, DbSqlType.DECIMAL), 5, DbSqlParameter("3", "foo")]
+        ) == [
+            DbSqlParameter("1", 1.0, DbSqlType.DECIMAL),
+            DbSqlParameter("", 5),
+            DbSqlParameter("3", "foo"),
+        ]
+
+    def test_type_inference(self):
+        with pytest.raises(ValueError):
+            infer_types([DbSqlParameter("", None)])
+        with pytest.raises(ValueError):
+            infer_types([DbSqlParameter("", {1: 1})])
+        assert infer_types([DbSqlParameter("", 1)]) == [
+            DbSqlParameter("", "1", DbSqlType.INTEGER)
+        ]
+        assert infer_types([DbSqlParameter("", True)]) == [
+            DbSqlParameter("", "True", DbSqlType.BOOLEAN)
+        ]
+        assert infer_types([DbSqlParameter("", 1.0)]) == [
+            DbSqlParameter("", "1.0", DbSqlType.FLOAT)
+        ]
+        assert infer_types([DbSqlParameter("", "foo")]) == [
+            DbSqlParameter("", "foo", DbSqlType.STRING)
+        ]
+        assert infer_types([DbSqlParameter("", 1.0, DbSqlType.DECIMAL)]) == [
+            DbSqlParameter("", "1.0", DbSqlType.DECIMAL)
+        ]
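To run just this new suite from the repo root (a sketch assuming the project's standard pytest setup):

    import pytest

    # Equivalent to `pytest -v tests/unit/test_parameters.py` on the command line.
    pytest.main(["-v", "tests/unit/test_parameters.py"])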
