|
1 | | -from typing import Any, Dict, List, Optional, Sequence, Union |
| 1 | +from typing import Any, List, Mapping, Optional, Sequence, Union |
2 | 2 |
|
3 | 3 | from pymc import Model |
4 | 4 | from pymc.logprob.transforms import RVTransform |
|
26 | 26 | from pymc_experimental.utils.pytensorf import rvs_in_graph |
27 | 27 |
|
28 | 28 |
|
29 | | -def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model: |
| 29 | +def observe( |
| 30 | + model: Model, vars_to_observations: Mapping[Union["str", TensorVariable], Any] |
| 31 | +) -> Model: |
30 | 32 | """Convert free RVs or Deterministics to observed RVs. |
31 | 33 |
|
32 | 34 | Parameters |
@@ -122,7 +124,9 @@ def replacement_fn(var, inner_replacements): |
122 | 124 |
|
123 | 125 |
|
124 | 126 | def do( |
125 | | - model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False |
| 127 | + model: Model, |
| 128 | + vars_to_interventions: Mapping[Union["str", TensorVariable], Any], |
| 129 | + prune_vars=False, |
126 | 130 | ) -> Model: |
127 | 131 | """Replace model variables by intervention variables. |
128 | 132 |
|
@@ -217,7 +221,7 @@ def do( |
217 | 221 |
|
218 | 222 | def change_value_transforms( |
219 | 223 | model: Model, |
220 | | - vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]], |
| 224 | + vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]], |
221 | 225 | ) -> Model: |
222 | 226 | """Change the value variables transforms in the model |
223 | 227 |
|
|
0 commit comments