@@ -8,13 +8,34 @@ eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) w
```

Assuming you are only using a single `OperatorEnum`, you can also use
-the following short-hand by using the expression as a function:
+the following shorthand by using the expression as a function:
+
+```
+(tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
+
+Evaluate a binary tree (equation) over a given data matrix. The
+operators contain all of the operators used in the tree.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
+- `operators::OperatorEnum`: The operators used in the tree.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+- `output::AbstractVector{T}`: the result, which is a 1D array.
+Any NaN, Inf, or other failure during the evaluation will result in the entire
+output array being set to NaN.
+```
+
+For example,

```@example
+using DynamicExpressions
+
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)

-tree(X)
+tree([1 2 3; 4 5 6.], operators)
```

This is possible because when you call `OperatorEnum`, it automatically re-defines
@@ -32,7 +53,31 @@ The notation is the same for `eval_tree_array`, though it will return `nothing`
when it can't find a method, and not do any NaN checks:

```@docs
-eval_tree_array(tree, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true)
+eval_tree_array(tree::Node, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+```
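+
+For instance, here is a brief sketch of handling a failed evaluation without throwing an error. It
+assumes `tree`, `X`, and a `GenericOperatorEnum` named `operators` are defined elsewhere, and
+that the generic method follows the same `(output, complete)` return convention as the numeric one:
+
+```julia
+# Sketch only: `tree`, `X`, and `operators::GenericOperatorEnum` come from your own code.
+output, complete = eval_tree_array(tree, X, operators; throw_errors=false)
+if !complete
+    # `output` is `nothing` here: some operator had no method for its input types.
+end
+```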
+
+Likewise for the shorthand notation:
+
+```
+(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+
+# Arguments
+- `X::AbstractArray`: The input data to evaluate the tree on.
+- `operators::GenericOperatorEnum`: The operators used in the tree.
+- `throw_errors::Bool=true`: Whether to throw errors
+if they occur during evaluation. Otherwise,
+MethodErrors will be caught and evaluation will
+return `nothing`, rather than throwing an error.
+This is useful in cases where you are unsure
+if a particular tree is valid or not,
+and would prefer to work with `nothing` as an output.
+
+# Returns
+- `output`: the result of the evaluation.
+If evaluation failed, `nothing` will be returned instead
+(this happens when an operator was called on input types
+that it was not defined for). You can get this behavior,
+rather than an error being thrown, by setting `throw_errors=false`.
```
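+
+A corresponding sketch with the shorthand (same assumptions about `tree`, `X`, and `operators`
+as above); here, a failed evaluation simply yields `nothing`:
+
+```julia
+output = tree(X, operators; throw_errors=false)
+if output === nothing
+    # Handle the invalid tree here.
+end
+```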

## Derivatives
@@ -46,7 +91,32 @@ all variables (or, all constants). Both use forward-mode automatic, but use

```@docs
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
-eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false) where {T<:Number}
+eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
+```
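+
+For example, here is a minimal sketch of calling these directly. It assumes that the
+`OperatorEnum` constructor accepts an `enable_autodiff` keyword (to set `operators.enable_autodiff`),
+that `eval_diff_tree_array` returns an `(evaluation, derivative, complete)` tuple analogous to
+`eval_grad_tree_array`, and that the tree's element type matches `eltype(X)`:
+
+```julia
+using DynamicExpressions
+
+operators = OperatorEnum(;
+    binary_operators=[+, -, *], unary_operators=[cos], enable_autodiff=true
+)
+tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)
+X = [1.0 2.0 3.0; 4.0 5.0 6.0]
+
+# Derivative of the expression with respect to feature 1 only:
+evaluation, derivative, complete = eval_diff_tree_array(tree, X, operators, 1)
+
+# Gradient with respect to all features:
+evaluation, gradient, complete = eval_grad_tree_array(tree, X, operators; variable=true)
+```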
96+
+You can also compute gradients with this shorthand notation (which by default computes
+gradients with respect to the input matrix, rather than the constants).
+
+```
+(tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
+
+Compute the forward-mode derivative of an expression, using a similar
+structure and optimization to eval_tree_array. `variable` specifies whether
+we should take derivatives with respect to features (i.e., X), or with respect
+to every constant in the expression.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
+- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
+must be `true`. This is needed to create the derivative operations.
+- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
+or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+
+- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
+the gradient, and whether the evaluation completed as normal (or encountered a NaN or Inf).
```
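+
+For example, re-using `operators`, `tree`, and `X` from the sketch above (built with
+`enable_autodiff=true`):
+
+```julia
+# Differentiate the expression with respect to the features in X:
+tree'(X, operators)
+```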

Alternatively, you can compute higher-order derivatives by using `ForwardDiff` on