Skip to content

Commit aace5f4

Browse files
mshauneureturnString
authored andcommitted
Add Time32, Time64 and Duration
1 parent 1ac88ac commit aace5f4

File tree

4 files changed

+227
-49
lines changed

4 files changed

+227
-49
lines changed

convergence-arrow/src/table.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlSta
66
use convergence::protocol_ext::DataRowBatch;
77
use datafusion::arrow::array::timezone::Tz;
88
use datafusion::arrow::array::{
9-
BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array, Float64Array, Int16Array,
10-
Int32Array, Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray,
11-
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
12-
UInt8Array,
9+
BooleanArray, Date32Array, Date64Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
10+
DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
11+
Int64Array, Int8Array, StringArray, StringViewArray, Time32MillisecondArray, Time32SecondArray,
12+
Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
13+
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
1314
};
1415
use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
1516
use datafusion::arrow::record_batch::RecordBatch;
@@ -64,6 +65,54 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
6465
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
6566
})?)
6667
}
68+
DataType::Time32(unit) => match unit {
69+
TimeUnit::Second => {
70+
row.write_time(array_val!(Time32SecondArray, col, row_idx, value_as_time).ok_or_else(
71+
|| ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type"),
72+
)?)
73+
}
74+
TimeUnit::Millisecond => row.write_time(
75+
array_val!(Time32MillisecondArray, col, row_idx, value_as_time).ok_or_else(|| {
76+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
77+
})?,
78+
),
79+
_ => {}
80+
},
81+
DataType::Time64(unit) => match unit {
82+
TimeUnit::Microsecond => row.write_time(
83+
array_val!(Time64MicrosecondArray, col, row_idx, value_as_time).ok_or_else(|| {
84+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
85+
})?,
86+
),
87+
TimeUnit::Nanosecond => row.write_time(
88+
array_val!(Time64NanosecondArray, col, row_idx, value_as_time).ok_or_else(|| {
89+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
90+
})?,
91+
),
92+
_ => {}
93+
},
94+
DataType::Duration(unit) => match unit {
95+
TimeUnit::Second => row.write_duration(
96+
array_val!(DurationSecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
97+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
98+
})?,
99+
),
100+
TimeUnit::Millisecond => row.write_duration(
101+
array_val!(DurationMillisecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
102+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
103+
})?,
104+
),
105+
TimeUnit::Microsecond => row.write_duration(
106+
array_val!(DurationMicrosecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
107+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
108+
})?,
109+
),
110+
TimeUnit::Nanosecond => row.write_duration(
111+
array_val!(DurationNanosecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
112+
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
113+
})?,
114+
),
115+
},
67116
DataType::Timestamp(unit, tz) => {
68117
match tz {
69118
Some(tz) => {
@@ -140,6 +189,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
140189
DataType::Decimal128(_, _) => DataTypeOid::Numeric,
141190
DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
142191
DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
192+
DataType::Time32(_) | DataType::Time64(_) => DataTypeOid::Time,
193+
DataType::Duration(_) => DataTypeOid::Interval,
143194
DataType::Timestamp(_, tz) => match tz {
144195
Some(_) => DataTypeOid::Timestamptz,
145196
None => DataTypeOid::Timestamp,

convergence-arrow/tests/test_arrow.rs

Lines changed: 141 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
use async_trait::async_trait;
2-
use chrono::{DateTime, NaiveDate, NaiveDateTime};
2+
use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta};
33
use convergence::engine::{Engine, Portal};
44
use convergence::protocol::{ErrorResponse, FieldDescription};
55
use convergence::protocol_ext::DataRowBatch;
66
use convergence::server::{self, BindOptions};
77
use convergence::sqlparser::ast::Statement;
88
use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
99
use datafusion::arrow::array::{
10-
ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray,
11-
TimestampSecondArray,
10+
ArrayRef, Date32Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
11+
DurationNanosecondArray, DurationSecondArray, Float32Array, Int32Array, StringArray, StringViewArray,
12+
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray,
1213
};
1314
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
1415
use datafusion::arrow::record_batch::RecordBatch;
1516
use rust_decimal::Decimal;
17+
use std::convert::TryInto;
1618
use std::sync::Arc;
19+
use tokio_postgres::types::{FromSql, Type};
1720
use tokio_postgres::{connect, NoTls};
1821

1922
struct ArrowPortal {
@@ -47,6 +50,23 @@ impl ArrowEngine {
4750
Arc::new(TimestampSecondArray::from(vec![1577854800, 1580533200, 1583038800]).with_timezone("+05:00"))
4851
as ArrayRef;
4952
let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;
53+
let time_s_col = Arc::new(Time32SecondArray::from(vec![30, 60, 90])) as ArrayRef;
54+
let time_ms_col = Arc::new(Time32MillisecondArray::from(vec![30_000, 60_000, 90_000])) as ArrayRef;
55+
let time_mcs_col = Arc::new(Time64MicrosecondArray::from(vec![30_000_000, 60_000_000, 90_000_000])) as ArrayRef;
56+
let time_ns_col = Arc::new(Time64NanosecondArray::from(vec![
57+
30_000_000_000,
58+
60_000_000_000,
59+
90_000_000_000,
60+
])) as ArrayRef;
61+
let duration_s_col = Arc::new(DurationSecondArray::from(vec![3, 6, 9])) as ArrayRef;
62+
let duration_ms_col = Arc::new(DurationMillisecondArray::from(vec![3_000, 6_000, 9_000])) as ArrayRef;
63+
let duration_mcs_col =
64+
Arc::new(DurationMicrosecondArray::from(vec![3_000_000, 6_000_000, 9_000_000])) as ArrayRef;
65+
let duration_ns_col = Arc::new(DurationNanosecondArray::from(vec![
66+
3_000_000_000,
67+
6_000_000_000,
68+
9_000_000_000,
69+
])) as ArrayRef;
5070

5171
let schema = Schema::new(vec![
5272
Field::new("int_col", DataType::Int32, true),
@@ -61,6 +81,14 @@ impl ArrowEngine {
6181
true,
6282
),
6383
Field::new("date_col", DataType::Date32, true),
84+
Field::new("time_s_col", DataType::Time32(TimeUnit::Second), true),
85+
Field::new("time_ms_col", DataType::Time32(TimeUnit::Millisecond), true),
86+
Field::new("time_mcs_col", DataType::Time64(TimeUnit::Microsecond), true),
87+
Field::new("time_ns_col", DataType::Time64(TimeUnit::Nanosecond), true),
88+
Field::new("duration_s_col", DataType::Duration(TimeUnit::Second), true),
89+
Field::new("duration_ms_col", DataType::Duration(TimeUnit::Millisecond), true),
90+
Field::new("duration_mcs_col", DataType::Duration(TimeUnit::Microsecond), true),
91+
Field::new("duration_ns_col", DataType::Duration(TimeUnit::Nanosecond), true),
6492
]);
6593

6694
Self {
@@ -75,6 +103,14 @@ impl ArrowEngine {
75103
ts_col,
76104
ts_tz_col,
77105
date_col,
106+
time_s_col,
107+
time_ms_col,
108+
time_mcs_col,
109+
time_ns_col,
110+
duration_s_col,
111+
duration_ms_col,
112+
duration_mcs_col,
113+
duration_ns_col,
78114
],
79115
)
80116
.expect("failed to create batch"),
@@ -114,14 +150,45 @@ async fn setup() -> tokio_postgres::Client {
114150
client
115151
}
116152

153+
// remove after https://github.com/sfackler/rust-postgres/pull/1238 is merged
154+
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
155+
struct DurationWrapper(TimeDelta);
156+
157+
impl<'a> FromSql<'a> for DurationWrapper {
158+
fn from_sql(_ty: &Type, raw: &[u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
159+
let micros = i64::from_be_bytes(raw.try_into().unwrap());
160+
Ok(DurationWrapper(Duration::microseconds(micros)))
161+
}
162+
fn accepts(ty: &Type) -> bool {
163+
matches!(ty, &Type::INTERVAL)
164+
}
165+
}
166+
117167
#[tokio::test]
118168
async fn basic_data_types() {
119169
let client = setup().await;
120170

121171
let rows = client.query("select 1", &[]).await.unwrap();
122172
let get_row = |idx: usize| {
123173
let row = &rows[idx];
124-
let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, DateTime<_>, NaiveDate) = (
174+
let cols: (
175+
i32,
176+
f32,
177+
Decimal,
178+
&str,
179+
&str,
180+
NaiveDateTime,
181+
DateTime<FixedOffset>,
182+
NaiveDate,
183+
NaiveTime,
184+
NaiveTime,
185+
NaiveTime,
186+
NaiveTime,
187+
DurationWrapper,
188+
DurationWrapper,
189+
DurationWrapper,
190+
DurationWrapper,
191+
) = (
125192
row.get(0),
126193
row.get(1),
127194
row.get(2),
@@ -130,56 +197,87 @@ async fn basic_data_types() {
130197
row.get(5),
131198
row.get(6),
132199
row.get(7),
200+
row.get(8),
201+
row.get(9),
202+
row.get(10),
203+
row.get(11),
204+
row.get(12),
205+
row.get(13),
206+
row.get(14),
207+
row.get(15),
133208
);
134209
cols
135210
};
136211

137-
assert_eq!(
138-
get_row(0),
139-
(
140-
1,
141-
1.5,
142-
Decimal::from(11),
143-
"a",
144-
"aa",
145-
NaiveDate::from_ymd_opt(2020, 1, 1)
212+
let row = get_row(0);
213+
assert!(row.0 == 1);
214+
assert!(row.1 == 1.5);
215+
assert!(row.2 == Decimal::from(11));
216+
assert!(row.3 == "a");
217+
assert!(row.4 == "aa");
218+
assert!(
219+
row.5
220+
== NaiveDate::from_ymd_opt(2020, 1, 1)
146221
.unwrap()
147222
.and_hms_opt(0, 0, 0)
148-
.unwrap(),
149-
DateTime::from_timestamp_millis(1577854800000).unwrap(),
150-
NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(),
151-
)
223+
.unwrap()
152224
);
153-
assert_eq!(
154-
get_row(1),
155-
(
156-
2,
157-
2.5,
158-
Decimal::from(22),
159-
"b",
160-
"bb",
161-
NaiveDate::from_ymd_opt(2020, 2, 1)
225+
assert!(row.6 == DateTime::from_timestamp_millis(1577854800000).unwrap());
226+
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 1).unwrap());
227+
assert!(row.8 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
228+
assert!(row.9 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
229+
assert!(row.10 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
230+
assert!(row.11 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
231+
assert!(row.12 == DurationWrapper(Duration::seconds(3)));
232+
assert!(row.13 == DurationWrapper(Duration::seconds(3)));
233+
assert!(row.14 == DurationWrapper(Duration::seconds(3)));
234+
assert!(row.15 == DurationWrapper(Duration::seconds(3)));
235+
236+
let row = get_row(1);
237+
assert!(row.0 == 2);
238+
assert!(row.1 == 2.5);
239+
assert!(row.2 == Decimal::from(22));
240+
assert!(row.3 == "b");
241+
assert!(row.4 == "bb");
242+
assert!(
243+
row.5
244+
== NaiveDate::from_ymd_opt(2020, 2, 1)
162245
.unwrap()
163246
.and_hms_opt(0, 0, 0)
164-
.unwrap(),
165-
DateTime::from_timestamp_millis(1580533200000).unwrap(),
166-
NaiveDate::from_ymd_opt(1970, 1, 2).unwrap()
167-
)
247+
.unwrap()
168248
);
169-
assert_eq!(
170-
get_row(2),
171-
(
172-
3,
173-
3.5,
174-
Decimal::from(33),
175-
"c",
176-
"cc",
177-
NaiveDate::from_ymd_opt(2020, 3, 1)
249+
assert!(row.6 == DateTime::from_timestamp_millis(1580533200000).unwrap());
250+
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 2).unwrap());
251+
assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
252+
assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
253+
assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
254+
assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
255+
assert!(row.12 == DurationWrapper(Duration::seconds(6)));
256+
assert!(row.13 == DurationWrapper(Duration::seconds(6)));
257+
assert!(row.14 == DurationWrapper(Duration::seconds(6)));
258+
assert!(row.15 == DurationWrapper(Duration::seconds(6)));
259+
260+
let row = get_row(2);
261+
assert!(row.0 == 3);
262+
assert!(row.1 == 3.5);
263+
assert!(row.2 == Decimal::from(33));
264+
assert!(row.3 == "c");
265+
assert!(row.4 == "cc");
266+
assert!(
267+
row.5
268+
== NaiveDate::from_ymd_opt(2020, 3, 1)
178269
.unwrap()
179270
.and_hms_opt(0, 0, 0)
180-
.unwrap(),
181-
DateTime::from_timestamp_millis(1583038800000).unwrap(),
182-
NaiveDate::from_ymd_opt(1970, 1, 3).unwrap()
183-
)
271+
.unwrap()
184272
);
273+
assert!(row.6 == DateTime::from_timestamp_millis(1583038800000).unwrap());
274+
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 3).unwrap());
275+
assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
276+
assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
277+
assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
278+
assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
279+
assert!(row.12 == DurationWrapper(Duration::seconds(9)));
280+
assert!(row.13 == DurationWrapper(Duration::seconds(9)));
281+
assert!(row.14 == DurationWrapper(Duration::seconds(9)));
282+
assert!(row.15 == DurationWrapper(Duration::seconds(9)));
185283
}

convergence/src/protocol.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,14 @@ data_types! {
7878
Numeric = 1700, -1
7979

8080
Date = 1082, 4
81+
82+
Time = 1083, 8
83+
8184
Timestamp = 1114, 8
8285
Timestamptz = 1184, 8
8386

87+
Interval = 1186, 16
88+
8489
Text = 25, -1
8590
}
8691

convergence/src/protocol_ext.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
44
use bytes::{BufMut, BytesMut};
5-
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime};
5+
use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
66
use rust_decimal::Decimal;
77
use tokio_postgres::types::{ToSql, Type};
88
use tokio_util::codec::Encoder;
@@ -119,7 +119,31 @@ impl<'a> DataRowWriter<'a> {
119119
}
120120
}
121121

122-
/// Writes a timestamp value for the next column.
122+
/// Writes a time value for the next column.
123+
pub fn write_time(&mut self, val: NaiveTime) {
124+
match self.parent.format_code {
125+
FormatCode::Binary => {
126+
self.write_int8((val.num_seconds_from_midnight() * 1_000_000 + val.nanosecond() / 1_000) as i64);
127+
}
128+
FormatCode::Text => self.write_string(&val.to_string()),
129+
}
130+
}
131+
132+
/// Writes a time value for the next column.
133+
pub fn write_duration(&mut self, val: Duration) {
134+
match self.parent.format_code {
135+
FormatCode::Binary => {
136+
let total_micros = val.num_microseconds().unwrap_or_else(|| {
137+
// Fallback for very large durations that may not fit in i64 microseconds
138+
val.num_seconds() * 1_000_000 + (val.subsec_nanos() as i64) / 1_000
139+
});
140+
self.write_int8(total_micros);
141+
}
142+
FormatCode::Text => self.write_string(&val.to_string()),
143+
}
144+
}
145+
146+
/// Writes a timestamp value fxor the next column.
123147
pub fn write_timestamp(&mut self, val: NaiveDateTime) {
124148
match self.parent.format_code {
125149
FormatCode::Binary => {

0 commit comments

Comments
 (0)