
ENH: pymc.math.sum could not be observed #7990

@louiszhuang

Description


Describe the issue:

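Observing a Deterministic defined with pm.math.sum breaks sampling: pm.sample raises a RuntimeError because PyMC cannot derive the log-probability of the summed variable.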

Reproducible code example:

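import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    # y is an unregistered vector of 5 Normal(x, 1) draws
    y = pm.Normal.dist(x, shape=(5,))
    # only the sum of y is tracked in the model
    y_sum = pm.Deterministic("y_sum", pm.math.sum(y))

# Observing the Deterministic makes sampling fail
with pm.observe(m, {"y_sum": 2.0}):
    trace = pm.sample()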
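In this example, y holds five independent Normal(x, 1) draws (sigma defaults to 1 in pm.Normal.dist), and a sum of independent Normals is itself Normal, so y_sum follows Normal(5x, √5). That is the density the observation would need. A quick NumPy Monte Carlo check of the mean and standard deviation, using an arbitrary illustrative value x = 0.4; the helper name `sum_of_normals_moments` is hypothetical:

```python
import numpy as np


def sum_of_normals_moments(x, n=5, draws=200_000, seed=0):
    """Monte Carlo mean and std of the sum of n iid Normal(x, 1) draws."""
    rng = np.random.default_rng(seed)
    # Each row mirrors one draw of pm.Normal.dist(x, shape=(n,))
    y = rng.normal(loc=x, scale=1.0, size=(draws, n))
    y_sum = y.sum(axis=1)
    return y_sum.mean(), y_sum.std()


mean, std = sum_of_normals_moments(x=0.4)
# theory: mean 5 * 0.4 = 2.0, std sqrt(5) ≈ 2.236
print(mean, std)
```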

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 3
      1 #%%
      2 with pm.observe(m, {"y_sum": 2.0}):
----> 3     trace = pm.sample(nuts_sampler='nutpie')

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    779     msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
    780     _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
    783 exclusive_nuts = (
    784     # User provided an instantiated NUTS step, and nothing else is needed
    785     (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
   (...)    792     )
    793 )
    795 if nuts_sampler != "pymc":

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:245, in assign_step_methods(model, step, methods)
    243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
    247 for var in model.value_vars:
    248     if var not in assigned_vars:
    249         # determine if a gradient can be computed

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\model\core.py:714, in Model.logp(self, vars, jacobian, sum)
    712 rv_logps: list[TensorVariable] = []
    713 if rvs:
--> 714     rv_logps = transformed_conditional_logp(
    715         rvs=rvs,
    716         rvs_to_values=self.rvs_to_values,
    717         rvs_to_transforms=self.rvs_to_transforms,
    718         jacobian=jacobian,
    719     )
    720     assert isinstance(rv_logps, list)
    722 # Replace random variables by their value variables in potential terms

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:574, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    571     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore[arg-type]
    573 kwargs.setdefault("warn_rvs", False)
--> 574 temp_logp_terms = conditional_logp(
    575     rvs_to_values,
    576     extra_rewrites=transform_rewrite,
    577     use_jacobian=jacobian,
    578     **kwargs,
    579 )
    581 # The function returns the logp for every single value term we provided to it.
    582 # This includes the extra values we plugged in above, so we filter those we
    583 # actually wanted in the same order they were given in.
    584 logp_terms = {}

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:531, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    529 missing_value_terms = set(original_values) - set(values_to_logprobs)
    530 if missing_value_terms:
--> 531     raise RuntimeError(
    532         f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
    533     )
    535 # Ensure same order as input
    536 logprobs = cleanup_ir(tuple(values_to_logprobs[v] for v in original_values))

RuntimeError: The logprob terms of the following value variables could not be derived: {TensorConstant(TensorType(float64, shape=()), data=array(2.))}

PyMC version information:

5.26.1

Context for the issue:

Being able to observe sum, max, and similar reductions would be very helpful in many cases.
