@@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
33using LoopVectorization: @turbo
44using DynamicExpressions: AbstractExpressionNode
55using DynamicExpressions. UtilsModule: ResultOk, fill_similar
6- using DynamicExpressions. EvaluateModule: @return_on_check
6+ using DynamicExpressions. EvaluateModule: @return_on_nonfinite_val , EvalOptions
77import DynamicExpressions. EvaluateModule:
88 deg1_eval,
99 deg2_eval,
@@ -18,7 +18,10 @@ import DynamicExpressions.ExtensionInterfaceModule:
1818_is_loopvectorization_loaded (:: Int ) = true
1919
2020function deg2_eval (
21- cumulator_l:: AbstractVector{T} , cumulator_r:: AbstractVector{T} , op:: F , :: Val{true}
21+ cumulator_l:: AbstractVector{T} ,
22+ cumulator_r:: AbstractVector{T} ,
23+ op:: F ,
24+ :: EvalOptions{true} ,
2225):: ResultOk where {T<: Number ,F}
2326 @turbo for j in eachindex (cumulator_l)
2427 x = op (cumulator_l[j], cumulator_r[j])
@@ -28,7 +31,7 @@ function deg2_eval(
2831end
2932
3033function deg1_eval (
31- cumulator:: AbstractVector{T} , op:: F , :: Val {true}
34+ cumulator:: AbstractVector{T} , op:: F , :: EvalOptions {true}
3235):: ResultOk where {T<: Number ,F}
3336 @turbo for j in eachindex (cumulator)
3437 x = op (cumulator[j])
@@ -38,21 +41,25 @@ function deg1_eval(
3841end
3942
4043function deg1_l2_ll0_lr0_eval (
41- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{true}
44+ tree:: AbstractExpressionNode{T} ,
45+ cX:: AbstractMatrix{T} ,
46+ op:: F ,
47+ op_l:: F2 ,
48+ eval_options:: EvalOptions{true} ,
4249) where {T<: Number ,F,F2}
4350 if tree. l. l. constant && tree. l. r. constant
4451 val_ll = tree. l. l. val
4552 val_lr = tree. l. r. val
46- @return_on_check val_ll cX
47- @return_on_check val_lr cX
53+ @return_on_nonfinite_val (eval_options, val_ll, cX)
54+ @return_on_nonfinite_val (eval_options, val_lr, cX)
4855 x_l = op_l (val_ll, val_lr):: T
49- @return_on_check x_l cX
56+ @return_on_nonfinite_val (eval_options, x_l, cX)
5057 x = op (x_l):: T
51- @return_on_check x cX
58+ @return_on_nonfinite_val (eval_options, x, cX)
5259 return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
5360 elseif tree. l. l. constant
5461 val_ll = tree. l. l. val
55- @return_on_check val_ll cX
62+ @return_on_nonfinite_val (eval_options, val_ll, cX)
5663 feature_lr = tree. l. r. feature
5764 cumulator = similar (cX, axes (cX, 2 ))
5865 @turbo for j in axes (cX, 2 )
@@ -64,7 +71,7 @@ function deg1_l2_ll0_lr0_eval(
6471 elseif tree. l. r. constant
6572 feature_ll = tree. l. l. feature
6673 val_lr = tree. l. r. val
67- @return_on_check val_lr cX
74+ @return_on_nonfinite_val (eval_options, val_lr, cX)
6875 cumulator = similar (cX, axes (cX, 2 ))
6976 @turbo for j in axes (cX, 2 )
7077 x_l = op_l (cX[feature_ll, j], val_lr)
@@ -86,15 +93,19 @@ function deg1_l2_ll0_lr0_eval(
8693end
8794
8895function deg1_l1_ll0_eval (
89- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{true}
96+ tree:: AbstractExpressionNode{T} ,
97+ cX:: AbstractMatrix{T} ,
98+ op:: F ,
99+ op_l:: F2 ,
100+ eval_options:: EvalOptions{true} ,
90101) where {T<: Number ,F,F2}
91102 if tree. l. l. constant
92103 val_ll = tree. l. l. val
93- @return_on_check val_ll cX
104+ @return_on_nonfinite_val (eval_options, val_ll, cX)
94105 x_l = op_l (val_ll):: T
95- @return_on_check x_l cX
106+ @return_on_nonfinite_val (eval_options, x_l, cX)
96107 x = op (x_l):: T
97- @return_on_check x cX
108+ @return_on_nonfinite_val (eval_options, x, cX)
98109 return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
99110 else
100111 feature_ll = tree. l. l. feature
@@ -109,20 +120,23 @@ function deg1_l1_ll0_eval(
109120end
110121
111122function deg2_l0_r0_eval (
112- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , :: Val{true}
123+ tree:: AbstractExpressionNode{T} ,
124+ cX:: AbstractMatrix{T} ,
125+ op:: F ,
126+ eval_options:: EvalOptions{true} ,
113127) where {T<: Number ,F}
114128 if tree. l. constant && tree. r. constant
115129 val_l = tree. l. val
116- @return_on_check val_l cX
130+ @return_on_nonfinite_val (eval_options, val_l, cX)
117131 val_r = tree. r. val
118- @return_on_check val_r cX
132+ @return_on_nonfinite_val (eval_options, val_r, cX)
119133 x = op (val_l, val_r):: T
120- @return_on_check x cX
134+ @return_on_nonfinite_val (eval_options, x, cX)
121135 return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
122136 elseif tree. l. constant
123137 cumulator = similar (cX, axes (cX, 2 ))
124138 val_l = tree. l. val
125- @return_on_check val_l cX
139+ @return_on_nonfinite_val (eval_options, val_l, cX)
126140 feature_r = tree. r. feature
127141 @turbo for j in axes (cX, 2 )
128142 x = op (val_l, cX[feature_r, j])
@@ -133,7 +147,7 @@ function deg2_l0_r0_eval(
133147 cumulator = similar (cX, axes (cX, 2 ))
134148 feature_l = tree. l. feature
135149 val_r = tree. r. val
136- @return_on_check val_r cX
150+ @return_on_nonfinite_val (eval_options, val_r, cX)
137151 @turbo for j in axes (cX, 2 )
138152 x = op (cX[feature_l, j], val_r)
139153 cumulator[j] = x
@@ -157,11 +171,11 @@ function deg2_l0_eval(
157171 cumulator:: AbstractVector{T} ,
158172 cX:: AbstractArray{T} ,
159173 op:: F ,
160- :: Val {true} ,
174+ eval_options :: EvalOptions {true} ,
161175) where {T<: Number ,F}
162176 if tree. l. constant
163177 val = tree. l. val
164- @return_on_check val cX
178+ @return_on_nonfinite_val (eval_options, val, cX)
165179 @turbo for j in eachindex (cumulator)
166180 x = op (val, cumulator[j])
167181 cumulator[j] = x
@@ -182,11 +196,11 @@ function deg2_r0_eval(
182196 cumulator:: AbstractVector{T} ,
183197 cX:: AbstractArray{T} ,
184198 op:: F ,
185- :: Val {true} ,
199+ eval_options :: EvalOptions {true} ,
186200) where {T<: Number ,F}
187201 if tree. r. constant
188202 val = tree. r. val
189- @return_on_check val cX
203+ @return_on_nonfinite_val (eval_options, val, cX)
190204 @turbo for j in eachindex (cumulator)
191205 x = op (cumulator[j], val)
192206 cumulator[j] = x
@@ -203,11 +217,15 @@ function deg2_r0_eval(
203217end
204218
205219# # Interface with Bumper.jl
206- function bumper_kern1! (op:: F , cumulator, :: Val{true} ) where {F}
220+ function bumper_kern1! (
221+ op:: F , cumulator, :: EvalOptions{true,true,early_exit}
222+ ) where {F,early_exit}
207223 @turbo @. cumulator = op (cumulator)
208224 return cumulator
209225end
210- function bumper_kern2! (op:: F , cumulator1, cumulator2, :: Val{true} ) where {F}
226+ function bumper_kern2! (
227+ op:: F , cumulator1, cumulator2, :: EvalOptions{true,true,early_exit}
228+ ) where {F,early_exit}
211229 @turbo @. cumulator1 = op (cumulator1, cumulator2)
212230 return cumulator1
213231end
0 commit comments