|
1 | 1 | import sys |
2 | 2 | from _typeshed import SupportsRichComparisonT |
3 | | -from collections.abc import Callable, Hashable, Iterable, Sequence |
| 3 | +from collections.abc import Callable, Hashable, Iterable, Sequence, Sized |
4 | 4 | from decimal import Decimal |
5 | 5 | from fractions import Fraction |
6 | | -from typing import Literal, NamedTuple, SupportsFloat, SupportsIndex, TypeVar |
| 6 | +from typing import Literal, NamedTuple, Protocol, SupportsFloat, SupportsIndex, TypeVar |
7 | 7 | from typing_extensions import Self, TypeAlias |
8 | 8 |
|
9 | 9 | __all__ = [ |
@@ -41,6 +41,10 @@ _HashableT = TypeVar("_HashableT", bound=Hashable) |
41 | 41 | # Used in NormalDist.samples and kde_random |
42 | 42 | _Seed: TypeAlias = int | float | str | bytes | bytearray # noqa: Y041 |
43 | 43 |
|
| 44 | +# Used in linear_regression |
| 45 | +_T_co = TypeVar("_T_co", covariant=True) |
| 46 | + |
| 47 | +class _SizedIterable(Iterable[_T_co], Sized, Protocol[_T_co]): ... |
44 | 48 | class StatisticsError(ValueError): ... |
45 | 49 |
|
46 | 50 | if sys.version_info >= (3, 11): |
@@ -129,11 +133,13 @@ if sys.version_info >= (3, 10): |
129 | 133 |
|
130 | 134 | if sys.version_info >= (3, 11): |
131 | 135 | def linear_regression( |
132 | | - regressor: Sequence[_Number], dependent_variable: Sequence[_Number], /, *, proportional: bool = False |
| 136 | + regressor: _SizedIterable[_Number], dependent_variable: _SizedIterable[_Number], /, *, proportional: bool = False |
133 | 137 | ) -> LinearRegression: ... |
134 | 138 |
|
135 | 139 | elif sys.version_info >= (3, 10): |
136 | | - def linear_regression(regressor: Sequence[_Number], dependent_variable: Sequence[_Number], /) -> LinearRegression: ... |
| 140 | + def linear_regression( |
| 141 | + regressor: _SizedIterable[_Number], dependent_variable: _SizedIterable[_Number], / |
| 142 | + ) -> LinearRegression: ... |
137 | 143 |
|
138 | 144 | if sys.version_info >= (3, 13): |
139 | 145 | _Kernel: TypeAlias = Literal[ |
|
0 commit comments