File tree Expand file tree Collapse file tree 5 files changed +15
-15
lines changed
Expand file tree Collapse file tree 5 files changed +15
-15
lines changed Original file line number Diff line number Diff line change 3434from aesara .scalar import UnaryScalarOp , upgrade_to_float_no_complex
3535from aesara .tensor import gammaln
3636from aesara .tensor .elemwise import Elemwise
37- from aesara .tensor .slinalg import Cholesky
38- from aesara .tensor .slinalg import solve_lower_triangular as solve_lower
39- from aesara .tensor .slinalg import solve_upper_triangular as solve_upper
37+ from aesara .tensor .slinalg import Cholesky , SolveTriangular
4038
4139from pymc .aesaraf import floatX
4240from pymc .distributions .shape_utils import to_tuple
4341
42+ solve_lower = SolveTriangular (lower = True )
43+ solve_upper = SolveTriangular (lower = False )
44+
4445f = floatX
4546c = - 0.5 * np .log (2.0 * np .pi )
4647_beta_clip_values = {
Original file line number Diff line number Diff line change 3333from aesara .tensor .random .basic import dirichlet , multinomial , multivariate_normal
3434from aesara .tensor .random .op import RandomVariable , default_supp_shape_from_params
3535from aesara .tensor .random .utils import broadcast_params , normalize_size_param
36- from aesara .tensor .slinalg import Cholesky
37- from aesara .tensor .slinalg import solve_lower_triangular as solve_lower
38- from aesara .tensor .slinalg import solve_upper_triangular as solve_upper
36+ from aesara .tensor .slinalg import Cholesky , SolveTriangular
3937from aesara .tensor .type import TensorType
4038from scipy import linalg , stats
4139
7977 "StickBreakingWeights" ,
8078]
8179
80+ solve_lower = SolveTriangular (lower = True )
81+ solve_upper = SolveTriangular (lower = False )
82+
8283
8384class SimplexContinuous (Continuous ):
8485 """Base class for simplex continuous distributions"""
Original file line number Diff line number Diff line change 1919
2020from aesara .compile import SharedVariable
2121from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
22+ SolveTriangular ,
2223 cholesky ,
2324 solve ,
2425)
25- from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
26- solve_lower_triangular as solve_lower ,
27- )
28- from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
29- solve_upper_triangular as solve_upper ,
30- )
3126from aesara .tensor .var import TensorConstant
3227from scipy .cluster .vq import kmeans
3328
4136
4237JITTER_DEFAULT = 1e-6
4338
39+ solve_lower = SolveTriangular (lower = True )
40+ solve_upper = SolveTriangular (lower = False )
41+
4442
4543def replace_with_values (vars_needed , replacements = None , model = None ):
4644 R"""
Original file line number Diff line number Diff line change @@ -230,8 +230,8 @@ def kron_vector_op(v):
230230
231231# Define kronecker functions that work on 1D and 2D arrays
232232kron_dot = partial (kron_matrix_op , op = at .dot )
233- kron_solve_lower = partial (kron_matrix_op , op = at .slinalg .solve_lower_triangular )
234- kron_solve_upper = partial (kron_matrix_op , op = at .slinalg .solve_upper_triangular )
233+ kron_solve_lower = partial (kron_matrix_op , op = at .slinalg .SolveTriangular ( lower = True ) )
234+ kron_solve_upper = partial (kron_matrix_op , op = at .slinalg .SolveTriangular ( lower = False ) )
235235
236236
237237def flat_outer (a , b ):
Original file line number Diff line number Diff line change @@ -116,7 +116,7 @@ def test_kron_solve_lower():
116116 x = np .random .rand (tot_size ).reshape ((tot_size , 1 ))
117117 # Construct entire kronecker product then solve
118118 big = kronecker (* Ls )
119- slow_ans = at .slinalg .solve_lower_triangular (big , x )
119+ slow_ans = at .slinalg .solve_triangular (big , x , lower = True )
120120 # Use tricks to avoid construction of entire kronecker product
121121 fast_ans = kron_solve_lower (Ls , x )
122122 np .testing .assert_array_almost_equal (slow_ans .eval (), fast_ans .eval ())
You can’t perform that action at this time.
0 commit comments