@@ -12,6 +12,7 @@ The available implementations are
1212"""
1313module Proposals
1414
15+ using .. NestedSamplers: prior_transform_and_loglikelihood
1516using .. Bounds
1617
1718using Random
5455
5556@deprecate Uniform () Rejection ()
5657
57- function (prop:: Rejection )(rng:: AbstractRNG ,
58+ function (prop:: Rejection )(
59+ rng:: AbstractRNG ,
5860 point:: AbstractVector ,
5961 logl_star,
6062 bounds:: AbstractBoundingSpace ,
61- loglike,
62- prior_transform )
63+ model
64+ )
6365
6466 ncall = 0
6567 for _ in 1 : prop. maxiter
6668 u = rand (rng, bounds)
6769 unitcheck (u) || continue
68- v = prior_transform (u)
69- logl = loglike (v)
70+ v, logl = prior_transform_and_loglikelihood (model, u)
7071 ncall += 1
7172 logl ≥ logl_star && return u, v, logl, ncall
7273 end
@@ -95,13 +96,14 @@ Propose a new live point by random walking away from an existing live point.
9596 @assert scale ≥ 0 " Proposal scale must be non-negative"
9697end
9798
98- function (prop:: RWalk )(rng:: AbstractRNG ,
99- point:: AbstractVector ,
100- logl_star,
101- bounds:: AbstractBoundingSpace ,
102- loglike,
103- prior_transform;
104- kwargs... )
99+ function (prop:: RWalk )(
100+ rng:: AbstractRNG ,
101+ point:: AbstractVector ,
102+ logl_star,
103+ bounds:: AbstractBoundingSpace ,
104+ model;
105+ kwargs...
106+ )
105107 # setup
106108 n = length (point)
107109 scale_init = prop. scale
@@ -129,8 +131,7 @@ function (prop::RWalk)(rng::AbstractRNG,
129131 end
130132 end
131133 # check proposed point
132- v_prop = prior_transform (u_prop)
133- logl_prop = loglike (v_prop)
134+ v_prop, logl_prop = prior_transform_and_loglikelihood (model, u_prop)
134135 if logl_prop ≥ logl_star
135136 u = u_prop
136137 v = v_prop
@@ -188,13 +189,14 @@ proposals.
188189 @assert scale ≥ 0 " Proposal scale must be non-negative"
189190end
190191
191- function (prop:: RStagger )(rng:: AbstractRNG ,
192- point:: AbstractVector ,
193- logl_star,
194- bounds:: AbstractBoundingSpace ,
195- loglike,
196- prior_transform;
197- kwargs... )
192+ function (prop:: RStagger )(
193+ rng:: AbstractRNG ,
194+ point:: AbstractVector ,
195+ logl_star,
196+ bounds:: AbstractBoundingSpace ,
197+ model;
198+ kwargs...
199+ )
198200 # setup
199201 n = length (point)
200202 scale_init = prop. scale
@@ -223,8 +225,7 @@ function (prop::RStagger)(rng::AbstractRNG,
223225 end
224226 end
225227 # check proposed point
226- v_prop = prior_transform (u_prop)
227- logl_prop = loglike (v_prop)
228+ v_prop, logl_prop = prior_transform_and_loglikelihood (model, u_prop)
228229 if logl_prop ≥ logl_star
229230 u = u_prop
230231 v = v_prop
@@ -276,13 +277,14 @@ This is a standard _Gibbs-like_ implementation where a single multivariate slice
276277 @assert scale ≥ 0 " Proposal scale must be non-negative"
277278end
278279
279- function (prop:: Slice )(rng:: AbstractRNG ,
280- point:: AbstractVector ,
281- logl_star,
282- bounds:: AbstractBoundingSpace ,
283- loglike,
284- prior_transform;
285- kwargs... )
280+ function (prop:: Slice )(
281+ rng:: AbstractRNG ,
282+ point:: AbstractVector ,
283+ logl_star,
284+ bounds:: AbstractBoundingSpace ,
285+ model;
286+ kwargs...
287+ )
286288 # setup
287289 n = length (point)
288290 nc = nexpand = ncontract = 0
@@ -303,8 +305,11 @@ function (prop::Slice)(rng::AbstractRNG,
303305 # select axis
304306 axis = axes[idx, :]
305307
306- u, v, logl, nc, nexpand, ncontract = sample_slice (rng, axis, point, logl_star, loglike,
307- prior_transform, nc, nexpand, ncontract)
308+ u, v, logl, nc, nexpand, ncontract = sample_slice (
309+ rng, axis, point, logl_star,
310+ model,
311+ nc, nexpand, ncontract
312+ )
308313 end # end of slice sample along a random direction
309314 end # end of slice sampling loop
310315
@@ -330,13 +335,14 @@ This is a standard _random_ implementation where each slice is along a random di
330335 @assert scale ≥ 0 " Proposal scale must be non-negative"
331336end
332337
333- function (prop:: RSlice )(rng:: AbstractRNG ,
334- point:: AbstractVector ,
335- logl_star,
336- bounds:: AbstractBoundingSpace ,
337- loglike,
338- prior_transform;
339- kwargs... )
338+ function (prop:: RSlice )(
339+ rng:: AbstractRNG ,
340+ point:: AbstractVector ,
341+ logl_star,
342+ bounds:: AbstractBoundingSpace ,
343+ model;
344+ kwargs...
345+ )
340346 # setup
341347 n = length (point)
342348 nc = nexpand = ncontract = 0
@@ -350,8 +356,11 @@ function (prop::RSlice)(rng::AbstractRNG,
350356
351357 # transform and scale into parameter space
352358 axis = prop. scale .* (Bounds. axes (bounds) * drhat)
353- u, v, logl, nc, nexpand, ncontract = sample_slice (rng, axis, point, logl_star, loglike,
354- prior_transform, nc, nexpand, ncontract)
359+ u, v, logl, nc, nexpand, ncontract = sample_slice (
360+ rng, axis, point, logl_star,
361+ model,
362+ nc, nexpand, ncontract
363+ )
355364 end # end of random slice sampling loop
356365
357366 # update random slice proposal scale based on the relative size of the slices compared to the initial guess
@@ -361,13 +370,12 @@ function (prop::RSlice)(rng::AbstractRNG,
361370end # end of function RSlice
362371
363372# Method for slice sampling
364- function sample_slice (rng, axis, u, logl_star, loglike, prior_transform, nc, nexpand, ncontract)
373+ function sample_slice (rng, axis, u, logl_star, model, nc, nexpand, ncontract)
365374 # define starting window
366375 r = rand (rng) # initial scale/offset
367376 u_l = @. u - r * axis # left bound
368377 if unitcheck (u_l)
369- v_l = prior_transform (u_l)
370- logl_l = loglike (v_l)
378+ v_l, logl_l = prior_transform_and_loglikelihood (model, u_l)
371379 else
372380 logl_l = - Inf
373381 end
@@ -376,8 +384,7 @@ function sample_slice(rng, axis, u, logl_star, loglike, prior_transform, nc, nex
376384
377385 u_r = u_l .+ axis # right bound
378386 if unitcheck (u_r)
379- v_r = prior_transform (u_r)
380- logl_r = loglike (v_r)
387+ v_r, logl_r = prior_transform_and_loglikelihood (model, u_r)
381388 else
382389 logl_r = - Inf
383390 end
@@ -387,9 +394,8 @@ function sample_slice(rng, axis, u, logl_star, loglike, prior_transform, nc, nex
387394 # stepping out left and right bounds
388395 while logl_l ≥ logl_star
389396 u_l .- = axis
390- if unitcheck (u_l)
391- v_l = prior_transform (u_l)
392- logl_l = loglike (v_l)
397+ if unitcheck (u_l)
398+ v_l, logl_l = prior_transform_and_loglikelihood (model, u_l)
393399 else
394400 logl_l = - Inf
395401 end
@@ -399,9 +405,8 @@ function sample_slice(rng, axis, u, logl_star, loglike, prior_transform, nc, nex
399405
400406 while logl_r ≥ logl_star
401407 u_r .+ = axis
402- if unitcheck (u_r)
403- v_r = prior_transform (u_r)
404- logl_r = loglike (v_r)
408+ if unitcheck (u_r)
409+ v_r, logl_r = prior_transform_and_loglikelihood (model, u_r)
405410 else
406411 logl_r = - Inf
407412 end
@@ -422,9 +427,8 @@ function sample_slice(rng, axis, u, logl_star, loglike, prior_transform, nc, nex
422427 # propose a new position
423428 r = rand (rng)
424429 u_prop = @. u_l + r * u_hat
425- if unitcheck (u_prop)
426- v_prop = prior_transform (u_prop)
427- logl_prop = loglike (v_prop)
430+ if unitcheck (u_prop)
431+ v_prop, logl_prop = prior_transform_and_loglikelihood (model, u_prop)
428432 else
429433 logl_prop = - Inf
430434 end
0 commit comments