Skip to content

Commit b02cc59

Browse files
authored
refactor(query): optimized UnaryState design and simplified string_agg implementation (#18941)
1 parent 8b98296 commit b02cc59

27 files changed

+471
-815
lines changed

src/query/expression/src/types.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ pub use self::bitmap::BitmapType;
6666
pub use self::boolean::Bitmap;
6767
pub use self::boolean::BooleanType;
6868
pub use self::boolean::MutableBitmap;
69-
pub use self::compute_view::StringConvert;
7069
pub use self::date::DateType;
7170
pub use self::decimal::*;
7271
pub use self::empty_array::EmptyArrayType;

src/query/expression/src/types/compute_view.rs

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,10 @@ use std::ops::Range;
2020
use databend_common_exception::Result;
2121
use num_traits::AsPrimitive;
2222

23-
use super::column_type_error;
24-
use super::domain_type_error;
25-
use super::scalar_type_error;
26-
use super::string::StringDomain;
27-
use super::string::StringIterator;
2823
use super::AccessType;
29-
use super::AnyType;
30-
use super::ArgType;
3124
use super::Number;
3225
use super::NumberType;
3326
use super::SimpleDomain;
34-
use super::StringColumn;
35-
use super::StringType;
36-
use crate::display::scalar_ref_to_string;
3727
use crate::Column;
3828
use crate::Domain;
3929
use crate::ScalarRef;
@@ -177,124 +167,3 @@ where
177167
SimpleDomain { min, max }
178168
}
179169
}
180-
181-
/// For number convert
182-
pub type StringConvertView = ComputeView<StringConvert, AnyType, OwnedStringType>;
183-
184-
#[derive(Debug, Clone, PartialEq, Eq)]
185-
pub struct OwnedStringType;
186-
187-
impl AccessType for OwnedStringType {
188-
type Scalar = String;
189-
type ScalarRef<'a> = String;
190-
type Column = StringColumn;
191-
type Domain = StringDomain;
192-
type ColumnIterator<'a> = std::iter::Map<StringIterator<'a>, fn(&str) -> String>;
193-
194-
fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar {
195-
scalar.to_string()
196-
}
197-
198-
fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> {
199-
scalar.clone()
200-
}
201-
202-
fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Result<Self::ScalarRef<'a>> {
203-
scalar
204-
.as_string()
205-
.map(|s| s.to_string())
206-
.ok_or_else(|| scalar_type_error::<Self>(scalar))
207-
}
208-
209-
fn try_downcast_column(col: &Column) -> Result<Self::Column> {
210-
col.as_string()
211-
.cloned()
212-
.ok_or_else(|| column_type_error::<Self>(col))
213-
}
214-
215-
fn try_downcast_domain(domain: &Domain) -> Result<Self::Domain> {
216-
domain
217-
.as_string()
218-
.cloned()
219-
.ok_or_else(|| domain_type_error::<Self>(domain))
220-
}
221-
222-
fn column_len(col: &Self::Column) -> usize {
223-
col.len()
224-
}
225-
226-
fn index_column(col: &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
227-
col.index(index).map(str::to_string)
228-
}
229-
230-
#[inline]
231-
unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> {
232-
col.value_unchecked(index).to_string()
233-
}
234-
235-
fn slice_column(col: &Self::Column, range: Range<usize>) -> Self::Column {
236-
col.clone().sliced(range.start, range.end - range.start)
237-
}
238-
239-
fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> {
240-
col.iter().map(str::to_string)
241-
}
242-
243-
fn scalar_memory_size(scalar: &Self::ScalarRef<'_>) -> usize {
244-
scalar.len()
245-
}
246-
247-
fn column_memory_size(col: &Self::Column) -> usize {
248-
col.memory_size()
249-
}
250-
251-
#[inline(always)]
252-
fn compare(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> Ordering {
253-
left.cmp(&right)
254-
}
255-
256-
#[inline(always)]
257-
fn equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
258-
left == right
259-
}
260-
261-
#[inline(always)]
262-
fn not_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
263-
left != right
264-
}
265-
266-
#[inline(always)]
267-
fn greater_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
268-
left > right
269-
}
270-
271-
#[inline(always)]
272-
fn greater_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
273-
left >= right
274-
}
275-
276-
#[inline(always)]
277-
fn less_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
278-
left < right
279-
}
280-
281-
#[inline(always)]
282-
fn less_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
283-
left <= right
284-
}
285-
}
286-
287-
#[derive(Debug, Clone, PartialEq, Eq)]
288-
pub struct StringConvert;
289-
290-
impl Compute<AnyType, OwnedStringType> for StringConvert {
291-
fn compute<'a>(
292-
value: <AnyType as AccessType>::ScalarRef<'a>,
293-
) -> <OwnedStringType as AccessType>::ScalarRef<'a> {
294-
scalar_ref_to_string(&value)
295-
}
296-
297-
fn compute_domain(_: &<AnyType as AccessType>::Domain) -> StringDomain {
298-
StringType::full_domain()
299-
}
300-
}

src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use itertools::Itertools;
4444
use super::batch_merge1;
4545
use super::batch_serialize1;
4646
use super::AggregateFunctionSortDesc;
47+
use super::SerializeInfo;
4748
use super::StateSerde;
4849

4950
#[derive(Debug, Clone)]
@@ -52,7 +53,7 @@ pub struct SortAggState {
5253
}
5354

5455
impl StateSerde for SortAggState {
55-
fn serialize_type(_: Option<&dyn super::FunctionData>) -> Vec<StateSerdeItem> {
56+
fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
5657
vec![StateSerdeItem::Binary(None)]
5758
}
5859

src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use super::AggregateFunction;
3636
use super::AggregateFunctionDescription;
3737
use super::AggregateFunctionSortDesc;
3838
use super::AggregateUnaryFunction;
39-
use super::FunctionData;
39+
use super::SerializeInfo;
4040
use super::StateSerde;
4141
use super::UnaryState;
4242

@@ -48,11 +48,7 @@ where
4848
T: ValueType,
4949
T::Scalar: Hash,
5050
{
51-
fn add(
52-
&mut self,
53-
other: T::ScalarRef<'_>,
54-
_function_data: Option<&dyn FunctionData>,
55-
) -> Result<()> {
51+
fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> {
5652
self.add_object(&T::to_owned_scalar(other));
5753
Ok(())
5854
}
@@ -65,15 +61,15 @@ where
6561
fn merge_result(
6662
&mut self,
6763
mut builder: BuilderMut<'_, UInt64Type>,
68-
_function_data: Option<&dyn FunctionData>,
64+
_: &Self::FunctionInfo,
6965
) -> Result<()> {
7066
builder.push(self.count() as u64);
7167
Ok(())
7268
}
7369
}
7470

7571
impl<const HLL_P: usize> StateSerde for AggregateApproxCountDistinctState<HLL_P> {
76-
fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec<StateSerdeItem> {
72+
fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
7773
vec![StateSerdeItem::Binary(None)]
7874
}
7975

@@ -144,39 +140,39 @@ fn create_templated<const P: usize>(
144140
let return_type = DataType::Number(NumberDataType::UInt64);
145141
with_number_mapped_type!(|NUM_TYPE| match &arguments[0] {
146142
DataType::Number(NumberDataType::NUM_TYPE) => {
147-
AggregateUnaryFunction::<HyperLogLog<P>, NumberType<NUM_TYPE>, UInt64Type>::create(
143+
AggregateUnaryFunction::<HyperLogLog<P>, NumberType<NUM_TYPE>, UInt64Type>::new(
148144
display_name,
149145
return_type,
150146
)
151147
.with_need_drop(true)
152148
.finish()
153149
}
154150
DataType::String => {
155-
AggregateUnaryFunction::<HyperLogLog<P>, StringType, UInt64Type>::create(
151+
AggregateUnaryFunction::<HyperLogLog<P>, StringType, UInt64Type>::new(
156152
display_name,
157153
return_type,
158154
)
159155
.with_need_drop(true)
160156
.finish()
161157
}
162158
DataType::Date => {
163-
AggregateUnaryFunction::<HyperLogLog<P>, DateType, UInt64Type>::create(
159+
AggregateUnaryFunction::<HyperLogLog<P>, DateType, UInt64Type>::new(
164160
display_name,
165161
return_type,
166162
)
167163
.with_need_drop(true)
168164
.finish()
169165
}
170166
DataType::Timestamp => {
171-
AggregateUnaryFunction::<HyperLogLog<P>, TimestampType, UInt64Type>::create(
167+
AggregateUnaryFunction::<HyperLogLog<P>, TimestampType, UInt64Type>::new(
172168
display_name,
173169
return_type,
174170
)
175171
.with_need_drop(true)
176172
.finish()
177173
}
178174
_ => {
179-
AggregateUnaryFunction::<HyperLogLog<P>, AnyType, UInt64Type>::create(
175+
AggregateUnaryFunction::<HyperLogLog<P>, AnyType, UInt64Type>::new(
180176
display_name,
181177
return_type,
182178
)

src/query/functions/src/aggregates/aggregate_array_agg.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use super::AggrStateLoc;
6363
use super::AggregateFunction;
6464
use super::AggregateFunctionDescription;
6565
use super::AggregateFunctionSortDesc;
66-
use super::FunctionData;
66+
use super::SerializeInfo;
6767
use super::StateAddr;
6868
use super::StateSerde;
6969

@@ -132,8 +132,8 @@ where
132132
Self: ScalarStateFunc<T>,
133133
T: ValueType,
134134
{
135-
fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec<StateSerdeItem> {
136-
let return_type = function_data
135+
fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
136+
let return_type = info
137137
.and_then(|data| data.as_any().downcast_ref::<DataType>())
138138
.cloned()
139139
.unwrap();
@@ -234,8 +234,8 @@ where T: SimpleType + Debug
234234
impl<T> StateSerde for ArrayAggStateSimple<T>
235235
where T: SimpleType
236236
{
237-
fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec<StateSerdeItem> {
238-
let data_type = function_data
237+
fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
238+
let data_type = info
239239
.and_then(|data| data.as_any().downcast_ref::<DataType>())
240240
.and_then(|ty| ty.as_array())
241241
.unwrap()
@@ -336,7 +336,7 @@ where V: ZeroSizeType
336336
}
337337

338338
impl<const IS_NULL: bool> StateSerde for ArrayAggStateZST<IS_NULL> {
339-
fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec<StateSerdeItem> {
339+
fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
340340
vec![ArrayType::<BooleanType>::data_type().into()]
341341
}
342342

@@ -464,8 +464,8 @@ where T: ArgType + Debug + std::marker::Send
464464
impl<T> StateSerde for ArrayAggStateBinary<T>
465465
where T: ArgType + std::marker::Send
466466
{
467-
fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec<StateSerdeItem> {
468-
let data_type = function_data
467+
fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
468+
let data_type = info
469469
.and_then(|data| data.as_any().downcast_ref::<DataType>())
470470
.and_then(|ty| ty.as_array())
471471
.unwrap()

src/query/functions/src/aggregates/aggregate_array_moving.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use super::AggregateFunction;
4646
use super::AggregateFunctionDescription;
4747
use super::AggregateFunctionRef;
4848
use super::AggregateFunctionSortDesc;
49+
use super::SerializeInfo;
4950
use super::StateAddr;
5051
use super::StateSerde;
5152

@@ -221,7 +222,7 @@ where
221222
Self: SumState,
222223
T: Number,
223224
{
224-
fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec<StateSerdeItem> {
225+
fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
225226
vec![ArrayType::<NumberType<T>>::data_type().into()]
226227
}
227228

@@ -452,7 +453,7 @@ where
452453
Self: SumState,
453454
T: Decimal,
454455
{
455-
fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec<StateSerdeItem> {
456+
fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec<StateSerdeItem> {
456457
vec![DataType::Array(Box::new(DataType::Decimal(T::default_decimal_size()))).into()]
457458
}
458459

0 commit comments

Comments
 (0)