@@ -30,7 +30,7 @@ impl<S: Storage> Binder<S> {
3030 select_items : & mut [ ScalarExpression ] ,
3131 ) -> Result < ( ) , BindError > {
3232 for column in select_items {
33- self . visit_column_agg_expr ( column) ;
33+ self . visit_column_agg_expr ( column, true ) ? ;
3434 }
3535 Ok ( ( ) )
3636 }
@@ -57,7 +57,8 @@ impl<S: Storage> Binder<S> {
5757 // Extract having expression.
5858 let return_having = if let Some ( having) = having {
5959 let mut having = self . bind_expr ( having) . await ?;
60- self . visit_column_agg_expr ( & mut having) ;
60+ self . visit_column_agg_expr ( & mut having, false ) ?;
61+
6162 Some ( having)
6263 } else {
6364 None
@@ -73,11 +74,11 @@ impl<S: Storage> Binder<S> {
7374 nulls_first,
7475 } = orderby;
7576 let mut expr = self . bind_expr ( expr) . await ?;
76- self . visit_column_agg_expr ( & mut expr) ;
77+ self . visit_column_agg_expr ( & mut expr, false ) ? ;
7778
7879 return_orderby. push ( SortField :: new (
7980 expr,
80- asc. map_or ( true , |asc| ! asc) ,
81+ asc. map_or ( true , |asc| asc) ,
8182 nulls_first. map_or ( false , |first| first) ,
8283 ) ) ;
8384 }
@@ -88,50 +89,67 @@ impl<S: Storage> Binder<S> {
8889 Ok ( ( return_having, return_orderby) )
8990 }
9091
91- fn visit_column_agg_expr ( & mut self , expr : & mut ScalarExpression ) {
92+ fn visit_column_agg_expr ( & mut self , expr : & mut ScalarExpression , is_select : bool ) -> Result < ( ) , BindError > {
9293 match expr {
9394 ScalarExpression :: AggCall {
9495 ty : return_type, ..
9596 } => {
96- let index = self . context . input_ref_index ( InputRefType :: AggCall ) ;
97- let input_ref = ScalarExpression :: InputRef {
98- index,
99- ty : return_type. clone ( ) ,
100- } ;
101- match std:: mem:: replace ( expr, input_ref) {
102- ScalarExpression :: AggCall {
103- kind,
104- args,
97+ let ty = return_type. clone ( ) ;
98+ if is_select {
99+ let index = self . context . input_ref_index ( InputRefType :: AggCall ) ;
100+ let input_ref = ScalarExpression :: InputRef {
101+ index,
105102 ty,
106- distinct
107- } => {
108- self . context . agg_calls . push ( ScalarExpression :: AggCall {
109- distinct,
103+ } ;
104+ match std:: mem:: replace ( expr, input_ref) {
105+ ScalarExpression :: AggCall {
110106 kind,
111107 args,
112108 ty,
113- } ) ;
109+ distinct
110+ } => {
111+ self . context . agg_calls . push ( ScalarExpression :: AggCall {
112+ distinct,
113+ kind,
114+ args,
115+ ty,
116+ } ) ;
117+ }
118+ _ => unreachable ! ( ) ,
114119 }
115- _ => unreachable ! ( ) ,
120+ } else {
121+ let ( index, _) = self
122+ . context
123+ . agg_calls
124+ . iter ( )
125+ . find_position ( |agg_expr| agg_expr == & expr)
126+ . ok_or_else ( || BindError :: AggMiss ( format ! ( "{:?}" , expr) ) ) ?;
127+
128+ let _ = std:: mem:: replace ( expr, ScalarExpression :: InputRef {
129+ index,
130+ ty,
131+ } ) ;
116132 }
117133 }
118134
119- ScalarExpression :: TypeCast { expr, .. } => self . visit_column_agg_expr ( expr) ,
120- ScalarExpression :: IsNull { expr } => self . visit_column_agg_expr ( expr) ,
121- ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr) ,
122- ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr) ,
135+ ScalarExpression :: TypeCast { expr, .. } => self . visit_column_agg_expr ( expr, is_select ) ? ,
136+ ScalarExpression :: IsNull { expr } => self . visit_column_agg_expr ( expr, is_select ) ? ,
137+ ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr, is_select ) ? ,
138+ ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr, is_select ) ? ,
123139 ScalarExpression :: Binary {
124140 left_expr,
125141 right_expr,
126142 ..
127143 } => {
128- self . visit_column_agg_expr ( left_expr) ;
129- self . visit_column_agg_expr ( right_expr) ;
144+ self . visit_column_agg_expr ( left_expr, is_select ) ? ;
145+ self . visit_column_agg_expr ( right_expr, is_select ) ? ;
130146 }
131147 ScalarExpression :: Constant ( _)
132148 | ScalarExpression :: ColumnRef { .. }
133149 | ScalarExpression :: InputRef { .. } => { }
134150 }
151+
152+ Ok ( ( ) )
135153 }
136154
137155 /// Validate select exprs must appear in the GROUP BY clause or be used in
@@ -173,6 +191,7 @@ impl<S: Storage> Binder<S> {
173191 if expr. has_agg_call ( & self . context ) {
174192 continue ;
175193 }
194+
176195 group_raw_set. remove ( expr) ;
177196
178197 if !group_raw_exprs. iter ( ) . contains ( expr) {
@@ -271,6 +290,9 @@ impl<S: Storage> Binder<S> {
271290 if self . context . group_by_exprs . contains ( expr) {
272291 return Ok ( ( ) ) ;
273292 }
293+ if matches ! ( expr, ScalarExpression :: Alias { .. } ) {
294+ return self . validate_having_orderby ( expr. unpack_alias ( ) ) ;
295+ }
274296
275297 Err ( BindError :: AggMiss (
276298 format ! (
0 commit comments