|
6 | 6 | # |
7 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
8 | 8 | # |
9 | | -# Unless required by applicable law or agreed to in writing, software |
| 9 | +# Unless required by applicable law or deemed in writing, software |
10 | 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC.""" |
16 | 16 |
|
| 17 | +from __future__ import annotations |
| 18 | + |
17 | 19 | import warnings |
18 | 20 |
|
19 | 21 | from collections.abc import Sequence |
20 | 22 | from functools import singledispatch |
21 | 23 | from types import EllipsisType |
22 | | -from typing import TYPE_CHECKING, Any, TypeAlias, cast |
| 24 | +from typing import TYPE_CHECKING, Any, cast |
23 | 25 |
|
24 | 26 | import numpy as np |
25 | 27 |
|
|
33 | 35 | from pytensor.tensor.type_other import NoneTypeT |
34 | 36 | from pytensor.tensor.variable import TensorVariable |
35 | 37 |
|
36 | | -from pymc.pytensorf import convert_observed_data |
| 38 | +from pymc.exceptions import ShapeError |
| 39 | +from pymc.pytensorf import PotentialShapeType, convert_observed_data |
| 40 | +from pymc.util import StrongDims, StrongShape |
37 | 41 |
|
38 | 42 | if TYPE_CHECKING: |
39 | 43 | from pymc.model import Model |
| 44 | +Shape = int | TensorVariable | Sequence[int | Variable] |
| 45 | +Dims = str | Sequence[str | None] |
| 46 | +DimsWithEllipsis = str | EllipsisType | Sequence[str | None | EllipsisType] |
| 47 | +Size = int | TensorVariable | Sequence[int | Variable] |
40 | 48 |
|
| 49 | +# Strong (validated) types from util |
| 50 | + |
| 51 | +# Additional strong types needed inside this file |
| 52 | +StrongDimsWithEllipsis = Sequence[str | EllipsisType] |
| 53 | +StrongSize = TensorVariable | tuple[int | Variable, ...] |
41 | 54 |
|
42 | 55 | __all__ = [ |
43 | 56 | "change_dist_size", |
44 | 57 | "rv_size_is_none", |
45 | 58 | "to_tuple", |
46 | 59 | ] |
47 | 60 |
|
48 | | -from pymc.exceptions import ShapeError |
49 | | -from pymc.pytensorf import PotentialShapeType |
50 | | - |
51 | 61 |
|
52 | 62 | def to_tuple(shape): |
53 | 63 | """Convert ints, arrays, and Nones to tuples. |
@@ -88,19 +98,6 @@ def _check_shape_type(shape): |
88 | 98 | return tuple(out) |
89 | 99 |
|
90 | 100 |
|
91 | | -# User-provided can be lazily specified as scalars |
92 | | -Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable] |
93 | | -Dims: TypeAlias = str | Sequence[str | None] |
94 | | -DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType] |
95 | | -Size: TypeAlias = int | TensorVariable | Sequence[int | Variable] |
96 | | - |
97 | | -# After conversion to vectors |
98 | | -StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...] |
99 | | -StrongDims: TypeAlias = Sequence[str] |
100 | | -StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] |
101 | | -StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] |
102 | | - |
103 | | - |
104 | 101 | def convert_dims(dims: Dims | None) -> StrongDims | None: |
105 | 102 | """Process a user-provided dims variable into None or a valid dims tuple.""" |
106 | 103 | if dims is None: |
@@ -167,7 +164,7 @@ def convert_size(size: Size) -> StrongSize | None: |
167 | 164 | ) |
168 | 165 |
|
169 | 166 |
|
170 | | -def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape: |
| 167 | +def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape: |
171 | 168 | """Determine shape from a `dims` tuple. |
172 | 169 |
|
173 | 170 | Parameters |
|
0 commit comments