Skip to content

Commit e83e4f9

Browse files
committed
assertionerror fix
1 parent cddc574 commit e83e4f9

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

pandas/core/arrays/categorical.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,11 @@ def __init__(
480480

481481
# If we should prserve object dtype, force categories to object dtype
482482
if preserve_object_dtpe:
483-
from pandas import Index
483+
# Only preserve object dtype if not all elements are strings
484+
if not all(isinstance(x, str) for x in categories):
485+
from pandas import Index
484486

485-
categories = Index(categories, dtype=object, copy=False)
487+
categories = Index(categories, dtype=object, copy=False)
486488
dtype = CategoricalDtype(categories, dtype.ordered)
487489

488490
elif isinstance(values.dtype, CategoricalDtype):

pandas/tests/arrays/categorical/test_constructors.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -799,10 +799,19 @@ def test_categorical_preserve_object_dtype_from_pandas(self):
799799
cat_from_arr = Categorical(arr)
800800
cat_from_list = Categorical(pylist)
801801

802-
# Series/Index with object dtype: preserve object dtype
803-
assert cat_from_ser.categories.dtype == "object"
804-
assert cat_from_idx.categories.dtype == "object"
802+
# Series/Index with object dtype: infer string
803+
# dtype if all elements are strings
804+
assert cat_from_ser.categories.inferred_type == "string"
805+
assert cat_from_idx.categories.inferred_type == "string"
805806

806807
# Numpy array or list: infer string dtype
807-
assert cat_from_arr.categories.dtype == "str"
808-
assert cat_from_list.categories.dtype == "str"
808+
assert cat_from_arr.categories.inferred_type == "string"
809+
assert cat_from_list.categories.inferred_type == "string"
810+
811+
# Mixed types: preserve object dtype
812+
ser_mixed = Series(["foo", 1, None], dtype="object")
813+
idx_mixed = Index(["foo", 1, None], dtype="object")
814+
cat_mixed_ser = Categorical(ser_mixed)
815+
cat_mixed_idx = Categorical(idx_mixed)
816+
assert cat_mixed_ser.categories.dtype == "object"
817+
assert cat_mixed_idx.categories.dtype == "object"

0 commit comments

Comments
 (0)