Skip to content

Commit 91c8699

Browse files
committed
ENH: Enable pytables to round-trip with StringDtype
1 parent 1be2637 commit 91c8699

File tree

2 files changed

+101
-20
lines changed

2 files changed

+101
-20
lines changed

pandas/io/pytables.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
writers as libwriters,
3939
)
4040
from pandas._libs.lib import is_string_array
41+
from pandas._libs.missing import NA
4142
from pandas._libs.tslibs import timezones
4243
from pandas.compat._optional import import_optional_dependency
4344
from pandas.compat.pickle_compat import patch_pickle
@@ -91,7 +92,10 @@
9192
PyTablesExpr,
9293
maybe_expression,
9394
)
94-
from pandas.core.construction import extract_array
95+
from pandas.core.construction import (
96+
array as pd_array,
97+
extract_array,
98+
)
9599
from pandas.core.indexes.api import ensure_index
96100

97101
from pandas.io.common import stringify_path
@@ -3023,6 +3027,18 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None
30233027

30243028
if isinstance(node, tables.VLArray):
30253029
ret = node[0][start:stop]
3030+
dtype = getattr(attrs, "value_type", None)
3031+
if dtype is not None:
3032+
if dtype == "str[python]":
3033+
dtype = StringDtype("python", np.nan)
3034+
elif dtype == "string[python]":
3035+
dtype = StringDtype("python", NA)
3036+
elif dtype == "str[pyarrow]":
3037+
dtype = StringDtype("pyarrow", np.nan)
3038+
else:
3039+
assert dtype == "string[pyarrow]"
3040+
dtype = StringDtype("pyarrow", NA)
3041+
ret = pd_array(ret, dtype=dtype)
30263042
else:
30273043
dtype = getattr(attrs, "value_type", None)
30283044
shape = getattr(attrs, "shape", None)
@@ -3210,6 +3226,8 @@ def write_array(
32103226
# get the atom for this datatype
32113227
atom = _tables().Atom.from_dtype(value.dtype)
32123228

3229+
from pandas.core.arrays.string_ import BaseStringArray
3230+
32133231
if atom is not None:
32143232
# We only get here if self._filters is non-None and
32153233
# the Atom.from_dtype call succeeded
@@ -3262,6 +3280,19 @@ def write_array(
32623280
elif lib.is_np_dtype(value.dtype, "m"):
32633281
self._handle.create_array(self.group, key, value.view("i8"))
32643282
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
3283+
elif isinstance(value, BaseStringArray):
3284+
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom())
3285+
vlarr.append(value.to_numpy())
3286+
node = getattr(self.group, key)
3287+
if value.dtype == StringDtype("python", np.nan):
3288+
node._v_attrs.value_type = "str[python]"
3289+
elif value.dtype == StringDtype("python", NA):
3290+
node._v_attrs.value_type = "string[python]"
3291+
elif value.dtype == StringDtype("pyarrow", np.nan):
3292+
node._v_attrs.value_type = "str[pyarrow]"
3293+
else:
3294+
assert value.dtype == StringDtype("pyarrow", NA)
3295+
node._v_attrs.value_type = "string[pyarrow]"
32653296
elif empty_array:
32663297
self.write_array_empty(key, value)
32673298
else:
@@ -3294,7 +3325,11 @@ def read(
32943325
index = self.read_index("index", start=start, stop=stop)
32953326
values = self.read_array("values", start=start, stop=stop)
32963327
result = Series(values, index=index, name=self.name, copy=False)
3297-
if using_string_dtype() and is_string_array(values, skipna=True):
3328+
if (
3329+
using_string_dtype()
3330+
and isinstance(values, np.ndarray)
3331+
and is_string_array(values, skipna=True)
3332+
):
32983333
result = result.astype(StringDtype(na_value=np.nan))
32993334
return result
33003335

@@ -3363,7 +3398,11 @@ def read(
33633398

33643399
columns = items[items.get_indexer(blk_items)]
33653400
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
3366-
if using_string_dtype() and is_string_array(values, skipna=True):
3401+
if (
3402+
using_string_dtype()
3403+
and isinstance(values, np.ndarray)
3404+
and is_string_array(values, skipna=True)
3405+
):
33673406
df = df.astype(StringDtype(na_value=np.nan))
33683407
dfs.append(df)
33693408

@@ -4737,9 +4776,13 @@ def read(
47374776
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
47384777
if not (using_string_dtype() and values.dtype.kind == "O"):
47394778
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4740-
if using_string_dtype() and is_string_array(
4741-
values, # type: ignore[arg-type]
4742-
skipna=True,
4779+
if (
4780+
using_string_dtype()
4781+
and isinstance(values, np.ndarray)
4782+
and is_string_array(
4783+
values, # type: ignore[arg-type]
4784+
skipna=True,
4785+
)
47434786
):
47444787
df = df.astype(StringDtype(na_value=np.nan))
47454788
frames.append(df)

pandas/tests/io/pytables/test_put.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas._libs.tslibs import Timestamp
97

108
import pandas as pd
@@ -26,7 +24,6 @@
2624

2725
pytestmark = [
2826
pytest.mark.single_cpu,
29-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
3027
]
3128

3229

@@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path):
5451
with ensure_clean_store(setup_path) as store:
5552
df = DataFrame(
5653
1.1 * np.arange(120).reshape((30, 4)),
57-
columns=Index(list("ABCD"), dtype=object),
58-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
54+
columns=Index(list("ABCD")),
55+
index=Index([f"i-{i}" for i in range(30)]),
5956
)
6057

6158
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path):
7976
path = tmp_path / setup_path
8077
df = DataFrame(
8178
1.1 * np.arange(120).reshape((30, 4)),
82-
columns=Index(list("ABCD"), dtype=object),
83-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
79+
columns=Index(list("ABCD")),
80+
index=Index([f"i-{i}" for i in range(30)]),
8481
)
8582

8683
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -106,7 +103,7 @@ def test_put(setup_path):
106103
)
107104
df = DataFrame(
108105
np.random.default_rng(2).standard_normal((20, 4)),
109-
columns=Index(list("ABCD"), dtype=object),
106+
columns=Index(list("ABCD")),
110107
index=date_range("2000-01-01", periods=20, freq="B"),
111108
)
112109
store["a"] = ts
@@ -166,7 +163,7 @@ def test_put_compression(setup_path):
166163
with ensure_clean_store(setup_path) as store:
167164
df = DataFrame(
168165
np.random.default_rng(2).standard_normal((10, 4)),
169-
columns=Index(list("ABCD"), dtype=object),
166+
columns=Index(list("ABCD")),
170167
index=date_range("2000-01-01", periods=10, freq="B"),
171168
)
172169

@@ -183,7 +180,7 @@ def test_put_compression(setup_path):
183180
def test_put_compression_blosc(setup_path):
184181
df = DataFrame(
185182
np.random.default_rng(2).standard_normal((10, 4)),
186-
columns=Index(list("ABCD"), dtype=object),
183+
columns=Index(list("ABCD")),
187184
index=date_range("2000-01-01", periods=10, freq="B"),
188185
)
189186

@@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path):
197194
tm.assert_frame_equal(store["c"], df)
198195

199196

200-
def test_put_mixed_type(setup_path, performance_warning):
197+
def test_put_datetime_ser(setup_path, performance_warning, using_infer_string):
    """Store a nanosecond-unit datetime Series and check it reads back equal."""
    # Three identical nanosecond-resolution timestamps make up the payload.
    stamp = Timestamp("20010102").as_unit("ns")
    original = Series([stamp] * 3)
    with ensure_clean_store(setup_path) as store:
        store.put("ser", original)
        # Compare against a copy taken after the write, then read back.
        snapshot = original.copy()
        loaded = store.get("ser")
        tm.assert_series_equal(loaded, snapshot)
205+
206+
207+
def test_put_mixed_type(setup_path, performance_warning, using_infer_string):
201208
df = DataFrame(
202209
np.random.default_rng(2).standard_normal((10, 4)),
203-
columns=Index(list("ABCD"), dtype=object),
210+
columns=Index(list("ABCD")),
204211
index=date_range("2000-01-01", periods=10, freq="B"),
205212
)
206213
df["obj1"] = "foo"
@@ -220,13 +227,38 @@ def test_put_mixed_type(setup_path, performance_warning):
220227
with ensure_clean_store(setup_path) as store:
221228
_maybe_remove(store, "df")
222229

223-
with tm.assert_produces_warning(performance_warning):
230+
warning = None if using_infer_string else performance_warning
231+
with tm.assert_produces_warning(warning):
224232
store.put("df", df)
225233

226234
expected = store.get("df")
227235
tm.assert_frame_equal(expected, df)
228236

229237

238+
def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments):
    """Round-trip a single-column StringDtype frame through ``put``/``get``.

    The dtype is built from whatever arguments the ``string_dtype_arguments``
    fixture supplies; the column contains a missing value (``pd.NA``) so the
    NA handling is exercised as well.
    """
    string_dtype = pd.StringDtype(*string_dtype_arguments)
    column = pd.array(["x", pd.NA, "y"], dtype=string_dtype)
    frame = DataFrame({"a": column})
    with ensure_clean_store(setup_path) as store:
        _maybe_remove(store, "df")

        store.put("df", frame)
        roundtripped = store.get("df")
        # The frame read back must equal the frame that was written.
        tm.assert_frame_equal(roundtripped, frame)
248+
249+
250+
def test_put_str_series(setup_path, performance_warning, string_dtype_arguments):
    """Round-trip a StringDtype Series through ``put``/``get``.

    The dtype is built from whatever arguments the ``string_dtype_arguments``
    fixture supplies; the data contains a missing value (``pd.NA``) so the
    NA handling is exercised as well.
    """
    dtype = pd.StringDtype(*string_dtype_arguments)
    ser = Series(["x", pd.NA, "y"], dtype=dtype)
    with ensure_clean_store(setup_path) as store:
        # Clear the key this test actually uses ("ser"); the previous code
        # removed "df", which this test never writes.
        _maybe_remove(store, "ser")

        store.put("ser", ser)
        result = store.get("ser")
        # The Series read back must equal the Series that was written.
        tm.assert_series_equal(result, ser)
260+
261+
230262
@pytest.mark.parametrize("format", ["table", "fixed"])
231263
@pytest.mark.parametrize(
232264
"index",
@@ -253,7 +285,7 @@ def test_store_index_types(setup_path, format, index):
253285
tm.assert_frame_equal(df, store["df"])
254286

255287

256-
def test_column_multiindex(setup_path):
288+
def test_column_multiindex(setup_path, using_infer_string):
257289
# GH 4710
258290
# recreate multi-indexes properly
259291

@@ -264,6 +296,12 @@ def test_column_multiindex(setup_path):
264296
expected = df.set_axis(df.index.to_numpy())
265297

266298
with ensure_clean_store(setup_path) as store:
299+
if using_infer_string:
300+
# TODO(infer_string) make this work for string dtype
301+
msg = "Saving a MultiIndex with an extension dtype is not supported."
302+
with pytest.raises(NotImplementedError, match=msg):
303+
store.put("df", df)
304+
return
267305
store.put("df", df)
268306
tm.assert_frame_equal(
269307
store["df"], expected, check_index_type=True, check_column_type=True

0 commit comments

Comments
 (0)