11use async_trait:: async_trait;
2- use chrono:: { DateTime , NaiveDate , NaiveDateTime } ;
2+ use chrono:: { DateTime , Duration , FixedOffset , NaiveDate , NaiveDateTime , NaiveTime , TimeDelta } ;
33use convergence:: engine:: { Engine , Portal } ;
44use convergence:: protocol:: { ErrorResponse , FieldDescription } ;
55use convergence:: protocol_ext:: DataRowBatch ;
66use convergence:: server:: { self , BindOptions } ;
77use convergence:: sqlparser:: ast:: Statement ;
88use convergence_arrow:: table:: { record_batch_to_rows, schema_to_field_desc} ;
99use datafusion:: arrow:: array:: {
10- ArrayRef , Date32Array , Decimal128Array , Float32Array , Int32Array , StringArray , StringViewArray ,
11- TimestampSecondArray ,
10+ ArrayRef , Date32Array , Decimal128Array , DurationMicrosecondArray , DurationMillisecondArray ,
11+ DurationNanosecondArray , DurationSecondArray , Float32Array , Int32Array , StringArray , StringViewArray ,
12+ Time32MillisecondArray , Time32SecondArray , Time64MicrosecondArray , Time64NanosecondArray , TimestampSecondArray ,
1213} ;
1314use datafusion:: arrow:: datatypes:: { DataType , Field , Schema , TimeUnit } ;
1415use datafusion:: arrow:: record_batch:: RecordBatch ;
1516use rust_decimal:: Decimal ;
17+ use std:: convert:: TryInto ;
1618use std:: sync:: Arc ;
19+ use tokio_postgres:: types:: { FromSql , Type } ;
1720use tokio_postgres:: { connect, NoTls } ;
1821
1922struct ArrowPortal {
@@ -47,6 +50,23 @@ impl ArrowEngine {
4750 Arc :: new ( TimestampSecondArray :: from ( vec ! [ 1577854800 , 1580533200 , 1583038800 ] ) . with_timezone ( "+05:00" ) )
4851 as ArrayRef ;
4952 let date_col = Arc :: new ( Date32Array :: from ( vec ! [ 0 , 1 , 2 ] ) ) as ArrayRef ;
53+ let time_s_col = Arc :: new ( Time32SecondArray :: from ( vec ! [ 30 , 60 , 90 ] ) ) as ArrayRef ;
54+ let time_ms_col = Arc :: new ( Time32MillisecondArray :: from ( vec ! [ 30_000 , 60_000 , 90_000 ] ) ) as ArrayRef ;
55+ let time_mcs_col = Arc :: new ( Time64MicrosecondArray :: from ( vec ! [ 30_000_000 , 60_000_000 , 90_000_000 ] ) ) as ArrayRef ;
56+ let time_ns_col = Arc :: new ( Time64NanosecondArray :: from ( vec ! [
57+ 30_000_000_000 ,
58+ 60_000_000_000 ,
59+ 90_000_000_000 ,
60+ ] ) ) as ArrayRef ;
61+ let duration_s_col = Arc :: new ( DurationSecondArray :: from ( vec ! [ 3 , 6 , 9 ] ) ) as ArrayRef ;
62+ let duration_ms_col = Arc :: new ( DurationMillisecondArray :: from ( vec ! [ 3_000 , 6_000 , 9_000 ] ) ) as ArrayRef ;
63+ let duration_mcs_col =
64+ Arc :: new ( DurationMicrosecondArray :: from ( vec ! [ 3_000_000 , 6_000_000 , 9_000_000 ] ) ) as ArrayRef ;
65+ let duration_ns_col = Arc :: new ( DurationNanosecondArray :: from ( vec ! [
66+ 3_000_000_000 ,
67+ 6_000_000_000 ,
68+ 9_000_000_000 ,
69+ ] ) ) as ArrayRef ;
5070
5171 let schema = Schema :: new ( vec ! [
5272 Field :: new( "int_col" , DataType :: Int32 , true ) ,
@@ -61,6 +81,14 @@ impl ArrowEngine {
6181 true ,
6282 ) ,
6383 Field :: new( "date_col" , DataType :: Date32 , true ) ,
84+ Field :: new( "time_s_col" , DataType :: Time32 ( TimeUnit :: Second ) , true ) ,
85+ Field :: new( "time_ms_col" , DataType :: Time32 ( TimeUnit :: Millisecond ) , true ) ,
86+ Field :: new( "time_mcs_col" , DataType :: Time64 ( TimeUnit :: Microsecond ) , true ) ,
87+ Field :: new( "time_ns_col" , DataType :: Time64 ( TimeUnit :: Nanosecond ) , true ) ,
88+ Field :: new( "duration_s_col" , DataType :: Duration ( TimeUnit :: Second ) , true ) ,
89+ Field :: new( "duration_ms_col" , DataType :: Duration ( TimeUnit :: Millisecond ) , true ) ,
90+ Field :: new( "duration_mcs_col" , DataType :: Duration ( TimeUnit :: Microsecond ) , true ) ,
91+ Field :: new( "duration_ns_col" , DataType :: Duration ( TimeUnit :: Nanosecond ) , true ) ,
6492 ] ) ;
6593
6694 Self {
@@ -75,6 +103,14 @@ impl ArrowEngine {
75103 ts_col,
76104 ts_tz_col,
77105 date_col,
106+ time_s_col,
107+ time_ms_col,
108+ time_mcs_col,
109+ time_ns_col,
110+ duration_s_col,
111+ duration_ms_col,
112+ duration_mcs_col,
113+ duration_ns_col,
78114 ] ,
79115 )
80116 . expect ( "failed to create batch" ) ,
@@ -114,14 +150,45 @@ async fn setup() -> tokio_postgres::Client {
114150 client
115151}
116152
153+ // remove after https://github.com/sfackler/rust-postgres/pull/1238 is merged
154+ #[ derive( Clone , Copy , Default , PartialEq , Eq , PartialOrd , Ord , Debug , Hash ) ]
155+ struct DurationWrapper ( TimeDelta ) ;
156+
157+ impl < ' a > FromSql < ' a > for DurationWrapper {
158+ fn from_sql ( _ty : & Type , raw : & [ u8 ] ) -> Result < Self , Box < dyn std:: error:: Error + Sync + Send > > {
159+ let micros = i64:: from_be_bytes ( raw. try_into ( ) . unwrap ( ) ) ;
160+ Ok ( DurationWrapper ( Duration :: microseconds ( micros) ) )
161+ }
162+ fn accepts ( ty : & Type ) -> bool {
163+ matches ! ( ty, & Type :: INTERVAL )
164+ }
165+ }
166+
117167#[ tokio:: test]
118168async fn basic_data_types ( ) {
119169 let client = setup ( ) . await ;
120170
121171 let rows = client. query ( "select 1" , & [ ] ) . await . unwrap ( ) ;
122172 let get_row = |idx : usize | {
123173 let row = & rows[ idx] ;
124- let cols: ( i32 , f32 , Decimal , & str , & str , NaiveDateTime , DateTime < _ > , NaiveDate ) = (
174+ let cols: (
175+ i32 ,
176+ f32 ,
177+ Decimal ,
178+ & str ,
179+ & str ,
180+ NaiveDateTime ,
181+ DateTime < FixedOffset > ,
182+ NaiveDate ,
183+ NaiveTime ,
184+ NaiveTime ,
185+ NaiveTime ,
186+ NaiveTime ,
187+ DurationWrapper ,
188+ DurationWrapper ,
189+ DurationWrapper ,
190+ DurationWrapper ,
191+ ) = (
125192 row. get ( 0 ) ,
126193 row. get ( 1 ) ,
127194 row. get ( 2 ) ,
@@ -130,56 +197,87 @@ async fn basic_data_types() {
130197 row. get ( 5 ) ,
131198 row. get ( 6 ) ,
132199 row. get ( 7 ) ,
200+ row. get ( 8 ) ,
201+ row. get ( 9 ) ,
202+ row. get ( 10 ) ,
203+ row. get ( 11 ) ,
204+ row. get ( 12 ) ,
205+ row. get ( 13 ) ,
206+ row. get ( 14 ) ,
207+ row. get ( 15 ) ,
133208 ) ;
134209 cols
135210 } ;
136211
137- assert_eq ! (
138- get_row ( 0 ) ,
139- (
140- 1 ,
141- 1.5 ,
142- Decimal :: from ( 11 ) ,
143- "a" ,
144- "aa" ,
145- NaiveDate :: from_ymd_opt( 2020 , 1 , 1 )
212+ let row = get_row ( 0 ) ;
213+ assert ! ( row . 0 == 1 ) ;
214+ assert ! ( row . 1 == 1.5 ) ;
215+ assert ! ( row . 2 == Decimal :: from ( 11 ) ) ;
216+ assert ! ( row . 3 == "a" ) ;
217+ assert ! ( row . 4 == "aa" ) ;
218+ assert ! (
219+ row . 5
220+ == NaiveDate :: from_ymd_opt( 2020 , 1 , 1 )
146221 . unwrap( )
147222 . and_hms_opt( 0 , 0 , 0 )
148- . unwrap( ) ,
149- DateTime :: from_timestamp_millis( 1577854800000 ) . unwrap( ) ,
150- NaiveDate :: from_ymd_opt( 1970 , 1 , 1 ) . unwrap( ) ,
151- )
223+ . unwrap( )
152224 ) ;
153- assert_eq ! (
154- get_row( 1 ) ,
155- (
156- 2 ,
157- 2.5 ,
158- Decimal :: from( 22 ) ,
159- "b" ,
160- "bb" ,
161- NaiveDate :: from_ymd_opt( 2020 , 2 , 1 )
225+ assert ! ( row. 6 == DateTime :: from_timestamp_millis( 1577854800000 ) . unwrap( ) ) ;
226+ assert ! ( row. 7 == NaiveDate :: from_ymd_opt( 1970 , 1 , 1 ) . unwrap( ) ) ;
227+ assert ! ( row. 8 == NaiveTime :: from_hms_opt( 0 , 0 , 30 ) . unwrap( ) ) ;
228+ assert ! ( row. 9 == NaiveTime :: from_hms_opt( 0 , 0 , 30 ) . unwrap( ) ) ;
229+ assert ! ( row. 10 == NaiveTime :: from_hms_opt( 0 , 0 , 30 ) . unwrap( ) ) ;
230+ assert ! ( row. 11 == NaiveTime :: from_hms_opt( 0 , 0 , 30 ) . unwrap( ) ) ;
231+ assert ! ( row. 12 == DurationWrapper ( Duration :: seconds( 3 ) ) ) ;
232+ assert ! ( row. 13 == DurationWrapper ( Duration :: seconds( 3 ) ) ) ;
233+ assert ! ( row. 14 == DurationWrapper ( Duration :: seconds( 3 ) ) ) ;
234+ assert ! ( row. 15 == DurationWrapper ( Duration :: seconds( 3 ) ) ) ;
235+
236+ let row = get_row ( 1 ) ;
237+ assert ! ( row. 0 == 2 ) ;
238+ assert ! ( row. 1 == 2.5 ) ;
239+ assert ! ( row. 2 == Decimal :: from( 22 ) ) ;
240+ assert ! ( row. 3 == "b" ) ;
241+ assert ! ( row. 4 == "bb" ) ;
242+ assert ! (
243+ row. 5
244+ == NaiveDate :: from_ymd_opt( 2020 , 2 , 1 )
162245 . unwrap( )
163246 . and_hms_opt( 0 , 0 , 0 )
164- . unwrap( ) ,
165- DateTime :: from_timestamp_millis( 1580533200000 ) . unwrap( ) ,
166- NaiveDate :: from_ymd_opt( 1970 , 1 , 2 ) . unwrap( )
167- )
247+ . unwrap( )
168248 ) ;
169- assert_eq ! (
170- get_row( 2 ) ,
171- (
172- 3 ,
173- 3.5 ,
174- Decimal :: from( 33 ) ,
175- "c" ,
176- "cc" ,
177- NaiveDate :: from_ymd_opt( 2020 , 3 , 1 )
249+ assert ! ( row. 6 == DateTime :: from_timestamp_millis( 1580533200000 ) . unwrap( ) ) ;
250+ assert ! ( row. 7 == NaiveDate :: from_ymd_opt( 1970 , 1 , 2 ) . unwrap( ) ) ;
251+ assert ! ( row. 8 == NaiveTime :: from_hms_opt( 0 , 1 , 0 ) . unwrap( ) ) ;
252+ assert ! ( row. 9 == NaiveTime :: from_hms_opt( 0 , 1 , 0 ) . unwrap( ) ) ;
253+ assert ! ( row. 10 == NaiveTime :: from_hms_opt( 0 , 1 , 0 ) . unwrap( ) ) ;
254+ assert ! ( row. 11 == NaiveTime :: from_hms_opt( 0 , 1 , 0 ) . unwrap( ) ) ;
255+ assert ! ( row. 12 == DurationWrapper ( Duration :: seconds( 6 ) ) ) ;
256+ assert ! ( row. 13 == DurationWrapper ( Duration :: seconds( 6 ) ) ) ;
257+ assert ! ( row. 14 == DurationWrapper ( Duration :: seconds( 6 ) ) ) ;
258+ assert ! ( row. 15 == DurationWrapper ( Duration :: seconds( 6 ) ) ) ;
259+
260+ let row = get_row ( 2 ) ;
261+ assert ! ( row. 0 == 3 ) ;
262+ assert ! ( row. 1 == 3.5 ) ;
263+ assert ! ( row. 2 == Decimal :: from( 33 ) ) ;
264+ assert ! ( row. 3 == "c" ) ;
265+ assert ! ( row. 4 == "cc" ) ;
266+ assert ! (
267+ row. 5
268+ == NaiveDate :: from_ymd_opt( 2020 , 3 , 1 )
178269 . unwrap( )
179270 . and_hms_opt( 0 , 0 , 0 )
180- . unwrap( ) ,
181- DateTime :: from_timestamp_millis( 1583038800000 ) . unwrap( ) ,
182- NaiveDate :: from_ymd_opt( 1970 , 1 , 3 ) . unwrap( )
183- )
271+ . unwrap( )
184272 ) ;
273+ assert ! ( row. 6 == DateTime :: from_timestamp_millis( 1583038800000 ) . unwrap( ) ) ;
274+ assert ! ( row. 7 == NaiveDate :: from_ymd_opt( 1970 , 1 , 3 ) . unwrap( ) ) ;
275+ assert ! ( row. 8 == NaiveTime :: from_hms_opt( 0 , 1 , 30 ) . unwrap( ) ) ;
276+ assert ! ( row. 9 == NaiveTime :: from_hms_opt( 0 , 1 , 30 ) . unwrap( ) ) ;
277+ assert ! ( row. 10 == NaiveTime :: from_hms_opt( 0 , 1 , 30 ) . unwrap( ) ) ;
278+ assert ! ( row. 11 == NaiveTime :: from_hms_opt( 0 , 1 , 30 ) . unwrap( ) ) ;
279+ assert ! ( row. 12 == DurationWrapper ( Duration :: seconds( 9 ) ) ) ;
280+ assert ! ( row. 13 == DurationWrapper ( Duration :: seconds( 9 ) ) ) ;
281+ assert ! ( row. 14 == DurationWrapper ( Duration :: seconds( 9 ) ) ) ;
282+ assert ! ( row. 15 == DurationWrapper ( Duration :: seconds( 9 ) ) ) ;
185283}
0 commit comments