Skip to content

Commit fc62545

Browse files
sander-goossusodapop
authored andcommitted
Use to_pandas deserialisation and add .column syntax
Leverage PyArrow's to_pandas and Pandas' to_numpy for more efficient conversion of Arrow table to python results. Also add the .column syntax with the Row class. Existing / modified tests. Added micro benchmark test. 30s micro benchmark without server (higher is better): - Original: ~ 47 times - Using zip and Row class: ~ 97 times - Using pandas and toList: ~ 442 times - Using pandas and Row class: ~ 256 times 30s benchmark with local server: - V1 client: ~ 27 times - Original V2 client: ~ 34 times - Using pandas and toList: ~ 65 times - Using pandas and Row class: ~ 58 times Using toList is still around 12% faster with a local server, because it skips generating the tuples. In reality, this will be diluted with network latency as Thriftserver will not be running on the same machine in practice.
1 parent 86b9055 commit fc62545

File tree

6 files changed

+403
-52
lines changed

6 files changed

+403
-52
lines changed

cmdexec/clients/python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
version="2.0.0rc2",
66
package_dir={"": "src"},
77
packages=setuptools.find_packages(where="src"),
8-
install_requires=["pyarrow", 'thrift>=0.13.0', "pandas"],
8+
install_requires=["pyarrow", 'thrift>=0.13.0', "pandas>=1.0.0"],
99
author="Databricks",
1010
)

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import re
66
from typing import Dict, Tuple, List, Optional, Any
77

8+
import pandas
89
import pyarrow
910

1011
from databricks.sql import USER_AGENT_NAME, __version__
1112
from databricks.sql import *
1213
from databricks.sql.thrift_backend import ThriftBackend
1314
from databricks.sql.utils import ExecuteResponse, ParamEscaper
15+
from databricks.sql.types import Row
1416

1517
logger = logging.getLogger(__name__)
1618

@@ -67,9 +69,13 @@ def __init__(self,
6769
# _socket_timeout
6870
# The timeout in seconds for socket send, recv and connect operations. Defaults to None for
6971
# no timeout. Should be a positive float or integer.
72+
# _disable_pandas
73+
# In case the deserialisation through pandas causes any issues, it can be disabled with
74+
# this flag.
7075

7176
self.host = server_hostname
7277
self.port = kwargs.get("_port", 443)
78+
self.disable_pandas = kwargs.get("_disable_pandas", False)
7379

7480
authorization_header = []
7581
if kwargs.get("_username") and kwargs.get("_password"):
@@ -324,7 +330,7 @@ def columns(self,
324330
self.buffer_size_bytes, self.arraysize)
325331
return self
326332

327-
def fetchall(self) -> List[Tuple]:
333+
def fetchall(self) -> List[Row]:
328334
"""
329335
Fetch all (remaining) rows of a query result, returning them as a sequence of sequences.
330336
@@ -337,7 +343,7 @@ def fetchall(self) -> List[Tuple]:
337343
else:
338344
raise Error("There is no active result set")
339345

340-
def fetchone(self) -> Tuple:
346+
def fetchone(self) -> Optional[Row]:
341347
"""
342348
Fetch the next row of a query result set, returning a single sequence, or ``None`` when
343349
no more data is available.
@@ -351,7 +357,7 @@ def fetchone(self) -> Tuple:
351357
else:
352358
raise Error("There is no active result set")
353359

354-
def fetchmany(self, size: int) -> List[Tuple]:
360+
def fetchmany(self, size: int) -> List[Row]:
355361
"""
356362
Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
357363
list of tuples).
@@ -373,14 +379,14 @@ def fetchmany(self, size: int) -> List[Tuple]:
373379
else:
374380
raise Error("There is no active result set")
375381

376-
def fetchall_arrow(self):
382+
def fetchall_arrow(self) -> pyarrow.Table:
377383
self._check_not_closed()
378384
if self.active_result_set:
379385
return self.active_result_set.fetchall_arrow()
380386
else:
381387
raise Error("There is no active result set")
382388

383-
def fetchmany_arrow(self, size):
389+
def fetchmany_arrow(self, size) -> pyarrow.Table:
384390
self._check_not_closed()
385391
if self.active_result_set:
386392
return self.active_result_set.fetchmany_arrow(size)
@@ -505,10 +511,43 @@ def _fill_results_buffer(self):
505511
self.has_more_rows = has_more_rows
506512

507513
def _convert_arrow_table(self, table):
508-
n_rows, _ = table.shape
509-
list_repr = [[col[row_index].as_py() for col in table.itercolumns()]
510-
for row_index in range(n_rows)]
511-
return list_repr
514+
column_names = [c[0] for c in self.description]
515+
ResultRow = Row(*column_names)
516+
517+
if self.connection.disable_pandas is True:
518+
return [ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())]
519+
520+
# Need to use nullable types, as otherwise type can change when there are missing values.
521+
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
522+
# NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
523+
dtype_mapping = {
524+
pyarrow.int8(): pandas.Int8Dtype(),
525+
pyarrow.int16(): pandas.Int16Dtype(),
526+
pyarrow.int32(): pandas.Int32Dtype(),
527+
pyarrow.int64(): pandas.Int64Dtype(),
528+
pyarrow.uint8(): pandas.UInt8Dtype(),
529+
pyarrow.uint16(): pandas.UInt16Dtype(),
530+
pyarrow.uint32(): pandas.UInt32Dtype(),
531+
pyarrow.uint64(): pandas.UInt64Dtype(),
532+
pyarrow.bool_(): pandas.BooleanDtype(),
533+
pyarrow.float32(): pandas.Float32Dtype(),
534+
pyarrow.float64(): pandas.Float64Dtype(),
535+
pyarrow.string(): pandas.StringDtype(),
536+
}
537+
538+
# Need to rename columns, as the to_pandas function cannot handle duplicate column names
539+
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
540+
df = table_renamed.to_pandas(types_mapper=dtype_mapping.get)
541+
542+
for (i, col) in enumerate(df.columns):
543+
# Check for 0 because .dt doesn't work on empty series
544+
if self.description[i][1] == 'timestamp' and len(df) > 0:
545+
# We store the dtype as object so we don't use the pandas datetime dtype but
546+
# a native datetime.datetime
547+
df[col] = pandas.Series(df[col].dt.to_pydatetime(), dtype='object')
548+
549+
res = df.to_numpy(na_value=None)
550+
return [ResultRow(*v) for v in res]
512551

513552
@property
514553
def rownumber(self):
@@ -548,7 +587,7 @@ def fetchall_arrow(self) -> pyarrow.Table:
548587

549588
return results
550589

551-
def fetchone(self) -> Optional[Tuple]:
590+
def fetchone(self) -> Optional[Row]:
552591
"""
553592
Fetch the next row of a query result set, returning a single sequence,
554593
or None when no more data is available.
@@ -559,15 +598,15 @@ def fetchone(self) -> Optional[Tuple]:
559598
else:
560599
return None
561600

562-
def fetchall(self) -> List[Tuple]:
601+
def fetchall(self) -> List[Row]:
563602
"""
564-
Fetch all (remaining) rows of a query result, returning them as a list of lists.
603+
Fetch all (remaining) rows of a query result, returning them as a list of rows.
565604
"""
566605
return self._convert_arrow_table(self.fetchall_arrow())
567606

568-
def fetchmany(self, size: int) -> List[Tuple]:
607+
def fetchmany(self, size: int) -> List[Row]:
569608
"""
570-
Fetch the next set of rows of a query result, returning a list of lists.
609+
Fetch the next set of rows of a query result, returning a list of rows.
571610
572611
An empty sequence is returned when no more rows are available.
573612
"""
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
# Row class was taken from Apache Spark pyspark.
18+
19+
from typing import (Any, Dict, List, Optional, Tuple, Union)
20+
21+
22+
class Row(tuple):
23+
"""
24+
A row in a query result.
25+
The fields in it can be accessed:
26+
27+
* like attributes (``row.key``)
28+
* like dictionary values (``row[key]``)
29+
30+
``key in row`` will search through row keys.
31+
32+
Row can be used to create a row object by using named arguments.
33+
It is not allowed to omit a named argument to represent that the value is
34+
None or missing. This should be explicitly set to None in this case.
35+
36+
Examples
37+
--------
38+
>>> row = Row(name="Alice", age=11)
39+
>>> row
40+
Row(name='Alice', age=11)
41+
>>> row['name'], row['age']
42+
('Alice', 11)
43+
>>> row.name, row.age
44+
('Alice', 11)
45+
>>> 'name' in row
46+
True
47+
>>> 'wrong_key' in row
48+
False
49+
50+
Row also can be used to create another Row like class, then it
51+
could be used to create Row objects, such as
52+
53+
>>> Person = Row("name", "age")
54+
>>> Person
55+
<Row('name', 'age')>
56+
>>> 'name' in Person
57+
True
58+
>>> 'wrong_key' in Person
59+
False
60+
>>> Person("Alice", 11)
61+
Row(name='Alice', age=11)
62+
63+
This form can also be used to create rows as tuple values, i.e. with unnamed
64+
fields.
65+
66+
>>> row1 = Row("Alice", 11)
67+
>>> row2 = Row(name="Alice", age=11)
68+
>>> row1 == row2
69+
True
70+
"""
71+
72+
def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row":
73+
if args and kwargs:
74+
raise ValueError("Can not use both args " "and kwargs to create Row")
75+
if kwargs:
76+
# create row objects
77+
row = tuple.__new__(cls, list(kwargs.values()))
78+
row.__fields__ = list(kwargs.keys())
79+
return row
80+
else:
81+
# create row class or objects
82+
return tuple.__new__(cls, args)
83+
84+
def asDict(self, recursive: bool = False) -> Dict[str, Any]:
85+
"""
86+
Return as a dict
87+
88+
Parameters
89+
----------
90+
recursive : bool, optional
91+
turns the nested Rows to dict (default: False).
92+
93+
Notes
94+
-----
95+
If a row contains duplicate field names, e.g., the rows of a join
96+
between two dataframes that both have the fields of same names,
97+
one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
98+
will also return one of the duplicate fields, however returned value might
99+
be different to ``asDict``.
100+
101+
Examples
102+
--------
103+
>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
104+
True
105+
>>> row = Row(key=1, value=Row(name='a', age=2))
106+
>>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
107+
True
108+
>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
109+
True
110+
"""
111+
if not hasattr(self, "__fields__"):
112+
raise TypeError("Cannot convert a Row class into dict")
113+
114+
if recursive:
115+
116+
def conv(obj: Any) -> Any:
117+
if isinstance(obj, Row):
118+
return obj.asDict(True)
119+
elif isinstance(obj, list):
120+
return [conv(o) for o in obj]
121+
elif isinstance(obj, dict):
122+
return dict((k, conv(v)) for k, v in obj.items())
123+
else:
124+
return obj
125+
126+
return dict(zip(self.__fields__, (conv(o) for o in self)))
127+
else:
128+
return dict(zip(self.__fields__, self))
129+
130+
def __contains__(self, item: Any) -> bool:
131+
if hasattr(self, "__fields__"):
132+
return item in self.__fields__
133+
else:
134+
return super(Row, self).__contains__(item)
135+
136+
# let object acts like class
137+
def __call__(self, *args: Any) -> "Row":
138+
"""create new Row object"""
139+
if len(args) > len(self):
140+
raise ValueError("Can not create Row with fields %s, expected %d values "
141+
"but got %s" % (self, len(self), args))
142+
return _create_row(self, args)
143+
144+
def __getitem__(self, item: Any) -> Any:
145+
if isinstance(item, (int, slice)):
146+
return super(Row, self).__getitem__(item)
147+
try:
148+
# it will be slow when it has many fields,
149+
# but this will not be used in normal cases
150+
idx = self.__fields__.index(item)
151+
return super(Row, self).__getitem__(idx)
152+
except IndexError:
153+
raise KeyError(item)
154+
except ValueError:
155+
raise ValueError(item)
156+
157+
def __getattr__(self, item: str) -> Any:
158+
if item.startswith("__"):
159+
raise AttributeError(item)
160+
try:
161+
# it will be slow when it has many fields,
162+
# but this will not be used in normal cases
163+
idx = self.__fields__.index(item)
164+
return self[idx]
165+
except IndexError:
166+
raise AttributeError(item)
167+
except ValueError:
168+
raise AttributeError(item)
169+
170+
def __setattr__(self, key: Any, value: Any) -> None:
171+
if key != "__fields__":
172+
raise RuntimeError("Row is read-only")
173+
self.__dict__[key] = value
174+
175+
def __reduce__(self, ) -> Union[str, Tuple[Any, ...]]:
176+
"""Returns a tuple so Python knows how to pickle Row."""
177+
if hasattr(self, "__fields__"):
178+
return (_create_row, (self.__fields__, tuple(self)))
179+
else:
180+
return tuple.__reduce__(self)
181+
182+
def __repr__(self) -> str:
183+
"""Printable representation of Row used in Python REPL."""
184+
if hasattr(self, "__fields__"):
185+
return "Row(%s)" % ", ".join("%s=%r" % (k, v)
186+
for k, v in zip(self.__fields__, tuple(self)))
187+
else:
188+
return "<Row(%s)>" % ", ".join("%r" % field for field in self)
189+
190+
191+
def _create_row(fields: Union["Row", List[str]],
192+
values: Union[Tuple[Any, ...], List[Any]]) -> "Row":
193+
row = Row(*values)
194+
row.__fields__ = fields
195+
return row

0 commit comments

Comments
 (0)