@@ -2,6 +2,7 @@ package redshiftdatasqldriver
22
33import (
44 "context"
5+ "database/sql"
56 "database/sql/driver"
67 "fmt"
78 "time"
@@ -16,6 +17,11 @@ type redshiftDataConn struct {
1617 cfg * RedshiftDataConfig
1718 aliveCh chan struct {}
1819 isClosed bool
20+
21+ inTx bool
22+ txOpts driver.TxOptions
23+ sqls []string
24+ delayedResult []* redshiftDataDelayedResult
1925}
2026
2127func newConn (client RedshiftDataClient , cfg * RedshiftDataConfig ) * redshiftDataConn {
@@ -44,16 +50,80 @@ func (conn *redshiftDataConn) Close() error {
4450}
4551
4652func (conn * redshiftDataConn ) BeginTx (ctx context.Context , opts driver.TxOptions ) (driver.Tx , error ) {
47- return nil , fmt .Errorf ("transaction %w" , ErrNotSupported )
53+ if conn .inTx {
54+ return nil , ErrInTx
55+ }
56+ if opts .Isolation != driver .IsolationLevel (sql .LevelDefault ) {
57+ return nil , fmt .Errorf ("transaction isolation level change: %w" , ErrNotSupported )
58+ }
59+ conn .inTx = true
60+ conn .txOpts = opts
61+ cleanup := func () error {
62+ conn .inTx = false
63+ conn .sqls = nil
64+ conn .delayedResult = nil
65+ return nil
66+ }
67+ tx := & redshiftDataTx {
68+ onRollback : func () error {
69+ if ! conn .inTx {
70+ return ErrNotInTx
71+ }
72+ return cleanup ()
73+ },
74+ onCommit : func () error {
75+ if ! conn .inTx {
76+ return ErrNotInTx
77+ }
78+ if len (conn .sqls ) == 0 {
79+ return cleanup ()
80+ }
81+ if len (conn .sqls ) != len (conn .delayedResult ) {
82+ panic (fmt .Sprintf ("sqls and delayedResult length is not match: sqls=%d delayedResult=%d" , len (conn .sqls ), len (conn .delayedResult )))
83+ }
84+ if len (conn .sqls ) == 1 {
85+ result , err := conn .ExecContext (ctx , conn .sqls [0 ], []driver.NamedValue {})
86+ if err != nil {
87+ return err
88+ }
89+ if conn .delayedResult [0 ] != nil {
90+ conn .delayedResult [0 ].Result = result
91+ }
92+ return nil
93+ }
94+ input := & redshiftdata.BatchExecuteStatementInput {
95+ Sqls : append (make ([]string , 0 , len (conn .sqls )), conn .sqls ... ),
96+ }
97+ _ , desc , err := conn .batchExecuteStatement (ctx , input )
98+ if err != nil {
99+ return err
100+ }
101+ for i := range input .Sqls {
102+ if i >= len (desc .SubStatements ) {
103+ return fmt .Errorf ("sub statement not found: %d" , i )
104+ }
105+ if conn .delayedResult [i ] != nil {
106+ conn .delayedResult [i ].Result = newResultWithSubStatementData (desc .SubStatements [i ])
107+ }
108+ }
109+ return cleanup ()
110+ },
111+ }
112+
113+ return tx , nil
48114}
49115
50116func (conn * redshiftDataConn ) Begin () (driver.Tx , error ) {
51117 return conn .BeginTx (context .Background (), driver.TxOptions {})
52118}
53119
54120func (conn * redshiftDataConn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
121+ if conn .inTx {
122+ return nil , fmt .Errorf ("query in transaction: %w" , ErrNotSupported )
123+ }
124+
55125 params := & redshiftdata.ExecuteStatementInput {
56- Sql : nullif (query ),
126+ Sql : nullif (rewriteQuery ( query , len ( args )) ),
57127 Parameters : convertArgsToParameters (args ),
58128 }
59129 p , output , err := conn .executeStatement (ctx , params )
@@ -65,6 +135,20 @@ func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, ar
65135}
66136
67137func (conn * redshiftDataConn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
138+ if conn .inTx {
139+ if len (args ) > 0 {
140+ return nil , fmt .Errorf ("exec with args in transaction: %w" , ErrNotSupported )
141+ }
142+ if conn .txOpts .ReadOnly {
143+ return nil , fmt .Errorf ("exec in read only transaction: %w" , ErrNotSupported )
144+ }
145+ conn .sqls = append (conn .sqls , query )
146+ result := & redshiftDataDelayedResult {}
147+ conn .delayedResult = append (conn .delayedResult , result )
148+ debugLogger .Printf ("delayedResult[%d] creaed for %q" , len (conn .delayedResult )- 1 , query )
149+ return result , nil
150+ }
151+
68152 params := & redshiftdata.ExecuteStatementInput {
69153 Sql : nullif (rewriteQuery (query , len (args ))),
70154 Parameters : convertArgsToParameters (args ),
@@ -132,15 +216,7 @@ func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *reds
132216 params .WorkgroupName = conn .cfg .WorkgroupName
133217 params .SecretArn = conn .cfg .SecretsARN
134218
135- if conn .cfg .Timeout == 0 {
136- conn .cfg .Timeout = 15 * time .Minute
137- }
138- if conn .cfg .Polling == 0 {
139- conn .cfg .Polling = 10 * time .Millisecond
140- }
141- ectx , cancel := context .WithTimeout (ctx , conn .cfg .Timeout )
142- defer cancel ()
143- executeOutput , err := conn .client .ExecuteStatement (ectx , params )
219+ executeOutput , err := conn .client .ExecuteStatement (ctx , params )
144220 if err != nil {
145221 return nil , nil , fmt .Errorf ("execute statement:%w" , err )
146222 }
@@ -170,6 +246,45 @@ func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *reds
170246 return p , describeOutput , nil
171247}
172248
249+ func (conn * redshiftDataConn ) batchExecuteStatement (ctx context.Context , params * redshiftdata.BatchExecuteStatementInput ) ([]* redshiftdata.GetStatementResultPaginator , * redshiftdata.DescribeStatementOutput , error ) {
250+ params .ClusterIdentifier = conn .cfg .ClusterIdentifier
251+ params .Database = conn .cfg .Database
252+ params .DbUser = conn .cfg .DbUser
253+ params .WorkgroupName = conn .cfg .WorkgroupName
254+ params .SecretArn = conn .cfg .SecretsARN
255+
256+ batchExecuteOutput , err := conn .client .BatchExecuteStatement (ctx , params )
257+ if err != nil {
258+ return nil , nil , fmt .Errorf ("execute statement:%w" , err )
259+ }
260+ queryStart := time .Now ()
261+ debugLogger .Printf ("[%s] sucess execute statement: %d sqls" , * batchExecuteOutput .Id , len (params .Sqls ))
262+ describeOutput , err := conn .waitWithCancel (ctx , batchExecuteOutput .Id , queryStart )
263+ if err != nil {
264+ return nil , nil , err
265+ }
266+ if describeOutput .Status == types .StatusStringAborted {
267+ return nil , nil , fmt .Errorf ("query aborted: %s" , * describeOutput .Error )
268+ }
269+ if describeOutput .Status == types .StatusStringFailed {
270+ return nil , nil , fmt .Errorf ("query failed: %s" , * describeOutput .Error )
271+ }
272+ if describeOutput .Status != types .StatusStringFinished {
273+ return nil , nil , fmt .Errorf ("query status is not finished: %s" , describeOutput .Status )
274+ }
275+ debugLogger .Printf ("[%s] success query: elapsed_time=%s" , * batchExecuteOutput .Id , time .Since (queryStart ))
276+ ps := make ([]* redshiftdata.GetStatementResultPaginator , len (params .Sqls ))
277+ for i , st := range describeOutput .SubStatements {
278+ if * st .HasResultSet {
279+ continue
280+ }
281+ ps [i ] = redshiftdata .NewGetStatementResultPaginator (conn .client , & redshiftdata.GetStatementResultInput {
282+ Id : st .Id ,
283+ })
284+ }
285+ return ps , describeOutput , nil
286+ }
287+
173288func isFinishedStatus (status types.StatusString ) bool {
174289 return status == types .StatusStringFinished || status == types .StatusStringFailed || status == types .StatusStringAborted
175290}
0 commit comments