37 | 37 |
38 | 38 | from pathlib import Path |
39 | 39 |
40 | | -import pytensor |
41 | | - |
42 | 40 | from pytensor import tensor as pt |
43 | | -from pytensor.graph.fg import FunctionGraph |
44 | | -from pytensor.graph.op import compute_test_value |
45 | 41 | from pytensor.graph.rewriting.basic import node_rewriter |
46 | 42 | from pytensor.tensor import TensorVariable |
47 | | -from pytensor.tensor.basic import Alloc, Join, MakeVector |
| 43 | +from pytensor.tensor.basic import Join, MakeVector |
48 | 44 | from pytensor.tensor.elemwise import DimShuffle |
49 | 45 | from pytensor.tensor.random.op import RandomVariable |
50 | 46 | from pytensor.tensor.random.rewriting import ( |
51 | 47 | local_dimshuffle_rv_lift, |
52 | | - local_rv_size_lift, |
53 | 48 | ) |
54 | 49 |
55 | 50 | from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper |
62 | 57 | from pymc.pytensorf import constant_fold |
63 | 58 |
64 | 59 |
65 | | -@node_rewriter([Alloc]) |
66 | | -def naive_bcast_rv_lift(fgraph: FunctionGraph, node): |
67 | | - """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. |
68 | | -
69 | | - XXX: This implementation simply broadcasts the ``RandomVariable``'s |
70 | | - parameters, which won't always work (e.g. multivariate distributions). |
71 | | -
72 | | - TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to |
73 | | - determine which dimensions of each parameter need to be broadcasted. |
74 | | - Also, this doesn't need to remove ``size`` to perform the lifting, like it |
75 | | - currently does. |
76 | | - """ |
77 | | - |
78 | | - if not ( |
79 | | - isinstance(node.op, Alloc) |
80 | | - and node.inputs[0].owner |
81 | | - and isinstance(node.inputs[0].owner.op, RandomVariable) |
82 | | - ): |
83 | | - return None # pragma: no cover |
84 | | - |
85 | | - bcast_shape = node.inputs[1:] |
86 | | - |
87 | | - rv_var = node.inputs[0] |
88 | | - rv_node = rv_var.owner |
89 | | - |
90 | | - if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars: |
91 | | - return None # pragma: no cover |
92 | | - |
93 | | - # Do not replace RV if it is associated with a value variable |
94 | | - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) |
95 | | - if rv_map_feature is not None and rv_var in rv_map_feature.rv_values: |
96 | | - return None |
97 | | - |
98 | | - if not bcast_shape: |
99 | | - # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing) |
100 | | - assert rv_var.ndim == 0 |
101 | | - return [rv_var] |
102 | | - |
103 | | - size_lift_res = local_rv_size_lift.transform(fgraph, rv_node) |
104 | | - if size_lift_res is None: |
105 | | - lifted_node = rv_node |
106 | | - else: |
107 | | - _, lifted_rv = size_lift_res |
108 | | - lifted_node = lifted_rv.owner |
109 | | - |
110 | | - rng, size, *dist_params = lifted_node.inputs |
111 | | - |
112 | | - new_dist_params = [ |
113 | | - pt.broadcast_to( |
114 | | - param, |
115 | | - pt.broadcast_shape(tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True), |
116 | | - ) |
117 | | - for param in dist_params |
118 | | - ] |
119 | | - bcasted_node = lifted_node.op.make_node(rng, size, *new_dist_params) |
120 | | - |
121 | | - if pytensor.config.compute_test_value != "off": |
122 | | - compute_test_value(bcasted_node) |
123 | | - |
124 | | - return [bcasted_node.outputs[1]] |
125 | | - |
126 | | - |
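For context: the removed `naive_bcast_rv_lift` rewrite exploited the fact that, for univariate distributions, broadcasting a `RandomVariable`'s parameters is equivalent to broadcasting its draws, which is exactly why it was unsafe for multivariate distributions, per the XXX note above. A minimal sketch of that idea, with hypothetical names (`mu`, `sigma`, and `bcast_shape` are illustrative, not from the codebase):

    import pytensor.tensor as pt

    # Illustrative univariate-normal parameters; `bcast_shape` stands in for
    # the target shape of the `Alloc` being lifted.
    mu = pt.vector("mu")
    sigma = pt.scalar("sigma")
    bcast_shape = (2, 3)

    # Broadcast each parameter to the combined shape, mirroring the removed
    # rewrite's use of `pt.broadcast_shape` and `pt.broadcast_to`.
    new_shape = pt.broadcast_shape(tuple(mu.shape), bcast_shape, arrays_are_shapes=True)
    mu_bcast = pt.broadcast_to(mu, new_shape)
    sigma_bcast = pt.broadcast_to(sigma, new_shape)

    # A draw from the broadcasted parameters already has the `Alloc`'s target
    # shape, so the `Alloc` can be dropped (valid for univariate RVs only).
    x = pt.random.normal(mu_bcast, sigma_bcast)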
127 | 60 | class MeasurableMakeVector(MeasurableOp, MakeVector): |
128 | 61 | """A placeholder used to specify a log-likelihood for a make_vector sub-graph."""
129 | 62 |
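The `Measurable*` placeholders let the logprob machinery assign densities to tensor-construction sub-graphs. A hedged usage sketch (assuming `pm.logp` triggers the MakeVector-to-MeasurableMakeVector rewrite, as in current PyMC):

    import pymc as pm
    import pytensor.tensor as pt

    # Two independent base RVs stacked into one vector via `MakeVector`.
    x_rv = pt.random.normal(0.0, 1.0)
    y_rv = pt.random.normal(1.0, 2.0)
    vec_rv = pt.stack([x_rv, y_rv])

    # The logprob rewriter replaces the `MakeVector` with a
    # `MeasurableMakeVector`, so the vector's logp at `value` is the
    # elementwise logp of each component.
    value = pt.vector("value")
    logp = pm.logp(vec_rv, value)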