diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..98b8598eb0 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -13,6 +13,8 @@ # limitations under the License. """PyMC-ArviZ conversion code.""" +from __future__ import annotations + import logging import warnings @@ -20,8 +22,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, + TypeAlias, cast, ) @@ -38,13 +39,16 @@ import pymc -from pymc.model import Model, modelcontext -from pymc.progress_bar import CustomProgress, default_progress_theme -from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import get_default_varnames +from pymc.model import modelcontext +from pymc.util import StrongCoords if TYPE_CHECKING: from pymc.backends.base import MultiTrace + from pymc.model import Model + +from pymc.progress_bar import CustomProgress, default_progress_theme +from pymc.pytensorf import PointFunc, extract_obs_data +from pymc.util import get_default_varnames ___all__ = [""] @@ -56,6 +60,7 @@ # random variable object ... Var = Any +DimsDict: TypeAlias = Mapping[str, Sequence[str]] def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs): @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs) -def find_observations(model: "Model") -> dict[str, Var]: +def find_observations(model: Model) -> dict[str, Var]: """If there are observations available, return them as a dictionary.""" observations = {} for obs in model.observed_RVs: @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]: return observations -def find_constants(model: "Model") -> dict[str, Var]: +def find_constants(model: Model) -> dict[str, Var]: """If there are constants available, return them as a dictionary.""" model_vars = model.basic_RVs + model.deterministics + model.potentials value_vars = set(model.rvs_to_values.values()) @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]: return constant_data -def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]: +def coords_and_dims_for_inferencedata( + model: Model, +) -> tuple[StrongCoords, DimsDict]: """Parse PyMC model coords and dims format to one accepted by InferenceData.""" coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals @@ -265,7 +272,7 @@ def __init__( self.observations = find_observations(self.model) - def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: + def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]: """Split MultiTrace object into posterior and warmup. Returns @@ -491,7 +498,7 @@ def to_inference_data(self): def to_inference_data( - trace: Optional["MultiTrace"] = None, + trace: MultiTrace | None = None, *, prior: Mapping[str, Any] | None = None, posterior_predictive: Mapping[str, Any] | None = None, @@ -500,7 +507,7 @@ def to_inference_data( coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, - model: Optional["Model"] = None, + model: Model | None = None, save_warmup: bool | None = None, include_transformed: bool = False, ) -> InferenceData: @@ -568,8 +575,8 @@ def to_inference_data( ### perhaps we should have an inplace argument? def predictions_to_inference_data( predictions, - posterior_trace: Optional["MultiTrace"] = None, - model: Optional["Model"] = None, + posterior_trace: MultiTrace | None = None, + model: Model | None = None, coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index cdab3046b1..244822e9e9 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -6,7 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law or deemed in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and @@ -14,12 +14,14 @@ """Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC.""" +from __future__ import annotations + import warnings from collections.abc import Sequence from functools import singledispatch from types import EllipsisType -from typing import Any, TypeAlias, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -33,8 +35,22 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -from pymc.model import modelcontext -from pymc.pytensorf import convert_observed_data +from pymc.exceptions import ShapeError +from pymc.pytensorf import PotentialShapeType, convert_observed_data +from pymc.util import StrongDims, StrongShape + +if TYPE_CHECKING: + from pymc.model import Model +Shape = int | TensorVariable | Sequence[int | Variable] +Dims = str | Sequence[str | None] +DimsWithEllipsis = str | EllipsisType | Sequence[str | None | EllipsisType] +Size = int | TensorVariable | Sequence[int | Variable] + +# Strong (validated) types from util + +# Additional strong types needed inside this file +StrongDimsWithEllipsis = Sequence[str | EllipsisType] +StrongSize = TensorVariable | tuple[int | Variable, ...] __all__ = [ "change_dist_size", @@ -42,9 +58,6 @@ "to_tuple", ] -from pymc.exceptions import ShapeError -from pymc.pytensorf import PotentialShapeType - def to_tuple(shape): """Convert ints, arrays, and Nones to tuples. @@ -85,19 +98,6 @@ def _check_shape_type(shape): return tuple(out) -# User-provided can be lazily specified as scalars -Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable] -Dims: TypeAlias = str | Sequence[str | None] -DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType] -Size: TypeAlias = int | TensorVariable | Sequence[int | Variable] - -# After conversion to vectors -StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...] -StrongDims: TypeAlias = Sequence[str] -StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] -StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] - - def convert_dims(dims: Dims | None) -> StrongDims | None: """Process a user-provided dims variable into None or a valid dims tuple.""" if dims is None: @@ -164,7 +164,7 @@ def convert_size(size: Size) -> StrongSize | None: ) -def shape_from_dims(dims: StrongDims, model) -> StrongShape: +def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape: """Determine shape from a `dims` tuple. Parameters @@ -176,9 +176,12 @@ def shape_from_dims(dims: StrongDims, model) -> StrongShape: Returns ------- - dims : tuple of (str or None) - Names or None for all RV dimensions. + shape : tuple + Shape inferred from model dimension lengths. """ + if model is None: + raise ValueError("model must be provided explicitly to infer shape from dims") + # Dims must be known already unknowndim_dims = set(dims) - set(model.dim_lengths) if unknowndim_dims: @@ -403,6 +406,8 @@ def get_support_shape( assert isinstance(dims, tuple) if len(dims) < ndim_supp: raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}") + from pymc.model.core import modelcontext + model = modelcontext(None) inferred_support_shape = [ model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0) diff --git a/pymc/model/core.py b/pymc/model/core.py index 3630138e00..20c2e3ec47 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -67,6 +67,9 @@ ) from pymc.util import ( UNSET, + Coords, + CoordValue, + StrongCoords, WithMemoization, _UnsetType, get_transformed_name, @@ -453,7 +456,7 @@ def _validate_name(name): def __init__( self, name="", - coords=None, + coords: Coords | None = None, check_bounds=True, *, model: _UnsetType | None | Model = UNSET, @@ -488,7 +491,7 @@ def __init__( self.deterministics = treelist() self.potentials = treelist() self.data_vars = treelist() - self._coords = {} + self._coords: StrongCoords = {} self._dim_lengths = {} self.add_coords(coords) @@ -907,7 +910,7 @@ def unobserved_RVs(self): return self.free_RVs + self.deterministics @property - def coords(self) -> dict[str, tuple | None]: + def coords(self) -> StrongCoords: """Coordinate values for model dimensions.""" return self._coords @@ -937,7 +940,7 @@ def shape_from_dims(self, dims): def add_coord( self, name: str, - values: Sequence | np.ndarray | None = None, + values: CoordValue = None, *, length: int | Variable | None = None, ): diff --git a/pymc/printing.py b/pymc/printing.py index 63514ac4d0..180038e650 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,9 +13,12 @@ # limitations under the License. +from __future__ import annotations + import re from functools import partial +from typing import TYPE_CHECKING from pytensor.compile import SharedVariable from pytensor.graph.basic import Constant @@ -26,7 +29,9 @@ from pytensor.tensor.random.type import RandomType from pytensor.tensor.type_other import NoneTypeT -from pymc.model import Model +if TYPE_CHECKING: + from pymc.model import Model + __all__ = [ "str_for_dist", @@ -302,6 +307,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): # register our custom pretty printer in ipython shells import IPython + from pymc.model.core import Model + IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) IPython.lib.pretty.for_type(Model, _default_repr_pretty) except (ModuleNotFoundError, AttributeError): diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 98e177aa03..6fe97de21a 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -30,7 +30,7 @@ class DataClassState: def equal_dataclass_values(v1, v2): if v1.__class__ != v2.__class__: return False - if isinstance(v1, (list, tuple)): # noqa: UP038 + if isinstance(v1, list | tuple): return len(v1) == len(v2) and all( equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True) ) diff --git a/pymc/util.py b/pymc/util.py index 32d8d65e70..c9d6943e09 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -16,9 +16,10 @@ import re from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy -from typing import cast +from types import EllipsisType +from typing import TypeAlias, cast import arviz import cloudpickle @@ -26,11 +27,40 @@ import xarray from cachetools import LRUCache, cachedmethod -from pytensor import Variable from pytensor.compile import SharedVariable +from pytensor.graph.basic import Variable +from pytensor.tensor.variable import TensorVariable from pymc.exceptions import BlockModelAccessError +# ---- User-facing coordinate types ---- +CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None +Coords: TypeAlias = Mapping[str, CoordValue] + +# ---- Internal strong coordinate types ---- +StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None +StrongCoords: TypeAlias = Mapping[str, StrongCoordValue] + +# ---- Internal strong dimension/shape types ---- +StrongDims: TypeAlias = tuple[str, ...] +StrongShape: TypeAlias = tuple[int, ...] + +# User-provided shape before processing +Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable] + +# User-provided dims before processing +Dims: TypeAlias = str | Sequence[str | None] + +# User-provided dims that may include ellipsis (...) +DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType] + +# User-provided size before processing +Size: TypeAlias = int | TensorVariable | Sequence[int | Variable] + +# Strong / normalized versions used internally +StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] +StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] + class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs."""