|
| 1 | +@testitem "Literate example" begin |
| 2 | + #literate_begin file="src/examples/base_operations.md" |
| 3 | + |
| 4 | + #= |
| 5 | + # Node and Tree Operations |
| 6 | +
|
| 7 | + This example demonstrates how to create and manipulate expression trees |
| 8 | + using the [`Node`](@ref) type. We'll create a tree, |
| 9 | + perform various operations, and show how to traverse and modify it. |
| 10 | +
|
| 11 | + First, let's create a simple expression tree. |
| 12 | + We can bootstrap this by creating a node to hold `feature=1`, |
| 13 | + indicating the first input variable (first column of data): |
| 14 | + =# |
| 15 | + using DynamicExpressions, Random |
| 16 | + |
| 17 | + x = Node{Float64}(; feature=1) |
| 18 | + |
| 19 | + # We can also create values, using `val`: |
| 20 | + const_1 = Node{Float64}(; val=1.0) |
| 21 | + |
| 22 | + #= |
| 23 | + Now, let's declare some operators to use in our expression tree. |
| 24 | +
|
| 25 | + Note that the declaration of the `OperatorEnum` updates |
| 26 | + a global mapping from operators to their index in a list. |
| 27 | + This is purely for convenience, and most of the time, you would |
| 28 | + either operate directly on the `OperatorEnum`, like with [`eval_tree_array`](@ref), |
| 29 | + or use [`Expression`](@ref) objects to store them alongside the expression. |
| 30 | + =# |
| 31 | + operators = OperatorEnum(; unary_operators=(sin, exp), binary_operators=(+, -, *, /)) |
| 32 | + |
| 33 | + # Now, let's create another variable |
| 34 | + y = Node{Float64}(; feature=2) |
| 35 | + |
| 36 | + # And we can now create expression trees: |
| 37 | + tree = (x + y) * const_1 - sin(x) |
| 38 | + |
| 39 | + # The type of this is the same as the type of the variables |
| 40 | + # and constants, meaning we have type stability: |
| 41 | + typeof(tree), typeof(x) |
| 42 | + @test typeof(tree) == typeof(x) #src |
| 43 | + |
| 44 | + # We can also just use scalars directly: |
| 45 | + tree2 = 2x - sin(x) |
| 46 | + |
| 47 | + # As you have noticed, the tree is printed as an expression. |
| 48 | + # We can control this with the [`string_tree`](@ref) function, |
| 49 | + # which also lets us pass the `operators` explicitly: |
| 50 | + string_tree(tree, operators; variable_names=["x", "y"]) |
| 51 | + #= |
| 52 | + This also lets us control how each branch node and leaf node (variable/constant) |
| 53 | + is printed in the tree. |
| 54 | +
|
| 55 | + There are a lot of operations you can do on tree objects, |
| 56 | + such as evaluating them over batched data: |
| 57 | + =# |
| 58 | + rng = Random.MersenneTwister(0) |
| 59 | + tree2(randn(rng, Float64, 2, 5), operators) |
| 60 | + |
| 61 | + #= |
| 62 | + Now, how does this actually work? How do these functions traverse |
| 63 | + the tree? |
| 64 | +
|
| 65 | + The core operation is the [`tree_mapreduce`](@ref) function, |
| 66 | + which applies a function to each node in the tree, |
| 67 | + and then combines the results. Unlike a standard `mapreduce`, |
| 68 | + the `tree_mapreduce` allows you to specify different maps for |
| 69 | + branch nodes and leaf nodes. Also unlike a `mapreduce`, the |
| 70 | + reduction function needs to handle a variable number of inputs – it takes |
| 71 | + the mapped branch node, as well as all of the mapped children. |
| 72 | +
|
| 73 | + Let's see an example. Say we just want to count the nodes in the tree: |
| 74 | + =# |
| 75 | + tree_mapreduce(node -> 1, +, tree) |
| 76 | + #= |
| 77 | + Here, the `+` handles both the cases of 1 child and 2 children. |
| 78 | + Here, we didn't need to specify a custom branch function, but we could do that too: |
| 79 | + =# |
| 80 | + tree_mapreduce(leaf_node -> 1, branch_node -> 0, +, tree) |
| 81 | + #= |
| 82 | + This counts the number of leaf nodes in the tree. For `tree`, |
| 83 | + this was `x`, `y`, `const_1`, and `x`. |
| 84 | +
|
| 85 | + You can access fields of the [`Node`](@ref) type here to create more |
| 86 | + complex operations, just be careful to not access undefined fields (be sure |
| 87 | + to read the API specification). |
| 88 | +
|
| 89 | + Most operators can be built with this simple pattern, even including |
| 90 | + evaluation of the tree, and printing of expressions. (It also allows |
| 91 | + for graph-like expressions like [`GraphNode`](@ref) via a `f_on_shared` keyword.) |
| 92 | +
|
| 93 | + As a more complex example, let's compute the depth of a tree. Here, we need |
| 94 | + to use a more complicated reduction operation – the `max`: |
| 95 | + =# |
| 96 | + tree_mapreduce( |
| 97 | + node -> 1, (parent, children...) -> 1 + max(children...), x + sin(sin(exp(x))) |
| 98 | + ) |
| 99 | + #= |
| 100 | + Here, the `max` handles both the cases of 1 child and 2 children. |
| 101 | + The parent node contributes `1` at each depth. Note that the inputs |
| 102 | + to the reduction are already mapped to `1`. |
| 103 | +
|
| 104 | + Many operations do not need to handle branching, and thus, many of the typical |
| 105 | + operations on collections in Julia are available. For example, |
| 106 | + we can collect each node in the tree into a list: |
| 107 | + =# |
| 108 | + collect(tree) |
| 109 | + # Note that the first node in this list is the root note, which is |
| 110 | + # the subtraction operation: |
| 111 | + tree == first(collect(tree)) |
| 112 | + @test tree == first(collect(tree)) #src |
| 113 | + # We can look at the operator: |
| 114 | + tree.degree, tree.op |
| 115 | + # And compare it to our list: |
| 116 | + operators.binops |
| 117 | + # Many other collection operations are available. For example, we can aggregate a relationship over each node: |
| 118 | + sum(node -> node.degree == 0 ? 1.5 : 0.0, tree) |
| 119 | + # We can even use `any` which has an early exit from the depth-first tree traversal: |
| 120 | + any(node -> node.degree == 2, tree) |
| 121 | + # We can also randomly sample nodes, using [`NodeSampler`](@ref), |
| 122 | + # which permits filters: |
| 123 | + rand(rng, NodeSampler(; tree, filter=node -> node.degree == 1)) |
| 124 | + |
| 125 | + #literate_end |
| 126 | + |
| 127 | + # y = Node(Float64; feature=2) # Represents variable y |
| 128 | + # const_3 = Node(3.0) # Constant node with value 3 |
| 129 | + # expr_tree = (x + y) * const_3 - sin(x) |
| 130 | + |
| 131 | + # println("Original expression tree:") |
| 132 | + # println(expr_tree) |
| 133 | + |
| 134 | + # # Now, let's demonstrate some operations on this tree |
| 135 | + |
| 136 | + # # 1. Counting nodes |
| 137 | + # node_count = count_nodes(expr_tree) |
| 138 | + # println("\nNumber of nodes in the tree: ", node_count) |
| 139 | + |
| 140 | + # # 2. Finding all constant nodes |
| 141 | + # constant_nodes = filter(t -> t.degree == 0 && t.constant, expr_tree) |
| 142 | + # println("\nConstant nodes in the tree:") |
| 143 | + # for node in constant_nodes |
| 144 | + # println(node) |
| 145 | + # end |
| 146 | + |
| 147 | + # # 3. Mapping: Double all constant values |
| 148 | + # map(expr_tree) do t |
| 149 | + # if t.degree == 0 && t.constant |
| 150 | + # t.val *= 2 |
| 151 | + # end |
| 152 | + # end |
| 153 | + |
| 154 | + # println("\nExpression after doubling constants:") |
| 155 | + # println(expr_tree) |
| 156 | + |
| 157 | + # # 4. Checking if a node is in the tree |
| 158 | + # println("\nIs x in the tree? ", x in expr_tree) |
| 159 | + # println("Is Node(4.0) in the tree? ", Node(4.0) in expr_tree) |
| 160 | + |
| 161 | + # # 5. Sum of all constant values |
| 162 | + # const_sum = sum(t -> t.degree == 0 && t.constant ? t.val : 0.0, expr_tree) |
| 163 | + # println("\nSum of all constant values: ", const_sum) |
| 164 | + |
| 165 | + # # This example showcases how to create expression trees using `Node`, |
| 166 | + # # and how to use various operations like count, filter, map, in, and sum |
| 167 | + # # to analyze and modify the tree structure. |
| 168 | +end |
0 commit comments