
Commit 7c72cf4

Author: Jesse
SQLAlchemy 2: reorganise dialect files into a single directory (#231)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent: 1239bff

15 files changed: +442 additions, −398 deletions

CONTRIBUTING.md

Lines changed: 3 additions & 2 deletions

@@ -148,10 +148,11 @@ The suites marked `[not documented]` require additional configuration which will
 
 SQLAlchemy provides reusable tests for testing dialect implementations.
 
-To run these tests, assuming the environment variables needed for e2e tests are set:
+To run these tests, assuming the environment variables needed for e2e tests are set, do the following:
 
 ```
-poetry run python -m pytest tests/sqlalchemy_dialect_compliance --dburi \
+cd src/databricks/sqlalchemy
+poetry run python -m pytest test/sqlalchemy_dialect_compliance.py --dburi \
 "databricks://token:$access_token@$host?http_path=$http_path&catalog=$catalog&schema=$schema"
 ```
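For reference, the `--dburi` value is assembled from the same environment variables the e2e tests read. A minimal sketch of that assembly in Python, assuming the variable names match the shell command above:

```
import os

# Interpolate the same variables the command above references
# ($access_token, $host, $http_path, $catalog, $schema).
dburi = (
    "databricks://token:{access_token}@{host}"
    "?http_path={http_path}&catalog={catalog}&schema={schema}"
).format(
    access_token=os.environ["access_token"],
    host=os.environ["host"],
    http_path=os.environ["http_path"],
    catalog=os.environ["catalog"],
    schema=os.environ["schema"],
)
print(dburi)
```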

examples/sqlalchemy.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
 
 # Known Gaps
 - MAP, ARRAY, and STRUCT types: this dialect can read these types out as strings. But you cannot
-  define a SQLAlchemy model with databricks.sqlalchemy.dialect.types.DatabricksMap (e.g.) because
+  define a SQLAlchemy model with databricks.sqlalchemy.types.DatabricksMap (e.g.) because
   we haven't implemented them yet.
 - Constraints: with the addition of information_schema to Unity Catalog, Databricks SQL supports
   foreign key and primary key constraints. This dialect can write these constraints but the ability
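To make the first gap concrete: a MAP value selected through this dialect arrives as a Python string, not a dict. A minimal sketch, assuming a reachable warehouse and the placeholder credentials used elsewhere in this commit:

```
from sqlalchemy import create_engine, text

engine = create_engine(
    "databricks://token:dapi***@***.cloud.databricks.com"
    "?http_path=/sql/***&catalog=***&schema=***"
)

with engine.connect() as conn:
    # map('a', 1) is a Databricks SQL MAP literal; per the Known Gaps note,
    # the dialect reads it back as a string rather than a Python dict.
    row = conn.execute(text("SELECT map('a', 1) AS m")).fetchone()
    print(type(row[0]))  # <class 'str'>
```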

poetry.lock

Lines changed: 86 additions & 44 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 3 deletions
@@ -33,7 +33,7 @@ urllib3 = ">=1.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.1.2"
-mypy = "^0.950"
+mypy = "^0.981"
 pylint = ">=2.12.0"
 black = "^22.3.0"
 pytest-dotenv = "^0.5.2"
@@ -62,5 +62,3 @@ log_cli = "false"
 log_cli_level = "INFO"
 testpaths = ["tests"]
 env_files = ["test.env"]
-addopts = "--ignore=tests/sqlalchemy_dialect_compliance"
-

setup.cfg

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/databricks/sql/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 # PEP 249 module globals
 apilevel = "2.0"
 threadsafety = 1  # Threads may share the module, but not connections.
-paramstyle = "pyformat"  # Python extended format codes, e.g. ...WHERE name=%(name)s
+paramstyle = "named"  # Named placeholders, e.g. ...WHERE name=:name
 
 
 class DBAPITypeObject(object):
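Under PEP 249, the `named` paramstyle means placeholders are written `:name` and bound from a mapping, where the previous `pyformat` style used `%(name)s`. A sketch of what the change implies for DB-API callers, assuming a reachable warehouse and placeholder credentials:

```
from databricks import sql

with sql.connect(
    server_hostname="***.cloud.databricks.com",
    http_path="/sql/***",
    access_token="dapi***",
) as connection:
    with connection.cursor() as cursor:
        # "named" style: :placeholders bound from a dict, where
        # "pyformat" would have used %(name)s with the same dict.
        cursor.execute("SELECT :greeting AS greeting", {"greeting": "hello"})
        print(cursor.fetchone())
```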
src/databricks/sqlalchemy/__init__.py

Lines changed: 341 additions & 1 deletion

@@ -1 +1,341 @@
-from databricks.sqlalchemy.dialect import DatabricksDialect
+"""This module's layout loosely follows example of SQLAlchemy's postgres dialect
+"""
+
+import decimal, re, datetime
+from dateutil.parser import parse
+
+import sqlalchemy
+from sqlalchemy import types, event
+from sqlalchemy.engine import default, Engine
+from sqlalchemy.exc import DatabaseError, SQLAlchemyError
+from sqlalchemy.engine import reflection
+
+from databricks import sql
+
+
+from databricks.sqlalchemy.base import (
+    DatabricksDDLCompiler,
+    DatabricksIdentifierPreparer,
+)
+from databricks.sqlalchemy.compiler import DatabricksTypeCompiler
+
+try:
+    import alembic
+except ImportError:
+    pass
+else:
+    from alembic.ddl import DefaultImpl
+
+    class DatabricksImpl(DefaultImpl):
+        __dialect__ = "databricks"
+
+
+class DatabricksDecimal(types.TypeDecorator):
+    """Translates strings to decimals"""
+
+    impl = types.DECIMAL
+
+    def process_result_value(self, value, dialect):
+        if value is not None:
+            return decimal.Decimal(value)
+        else:
+            return None
+
+
+class DatabricksTimestamp(types.TypeDecorator):
+    """Translates timestamp strings to datetime objects"""
+
+    impl = types.TIMESTAMP
+
+    def process_result_value(self, value, dialect):
+        return value
+
+    def adapt(self, impltype, **kwargs):
+        return self.impl
+
+
+class DatabricksDate(types.TypeDecorator):
+    """Translates date strings to date objects"""
+
+    impl = types.DATE
+
+    def process_result_value(self, value, dialect):
+        return value
+
+    def adapt(self, impltype, **kwargs):
+        return self.impl
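These three `TypeDecorator` subclasses only intervene on the result path: `process_result_value` converts what the server returns, and `adapt` hands back the underlying `impl`. A minimal check of the decimal conversion, assuming this new module is importable as `databricks.sqlalchemy` (no warehouse needed):

```
import decimal

from databricks.sqlalchemy import DatabricksDecimal

# process_result_value turns the string the server returns into a
# Decimal, and passes None through untouched.
dec_type = DatabricksDecimal()
assert dec_type.process_result_value("123.45", dialect=None) == decimal.Decimal("123.45")
assert dec_type.process_result_value(None, dialect=None) is None
```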
+
+class DatabricksDialect(default.DefaultDialect):
+    """This dialect implements only those methods required to pass our e2e tests"""
+
+    # Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect
+    name: str = "databricks"
+    driver: str = "databricks"
+    default_schema_name: str = "default"
+
+    preparer = DatabricksIdentifierPreparer  # type: ignore
+    type_compiler = DatabricksTypeCompiler
+    ddl_compiler = DatabricksDDLCompiler
+    supports_statement_cache: bool = True
+    supports_multivalues_insert: bool = True
+    supports_native_decimal: bool = True
+    supports_sane_rowcount: bool = False
+    non_native_boolean_check_constraint: bool = False
+    paramstyle: str = "named"
+
+    @classmethod
+    def dbapi(cls):
+        return sql
+
+    def create_connect_args(self, url):
+        # TODO: can schema be provided after HOST?
+        # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/***
+
+        kwargs = {
+            "server_hostname": url.host,
+            "access_token": url.password,
+            "http_path": url.query.get("http_path"),
+            "catalog": url.query.get("catalog"),
+            "schema": url.query.get("schema"),
+        }
+
+        self.schema = kwargs["schema"]
+        self.catalog = kwargs["catalog"]
+
+        return [], kwargs
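`create_connect_args` is what lets a SQLAlchemy URL carry the connector's keyword arguments: the host, the token in the password slot, and `http_path`, `catalog`, and `schema` as query parameters. A sketch of the corresponding `create_engine` call, using the same placeholder values as the comment above:

```
from sqlalchemy import create_engine

# Each URL component maps onto one entry in the kwargs dict that
# create_connect_args builds for the connector.
engine = create_engine(
    "databricks://token:dapi***@***.cloud.databricks.com"
    "?http_path=/sql/***&catalog=***&schema=***"
)
```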
+    def get_columns(self, connection, table_name, schema=None, **kwargs):
+        """Return information about columns in `table_name`.
+
+        Given a :class:`_engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return column
+        information as a list of dictionaries with these keys:
+
+        name
+          the column's name
+
+        type
+          [sqlalchemy.types#TypeEngine]
+
+        nullable
+          boolean
+
+        default
+          the column's default value
+
+        autoincrement
+          boolean
+
+        sequence
+          a dictionary of the form
+              {'name' : str, 'start' :int, 'increment': int, 'minvalue': int,
+               'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool,
+               'cycle': bool, 'cache': int, 'order': bool}
+
+        Additional column attributes may be present.
+        """
+
+        _type_map = {
+            "boolean": types.Boolean,
+            "smallint": types.SmallInteger,
+            "int": types.Integer,
+            "bigint": types.BigInteger,
+            "float": types.Float,
+            "double": types.Float,
+            "string": types.String,
+            "varchar": types.String,
+            "char": types.String,
+            "binary": types.String,
+            "array": types.String,
+            "map": types.String,
+            "struct": types.String,
+            "uniontype": types.String,
+            "decimal": DatabricksDecimal,
+            "timestamp": DatabricksTimestamp,
+            "date": DatabricksDate,
+        }
+
+        with self.get_connection_cursor(connection) as cur:
+            resp = cur.columns(
+                catalog_name=self.catalog,
+                schema_name=schema or self.schema,
+                table_name=table_name,
+            ).fetchall()
+
+        columns = []
+
+        for col in resp:
+            # Taken from PyHive. This removes added type info from decimals and maps
+            _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0)
+            this_column = {
+                "name": col.COLUMN_NAME,
+                "type": _type_map[_col_type.lower()],
+                "nullable": bool(col.NULLABLE),
+                "default": col.COLUMN_DEF,
+                "autoincrement": False if col.IS_AUTO_INCREMENT == "NO" else True,
+            }
+            columns.append(this_column)
+
+        return columns
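Reflection callers rarely invoke `get_columns` directly; it backs `sqlalchemy.inspect`. A sketch using the placeholder `engine` built above (the table name is hypothetical):

```
from sqlalchemy import inspect

# inspect() routes through DatabricksDialect.get_columns, which maps the
# server's TYPE_NAME strings onto the SQLAlchemy types in _type_map.
inspector = inspect(engine)
for column in inspector.get_columns("some_table"):  # hypothetical table
    print(column["name"], column["type"], column["nullable"])
```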
+    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+        """Return information about the primary key constraint on
+        `table_name`.
+
+        Given a :class:`_engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return primary
+        key information as a dictionary with these keys:
+
+        constrained_columns
+          a list of column names that make up the primary key
+
+        name
+          optional name of the primary key constraint.
+
+        """
+        # TODO: implement this behaviour
+        return {"constrained_columns": []}
+
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        """Return information about foreign_keys in `table_name`.
+
+        Given a :class:`_engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return foreign
+        key information as a list of dicts with these keys:
+
+        name
+          the constraint's name
+
+        constrained_columns
+          a list of column names that make up the foreign key
+
+        referred_schema
+          the name of the referred schema
+
+        referred_table
+          the name of the referred table
+
+        referred_columns
+          a list of column names in the referred table that correspond to
+          constrained_columns
+        """
+        # TODO: Implement this behaviour
+        return []
+
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        """Return information about indexes in `table_name`.
+
+        Given a :class:`_engine.Connection`, a string
+        `table_name` and an optional string `schema`, return index
+        information as a list of dictionaries with these keys:
+
+        name
+          the index's name
+
+        column_names
+          list of column names in order
+
+        unique
+          boolean
+        """
+        # TODO: Implement this behaviour
+        return []
+
+    def get_table_names(self, connection, schema=None, **kwargs):
+        TABLE_NAME = 1
+        with self.get_connection_cursor(connection) as cur:
+            sql_str = "SHOW TABLES FROM {}".format(
+                ".".join([self.catalog, schema or self.schema])
+            )
+            data = cur.execute(sql_str).fetchall()
+            _tables = [i[TABLE_NAME] for i in data]
+
+        return _tables
+
+    def get_view_names(self, connection, schema=None, **kwargs):
+        VIEW_NAME = 1
+        with self.get_connection_cursor(connection) as cur:
+            sql_str = "SHOW VIEWS FROM {}".format(
+                ".".join([self.catalog, schema or self.schema])
+            )
+            data = cur.execute(sql_str).fetchall()
+            _tables = [i[VIEW_NAME] for i in data]
+
+        return _tables
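`get_table_names` and `get_view_names` both run a `SHOW TABLES FROM` / `SHOW VIEWS FROM` statement against `catalog.schema` and pick column index 1 (the name) out of each result row. Through the inspector, with the same placeholder `engine` as above:

```
from sqlalchemy import inspect

# Each call issues SHOW TABLES / SHOW VIEWS and collects the name column.
inspector = inspect(engine)
print(inspector.get_table_names(schema="default"))
print(inspector.get_view_names(schema="default"))
```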
+    def do_rollback(self, dbapi_connection):
+        # Databricks SQL Does not support transactions
+        pass
+
+    def has_table(
+        self, connection, table_name, schema=None, catalog=None, **kwargs
+    ) -> bool:
+        """SQLAlchemy docstrings say dialect providers must implement this method"""
+
+        _schema = schema or self.schema
+        _catalog = catalog or self.catalog
+
+        # DBR >12.x uses underscores in error messages
+        DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
+        DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
+
+        try:
+            res = connection.execute(
+                f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}"
+            )
+            return True
+        except DatabaseError as e:
+            if DBR_GT_12_NOT_FOUND_STRING in str(
+                e
+            ) or DBR_LTE_12_NOT_FOUND_STRING in str(e):
+                return False
+            else:
+                raise e
+
+    def get_connection_cursor(self, connection):
+        """Added for backwards compatibility with 1.3.x"""
+        if hasattr(connection, "_dbapi_connection"):
+            return connection._dbapi_connection.dbapi_connection.cursor()
+        elif hasattr(connection, "raw_connection"):
+            return connection.raw_connection().cursor()
+        elif hasattr(connection, "connection"):
+            return connection.connection.cursor()
+
+        raise SQLAlchemyError(
+            "Databricks dialect can't obtain a cursor context manager from the dbapi"
+        )
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        # Equivalent to SHOW DATABASES
+
+        # TODO: replace with call to cursor.schemas() once its performance matches raw SQL
+        return [row[0] for row in connection.execute("SHOW SCHEMAS")]
+
+
+@event.listens_for(Engine, "do_connect")
+def receive_do_connect(dialect, conn_rec, cargs, cparams):
+    """Helpful for DS on traffic from clients using SQLAlchemy in particular"""
+
+    # Ignore connect invocations that don't use our dialect
+    if not dialect.name == "databricks":
+        return
+
+    if "_user_agent_entry" in cparams:
+        new_user_agent = f"sqlalchemy + {cparams['_user_agent_entry']}"
+    else:
+        new_user_agent = "sqlalchemy"
+
+    cparams["_user_agent_entry"] = new_user_agent
+
+    if sqlalchemy.__version__.startswith("1.3"):
+        # SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string
+        # These should be passed in as connect_args when building the Engine
+
+        if "schema" in cparams:
+            dialect.schema = cparams["schema"]
+
+        if "catalog" in cparams:
+            dialect.catalog = cparams["catalog"]
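The `do_connect` listener tags every Databricks connection's user agent with `sqlalchemy` so this traffic can be distinguished server-side. Callers can extend the tag through `connect_args`; a sketch (the `my-app` suffix is an arbitrary example):

```
from sqlalchemy import create_engine

# The listener rewrites this entry to "sqlalchemy + my-app" before the
# connection opens; without it the user agent is just "sqlalchemy".
engine = create_engine(
    "databricks://token:dapi***@***.cloud.databricks.com"
    "?http_path=/sql/***&catalog=***&schema=***",
    connect_args={"_user_agent_entry": "my-app"},
)
```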
File renamed without changes.

0 commit comments