 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-
-
 from typing import cast
 
 import pytensor.tensor as pt
 
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Max
-from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.logprob.abstract import (
+    MeasurableElemwise,
+    MeasurableOp,
     MeasurableOpMixin,
     _logcdf_helper,
     _logprob,
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
-from pymc.logprob.utils import find_negated_var
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
 
@@ -73,25 +70,41 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     if rv_map_feature is None:
         return None  # pragma: no cover
 
-    if isinstance(node.op, MeasurableMax):
-        return None  # pragma: no cover
+    if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
+        return None
 
-    base_var = cast(TensorVariable, node.inputs[0])
+    [base_var] = node.inputs
 
     if base_var.owner is None:
         return None
 
     if not rv_map_feature.request_measurable(node.inputs):
         return None
 
-    # Non-univariate distributions and non-RVs must be rejected
-    if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
+    # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
+    if isinstance(base_var.owner.op, MeasurableElemwise):
+        latent_base_vars = [
+            var
+            for var in base_var.owner.inputs
+            if (var.owner and isinstance(var.owner.op, MeasurableOp))
+        ]
+        if len(latent_base_vars) != 1:
+            return None
+        [latent_base_var] = latent_base_vars
+    else:
+        latent_base_var = base_var
+
+    latent_op = latent_base_var.owner.op
+    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
         return None
 
     # univariate i.i.d. test which also rules out other distributions
-    for params in base_var.owner.op.dist_params(base_var.owner):
-        if not all(params.type.broadcastable):
-            return None
+    if not all(
+        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
+    ):
+        return None
+
+    base_var = cast(TensorVariable, base_var)
 
     if node.op.axis is None:
         axis = tuple(range(base_var.ndim))
@@ -102,16 +115,11 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
             return None
 
     # distinguish measurable discrete and continuous (because logprob is different)
-    measurable_max: Max
-    if base_var.type.dtype.startswith("int"):
-        measurable_max = MeasurableMaxDiscrete(axis)
-    else:
-        measurable_max = MeasurableMax(axis)
-
-    max_rv_node = measurable_max.make_node(base_var)
-    max_rv = max_rv_node.outputs
-
-    return max_rv
+    measurable_max_class = (
+        MeasurableMaxDiscrete if latent_base_var.type.dtype.startswith("int") else MeasurableMax
+    )
+    max_rv = cast(TensorVariable, measurable_max_class(axis)(base_var))
+    return [max_rv]
 
 
 measurable_ir_rewrites_db.register(
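As an aside (not part of this diff), the rewrite registered above is what lets `pymc.logp` derive the log-probability of `pt.max` applied to an i.i.d. univariate random variable, while bases with non-identical parameters still fail the broadcastable-parameter check. A minimal sketch, assuming the public `pm.logp` helper; the exact exception raised in the rejected case is not asserted here:

# example sketch (not part of the diff)
import pymc as pm
import pytensor.tensor as pt

# i.i.d. base: scalar dist_params, so find_measurable_max applies and a logprob graph is built
iid = pm.Normal.dist(mu=0.0, sigma=1.0, size=3)
max_logp = pm.logp(pt.max(iid), 0.5)

# non-identical means: all(params.type.broadcastable) is False, the Max is not made
# measurable, and asking for its logprob raises an error
non_iid = pm.Normal.dist(mu=[0.0, 1.0, 2.0], sigma=1.0)
try:
    pm.logp(pt.max(non_iid), 0.5)
except Exception as err:
    print(type(err).__name__)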
@@ -127,13 +135,13 @@ def max_logprob(op, values, base_rv, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation."""
     (value,) = values
 
-    logprob = _logprob_helper(base_rv, value)
-    logcdf = _logcdf_helper(base_rv, value)
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    bcast_value = pt.broadcast_to(value, base_rv_shape)
+    logprob = _logprob_helper(base_rv, bcast_value)[0]
+    logcdf = _logcdf_helper(base_rv, bcast_value)[0]
 
-    [n] = constant_fold([base_rv.size])
-    logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
-
-    return logprob
+    n = pt.prod(base_rv_shape)
+    return (n - 1) * logcdf + logprob + pt.math.log(n)
 
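For reference (not in the commit), a small numerical check of the order-statistic formula implemented above, ln f_(n)(m) = ln(n) + (n - 1) * ln F(m) + ln f(m); the SciPy comparison values are an assumption of this sketch:

# example sketch (not part of the diff)
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy import stats

# max of n = 3 i.i.d. standard normals, evaluated at m = 0.5
x = pm.Normal.dist(mu=0.0, sigma=1.0, size=3)
logp_val = pm.logp(pt.max(x), 0.5).eval()

n, m = 3, 0.5
expected = np.log(n) + (n - 1) * stats.norm.logcdf(m) + stats.norm.logpdf(m)
np.testing.assert_allclose(logp_val, expected)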
 
 @_logprob.register(MeasurableMaxDiscrete)
@@ -146,126 +154,11 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
     where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
     """
     (value,) = values
-    logcdf = _logcdf_helper(base_rv, value)
-    logcdf_prev = _logcdf_helper(base_rv, value - 1)
-
-    [n] = constant_fold([base_rv.size])
-
-    logprob = logdiffexp(n * logcdf, n * logcdf_prev)
-
-    return logprob
-
-
-class MeasurableMaxNeg(MeasurableOpMixin, Max):
-    """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
-    This shows up in the graph of min, which is (neg(max(neg(x)))."""
-
-
-class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
-    """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
-
-
-@node_rewriter(tracks=[Max])
-def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
-    if isinstance(node.op, MeasurableMaxNeg):
-        return None  # pragma: no cover
-
-    base_var = cast(TensorVariable, node.inputs[0])
-
-    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
-    if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
-        return None
-
-    base_rv = find_negated_var(base_var)
-
-    # negation is rv * (-1). Hence the scalar_op must be Mul
-    if base_rv is None:
-        return None
-
-    # Non-univariate distributions and non-RVs must be rejected
-    if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
-        return None
-
-    # univariate i.i.d. test which also rules out other distributions
-    for params in base_rv.owner.op.dist_params(base_rv.owner):
-        if not all(params.type.broadcastable):
-            return None
 
-    if node.op.axis is None:
-        axis = tuple(range(base_var.ndim))
-    else:
-        # Check whether axis is supported or not
-        axis = tuple(sorted(node.op.axis))
-        if axis != tuple(range(base_var.ndim)):
-            return None
-
-    if not rv_map_feature.request_measurable([base_rv]):
-        return None
-
-    # distinguish measurable discrete and continuous (because logprob is different)
-    measurable_min: Max
-    if base_rv.type.dtype.startswith("int"):
-        measurable_min = MeasurableDiscreteMaxNeg(axis)
-    else:
-        measurable_min = MeasurableMaxNeg(axis)
-
-    return measurable_min.make_node(base_rv).outputs
-
-
-measurable_ir_rewrites_db.register(
-    "find_measurable_max_neg",
-    find_measurable_max_neg,
-    "basic",
-    "min",
-)
-
-
-@_logprob.register(MeasurableMaxNeg)
-def max_neg_logprob(op, values, base_rv, **kwargs):
-    r"""Compute the log-likelihood graph for the `Max` operation.
-    The formula that we use here is :
-        \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
-    where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
-    """
-    (value,) = values
-
-    logprob = _logprob_helper(base_rv, -value)
-    logcdf = _logcdf_helper(base_rv, -value)
-
-    [n] = constant_fold([base_rv.size])
-    logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
-
-    return logprob
-
-
-@_logprob.register(MeasurableDiscreteMaxNeg)
-def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
-    r"""Compute the log-likelihood graph for the `Max` operation.
-
-    The formula that we use here is :
-    .. math::
-        \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
-    where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
-    """
-
-    (value,) = values
-
-    # The cdf of a negative variable is the survival at the negated value
-    logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
-    logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
-
-    [n] = constant_fold([base_rv.size])
-
-    # Now we can use the same expression as the discrete max
-    logprob = pt.where(
-        pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-        -pt.inf,
-        logdiffexp(n * logcdf_prev, n * logcdf),
-    )
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    bcast_value = pt.broadcast_to(value, base_rv_shape)
+    logcdf = _logcdf_helper(base_rv, bcast_value)[0]
+    logcdf_prev = _logcdf_helper(base_rv, bcast_value - 1)[0]
 
-    return logprob
+    n = pt.prod(base_rv_shape)
+    return logdiffexp(n * logcdf, n * logcdf_prev)
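Similarly, a hedged sketch (not part of this commit) of the discrete formula above, ln P(max = x) = ln(F(x)^n - F(x - 1)^n), checked against SciPy for i.i.d. Poisson variables:

# example sketch (not part of the diff)
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy import stats

# max of n = 4 i.i.d. Poisson(3) variables, evaluated at the integer value 5
x = pm.Poisson.dist(mu=3.0, size=4)
logp_val = pm.logp(pt.max(x), 5).eval()

n = 4
expected = np.log(stats.poisson.cdf(5, 3.0) ** n - stats.poisson.cdf(4, 3.0) ** n)
np.testing.assert_allclose(logp_val, expected, rtol=1e-6)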