Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ np_1darray_complex: TypeAlias = np_1darray[np.complexfloating]
np_1darray_object: TypeAlias = np_1darray[np.object_]
np_1darray_bool: TypeAlias = np_1darray[np.bool]
np_1darray_intp: TypeAlias = np_1darray[np.intp]
np_1darray_int8: TypeAlias = np_1darray[np.int8]
np_1darray_int64: TypeAlias = np_1darray[np.int64]
np_1darray_anyint: TypeAlias = np_1darray[np.integer]
np_1darray_float: TypeAlias = np_1darray[np.floating]
Expand Down
103 changes: 67 additions & 36 deletions pandas-stubs/core/indexes/multi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ from collections.abc import (
)
from typing import (
Any,
final,
overload,
)

Expand All @@ -19,15 +18,18 @@ from typing_extensions import Self
from pandas._typing import (
AnyAll,
Axes,
DropKeep,
Dtype,
HashableT,
IndexLabel,
Label,
Level,
MaskType,
NaPosition,
SequenceNotStr,
Shape,
np_1darray_bool,
np_1darray_int8,
np_1darray_intp,
np_ndarray_anyint,
)

Expand Down Expand Up @@ -70,19 +72,46 @@ class MultiIndex(Index):
sortorder: int | None = ...,
names: SequenceNotStr[Hashable] = ...,
) -> Self: ...
@property
def shape(self): ...
@property # Should be read-only
def levels(self) -> list[Index]: ...
def set_levels(self, levels, *, level=..., verify_integrity: bool = ...): ...
@overload
def set_levels(
self,
levels: Sequence[SequenceNotStr[Hashable]],
*,
level: Sequence[Level] | None = None,
verify_integrity: bool = True,
) -> MultiIndex: ...
@overload
def set_levels(
self,
levels: SequenceNotStr[Hashable],
*,
level: Level,
verify_integrity: bool = True,
) -> MultiIndex: ...
@property
def codes(self): ...
def set_codes(self, codes, *, level=..., verify_integrity: bool = ...): ...
def codes(self) -> list[np_1darray_int8]: ...
@overload
def set_codes(
self,
codes: Sequence[Sequence[int]],
*,
level: Sequence[Level] | None = None,
verify_integrity: bool = True,
) -> MultiIndex: ...
@overload
def set_codes(
self,
codes: Sequence[int],
*,
level: Level,
verify_integrity: bool = True,
) -> MultiIndex: ...
def copy( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # pyrefly: ignore
self, names: SequenceNotStr[Hashable] = ..., deep: bool = False
) -> Self: ...
def view(self, cls=...): ...
def __contains__(self, key) -> bool: ...
def view(self, cls: Any = None) -> MultiIndex: ... # type: ignore[override] # pyrefly: ignore[bad-override] # pyright: ignore[reportIncompatibleMethodOverride]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an odd one, for multiindex.view, cls is ignored 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may be able to refine it, docs suggest: data-type or ndarray sub-class or None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, i've made it the same as index.view but without the overloads

@property
def dtype(self) -> np.dtype: ...
@property
Expand All @@ -92,29 +121,34 @@ class MultiIndex(Index):
def nbytes(self) -> int: ...
def __len__(self) -> int: ...
@property
def values(self): ...
@property
def is_monotonic_increasing(self) -> bool: ...
@property
def is_monotonic_decreasing(self) -> bool: ...
def duplicated(self, keep: DropKeep = "first"): ...
def dropna(self, how: AnyAll = "any") -> Self: ...
def droplevel(self, level: Level | Sequence[Level] = 0) -> MultiIndex | Index: ... # type: ignore[override]
def get_level_values(self, level: str | int) -> Index: ...
def unique(self, level=...): ...
@overload # type: ignore[override]
def unique( # pyrefly: ignore[bad-override]
self, level: None = None
) -> MultiIndex: ...
@overload
def unique( # ty: ignore[invalid-method-override] # pyright: ignore[reportIncompatibleMethodOverride]
self, level: Level
) -> (
Index
): ... # ty: ignore[invalid-method-override] # pyrefly: ignore[bad-override]
def to_frame( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
self,
index: bool = True,
name: list[HashableT] = ...,
allow_duplicates: bool = False,
) -> pd.DataFrame: ...
def to_flat_index(self) -> Index: ...
def remove_unused_levels(self): ...
def remove_unused_levels(self) -> MultiIndex: ...
@property
def nlevels(self) -> int: ...
@property
def levshape(self): ...
def __reduce__(self): ...
def levshape(self) -> Shape: ...
@overload # type: ignore[override]
# pyrefly: ignore # bad-override
def __getitem__(
Expand All @@ -125,36 +159,33 @@ class MultiIndex(Index):
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] # ty: ignore[invalid-method-override]
self, key: int
) -> tuple[Hashable, ...]: ...
def append(self, other): ...
def repeat(self, repeats, axis=...): ...
def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@overload # type: ignore[override]
def append(self, other: MultiIndex | Sequence[MultiIndex]) -> MultiIndex: ...
@overload
def append( # pyright: ignore[reportIncompatibleMethodOverride]
self, other: Index | Sequence[Index]
) -> Index: ... # pyrefly: ignore[bad-override]
def drop(self, codes: Level | Sequence[Level], level: Level | None = None, errors: str = "raise") -> MultiIndex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def swaplevel(self, i: int = -2, j: int = -1) -> Self: ...
def reorder_levels(self, order): ...
def reorder_levels(self, order: Sequence[Level]) -> MultiIndex: ...
def sortlevel(
self,
level: Level | Sequence[Level] = 0,
ascending: bool = True,
sort_remaining: bool = True,
na_position: NaPosition = "first",
): ...
@final
def get_indexer(self, target, method=..., limit=..., tolerance=...): ...
def get_indexer_non_unique(self, target): ...
def reindex(self, target, method=..., level=..., limit=..., tolerance=...): ...
def get_slice_bound(
self, label: Hashable | Sequence[Hashable], side: str
) -> int: ...
) -> tuple[MultiIndex, np_1darray_intp]: ...
def get_loc_level(
self, key, level: Level | list[Level] | None = None, drop_level: bool = True
): ...
def get_locs(self, seq): ...
self,
key: Label | Sequence[Label],
level: Level | Sequence[Level] | None = None,
drop_level: bool = True,
) -> tuple[int | slice | np_1darray_bool, Index]: ...
def get_locs(self, seq: Level | Sequence[Level]) -> np_1darray_intp: ...
def truncate(
self, before: IndexLabel | None = None, after: IndexLabel | None = None
): ...
def equals(self, other) -> bool: ...
def equal_levels(self, other): ...
def insert(self, loc, item): ...
def delete(self, loc): ...
) -> MultiIndex: ...
def equal_levels(self, other: MultiIndex) -> bool: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually documented? I can't find it in the docs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could not find it at all, prob worth deleting unless you have seen use cases in production.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well i'm always happy to delete code 🔥

@overload # type: ignore[override]
def isin( # pyrefly: ignore[bad-override]
self, values: Iterable[Any], level: Level
Expand Down
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,6 @@ ignore = [
"PYI042", # https://docs.astral.sh/ruff/rules/snake-case-type-alias/
"ERA001", "PLR0402", "PLC0105"
]
"multi.pyi" = [
# TODO: remove when multi.pyi is fully typed
"ANN001", "ANN201", "ANN204", "ANN206",
]
"indexing.pyi" = [
# TODO: remove when indexing.pyi is fully typed
"ANN001", "ANN201", "ANN204", "ANN206",
Expand Down
2 changes: 2 additions & 0 deletions tests/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
np_1darray_complex,
np_1darray_dt,
np_1darray_float,
np_1darray_int8,
np_1darray_int64,
np_1darray_intp,
np_1darray_object,
Expand Down Expand Up @@ -81,6 +82,7 @@
"np_ndarray_dt",
"np_1darray_object",
"np_1darray_td",
"np_1darray_int8",
"np_1darray_int64",
"np_ndarray_num",
"FloatDtypeArg",
Expand Down
98 changes: 98 additions & 0 deletions tests/indexes/test_multi_indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

import pandas as pd
from typing_extensions import (
assert_type,
)

from tests import (
check,
)
from tests._typing import (
np_1darray_bool,
np_1darray_int8,
np_1darray_intp,
)


def test_multiindex_unique() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
check(assert_type(mi.unique(), pd.MultiIndex), pd.MultiIndex)
check(assert_type(mi.unique(level=0), pd.Index), pd.Index)


def test_multiindex_set_levels() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
res = mi.set_levels([[10, 20, 30], [40, 50, 60]])
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
res = mi.set_levels([10, 20, 30], level=0)
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)


def test_multiindex_codes() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
check(assert_type(mi.codes, list[np_1darray_int8]), list)


def test_multiindex_set_codes() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
res = mi.set_codes([[0, 1, 2], [0, 1, 2]])
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
res = mi.set_codes([0, 1, 2], level=0)
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)


def test_multiindex_view() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
check(assert_type(mi.view(), pd.MultiIndex), pd.MultiIndex)
check(assert_type(mi.view(pd.Index), pd.MultiIndex), pd.MultiIndex)


def test_multiindex_remove_unused_levels() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
res = mi.remove_unused_levels()
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)


def test_multiindex_levshape() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
ls = mi.levshape
check(assert_type(ls, tuple[int, ...]), tuple, int)


def test_multiindex_append() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
check(assert_type(mi.append([mi]), pd.MultiIndex), pd.MultiIndex)
check(assert_type(mi.append([pd.Index([1, 2])]), pd.Index), pd.Index)


def test_multiindex_drop() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
dropped = mi.drop([1])
check(assert_type(dropped, pd.MultiIndex), pd.MultiIndex)


def test_multiindex_reorder_levels() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
reordered = mi.reorder_levels([1, 0])
check(assert_type(reordered, pd.MultiIndex), pd.MultiIndex)


def test_multiindex_get_locs() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
locs = mi.get_locs([1, 4])
check(assert_type(locs, np_1darray_intp), np_1darray_intp)


def test_multiindex_equal_levels() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
mi2 = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
eq = mi.equal_levels(mi2)
check(assert_type(eq, bool), bool)


def test_multiindex_get_loc_level() -> None:
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
res_0, res_1 = mi.get_loc_level(1, level=0)
check(assert_type(res_0, int | slice | np_1darray_bool), np_1darray_bool)
check(assert_type(res_1, pd.Index), pd.Index)