Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions python/interpret-core/interpret/utils/_clean_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,24 @@ def categorical_encode(uniques, indexes, nonmissings, categories):

def _process_column_initial(X_col, nonmissings, processing, min_unique_continuous):
# called under: fit

if issubclass(X_col.dtype.type, np.floating):

# Check if X_col is a pandas StringArray
if _pandas_installed and hasattr(X_col, 'dtype') and isinstance(X_col.dtype, pd.StringDtype):
# Use pd.factorize for StringDtype which is more efficient than np.unique
# factorize returns (codes, uniques) where codes are integer indices
indexes, uniques = pd.factorize(X_col)

# Convert uniques from Index to numpy array and then to strings
uniques = uniques.to_numpy(dtype=np.str_, na_value='')

# Count occurrences for each unique value
counts = np.bincount(indexes[indexes >= 0]) # -1 is for NA values, which we don't count

# Convert to strings (already done above)
floats = None # StringDtype values are strings, not floats

# Continue with the sorting logic below
elif issubclass(X_col.dtype.type, np.floating):
m = np.isnan(X_col)
if m.any():
np.logical_not(m, out=m)
Expand Down Expand Up @@ -592,6 +608,14 @@ def _encode_categorical_existing(X_col, nonmissings):

# TODO: add special case handling if there is only 1 sample to make that faster
# if we have just 1 sample, we can avoid making the mapping below

# Check if X_col is a pandas StringArray
if _pandas_installed and hasattr(X_col, 'dtype') and isinstance(X_col.dtype, pd.StringDtype):
# Use pd.factorize for StringDtype which is more efficient than np.unique
indexes, uniques = pd.factorize(X_col)
# Convert uniques from Index to numpy array and then to strings
uniques = uniques.to_numpy(dtype=np.str_, na_value='')
return nonmissings, uniques, indexes

tt = X_col.dtype.type
if issubclass(tt, np.floating):
Expand Down Expand Up @@ -1056,10 +1080,27 @@ def _process_pandas_column(X_col, is_initial, feature_type, min_unique_continuou
feature_type,
min_unique_continuous,
)
elif isinstance(dt, pd.StringDtype):
# this handles pd.StringDtype both the numpy and arrow versions
# StringDtype is similar to object dtype but with proper NA handling
if X_col.hasnans:
# if hasnans is true then there is definitely a real missing value in there and not just a mask
return _process_ndarray(
X_col.dropna().values.astype(np.str_, copy=False),
X_col.notna().values,
is_initial,
feature_type,
min_unique_continuous,
)
return _process_ndarray(
X_col.values.astype(np.str_, copy=False),
None,
is_initial,
feature_type,
min_unique_continuous,
)

# TODO: implement pd.SparseDtype
msg = f"{type(dt)} not supported"
_log.error(msg)
raise TypeError(msg)
Expand Down
4 changes: 4 additions & 0 deletions python/interpret-core/tests/utils/test_clean_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,10 @@ def test_unify_columns_pandas_missings_BooleanDtype():
check_pandas_missings(pd.BooleanDtype(), False, True)


def test_unify_columns_pandas_missings_StringDtype():
    # Exercise missing-value handling for a pandas StringDtype column using the
    # shared check_pandas_missings helper (defined elsewhere in this test file).
    # Mirrors the np.object_ string test so StringDtype and object-dtype string
    # columns are verified to behave the same way.
    check_pandas_missings(pd.StringDtype(), "abc", "def")


def test_unify_columns_pandas_missings_str():
    # Exercise missing-value handling for a plain object-dtype (np.object_)
    # string column via the shared check_pandas_missings helper
    # (defined elsewhere in this test file).
    check_pandas_missings(np.object_, "abc", "def")

Expand Down