Skip to content

Commit 202cb04

Browse files
committed
Move coords typing to pymc.typing and fix circular imports
1 parent d3759b8 commit 202cb04

File tree

4 files changed

+43
-13
lines changed

4 files changed

+43
-13
lines changed

pymc/backends/arviz.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@
3939

4040
import pymc
4141

42-
from pymc.distributions.shape_utils import StrongCoords
43-
from pymc.model import Model, modelcontext
44-
from pymc.progress_bar import CustomProgress, default_progress_theme
45-
from pymc.pytensorf import PointFunc, extract_obs_data
46-
from pymc.util import get_default_varnames
42+
from pymc.model import modelcontext
43+
from pymc.typing import StrongCoords
4744

4845
if TYPE_CHECKING:
4946
from pymc.backends.base import MultiTrace
47+
from pymc.model import Model
48+
49+
from pymc.progress_bar import CustomProgress, default_progress_theme
50+
from pymc.pytensorf import PointFunc, extract_obs_data
51+
from pymc.util import get_default_varnames
5052

5153
___all__ = [""]
5254

pymc/distributions/shape_utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818

19-
from collections.abc import Hashable, Mapping, Sequence
19+
from collections.abc import Sequence
2020
from functools import singledispatch
2121
from types import EllipsisType
2222
from typing import Any, TypeAlias, cast
@@ -97,12 +97,6 @@ def _check_shape_type(shape):
9797
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
9898
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
9999

100-
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
101-
Coords: TypeAlias = Mapping[str, CoordValue]
102-
103-
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
104-
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]
105-
106100

107101
def convert_dims(dims: Dims | None) -> StrongDims | None:
108102
"""Process a user-provided dims variable into None or a valid dims tuple."""
@@ -416,6 +410,7 @@ def get_support_shape(
416410
if len(dims) < ndim_supp:
417411
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
418412
from pymc.model.core import modelcontext
413+
419414
model = modelcontext(None)
420415
inferred_support_shape = [
421416
model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0)

pymc/step_methods/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DataClassState:
3030
def equal_dataclass_values(v1, v2):
3131
if v1.__class__ != v2.__class__:
3232
return False
33-
if isinstance(v1, (list, tuple)): # noqa: UP038
33+
if isinstance(v1, (list, tuple)):
3434
return len(v1) == len(v2) and all(
3535
equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
3636
)

pymc/typing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from __future__ import annotations
17+
18+
from collections.abc import Hashable, Mapping, Sequence
19+
from typing import TypeAlias
20+
21+
import numpy as np
22+
23+
# -------------------------
24+
# Coordinate typing helpers
25+
# -------------------------
26+
27+
# User-facing coordinate values (before normalization)
28+
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
29+
Coords: TypeAlias = Mapping[str, CoordValue]
30+
31+
# After normalization / internal representation
32+
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
33+
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]

0 commit comments

Comments
 (0)