Skip to content

Commit 27b010d

Browse files
committed
Fixes
1 parent 0da9559 commit 27b010d

File tree

4 files changed

+43
-27
lines changed

4 files changed

+43
-27
lines changed

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import pymc
4141

4242
from pymc.model import modelcontext
43-
from pymc.util import StrongCoords
43+
from pymc.util import StrongCoords
4444

4545
if TYPE_CHECKING:
4646
from pymc.backends.base import MultiTrace

pymc/distributions/shape_utils.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
# Unless required by applicable law or agreed to in writing, software
9+
# Unless required by applicable law or deemed in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

1515
"""Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC."""
1616

17+
from __future__ import annotations
18+
1719
import warnings
1820

1921
from collections.abc import Sequence
2022
from functools import singledispatch
2123
from types import EllipsisType
22-
from typing import TYPE_CHECKING, Any, TypeAlias, cast
24+
from typing import TYPE_CHECKING, Any, cast
2325

2426
import numpy as np
2527

@@ -33,21 +35,29 @@
3335
from pytensor.tensor.type_other import NoneTypeT
3436
from pytensor.tensor.variable import TensorVariable
3537

36-
from pymc.pytensorf import convert_observed_data
38+
from pymc.exceptions import ShapeError
39+
from pymc.pytensorf import PotentialShapeType, convert_observed_data
40+
from pymc.util import StrongDims, StrongShape
3741

3842
if TYPE_CHECKING:
3943
from pymc.model import Model
44+
Shape = int | TensorVariable | Sequence[int | Variable]
45+
Dims = str | Sequence[str | None]
46+
DimsWithEllipsis = str | EllipsisType | Sequence[str | None | EllipsisType]
47+
Size = int | TensorVariable | Sequence[int | Variable]
4048

49+
# Strong (validated) types from util
50+
51+
# Additional strong types needed inside this file
52+
StrongDimsWithEllipsis = Sequence[str | EllipsisType]
53+
StrongSize = TensorVariable | tuple[int | Variable, ...]
4154

4255
__all__ = [
4356
"change_dist_size",
4457
"rv_size_is_none",
4558
"to_tuple",
4659
]
4760

48-
from pymc.exceptions import ShapeError
49-
from pymc.pytensorf import PotentialShapeType
50-
5161

5262
def to_tuple(shape):
5363
"""Convert ints, arrays, and Nones to tuples.
@@ -88,19 +98,6 @@ def _check_shape_type(shape):
8898
return tuple(out)
8999

90100

91-
# User-provided can be lazily specified as scalars
92-
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
93-
Dims: TypeAlias = str | Sequence[str | None]
94-
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
95-
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]
96-
97-
# After conversion to vectors
98-
StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
99-
StrongDims: TypeAlias = Sequence[str]
100-
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
101-
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
102-
103-
104101
def convert_dims(dims: Dims | None) -> StrongDims | None:
105102
"""Process a user-provided dims variable into None or a valid dims tuple."""
106103
if dims is None:
@@ -167,7 +164,7 @@ def convert_size(size: Size) -> StrongSize | None:
167164
)
168165

169166

170-
def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape:
167+
def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape:
171168
"""Determine shape from a `dims` tuple.
172169
173170
Parameters

pymc/model/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@
6666
rewrite_pregrad,
6767
)
6868
from pymc.util import (
69+
UNSET,
6970
Coords,
7071
CoordValue,
7172
StrongCoords,
72-
UNSET,
7373
WithMemoization,
7474
_UnsetType,
7575
get_transformed_name,
@@ -78,7 +78,6 @@
7878
treedict,
7979
treelist,
8080
)
81-
8281
from pymc.vartypes import continuous_types, discrete_types, typefilter
8382

8483
__all__ = [

pymc/util.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,51 @@
1616
import re
1717

1818
from collections import namedtuple
19-
from collections.abc import Sequence
19+
from collections.abc import Hashable, Mapping, Sequence
2020
from copy import deepcopy
21-
from typing import Mapping, TypeAlias, Hashable, cast
21+
from types import EllipsisType
22+
from typing import TypeAlias, cast
2223

2324
import arviz
2425
import cloudpickle
2526
import numpy as np
2627
import xarray
2728

2829
from cachetools import LRUCache, cachedmethod
29-
from pytensor import Variable
3030
from pytensor.compile import SharedVariable
31+
from pytensor.graph.basic import Variable
32+
from pytensor.tensor.variable import TensorVariable
3133

3234
from pymc.exceptions import BlockModelAccessError
3335

34-
#Coordinate & Shape Typing
36+
# ---- User-facing coordinate types ----
3537
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
3638
Coords: TypeAlias = Mapping[str, CoordValue]
3739

40+
# ---- Internal strong coordinate types ----
3841
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
3942
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]
4043

44+
# ---- Internal strong dimension/shape types ----
4145
StrongDims: TypeAlias = tuple[str, ...]
4246
StrongShape: TypeAlias = tuple[int, ...]
4347

48+
# User-provided shape before processing
49+
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
50+
51+
# User-provided dims before processing
52+
Dims: TypeAlias = str | Sequence[str | None]
53+
54+
# User-provided dims that may include ellipsis (...)
55+
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
56+
57+
# User-provided size before processing
58+
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]
59+
60+
# Strong / normalized versions used internally
61+
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
62+
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
63+
4464

4565
class _UnsetType:
4666
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""

0 commit comments

Comments
 (0)