Skip to content

Commit 0ad3c93

Browse files
committed
ci: add extra benchmark for parametric nodes
1 parent f8e5acf commit 0ad3c93

File tree

2 files changed

+33
-19
lines changed

2 files changed

+33
-19
lines changed

benchmark/benchmarks.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ function benchmark_utilities()
213213
setup=(
214214
ntrees=100;
215215
n=20;
216-
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
216+
rng=Random.MersenneTwister(0);
217+
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32, Node, rng)) for _ in 1:ntrees]
217218
)
218219
)
219220
#! format: on
@@ -222,6 +223,37 @@ function benchmark_utilities()
222223
end
223224
end
224225

226+
# Additional methods
227+
@static if PACKAGE_VERSION >= v"0.18.0"
228+
suite["get_set_constants_parametric"] = @benchmarkable(
229+
[get_set_constants!(ex) for ex in exs],
230+
seconds = 10.0,
231+
setup = (
232+
operators = $operators;
233+
ntrees = 100;
234+
n = 20;
235+
n_features = 5;
236+
n_params = 3;
237+
n_param_classes = 10;
238+
rng = Random.MersenneTwister(0);
239+
exs = [
240+
let tree = gen_random_tree_fixed_size(
241+
n, operators, n_features, Float32, ParametricNode, rng
242+
)
243+
ex = ParametricExpression(
244+
tree;
245+
operators,
246+
variable_names=map(i -> "x$i", 1:n_features),
247+
parameters=randn(rng, Float32, n_params, n_param_classes),
248+
parameter_names=map(i -> "p$i", 1:n_params),
249+
)
250+
ex
251+
end for _ in 1:ntrees
252+
]
253+
)
254+
)
255+
end
256+
225257
return suite
226258
end
227259

src/ParametricExpression.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -173,26 +173,8 @@ end
173173
###############################################################################
174174
# Extra utilities for parametric-specific behavior ############################
175175
###############################################################################
176-
## As explained in AbstractExpressionNode, we can implement custom behavior for
177-
## the parametric expression by implementing the following methods:
178-
# - `count_nodes`
179-
# - `count_constants`
180-
# - `count_depth`
181-
# - `index_constant_nodes`
182-
# - `has_operators`
183-
# - `has_constants`
184-
# - `get_scalar_constants`
185-
# - `set_scalar_constants!`
186-
# - `string_tree`
187-
# - `max_feature`
188-
# - `eval_tree_array`
189-
# - `eval_grad_tree_array`
190-
# - `_grad_evaluator`
191-
192-
## For a parametric struct, we only wish to implement the following
193176

194177
#! format: off
195-
196178
struct InterfaceError <: Exception
197179
end
198180
_interface_error() = throw(InterfaceError())

0 commit comments

Comments
 (0)