Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
7ccf981
Add Coords and StrongCoords typing aliases and standardize model/arvi…
aman-coder03 Dec 2, 2025
77c2acb
Fix ruff formatting, imports, and dev requirements
aman-coder03 Dec 3, 2025
b384bff
Fix circular import by importing modelcontext from pymc.model.core
aman-coder03 Dec 3, 2025
e7da246
Fix circular import by lazily importing modelcontext in shape_from_dims
aman-coder03 Dec 3, 2025
3079404
Fix circular import by using only lazy modelcontext imports
aman-coder03 Dec 3, 2025
49c0db9
Fix Model circular import using TYPE_CHECKING and lazy import
aman-coder03 Dec 3, 2025
ec222e2
Fix lazy modelcontext import flagged by ruff
aman-coder03 Dec 3, 2025
9444946
Fix missing modelcontext import flagged by ruff
aman-coder03 Dec 3, 2025
d3759b8
Fix circular import of Model in printing.py
aman-coder03 Dec 3, 2025
202cb04
Move coords typing to pymc.typing and fix circular imports
aman-coder03 Dec 4, 2025
5d0ecac
Fix ruff UP038 isinstance union style
aman-coder03 Dec 4, 2025
f0a92c5
Fix printing Model NameError and move Coord typing to pymc.typing
aman-coder03 Dec 4, 2025
569b99c
Move coords typing to pymc.typing and fix printing imports
aman-coder03 Dec 4, 2025
b8e5ce1
Remove implicit modelcontext fallback from shape_from_dims
aman-coder03 Dec 4, 2025
d14a85a
Fix shape_from_dims typing and remove circular import
aman-coder03 Dec 4, 2025
fb5980e
Fix typing and import order
aman-coder03 Dec 4, 2025
3f71100
Removing comments from typing.py
aman-coder03 Dec 6, 2025
0da9559
Remove deprecated pymc.typing module after moving aliases to pymc.util
aman-coder03 Dec 6, 2025
27b010d
Fixes
aman-coder03 Dec 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.
"""PyMC-ArviZ conversion code."""

from __future__ import annotations

import logging
import warnings

from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
TypeAlias,
cast,
)

Expand All @@ -38,13 +39,16 @@

import pymc

from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames
from pymc.model import modelcontext
from pymc.util import StrongCoords

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
from pymc.model import Model

from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames

___all__ = [""]

Expand All @@ -56,6 +60,7 @@

# random variable object ...
Var = Any
DimsDict: TypeAlias = Mapping[str, Sequence[str]]


def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
Expand Down Expand Up @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)


def find_observations(model: "Model") -> dict[str, Var]:
def find_observations(model: Model) -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
Expand All @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
return observations


def find_constants(model: "Model") -> dict[str, Var]:
def find_constants(model: Model) -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
model_vars = model.basic_RVs + model.deterministics + model.potentials
value_vars = set(model.rvs_to_values.values())
Expand All @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
def coords_and_dims_for_inferencedata(
model: Model,
) -> tuple[StrongCoords, DimsDict]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
Expand Down Expand Up @@ -265,7 +272,7 @@ def __init__(

self.observations = find_observations(self.model)

def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]:
"""Split MultiTrace object into posterior and warmup.

Returns
Expand Down Expand Up @@ -491,7 +498,7 @@ def to_inference_data(self):


def to_inference_data(
trace: Optional["MultiTrace"] = None,
trace: MultiTrace | None = None,
*,
prior: Mapping[str, Any] | None = None,
posterior_predictive: Mapping[str, Any] | None = None,
Expand All @@ -500,7 +507,7 @@ def to_inference_data(
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model: Optional["Model"] = None,
model: Model | None = None,
save_warmup: bool | None = None,
include_transformed: bool = False,
) -> InferenceData:
Expand Down Expand Up @@ -568,8 +575,8 @@ def to_inference_data(
### perhaps we should have an inplace argument?
def predictions_to_inference_data(
predictions,
posterior_trace: Optional["MultiTrace"] = None,
model: Optional["Model"] = None,
posterior_trace: MultiTrace | None = None,
model: Model | None = None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
Expand Down
51 changes: 28 additions & 23 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or deemed in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC."""

from __future__ import annotations

import warnings

from collections.abc import Sequence
from functools import singledispatch
from types import EllipsisType
from typing import Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, cast

import numpy as np

Expand All @@ -33,18 +35,29 @@
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable

from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data
from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType, convert_observed_data
from pymc.util import StrongDims, StrongShape

if TYPE_CHECKING:
from pymc.model import Model
Shape = int | TensorVariable | Sequence[int | Variable]
Dims = str | Sequence[str | None]
DimsWithEllipsis = str | EllipsisType | Sequence[str | None | EllipsisType]
Size = int | TensorVariable | Sequence[int | Variable]

# Strong (validated) types from util

# Additional strong types needed inside this file
StrongDimsWithEllipsis = Sequence[str | EllipsisType]
StrongSize = TensorVariable | tuple[int | Variable, ...]

__all__ = [
"change_dist_size",
"rv_size_is_none",
"to_tuple",
]

from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType


def to_tuple(shape):
"""Convert ints, arrays, and Nones to tuples.
Expand Down Expand Up @@ -85,19 +98,6 @@ def _check_shape_type(shape):
return tuple(out)


# User-provided can be lazily specified as scalars
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
Dims: TypeAlias = str | Sequence[str | None]
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]

# After conversion to vectors
StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
StrongDims: TypeAlias = Sequence[str]
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]


def convert_dims(dims: Dims | None) -> StrongDims | None:
"""Process a user-provided dims variable into None or a valid dims tuple."""
if dims is None:
Expand Down Expand Up @@ -164,7 +164,7 @@ def convert_size(size: Size) -> StrongSize | None:
)


def shape_from_dims(dims: StrongDims, model) -> StrongShape:
def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape:
"""Determine shape from a `dims` tuple.

Parameters
Expand All @@ -176,9 +176,12 @@ def shape_from_dims(dims: StrongDims, model) -> StrongShape:

Returns
-------
dims : tuple of (str or None)
Names or None for all RV dimensions.
shape : tuple
Shape inferred from model dimension lengths.
"""
if model is None:
raise ValueError("model must be provided explicitly to infer shape from dims")

# Dims must be known already
unknowndim_dims = set(dims) - set(model.dim_lengths)
if unknowndim_dims:
Expand Down Expand Up @@ -403,6 +406,8 @@ def get_support_shape(
assert isinstance(dims, tuple)
if len(dims) < ndim_supp:
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
from pymc.model.core import modelcontext

model = modelcontext(None)
inferred_support_shape = [
model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0)
Expand Down
11 changes: 7 additions & 4 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
)
from pymc.util import (
UNSET,
Coords,
CoordValue,
StrongCoords,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -453,7 +456,7 @@ def _validate_name(name):
def __init__(
self,
name="",
coords=None,
coords: Coords | None = None,
check_bounds=True,
*,
model: _UnsetType | None | Model = UNSET,
Expand Down Expand Up @@ -488,7 +491,7 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()
self.data_vars = treelist()
self._coords = {}
self._coords: StrongCoords = {}
self._dim_lengths = {}
self.add_coords(coords)

Expand Down Expand Up @@ -907,7 +910,7 @@ def unobserved_RVs(self):
return self.free_RVs + self.deterministics

@property
def coords(self) -> dict[str, tuple | None]:
def coords(self) -> StrongCoords:
"""Coordinate values for model dimensions."""
return self._coords

Expand Down Expand Up @@ -937,7 +940,7 @@ def shape_from_dims(self, dims):
def add_coord(
self,
name: str,
values: Sequence | np.ndarray | None = None,
values: CoordValue = None,
*,
length: int | Variable | None = None,
):
Expand Down
9 changes: 8 additions & 1 deletion pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.


from __future__ import annotations

import re

from functools import partial
from typing import TYPE_CHECKING

from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant
Expand All @@ -26,7 +29,9 @@
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT

from pymc.model import Model
if TYPE_CHECKING:
from pymc.model import Model


__all__ = [
"str_for_dist",
Expand Down Expand Up @@ -302,6 +307,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
# register our custom pretty printer in ipython shells
import IPython

from pymc.model.core import Model

IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty)
IPython.lib.pretty.for_type(Model, _default_repr_pretty)
except (ModuleNotFoundError, AttributeError):
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassState:
def equal_dataclass_values(v1, v2):
if v1.__class__ != v2.__class__:
return False
if isinstance(v1, (list, tuple)): # noqa: UP038
if isinstance(v1, list | tuple):
return len(v1) == len(v2) and all(
equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
)
Expand Down
36 changes: 33 additions & 3 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,51 @@
import re

from collections import namedtuple
from collections.abc import Sequence
from collections.abc import Hashable, Mapping, Sequence
from copy import deepcopy
from typing import cast
from types import EllipsisType
from typing import TypeAlias, cast

import arviz
import cloudpickle
import numpy as np
import xarray

from cachetools import LRUCache, cachedmethod
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import BlockModelAccessError

# ---- User-facing coordinate types ----
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
Coords: TypeAlias = Mapping[str, CoordValue]

# ---- Internal strong coordinate types ----
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]

# ---- Internal strong dimension/shape types ----
StrongDims: TypeAlias = tuple[str, ...]
StrongShape: TypeAlias = tuple[int, ...]

# User-provided shape before processing
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]

# User-provided dims before processing
Dims: TypeAlias = str | Sequence[str | None]

# User-provided dims that may include ellipsis (...)
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]

# User-provided size before processing
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]

# Strong / normalized versions used internally
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
Expand Down
Loading