Skip to content

Commit 4092415

Browse files
authored
statistics: relax linear_regression input types (#15249)
1 parent 9d40875 commit 4092415

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

stdlib/statistics.pyi

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import sys
22
from _typeshed import SupportsRichComparisonT
3-
from collections.abc import Callable, Hashable, Iterable, Sequence
3+
from collections.abc import Callable, Hashable, Iterable, Sequence, Sized
44
from decimal import Decimal
55
from fractions import Fraction
6-
from typing import Literal, NamedTuple, SupportsFloat, SupportsIndex, TypeVar
6+
from typing import Literal, NamedTuple, Protocol, SupportsFloat, SupportsIndex, TypeVar
77
from typing_extensions import Self, TypeAlias
88

99
__all__ = [
@@ -41,6 +41,10 @@ _HashableT = TypeVar("_HashableT", bound=Hashable)
4141
# Used in NormalDist.samples and kde_random
4242
_Seed: TypeAlias = int | float | str | bytes | bytearray # noqa: Y041
4343

44+
# Used in linear_regression
45+
_T_co = TypeVar("_T_co", covariant=True)
46+
47+
class _SizedIterable(Iterable[_T_co], Sized, Protocol[_T_co]): ...
4448
class StatisticsError(ValueError): ...
4549

4650
if sys.version_info >= (3, 11):
@@ -129,11 +133,13 @@ if sys.version_info >= (3, 10):
129133

130134
if sys.version_info >= (3, 11):
131135
def linear_regression(
132-
regressor: Sequence[_Number], dependent_variable: Sequence[_Number], /, *, proportional: bool = False
136+
regressor: _SizedIterable[_Number], dependent_variable: _SizedIterable[_Number], /, *, proportional: bool = False
133137
) -> LinearRegression: ...
134138

135139
elif sys.version_info >= (3, 10):
136-
def linear_regression(regressor: Sequence[_Number], dependent_variable: Sequence[_Number], /) -> LinearRegression: ...
140+
def linear_regression(
141+
regressor: _SizedIterable[_Number], dependent_variable: _SizedIterable[_Number], /
142+
) -> LinearRegression: ...
137143

138144
if sys.version_info >= (3, 13):
139145
_Kernel: TypeAlias = Literal[

0 commit comments

Comments
 (0)