|
17 | 17 |
|
18 | 18 | from bigframes.core import bigframe_node |
19 | 19 | from bigframes.core import expression as ex |
20 | | -from bigframes.core import nodes |
| 20 | +from bigframes.core import nodes, ordering |
21 | 21 |
|
22 | 22 |
|
23 | 23 | def bind_schema_to_tree( |
@@ -79,46 +79,77 @@ def bind_schema_to_node( |
79 | 79 | if isinstance(node, nodes.AggregateNode): |
80 | 80 | aggregations = [] |
81 | 81 | for aggregation, id in node.aggregations: |
82 | | - if isinstance(aggregation, ex.UnaryAggregation): |
83 | | - replaced = typing.cast( |
84 | | - ex.Aggregation, |
85 | | - dataclasses.replace( |
86 | | - aggregation, |
87 | | - arg=typing.cast( |
88 | | - ex.RefOrConstant, |
89 | | - ex.bind_schema_fields( |
90 | | - aggregation.arg, node.child.field_by_id |
91 | | - ), |
92 | | - ), |
93 | | - ), |
| 82 | + aggregations.append( |
| 83 | + (_bind_schema_to_aggregation_expr(aggregation, node.child), id) |
| 84 | + ) |
| 85 | + |
| 86 | + return dataclasses.replace( |
| 87 | + node, |
| 88 | + aggregations=tuple(aggregations), |
| 89 | + ) |
| 90 | + |
| 91 | + if isinstance(node, nodes.WindowOpNode): |
| 92 | + window_spec = dataclasses.replace( |
| 93 | + node.window_spec, |
| 94 | + grouping_keys=tuple( |
| 95 | + typing.cast( |
| 96 | + ex.DerefOp, ex.bind_schema_fields(expr, node.child.field_by_id) |
94 | 97 | ) |
95 | | - aggregations.append((replaced, id)) |
96 | | - elif isinstance(aggregation, ex.BinaryAggregation): |
97 | | - replaced = typing.cast( |
98 | | - ex.Aggregation, |
99 | | - dataclasses.replace( |
100 | | - aggregation, |
101 | | - left=typing.cast( |
102 | | - ex.RefOrConstant, |
103 | | - ex.bind_schema_fields( |
104 | | - aggregation.left, node.child.field_by_id |
105 | | - ), |
106 | | - ), |
107 | | - right=typing.cast( |
108 | | - ex.RefOrConstant, |
109 | | - ex.bind_schema_fields( |
110 | | - aggregation.right, node.child.field_by_id |
111 | | - ), |
112 | | - ), |
| 98 | + for expr in node.window_spec.grouping_keys |
| 99 | + ), |
| 100 | + ordering=tuple( |
| 101 | + ordering.OrderingExpression( |
| 102 | + scalar_expression=ex.bind_schema_fields( |
| 103 | + expr.scalar_expression, node.child.field_by_id |
113 | 104 | ), |
| 105 | + direction=expr.direction, |
| 106 | + na_last=expr.na_last, |
114 | 107 | ) |
115 | | - aggregations.append((replaced, id)) |
116 | | - else: |
117 | | - aggregations.append((aggregation, id)) |
118 | | - |
| 108 | + for expr in node.window_spec.ordering |
| 109 | + ), |
| 110 | + ) |
119 | 111 | return dataclasses.replace( |
120 | 112 | node, |
121 | | - aggregations=tuple(aggregations), |
| 113 | + expression=_bind_schema_to_aggregation_expr(node.expression, node.child), |
| 114 | + window_spec=window_spec, |
122 | 115 | ) |
123 | 116 |
|
124 | 117 | return node |
| 118 | + |
| 119 | + |
| 120 | +def _bind_schema_to_aggregation_expr( |
| 121 | + aggregation: ex.Aggregation, |
| 122 | + child: bigframe_node.BigFrameNode, |
| 123 | +) -> ex.Aggregation: |
| 124 | + assert isinstance( |
| 125 | + aggregation, ex.Aggregation |
| 126 | + ), f"Expected Aggregation, got {type(aggregation)}" |
| 127 | + |
| 128 | + if isinstance(aggregation, ex.UnaryAggregation): |
| 129 | + return typing.cast( |
| 130 | + ex.Aggregation, |
| 131 | + dataclasses.replace( |
| 132 | + aggregation, |
| 133 | + arg=typing.cast( |
| 134 | + ex.RefOrConstant, |
| 135 | + ex.bind_schema_fields(aggregation.arg, child.field_by_id), |
| 136 | + ), |
| 137 | + ), |
| 138 | + ) |
| 139 | + elif isinstance(aggregation, ex.BinaryAggregation): |
| 140 | + return typing.cast( |
| 141 | + ex.Aggregation, |
| 142 | + dataclasses.replace( |
| 143 | + aggregation, |
| 144 | + left=typing.cast( |
| 145 | + ex.RefOrConstant, |
| 146 | + ex.bind_schema_fields(aggregation.left, child.field_by_id), |
| 147 | + ), |
| 148 | + right=typing.cast( |
| 149 | + ex.RefOrConstant, |
| 150 | + ex.bind_schema_fields(aggregation.right, child.field_by_id), |
| 151 | + ), |
| 152 | + ), |
| 153 | + ) |
| 154 | + else: |
| 155 | + return aggregation |
0 commit comments