Skip to content

Commit 1853a9b

Browse files
authored
feat: Refactor 'DataValue' length check logic; add support for varchar computations (#60)
1 parent 15e72bc commit 1853a9b

File tree

8 files changed

+193
-43
lines changed

8 files changed

+193
-43
lines changed

src/binder/create_table.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ mod tests {
7878
assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true));
7979
assert_eq!(op.columns[1].name, "name".to_string());
8080
assert_eq!(op.columns[1].nullable, true);
81-
assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar, false));
81+
assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar(Some(10)), false));
8282
}
8383
_ => unreachable!()
8484
}

src/binder/insert.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ impl<S: Storage> Binder<S> {
4949
for (i, expr) in expr_row.into_iter().enumerate() {
5050
match &self.bind_expr(expr).await? {
5151
ScalarExpression::Constant(value) => {
52+
// Check if the value length is too long
53+
value.check_length(columns[i].datatype())?;
5254
let cast_value = DataValue::clone(value)
5355
.cast(columns[i].datatype())?;
54-
5556
row.push(Arc::new(cast_value))
5657
},
5758
ScalarExpression::Unary { expr, op, .. } => {

src/binder/update.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ impl<S: Storage> Binder<S> {
4444
bind_table_name.as_ref()
4545
).await? {
4646
ScalarExpression::ColumnRef(catalog) => {
47+
value.check_length(catalog.datatype())?;
4748
columns.push(catalog);
4849
row.push(value.clone());
4950
},

src/expression/value_compute.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ fn unpack_date(value: DataValue) -> Option<i64> {
5959
}
6060
}
6161

62+
fn unpack_utf8(value: DataValue) -> Option<String> {
63+
match value {
64+
DataValue::Utf8(inner) => inner,
65+
_ => None
66+
}
67+
}
68+
6269
pub fn unary_op(
6370
value: &DataValue,
6471
op: &UnaryOperator,
@@ -114,7 +121,7 @@ pub fn binary_op(
114121
) -> Result<DataValue, TypeError> {
115122
let unified_type = LogicalType::max_logical_type(
116123
&left.logical_type(),
117-
&right.logical_type()
124+
&right.logical_type(),
118125
)?;
119126

120127
let value = match &unified_type {
@@ -844,6 +851,76 @@ pub fn binary_op(
844851
_ => todo!("unsupported operator")
845852
}
846853
}
854+
LogicalType::Varchar(None) => {
855+
let left_value = unpack_utf8(left.clone().cast(&unified_type)?);
856+
let right_value = unpack_utf8(right.clone().cast(&unified_type)?);
857+
858+
match op {
859+
BinaryOperator::Gt => {
860+
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
861+
Some(v1 > v2)
862+
} else {
863+
None
864+
};
865+
866+
DataValue::Boolean(value)
867+
}
868+
BinaryOperator::Lt => {
869+
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
870+
Some(v1 < v2)
871+
} else {
872+
None
873+
};
874+
875+
DataValue::Boolean(value)
876+
}
877+
BinaryOperator::GtEq => {
878+
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
879+
Some(v1 >= v2)
880+
} else {
881+
None
882+
};
883+
884+
DataValue::Boolean(value)
885+
}
886+
BinaryOperator::LtEq => {
887+
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
888+
Some(v1 <= v2)
889+
} else {
890+
None
891+
};
892+
893+
DataValue::Boolean(value)
894+
}
895+
BinaryOperator::Eq => {
896+
let value = match (left_value, right_value) {
897+
(Some(v1), Some(v2)) => {
898+
Some(v1 == v2)
899+
}
900+
(None, None) => {
901+
Some(true)
902+
}
903+
(_, _) => {
904+
None
905+
}
906+
};
907+
908+
DataValue::Boolean(value)
909+
}
910+
BinaryOperator::NotEq => {
911+
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
912+
Some(v1 != v2)
913+
} else {
914+
None
915+
};
916+
917+
DataValue::Boolean(value)
918+
}
919+
_ => todo!("unsupported operator")
920+
}
921+
}
922+
// Utf8
923+
847924
_ => todo!("unsupported data type"),
848925
};
849926

@@ -1105,4 +1182,22 @@ mod test {
11051182

11061183
Ok(())
11071184
}
1185+
1186+
#[test]
1187+
fn test_binary_op_Utf8_compare()->Result<(),TypeError>{
1188+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Gt)?, DataValue::Boolean(Some(false)));
1189+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Lt)?, DataValue::Boolean(Some(true)));
1190+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true)));
1191+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true)));
1192+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false)));
1193+
assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true)));
1194+
1195+
assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Gt)?, DataValue::Boolean(None));
1196+
assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Lt)?, DataValue::Boolean(None));
1197+
assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::GtEq)?, DataValue::Boolean(None));
1198+
assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::LtEq)?, DataValue::Boolean(None));
1199+
assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::NotEq)?, DataValue::Boolean(None));
1200+
1201+
Ok(())
1202+
}
11081203
}

src/types/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub enum TypeError {
1212
InternalError(String),
1313
#[error("cast fail")]
1414
CastFail,
15+
#[error("Too long")]
16+
TooLong,
1517
#[error("cannot be Null")]
1618
NotNull,
1719
#[error("try from int")]

src/types/mod.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pub enum LogicalType {
5454
UBigint,
5555
Float,
5656
Double,
57-
Varchar,
57+
Varchar(Option<u32>),
5858
Date,
5959
DateTime,
6060
}
@@ -75,7 +75,8 @@ impl LogicalType {
7575
LogicalType::UBigint => Some(8),
7676
LogicalType::Float => Some(4),
7777
LogicalType::Double => Some(8),
78-
LogicalType::Varchar => None,
78+
/// Note: The non-fixed length type's raw_len is None
79+
LogicalType::Varchar(_)=>None,
7980
LogicalType::Date => Some(4),
8081
LogicalType::DateTime => Some(8),
8182
}
@@ -156,13 +157,13 @@ impl LogicalType {
156157
if left.is_numeric() && right.is_numeric() {
157158
return LogicalType::combine_numeric_types(left, right);
158159
}
159-
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar) | (LogicalType::Varchar, LogicalType::Date)) {
160+
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::Date)) {
160161
return Ok(LogicalType::Date);
161162
}
162163
if matches!((left, right), (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date)) {
163164
return Ok(LogicalType::DateTime);
164165
}
165-
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar) | (LogicalType::Varchar, LogicalType::DateTime)) {
166+
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::DateTime)) {
166167
return Ok(LogicalType::DateTime);
167168
}
168169
Err(TypeError::InternalError(format!(
@@ -265,9 +266,9 @@ impl LogicalType {
265266
LogicalType::UBigint => matches!(to, LogicalType::Float | LogicalType::Double),
266267
LogicalType::Float => matches!(to, LogicalType::Double),
267268
LogicalType::Double => false,
268-
LogicalType::Varchar => false,
269-
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar),
270-
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar),
269+
LogicalType::Varchar(_) => false,
270+
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)),
271+
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)),
271272
}
272273
}
273274
}
@@ -278,11 +279,8 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {
278279

279280
fn try_from(value: sqlparser::ast::DataType) -> Result<Self, Self::Error> {
280281
match value {
281-
sqlparser::ast::DataType::Char(_)
282-
| sqlparser::ast::DataType::Varchar(_)
283-
| sqlparser::ast::DataType::Nvarchar(_)
284-
| sqlparser::ast::DataType::Text
285-
| sqlparser::ast::DataType::String => Ok(LogicalType::Varchar),
282+
sqlparser::ast::DataType::Char(len)
283+
| sqlparser::ast::DataType::Varchar(len)=> Ok(LogicalType::Varchar(len.map(|len| len.length as u32))),
286284
sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float),
287285
sqlparser::ast::DataType::Double => Ok(LogicalType::Double),
288286
sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint),
@@ -315,7 +313,7 @@ impl std::fmt::Display for LogicalType {
315313
mod test {
316314
use std::sync::atomic::Ordering::Release;
317315

318-
use crate::types::{IdGenerator, ID_BUF};
316+
use crate::types::{IdGenerator, ID_BUF, LogicalType};
319317

320318
/// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰
321319
#[test]

src/types/tuple.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ impl Tuple {
3535
if bit_index(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) {
3636
values.push(Arc::new(DataValue::none(logic_type)));
3737
} else if let Some(len) = logic_type.raw_len() {
38+
/// fixed length (e.g.: int)
3839
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
3940
pos += len;
4041
} else {
42+
/// variable length (e.g.: varchar)
4143
let len = u32::decode_fixed(&bytes[pos..pos + 4]) as usize;
4244
pos += 4;
4345
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
@@ -133,7 +135,7 @@ mod tests {
133135
Arc::new(ColumnCatalog::new(
134136
"c3".to_string(),
135137
false,
136-
ColumnDesc::new(LogicalType::Varchar, false)
138+
ColumnDesc::new(LogicalType::Varchar(Some(2)), false)
137139
)),
138140
Arc::new(ColumnCatalog::new(
139141
"c4".to_string(),

0 commit comments

Comments
 (0)