Skip to content

Commit a696e2a

Browse files
loloxwgKKould
andauthored
feat(type): add support for Decimal type in database (#66)
Co-authored-by: Kould <2435992353@qq.com>
1 parent a348db1 commit a696e2a

File tree

10 files changed

+183
-33
lines changed

10 files changed

+183
-33
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ comfy-table = "7.0.1"
3232
bytes = "*"
3333
kip_db = "0.1.2-alpha.15"
3434
async-recursion = "1.0.5"
35+
rust_decimal = "1"
3536

3637
[dev-dependencies]
3738
tokio-test = "0.4.2"

src/binder/insert.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl<S: Storage> Binder<S> {
5050
match &self.bind_expr(expr).await? {
5151
ScalarExpression::Constant(value) => {
5252
// Check if the value length is too long
53-
value.check_length(columns[i].datatype())?;
53+
value.check_len(columns[i].datatype())?;
5454
let cast_value = DataValue::clone(value)
5555
.cast(columns[i].datatype())?;
5656
row.push(Arc::new(cast_value))

src/binder/update.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ impl<S: Storage> Binder<S> {
4444
bind_table_name.as_ref()
4545
).await? {
4646
ScalarExpression::ColumnRef(catalog) => {
47-
value.check_length(catalog.datatype())?;
47+
value.check_len(catalog.datatype())?;
4848
columns.push(catalog);
4949
row.push(value.clone());
5050
},

src/db.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ mod test {
186186
let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?;
187187
let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 2, 2)").await?;
188188
let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?;
189+
let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?;
190+
let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?;
191+
let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?;
189192

190193
println!("full t1:");
191194
let tuples_full_fields_t1 = kipsql.run("select * from t1").await?;
@@ -305,6 +308,10 @@ mod test {
305308
let tuples_show_tables = kipsql.run("show tables").await?;
306309
println!("{}", create_table(&tuples_show_tables));
307310

311+
println!("decimal:");
312+
let tuples_decimal = kipsql.run("select * from t3").await?;
313+
println!("{}", create_table(&tuples_decimal));
314+
308315
Ok(())
309316
}
310317
}

src/expression/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::fmt;
2-
use std::fmt::Formatter;
2+
use std::fmt::{Debug, Formatter};
33
use std::sync::Arc;
44
use itertools::Itertools;
55

src/storage/table_codec.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ mod tests {
146146
use std::ops::Bound;
147147
use std::sync::Arc;
148148
use itertools::Itertools;
149+
use rust_decimal::Decimal;
149150
use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog};
150151
use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec};
151152
use crate::types::errors::TypeError;
@@ -159,7 +160,12 @@ mod tests {
159160
"c1".into(),
160161
false,
161162
ColumnDesc::new(LogicalType::Integer, true)
162-
)
163+
),
164+
ColumnCatalog::new(
165+
"c2".into(),
166+
false,
167+
ColumnDesc::new(LogicalType::Decimal(None,None), false)
168+
),
163169
];
164170
let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap();
165171
let codec = TableCodec { table: table_catalog.clone() };
@@ -175,6 +181,7 @@ mod tests {
175181
columns: table_catalog.all_columns(),
176182
values: vec![
177183
Arc::new(DataValue::Int32(Some(0))),
184+
Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))),
178185
]
179186
};
180187

src/types/errors.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,10 @@ pub enum TypeError {
4646
#[from]
4747
ParseError,
4848
),
49+
#[error("try from decimal")]
50+
TryFromDecimal(
51+
#[source]
52+
#[from]
53+
rust_decimal::Error,
54+
),
4955
}

src/types/mod.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::atomic::Ordering::{Acquire, Release};
77
use serde::{Deserialize, Serialize};
88

99
use integer_encoding::FixedInt;
10+
use sqlparser::ast::ExactNumberInfo;
1011
use strum_macros::AsRefStr;
1112

1213
use crate::types::errors::TypeError;
@@ -57,6 +58,8 @@ pub enum LogicalType {
5758
Varchar(Option<u32>),
5859
Date,
5960
DateTime,
61+
// decimal (precision, scale)
62+
Decimal(Option<u8>, Option<u8>),
6063
}
6164

6265
impl LogicalType {
@@ -75,8 +78,9 @@ impl LogicalType {
7578
LogicalType::UBigint => Some(8),
7679
LogicalType::Float => Some(4),
7780
LogicalType::Double => Some(8),
78-
/// Note: The non-fixed length type's raw_len is None
79-
LogicalType::Varchar(_)=>None,
81+
/// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal
82+
LogicalType::Varchar(_) => None,
83+
LogicalType::Decimal(_, _) => Some(16),
8084
LogicalType::Date => Some(4),
8185
LogicalType::DateTime => Some(8),
8286
}
@@ -269,6 +273,7 @@ impl LogicalType {
269273
LogicalType::Varchar(_) => false,
270274
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)),
271275
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)),
276+
LogicalType::Decimal(_, _) => false,
272277
}
273278
}
274279
}
@@ -296,6 +301,13 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {
296301
sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint),
297302
sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean),
298303
sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime),
304+
sqlparser::ast::DataType::Decimal(info) => match info {
305+
ExactNumberInfo::None => Ok(Self::Decimal(None, None)),
306+
ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)),
307+
ExactNumberInfo::PrecisionAndScale(p, s) => {
308+
Ok(Self::Decimal(Some(p as u8), Some(s as u8)))
309+
}
310+
},
299311
other => Err(TypeError::NotImplementedSqlparserDataType(
300312
other.to_string(),
301313
)),
@@ -313,7 +325,7 @@ impl std::fmt::Display for LogicalType {
313325
mod test {
314326
use std::sync::atomic::Ordering::Release;
315327

316-
use crate::types::{IdGenerator, ID_BUF, LogicalType};
328+
use crate::types::{IdGenerator, ID_BUF};
317329

318330
/// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰
319331
#[test]

0 commit comments

Comments
 (0)