
Commit 74f4126

Author: Jesse
SQLAlchemy 2: add type compilation for all CamelCase types (#238)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 9afb215 commit 74f4126

File tree

4 files changed: +223 -94 lines changed
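For orientation before the diff: the net effect of this commit is that SQLAlchemy's generic CamelCase types compile to Databricks SQL type names under this dialect. A minimal sketch of that behaviour (expected outputs taken from the tests added in this commit):

from sqlalchemy.types import DateTime, Integer, String

from databricks.sqlalchemy import DatabricksDialect

dialect = DatabricksDialect()
print(Integer().compile(dialect=dialect))   # expected: INT
print(String().compile(dialect=dialect))    # expected: STRING
print(DateTime().compile(dialect=dialect))  # expected: TIMESTAMP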

src/databricks/sqlalchemy/__init__.py

Lines changed: 20 additions & 56 deletions
@@ -12,12 +12,14 @@

 from databricks import sql

+# This import is required to process our @compiles decorators
+import databricks.sqlalchemy.types
+
 from databricks.sqlalchemy.base import (
     DatabricksDDLCompiler,
     DatabricksIdentifierPreparer,
 )
-from databricks.sqlalchemy.compiler import DatabricksTypeCompiler

 try:
     import alembic
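The bare import added above matters even though nothing references the module directly: sqlalchemy.ext.compiler registers @compiles handlers as an import-time side effect, so the handlers in databricks.sqlalchemy.types only take effect once that module has been imported somewhere. A standalone sketch of the mechanism, using a hypothetical MyType:

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import UserDefinedType


class MyType(UserDefinedType):
    """Hypothetical type, present only to illustrate @compiles registration."""


@compiles(MyType, "databricks")  # registered the moment this module is imported
def compile_mytype(type_, compiler, **kw):
    return "STRING"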
@@ -30,52 +32,14 @@ class DatabricksImpl(DefaultImpl):
     __dialect__ = "databricks"


-class DatabricksDecimal(types.TypeDecorator):
-    """Translates strings to decimals"""
-
-    impl = types.DECIMAL
-
-    def process_result_value(self, value, dialect):
-        if value is not None:
-            return decimal.Decimal(value)
-        else:
-            return None
-
-
-class DatabricksTimestamp(types.TypeDecorator):
-    """Translates timestamp strings to datetime objects"""
-
-    impl = types.TIMESTAMP
-
-    def process_result_value(self, value, dialect):
-        return value
-
-    def adapt(self, impltype, **kwargs):
-        return self.impl
-
-
-class DatabricksDate(types.TypeDecorator):
-    """Translates date strings to date objects"""
-
-    impl = types.DATE
-
-    def process_result_value(self, value, dialect):
-        return value
-
-    def adapt(self, impltype, **kwargs):
-        return self.impl
-
-
 class DatabricksDialect(default.DefaultDialect):
     """This dialect implements only those methods required to pass our e2e tests"""

     # Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect
     name: str = "databricks"
     driver: str = "databricks"
     default_schema_name: str = "default"
-
     preparer = DatabricksIdentifierPreparer  # type: ignore
-    type_compiler = DatabricksTypeCompiler
     ddl_compiler = DatabricksDDLCompiler
     supports_statement_cache: bool = True
     supports_multivalues_insert: bool = True
@@ -137,23 +101,23 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
         """

         _type_map = {
-            "boolean": types.Boolean,
-            "smallint": types.SmallInteger,
-            "int": types.Integer,
-            "bigint": types.BigInteger,
-            "float": types.Float,
-            "double": types.Float,
-            "string": types.String,
-            "varchar": types.String,
-            "char": types.String,
-            "binary": types.String,
-            "array": types.String,
-            "map": types.String,
-            "struct": types.String,
-            "uniontype": types.String,
-            "decimal": DatabricksDecimal,
-            "timestamp": DatabricksTimestamp,
-            "date": DatabricksDate,
+            "boolean": sqlalchemy.types.Boolean,
+            "smallint": sqlalchemy.types.SmallInteger,
+            "int": sqlalchemy.types.Integer,
+            "bigint": sqlalchemy.types.BigInteger,
+            "float": sqlalchemy.types.Float,
+            "double": sqlalchemy.types.Float,
+            "string": sqlalchemy.types.String,
+            "varchar": sqlalchemy.types.String,
+            "char": sqlalchemy.types.String,
+            "binary": sqlalchemy.types.String,
+            "array": sqlalchemy.types.String,
+            "map": sqlalchemy.types.String,
+            "struct": sqlalchemy.types.String,
+            "uniontype": sqlalchemy.types.String,
+            "decimal": sqlalchemy.types.Numeric,
+            "timestamp": sqlalchemy.types.DateTime,
+            "date": sqlalchemy.types.Date,
         }

         with self.get_connection_cursor(connection) as cur:
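The _type_map above backs table reflection: get_columns looks up the type name Databricks reports for each column and returns the matching SQLAlchemy class. A hedged sketch of that lookup step (the helper name and the String fallback are illustrative, not part of this commit):

import sqlalchemy


def lookup_reflected_type(databricks_type_name: str):
    # Hypothetical helper mirroring the _type_map lookup in get_columns.
    _type_map = {
        "int": sqlalchemy.types.Integer,
        "decimal": sqlalchemy.types.Numeric,
        "timestamp": sqlalchemy.types.DateTime,
        # ...remaining entries exactly as in the hunk above
    }
    return _type_map.get(databricks_type_name.lower(), sqlalchemy.types.String)


print(lookup_reflected_type("DECIMAL"))  # <class 'sqlalchemy.sql.sqltypes.Numeric'>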

src/databricks/sqlalchemy/compiler.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
(new test file)

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+import enum
+
+import pytest
+from sqlalchemy.types import (
+    BigInteger,
+    Boolean,
+    Date,
+    DateTime,
+    Double,
+    Enum,
+    Float,
+    Integer,
+    Interval,
+    LargeBinary,
+    MatchType,
+    Numeric,
+    PickleType,
+    SchemaType,
+    SmallInteger,
+    String,
+    Text,
+    Time,
+    TypeEngine,
+    Unicode,
+    UnicodeText,
+    Uuid,
+)
+
+from databricks.sqlalchemy import DatabricksDialect
+
+
+class DatabricksDataType(enum.Enum):
+    """https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html"""
+
+    BIGINT = enum.auto()
+    BINARY = enum.auto()
+    BOOLEAN = enum.auto()
+    DATE = enum.auto()
+    DECIMAL = enum.auto()
+    DOUBLE = enum.auto()
+    FLOAT = enum.auto()
+    INT = enum.auto()
+    INTERVAL = enum.auto()
+    VOID = enum.auto()
+    SMALLINT = enum.auto()
+    STRING = enum.auto()
+    TIMESTAMP = enum.auto()
+    TIMESTAMP_NTZ = enum.auto()
+    TINYINT = enum.auto()
+    ARRAY = enum.auto()
+    MAP = enum.auto()
+    STRUCT = enum.auto()
+
+
+# Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
+# Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that.
+camel_case_type_map = {
+    BigInteger: DatabricksDataType.BIGINT,
+    LargeBinary: DatabricksDataType.BINARY,
+    Boolean: DatabricksDataType.BOOLEAN,
+    Date: DatabricksDataType.DATE,
+    DateTime: DatabricksDataType.TIMESTAMP,
+    Double: DatabricksDataType.DOUBLE,
+    Enum: DatabricksDataType.STRING,
+    Float: DatabricksDataType.FLOAT,
+    Integer: DatabricksDataType.INT,
+    Interval: DatabricksDataType.TIMESTAMP,
+    Numeric: DatabricksDataType.DECIMAL,
+    PickleType: DatabricksDataType.BINARY,
+    SmallInteger: DatabricksDataType.SMALLINT,
+    String: DatabricksDataType.STRING,
+    Text: DatabricksDataType.STRING,
+    Time: DatabricksDataType.STRING,
+    Unicode: DatabricksDataType.STRING,
+    UnicodeText: DatabricksDataType.STRING,
+    Uuid: DatabricksDataType.STRING,
+}
+
+# Convert the dictionary into a list of tuples for use in pytest.mark.parametrize
+_as_tuple_list = [(key, value) for key, value in camel_case_type_map.items()]
+
+
+class CompilationTestBase:
+    dialect = DatabricksDialect()
+
+    def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType):
+        """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name.
+
+        This method initialises the type_ with no arguments.
+        """
+        compiled_result = type_().compile(dialect=self.dialect)  # type: ignore
+        assert compiled_result == expected.name
+
+    def _assert_compiled_value_explicit(self, type_: TypeEngine, expected: str):
+        """Assert that when type_ is compiled for the databricks dialect, it renders the expected string.
+
+        This method expects an initialised type_ so that we can test how a TypeEngine created with arguments
+        is compiled.
+        """
+        compiled_result = type_.compile(dialect=self.dialect)
+        assert compiled_result == expected
+
+
+class TestCamelCaseTypesCompilation(CompilationTestBase):
+    """Per the SQLAlchemy documentation[1], the CamelCase members of sqlalchemy.types are
+    expected to work across all dialects. These tests verify that those types compile into valid
+    Databricks SQL type strings. For example, sqlalchemy.types.Integer() should compile as "INT".
+
+    Truly custom types like STRUCT (notice the uppercase) are not expected to work across all dialects.
+    We test those separately.
+
+    Note that these tests cover type **name** compilation, which is separate from actually
+    mapping values between Python and Databricks.
+
+    Note: SchemaType and MatchType are not tested because they are not used in table definitions.
+
+    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types
+    """
+
+    @pytest.mark.parametrize("type_, expected", _as_tuple_list)
+    def test_bare_camel_case_types_compile(self, type_, expected):
+        self._assert_compiled_value(type_, expected)
+
+    def test_numeric_renders_as_decimal_with_precision(self):
+        self._assert_compiled_value_explicit(Numeric(10), "DECIMAL(10)")
+
+    def test_numeric_renders_as_decimal_with_precision_and_scale(self):
+        self._assert_compiled_value_explicit(Numeric(10, 2), "DECIMAL(10, 2)")
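One reason this work targets SQLAlchemy 2 (per the commit title): Double and Uuid only appear in sqlalchemy.types from the 2.0 series onward, so the import block above would fail on 1.4. A quick probe (illustrative only):

import sqlalchemy

# Double and Uuid were added to sqlalchemy.types in SQLAlchemy 2.0;
# on 1.4 these attributes do not exist.
print(sqlalchemy.__version__)
print(hasattr(sqlalchemy.types, "Double"), hasattr(sqlalchemy.types, "Uuid"))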

src/databricks/sqlalchemy/types.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql.compiler import GenericTypeCompiler
+from sqlalchemy.types import (
+    DateTime,
+    Enum,
+    Integer,
+    LargeBinary,
+    Numeric,
+    String,
+    Text,
+    Time,
+    Unicode,
+    UnicodeText,
+    Uuid,
+)
+
+
+@compiles(Enum, "databricks")
+@compiles(String, "databricks")
+@compiles(Text, "databricks")
+@compiles(Time, "databricks")
+@compiles(Unicode, "databricks")
+@compiles(UnicodeText, "databricks")
+@compiles(Uuid, "databricks")
+def compile_string_databricks(type_, compiler, **kw):
+    """
+    We override the default compilation for Enum(), String(), Text(), Time(), Unicode(),
+    UnicodeText(), and Uuid() because SQLAlchemy defaults to incompatible / abnormal compiled names:
+
+    Enum -> VARCHAR
+    String -> VARCHAR[LENGTH]
+    Text -> VARCHAR[LENGTH]
+    Time -> TIME
+    Unicode -> VARCHAR[LENGTH]
+    UnicodeText -> TEXT
+    Uuid -> CHAR[32]
+
+    All of these types compile to STRING in Databricks SQL.
+    """
+    return "STRING"
+
+
+@compiles(Integer, "databricks")
+def compile_integer_databricks(type_, compiler, **kw):
+    """
+    We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER"
+    """
+    return "INT"
+
+
+@compiles(LargeBinary, "databricks")
+def compile_binary_databricks(type_, compiler, **kw):
+    """
+    We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB"
+    """
+    return "BINARY"
+
+
+@compiles(Numeric, "databricks")
+def compile_numeric_databricks(type_, compiler, **kw):
+    """
+    We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC"
+
+    The built-in visit_DECIMAL behaviour captures the precision and scale. Here we're just routing
+    Numeric compilation to SQLAlchemy's built-in DECIMAL rendering.
+    """
+    return compiler.visit_DECIMAL(type_, **kw)
+
+
+@compiles(DateTime, "databricks")
+def compile_datetime_databricks(type_, compiler, **kw):
+    """
+    We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME"
+    """
+    return "TIMESTAMP"
