Commit ad80af8

Add docstrings for convenience calls
1 parent 453c3bc commit ad80af8

File tree

2 files changed: +133 −4 lines


docs/src/eval.md

Lines changed: 74 additions & 4 deletions
@@ -8,13 +8,34 @@ eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) w
 ```
 
 Assuming you are only using a single `OperatorEnum`, you can also use
-the following short-hand by using the expression as a function:
+the following shorthand by using the expression as a function:
+
+```
+(tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
+
+Evaluate a binary tree (equation) over a given data matrix. The
+operators contain all of the operators used in the tree.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
+- `operators::OperatorEnum`: The operators used in the tree.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+- `output::AbstractVector{T}`: the result, which is a 1D array.
+Any NaN, Inf, or other failure during the evaluation will result in the entire
+output array being set to NaN.
+```
+
+For example,
 
 ```@example
+using DynamicExpressions
+
 operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
 tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)
 
-tree(X)
+tree([1 2 3; 4 5 6.], operators)
 ```
 
 This is possible because when you call `OperatorEnum`, it automatically re-defines
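Rendered outside the diff, the convenience call documented above can be sketched as follows. This is a hypothetical usage sketch, not part of the commit; it assumes DynamicExpressions is installed, and reuses the operators and expression from the `@example` block in the hunk:

```julia
using DynamicExpressions

# Same example expression as in the docs hunk above.
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)

# Calling the tree like a function evaluates it over each column of X.
# Per the new docstring, this matches eval_tree_array(tree, X, operators),
# except that a failed evaluation fills the whole output with NaN instead
# of returning a completion flag.
X = [1.0 2.0 3.0; 4.0 5.0 6.0]  # 2 features, 3 samples
y = tree(X, operators)          # one output value per column of X
```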
@@ -32,7 +53,31 @@ The notation is the same for `eval_tree_array`, though it will return `nothing`
 when it can't find a method, and not do any NaN checks:
 
 ```@docs
-eval_tree_array(tree, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true)
+eval_tree_array(tree::Node, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+```
+
+Likewise for the shorthand notation:
+
+```
+(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+
+# Arguments
+- `X::AbstractArray`: The input data to evaluate the tree on.
+- `operators::GenericOperatorEnum`: The operators used in the tree.
+- `throw_errors::Bool=true`: Whether to throw errors
+if they occur during evaluation. Otherwise,
+MethodErrors will be caught before they happen and
+evaluation will return `nothing`,
+rather than throwing an error. This is useful in cases
+where you are unsure if a particular tree is valid or not,
+and would prefer to work with `nothing` as an output.
+
+# Returns
+- `output`: the result of the evaluation.
+If evaluation failed, `nothing` will be returned for the first argument.
+A `false` complete means an operator was called on input types
+that it was not defined for. You can change this behavior by
+setting `throw_errors=false`.
 ```
 
 ## Derivatives
@@ -46,7 +91,32 @@ all variables (or, all constants). Both use forward-mode automatic, but use
 
 ```@docs
 eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
-eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false) where {T<:Number}
+eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
+```
+
+You can also compute gradients with the shorthand notation (which by default computes
+gradients with respect to the input matrix, rather than the constants).
+
+```
+(tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
+
+Compute the forward-mode derivative of an expression, using a similar
+structure and optimization to eval_tree_array. `variable` specifies whether
+we should take derivatives with respect to features (i.e., X), or with respect
+to every constant in the expression.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
+- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
+must be `true`. This is needed to create the derivative operations.
+- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
+or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+
+- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
+the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
 ```
 
 Alternatively, you can compute higher-order derivatives by using `ForwardDiff` on
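As a usage sketch for the gradient shorthand documented above (hedged, not part of the commit: it assumes an `OperatorEnum` accepting an `enable_autodiff` keyword, which the docstring's reference to `operators.enable_autodiff` suggests, and that `tree'` forwards to the gradient evaluator as described):

```julia
using DynamicExpressions

# enable_autodiff is required so derivative operators are generated,
# per the `operators.enable_autodiff` note in the docstring above.
operators = OperatorEnum(;
    binary_operators=[+, -, *], unary_operators=[cos], enable_autodiff=true
)
tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
# By default the shorthand differentiates with respect to X (variable=true);
# pass variable=false to differentiate with respect to the constants instead.
grad = tree'(X, operators)
```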

src/EvaluationHelpers.jl

Lines changed: 59 additions & 0 deletions
@@ -7,11 +7,48 @@ import ..EvaluateEquationModule: eval_tree_array
 import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
 
 # Evaluation:
+"""
+    (tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
+
+Evaluate a binary tree (equation) over a given data matrix. The
+operators contain all of the operators used in the tree.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
+- `operators::OperatorEnum`: The operators used in the tree.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+- `output::AbstractVector{T}`: the result, which is a 1D array.
+Any NaN, Inf, or other failure during the evaluation will result in the entire
+output array being set to NaN.
+"""
 function (tree::Node)(X, operators::OperatorEnum; kws...)
     out, did_finish = eval_tree_array(tree, X, operators; kws...)
     !did_finish && (out .= convert(eltype(out), NaN))
     return out
 end
+"""
+    (tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+
+# Arguments
+- `X::AbstractArray`: The input data to evaluate the tree on.
+- `operators::GenericOperatorEnum`: The operators used in the tree.
+- `throw_errors::Bool=true`: Whether to throw errors
+if they occur during evaluation. Otherwise,
+MethodErrors will be caught before they happen and
+evaluation will return `nothing`,
+rather than throwing an error. This is useful in cases
+where you are unsure if a particular tree is valid or not,
+and would prefer to work with `nothing` as an output.
+
+# Returns
+- `output`: the result of the evaluation.
+If evaluation failed, `nothing` will be returned for the first argument.
+A `false` complete means an operator was called on input types
+that it was not defined for. You can change this behavior by
+setting `throw_errors=false`.
+"""
 function (tree::Node)(X, operators::GenericOperatorEnum; kws...)
     out, did_finish = eval_tree_array(tree, X, operators; kws...)
     !did_finish && return nothing
@@ -37,6 +74,28 @@ function _grad_evaluator(tree::Node, X; kws...)
     ## into a depwarn
     @error "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead."
 end
+
+"""
+    (tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
+
+Compute the forward-mode derivative of an expression, using a similar
+structure and optimization to eval_tree_array. `variable` specifies whether
+we should take derivatives with respect to features (i.e., X), or with respect
+to every constant in the expression.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
+- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
+must be `true`. This is needed to create the derivative operations.
+- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
+or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+
+- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
+the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
+"""
 Base.adjoint(tree::Node) = ((args...; kws...) -> _grad_evaluator(tree, args...; kws...))
 
 end
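The two convenience methods added in this file share one wrapper pattern around `eval_tree_array`: evaluate, then translate the completion flag into either a NaN-filled array (`OperatorEnum`) or `nothing` (`GenericOperatorEnum`). A minimal standalone sketch of that pattern in plain Julia (the `evaluate` argument and function names here are made up for illustration, not the package's actual API):

```julia
# `evaluate` stands in for eval_tree_array: it returns (output, did_finish).

# OperatorEnum-style wrapper: fill the output with NaN on failure.
function call_with_nan_fill(evaluate, X)
    out, did_finish = evaluate(X)
    !did_finish && (out .= convert(eltype(out), NaN))
    return out
end

# GenericOperatorEnum-style wrapper: return `nothing` on failure.
function call_or_nothing(evaluate, X)
    out, did_finish = evaluate(X)
    return did_finish ? out : nothing
end

# Demo with a stand-in evaluator that always reports failure:
failing = X -> (copy(X[1, :]), false)
call_with_nan_fill(failing, [1.0 2.0; 3.0 4.0])  # [NaN, NaN]
call_or_nothing(failing, [1.0 2.0; 3.0 4.0])     # nothing
```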

0 commit comments