Unify constraint handling for AutoContinuous, AutoDelta, AutoNormal.#2015
Unify constraint handling for AutoContinuous, AutoDelta, AutoNormal.#2015tillahoffmann wants to merge 5 commits intopyro-ppl:masterfrom
AutoContinuous, AutoDelta, AutoNormal.#2015Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR unifies constraint handling across auto guides by constructing distributions in the unconstrained space and then transforming them to the appropriate constrained support. Key changes include updating tests to verify matching supports between guide and model, revising docstrings to emphasize unconstrained values, and refactoring AutoGuide implementations to use a common helper for constraint handling.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| test/infer/test_autoguide.py | Added tests to verify that guide and model distributions have matching supports. |
| numpyro/infer/util.py | Updated docstrings to clarify that parameters are now in unconstrained space. |
| numpyro/infer/autoguide.py | Refactored constraint handling in auto guides including updated event dimension logic. |
Comments suppressed due to low confidence (1)
test/infer/test_autoguide.py:487
- Consider adding AutoContinuous to the parameterized tests to ensure its constraint handling behavior is also validated.
@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])
c1a47d1 to
65f1b29
Compare
65f1b29 to
d861328
Compare
|
Interestingly, the |
|
I think it is memory issue. Could you try to set a smaller number for this flag: https://github.com/pyro-ppl/numpyro/blob/master/examples/neutra.py#L205 in the test? |
|
The example didn't pass with a smaller number of hidden features. I had a look at the last passing and first failing runs. The dependencies are almost the same and don't look like they would cause this issue. Maybe something changed about the runner? On a related note, should we consider locking the requirements using a version-locked 20c20
< coverage==7.7.1
---
> coverage==7.8.0
32c32
< flax==0.10.4
---
> flax==0.10.5
35c35
< fsspec==2025.3.0
---
> fsspec==2025.3.2
98c98
< -e git+https://github.com/pyro-ppl/numpyro@19fbd57d96973d7d33ec594ad99110ff544e9ea7#egg=numpyro
---
> -e git+https://github.com/pyro-ppl/numpyro@4027928e31e859cfc7eacc41d8b53baba36137de#egg=numpyro
133c133
< rich==13.9.4
---
> rich==14.0.0
170c170
< xarray==2025.3.0
---
> xarray==2025.3.1 |
|
Could you mark this test as xfail in CI instead? like in https://github.com/pyro-ppl/numpyro/blob/master/test/test_examples.py#L52-L59 Re version lock: how does it work? what if users want to just install numpyro without upgrading/degrading other dependencies? |
| f"{name}_{self.prefix}_loc", init_loc, event_dim=event_dim | ||
| ) | ||
|
|
||
| site_fn = dist.Delta(site_loc).to_event(event_dim) |
There was a problem hiding this comment.
Here you are finding the MAP point in unconstrained space. This class gets MAP point in constrained space.
| transform = biject_to(site["fn"].support) | ||
| value = transform(unconstrained_value) | ||
| event_ndim = site["fn"].event_dim | ||
| if numpyro.get_mask() is False: |
There was a problem hiding this comment.
we need this logic to save computation for prediction.
This PR unifies constraint handling for several auto guides. In short, distributions for all variables in the guide are constructed in unconstrained space. If the variable has non-real support, the distribution is transformed.
The original motivation was to ensure that the support of random variables in the guide and model match.
AutoDeltaandAutoContinuousdid not meet that requirement because they useDeltadistributions in constrained space.