
Commit efc0337

Author: Jesse

SQLAlchemy 2: add type compilation for uppercase types (#240)

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

1 parent 74f4126 · commit efc0337

File tree: 2 files changed (+110 −77 lines)
Lines changed: 78 additions & 50 deletions

@@ -1,30 +1,7 @@
 import enum

 import pytest
-from sqlalchemy.types import (
-    BigInteger,
-    Boolean,
-    Date,
-    DateTime,
-    Double,
-    Enum,
-    Float,
-    Integer,
-    Interval,
-    LargeBinary,
-    MatchType,
-    Numeric,
-    PickleType,
-    SchemaType,
-    SmallInteger,
-    String,
-    Text,
-    Time,
-    TypeEngine,
-    Unicode,
-    UnicodeText,
-    Uuid,
-)
+import sqlalchemy

 from databricks.sqlalchemy import DatabricksDialect


@@ -55,43 +32,49 @@ class DatabricksDataType(enum.Enum):
 # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
 # Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that.
 camel_case_type_map = {
-    BigInteger: DatabricksDataType.BIGINT,
-    LargeBinary: DatabricksDataType.BINARY,
-    Boolean: DatabricksDataType.BOOLEAN,
-    Date: DatabricksDataType.DATE,
-    DateTime: DatabricksDataType.TIMESTAMP,
-    Double: DatabricksDataType.DOUBLE,
-    Enum: DatabricksDataType.STRING,
-    Float: DatabricksDataType.FLOAT,
-    Integer: DatabricksDataType.INT,
-    Interval: DatabricksDataType.TIMESTAMP,
-    Numeric: DatabricksDataType.DECIMAL,
-    PickleType: DatabricksDataType.BINARY,
-    SmallInteger: DatabricksDataType.SMALLINT,
-    String: DatabricksDataType.STRING,
-    Text: DatabricksDataType.STRING,
-    Time: DatabricksDataType.STRING,
-    Unicode: DatabricksDataType.STRING,
-    UnicodeText: DatabricksDataType.STRING,
-    Uuid: DatabricksDataType.STRING,
+    sqlalchemy.types.BigInteger: DatabricksDataType.BIGINT,
+    sqlalchemy.types.LargeBinary: DatabricksDataType.BINARY,
+    sqlalchemy.types.Boolean: DatabricksDataType.BOOLEAN,
+    sqlalchemy.types.Date: DatabricksDataType.DATE,
+    sqlalchemy.types.DateTime: DatabricksDataType.TIMESTAMP,
+    sqlalchemy.types.Double: DatabricksDataType.DOUBLE,
+    sqlalchemy.types.Enum: DatabricksDataType.STRING,
+    sqlalchemy.types.Float: DatabricksDataType.FLOAT,
+    sqlalchemy.types.Integer: DatabricksDataType.INT,
+    sqlalchemy.types.Interval: DatabricksDataType.TIMESTAMP,
+    sqlalchemy.types.Numeric: DatabricksDataType.DECIMAL,
+    sqlalchemy.types.PickleType: DatabricksDataType.BINARY,
+    sqlalchemy.types.SmallInteger: DatabricksDataType.SMALLINT,
+    sqlalchemy.types.String: DatabricksDataType.STRING,
+    sqlalchemy.types.Text: DatabricksDataType.STRING,
+    sqlalchemy.types.Time: DatabricksDataType.STRING,
+    sqlalchemy.types.Unicode: DatabricksDataType.STRING,
+    sqlalchemy.types.UnicodeText: DatabricksDataType.STRING,
+    sqlalchemy.types.Uuid: DatabricksDataType.STRING,
 }

-# Convert the dictionary into a list of tuples for use in pytest.mark.parametrize
-_as_tuple_list = [(key, value) for key, value in camel_case_type_map.items()]
+
+def dict_as_tuple_list(d: dict):
+    """Return a list of [(key, value), ...] from a dictionary."""
+    return [(key, value) for key, value in d.items()]


 class CompilationTestBase:
     dialect = DatabricksDialect()

-    def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType):
+    def _assert_compiled_value(
+        self, type_: sqlalchemy.types.TypeEngine, expected: DatabricksDataType
+    ):
         """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name.

         This method initialises the type_ with no arguments.
         """
         compiled_result = type_().compile(dialect=self.dialect)  # type: ignore
         assert compiled_result == expected.name

-    def _assert_compiled_value_explicit(self, type_: TypeEngine, expected: str):
+    def _assert_compiled_value_explicit(
+        self, type_: sqlalchemy.types.TypeEngine, expected: str
+    ):
         """Assert that when type_ is compiled for the databricks dialect, it renders the expected string.

         This method expects an initialised type_ so that we can test how a TypeEngine created with arguments
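For context (not part of the diff): a minimal sketch of what _assert_compiled_value exercises, assuming databricks-sqlalchemy is installed. Compiling a bare CamelCase type against the Databricks dialect yields the type name recorded in camel_case_type_map:

import sqlalchemy

from databricks.sqlalchemy import DatabricksDialect

dialect = DatabricksDialect()

# compile() renders the dialect-specific type string for a type instance;
# _assert_compiled_value compares this against expected.name.
assert sqlalchemy.types.DateTime().compile(dialect=dialect) == "TIMESTAMP"
assert sqlalchemy.types.Integer().compile(dialect=dialect) == "INT"
assert sqlalchemy.types.Uuid().compile(dialect=dialect) == "STRING"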
@@ -117,12 +100,57 @@ class TestCamelCaseTypesCompilation(CompilationTestBase):
     [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types
     """

-    @pytest.mark.parametrize("type_, expected", _as_tuple_list)
+    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(camel_case_type_map))
     def test_bare_camel_case_types_compile(self, type_, expected):
         self._assert_compiled_value(type_, expected)

     def test_numeric_renders_as_decimal_with_precision(self):
-        self._assert_compiled_value_explicit(Numeric(10), "DECIMAL(10)")
+        self._assert_compiled_value_explicit(
+            sqlalchemy.types.Numeric(10), "DECIMAL(10)"
+        )

     def test_numeric_renders_as_decimal_with_precision_and_scale(self):
-        self._assert_compiled_value_explicit(Numeric(10, 2), "DECIMAL(10, 2)")
+        return self._assert_compiled_value_explicit(
+            sqlalchemy.types.Numeric(10, 2), "DECIMAL(10, 2)"
+        )
+
+
+uppercase_type_map = {
+    sqlalchemy.types.ARRAY: DatabricksDataType.ARRAY,
+    sqlalchemy.types.BIGINT: DatabricksDataType.BIGINT,
+    sqlalchemy.types.BINARY: DatabricksDataType.BINARY,
+    sqlalchemy.types.BOOLEAN: DatabricksDataType.BOOLEAN,
+    sqlalchemy.types.DATE: DatabricksDataType.DATE,
+    sqlalchemy.types.DECIMAL: DatabricksDataType.DECIMAL,
+    sqlalchemy.types.DOUBLE: DatabricksDataType.DOUBLE,
+    sqlalchemy.types.FLOAT: DatabricksDataType.FLOAT,
+    sqlalchemy.types.INT: DatabricksDataType.INT,
+    sqlalchemy.types.SMALLINT: DatabricksDataType.SMALLINT,
+    sqlalchemy.types.TIMESTAMP: DatabricksDataType.TIMESTAMP,
+}
+
+
+class TestUppercaseTypesCompilation(CompilationTestBase):
+    """Per the sqlalchemy documentation[^1], uppercase types are considered to be specific to some
+    database backends. These tests verify that the types compile into valid Databricks SQL type strings.
+
+    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#backend-specific-uppercase-datatypes
+    """
+
+    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(uppercase_type_map))
+    def test_bare_uppercase_types_compile(self, type_, expected):
+        if isinstance(type_, type(sqlalchemy.types.ARRAY)):
+            # ARRAY cannot be initialised without passing an item definition so we test separately
+            # I preserve it in the uppercase_type_map for clarity
+            return True
+        return self._assert_compiled_value(type_, expected)
+
+    def test_array_string_renders_as_array_of_string(self):
+        """SQLAlchemy's ARRAY type requires an item definition. And their docs indicate that they've only tested
+        it with Postgres since that's the only first-class dialect with support for ARRAY.
+
+        https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
+        """
+        return self._assert_compiled_value_explicit(
+            sqlalchemy.types.ARRAY(sqlalchemy.types.String), "ARRAY<STRING>"
+        )
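As a usage note (not part of the diff), the ARRAY expectation above can be reproduced directly, assuming the dialect from this commit is importable:

import sqlalchemy

from databricks.sqlalchemy import DatabricksDialect

# Mirrors test_array_string_renders_as_array_of_string.
array_type = sqlalchemy.types.ARRAY(sqlalchemy.types.String)
print(array_type.compile(dialect=DatabricksDialect()))  # ARRAY<STRING>

# Because compile_array_databricks (in types.py, below) routes item_type
# back through the compiler, nested arrays should render recursively,
# e.g. ARRAY<ARRAY<INT>>, though this commit only tests a single level.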

src/databricks/sqlalchemy/types.py

Lines changed: 32 additions & 27 deletions

@@ -1,27 +1,14 @@
+import sqlalchemy
 from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.sql.compiler import GenericTypeCompiler
-from sqlalchemy.types import (
-    DateTime,
-    Enum,
-    Integer,
-    LargeBinary,
-    Numeric,
-    String,
-    Text,
-    Time,
-    Unicode,
-    UnicodeText,
-    Uuid,
-)
-
-
-@compiles(Enum, "databricks")
-@compiles(String, "databricks")
-@compiles(Text, "databricks")
-@compiles(Time, "databricks")
-@compiles(Unicode, "databricks")
-@compiles(UnicodeText, "databricks")
-@compiles(Uuid, "databricks")
+
+
+@compiles(sqlalchemy.types.Enum, "databricks")
+@compiles(sqlalchemy.types.String, "databricks")
+@compiles(sqlalchemy.types.Text, "databricks")
+@compiles(sqlalchemy.types.Time, "databricks")
+@compiles(sqlalchemy.types.Unicode, "databricks")
+@compiles(sqlalchemy.types.UnicodeText, "databricks")
+@compiles(sqlalchemy.types.Uuid, "databricks")
 def compile_string_databricks(type_, compiler, **kw):
     """
     We override the default compilation for Enum(), String(), Text(), and Time() because SQLAlchemy
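For readers unfamiliar with sqlalchemy.ext.compiler: @compiles(<type>, <dialect name>) registers a function that receives the type instance and the dialect's type compiler and returns the rendered SQL type string; stacking decorators, as above, registers one function for several types at once. A hypothetical sketch of the mechanism (not in this commit; the JSON mapping is illustrative only):

import sqlalchemy
from sqlalchemy.ext.compiler import compiles


# Hypothetical example: render sqlalchemy.types.JSON as STRING on the
# "databricks" dialect. The decorated function must return the SQL type
# string for the given type instance.
@compiles(sqlalchemy.types.JSON, "databricks")
def compile_json_databricks(type_, compiler, **kw):
    return "STRING"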
@@ -40,23 +27,23 @@ def compile_string_databricks(type_, compiler, **kw):
     return "STRING"


-@compiles(Integer, "databricks")
+@compiles(sqlalchemy.types.Integer, "databricks")
 def compile_integer_databricks(type_, compiler, **kw):
     """
     We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER"
     """
     return "INT"


-@compiles(LargeBinary, "databricks")
+@compiles(sqlalchemy.types.LargeBinary, "databricks")
 def compile_binary_databricks(type_, compiler, **kw):
     """
     We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB"
     """
     return "BINARY"


-@compiles(Numeric, "databricks")
+@compiles(sqlalchemy.types.Numeric, "databricks")
 def compile_numeric_databricks(type_, compiler, **kw):
     """
     We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC"
@@ -67,9 +54,27 @@ def compile_numeric_databricks(type_, compiler, **kw):
     return compiler.visit_DECIMAL(type_, **kw)


-@compiles(DateTime, "databricks")
+@compiles(sqlalchemy.types.DateTime, "databricks")
 def compile_datetime_databricks(type_, compiler, **kw):
     """
     We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME"
     """
     return "TIMESTAMP"
+
+
+@compiles(sqlalchemy.types.ARRAY, "databricks")
+def compile_array_databricks(type_, compiler, **kw):
+    """
+    SQLAlchemy's default ARRAY can't compile as it's only implemented for Postgresql.
+    The Postgres implementation works for Databricks SQL, so we duplicate that here.
+
+    :type_:
+        This is an instance of sqlalchemy.types.ARRAY which always includes an item_type attribute
+        which is itself an instance of TypeEngine
+
+    https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
+    """
+
+    inner = compiler.process(type_.item_type, **kw)
+
+    return f"ARRAY<{inner}>"
