Skip to content

Commit ec82909

Browse files
committed
implement transaction support
1 parent 63b3379 commit ec82909

File tree

5 files changed

+210
-11
lines changed

5 files changed

+210
-11
lines changed

conn.go

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package redshiftdatasqldriver
22

33
import (
44
"context"
5+
"database/sql"
56
"database/sql/driver"
67
"fmt"
78
"time"
@@ -16,6 +17,11 @@ type redshiftDataConn struct {
1617
cfg *RedshiftDataConfig
1718
aliveCh chan struct{}
1819
isClosed bool
20+
21+
inTx bool
22+
txOpts driver.TxOptions
23+
sqls []string
24+
delayedResult []*redshiftDataDelayedResult
1925
}
2026

2127
func newConn(client RedshiftDataClient, cfg *RedshiftDataConfig) *redshiftDataConn {
@@ -44,16 +50,80 @@ func (conn *redshiftDataConn) Close() error {
4450
}
4551

4652
func (conn *redshiftDataConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
47-
return nil, fmt.Errorf("transaction %w", ErrNotSupported)
53+
if conn.inTx {
54+
return nil, ErrInTx
55+
}
56+
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
57+
return nil, fmt.Errorf("transaction isolation level change: %w", ErrNotSupported)
58+
}
59+
conn.inTx = true
60+
conn.txOpts = opts
61+
cleanup := func() error {
62+
conn.inTx = false
63+
conn.sqls = nil
64+
conn.delayedResult = nil
65+
return nil
66+
}
67+
tx := &redshiftDataTx{
68+
onRollback: func() error {
69+
if !conn.inTx {
70+
return ErrNotInTx
71+
}
72+
return cleanup()
73+
},
74+
onCommit: func() error {
75+
if !conn.inTx {
76+
return ErrNotInTx
77+
}
78+
if len(conn.sqls) == 0 {
79+
return cleanup()
80+
}
81+
if len(conn.sqls) != len(conn.delayedResult) {
82+
panic(fmt.Sprintf("sqls and delayedResult length is not match: sqls=%d delayedResult=%d", len(conn.sqls), len(conn.delayedResult)))
83+
}
84+
if len(conn.sqls) == 1 {
85+
result, err := conn.ExecContext(ctx, conn.sqls[0], []driver.NamedValue{})
86+
if err != nil {
87+
return err
88+
}
89+
if conn.delayedResult[0] != nil {
90+
conn.delayedResult[0].Result = result
91+
}
92+
return nil
93+
}
94+
input := &redshiftdata.BatchExecuteStatementInput{
95+
Sqls: append(make([]string, 0, len(conn.sqls)), conn.sqls...),
96+
}
97+
_, desc, err := conn.batchExecuteStatement(ctx, input)
98+
if err != nil {
99+
return err
100+
}
101+
for i := range input.Sqls {
102+
if i >= len(desc.SubStatements) {
103+
return fmt.Errorf("sub statement not found: %d", i)
104+
}
105+
if conn.delayedResult[i] != nil {
106+
conn.delayedResult[i].Result = newResultWithSubStatementData(desc.SubStatements[i])
107+
}
108+
}
109+
return cleanup()
110+
},
111+
}
112+
113+
return tx, nil
48114
}
49115

50116
func (conn *redshiftDataConn) Begin() (driver.Tx, error) {
51117
return conn.BeginTx(context.Background(), driver.TxOptions{})
52118
}
53119

54120
func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
121+
if conn.inTx {
122+
return nil, fmt.Errorf("query in transaction: %w", ErrNotSupported)
123+
}
124+
55125
params := &redshiftdata.ExecuteStatementInput{
56-
Sql: nullif(query),
126+
Sql: nullif(rewriteQuery(query, len(args))),
57127
Parameters: convertArgsToParameters(args),
58128
}
59129
p, output, err := conn.executeStatement(ctx, params)
@@ -65,6 +135,20 @@ func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, ar
65135
}
66136

67137
func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
138+
if conn.inTx {
139+
if len(args) > 0 {
140+
return nil, fmt.Errorf("exec with args in transaction: %w", ErrNotSupported)
141+
}
142+
if conn.txOpts.ReadOnly {
143+
return nil, fmt.Errorf("exec in read only transaction: %w", ErrNotSupported)
144+
}
145+
conn.sqls = append(conn.sqls, query)
146+
result := &redshiftDataDelayedResult{}
147+
conn.delayedResult = append(conn.delayedResult, result)
148+
debugLogger.Printf("delayedResult[%d] creaed for %q", len(conn.delayedResult)-1, query)
149+
return result, nil
150+
}
151+
68152
params := &redshiftdata.ExecuteStatementInput{
69153
Sql: nullif(rewriteQuery(query, len(args))),
70154
Parameters: convertArgsToParameters(args),
@@ -132,15 +216,7 @@ func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *reds
132216
params.WorkgroupName = conn.cfg.WorkgroupName
133217
params.SecretArn = conn.cfg.SecretsARN
134218

135-
if conn.cfg.Timeout == 0 {
136-
conn.cfg.Timeout = 15 * time.Minute
137-
}
138-
if conn.cfg.Polling == 0 {
139-
conn.cfg.Polling = 10 * time.Millisecond
140-
}
141-
ectx, cancel := context.WithTimeout(ctx, conn.cfg.Timeout)
142-
defer cancel()
143-
executeOutput, err := conn.client.ExecuteStatement(ectx, params)
219+
executeOutput, err := conn.client.ExecuteStatement(ctx, params)
144220
if err != nil {
145221
return nil, nil, fmt.Errorf("execute statement:%w", err)
146222
}
@@ -170,6 +246,45 @@ func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *reds
170246
return p, describeOutput, nil
171247
}
172248

249+
func (conn *redshiftDataConn) batchExecuteStatement(ctx context.Context, params *redshiftdata.BatchExecuteStatementInput) ([]*redshiftdata.GetStatementResultPaginator, *redshiftdata.DescribeStatementOutput, error) {
250+
params.ClusterIdentifier = conn.cfg.ClusterIdentifier
251+
params.Database = conn.cfg.Database
252+
params.DbUser = conn.cfg.DbUser
253+
params.WorkgroupName = conn.cfg.WorkgroupName
254+
params.SecretArn = conn.cfg.SecretsARN
255+
256+
batchExecuteOutput, err := conn.client.BatchExecuteStatement(ctx, params)
257+
if err != nil {
258+
return nil, nil, fmt.Errorf("execute statement:%w", err)
259+
}
260+
queryStart := time.Now()
261+
debugLogger.Printf("[%s] sucess execute statement: %d sqls", *batchExecuteOutput.Id, len(params.Sqls))
262+
describeOutput, err := conn.waitWithCancel(ctx, batchExecuteOutput.Id, queryStart)
263+
if err != nil {
264+
return nil, nil, err
265+
}
266+
if describeOutput.Status == types.StatusStringAborted {
267+
return nil, nil, fmt.Errorf("query aborted: %s", *describeOutput.Error)
268+
}
269+
if describeOutput.Status == types.StatusStringFailed {
270+
return nil, nil, fmt.Errorf("query failed: %s", *describeOutput.Error)
271+
}
272+
if describeOutput.Status != types.StatusStringFinished {
273+
return nil, nil, fmt.Errorf("query status is not finished: %s", describeOutput.Status)
274+
}
275+
debugLogger.Printf("[%s] success query: elapsed_time=%s", *batchExecuteOutput.Id, time.Since(queryStart))
276+
ps := make([]*redshiftdata.GetStatementResultPaginator, len(params.Sqls))
277+
for i, st := range describeOutput.SubStatements {
278+
if *st.HasResultSet {
279+
continue
280+
}
281+
ps[i] = redshiftdata.NewGetStatementResultPaginator(conn.client, &redshiftdata.GetStatementResultInput{
282+
Id: st.Id,
283+
})
284+
}
285+
return ps, describeOutput, nil
286+
}
287+
173288
func isFinishedStatus(status types.StatusString) bool {
174289
return status == types.StatusStringFinished || status == types.StatusStringFailed || status == types.StatusStringAborted
175290
}

driver_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,45 @@ func TestSimpleExec(t *testing.T) {
140140
`SELECT * FROM "public"."redshift_data_sql_driver_test" WHERE id = :id`,
141141
sql.Named("id", 1),
142142
)
143+
require.NoError(t, err)
143144
defer func() {
144145
require.NoError(t, rows.Close())
145146
}()
147+
require.True(t, rows.Next())
148+
var id int64
149+
var createdAt sql.NullTime
150+
require.NoError(t, rows.Scan(&id, &createdAt))
151+
require.Equal(t, int64(1), id)
152+
require.True(t, time.Until(createdAt.Time) <= time.Hour)
153+
})
154+
}
155+
156+
func TestTx(t *testing.T) {
157+
runTestsWithDB(t, dsn, func(t *testing.T, db *sql.DB) {
158+
restore := requireNoErrorLog(t)
159+
defer restore()
160+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
161+
defer cancel()
162+
tx, err := db.BeginTx(ctx, nil)
163+
require.NoError(t, err)
164+
_, err = tx.ExecContext(ctx, `DROP TABLE IF EXISTS "public"."redshift_data_sql_driver_test"`)
165+
require.NoError(t, err)
166+
_, err = tx.ExecContext(ctx, `CREATE TABLE "public"."redshift_data_sql_driver_test" (id BIGINT, created_at TIMESTAMP)`)
167+
require.NoError(t, err)
168+
result, err := tx.ExecContext(ctx, `INSERT INTO "public"."redshift_data_sql_driver_test" SELECT 1 as id, getdate() as created_at`)
169+
require.NoError(t, err)
170+
_, err = result.RowsAffected()
171+
require.ErrorIs(t, err, ErrBeforeCommit)
172+
err = tx.Commit() // BatchExecuteStatement is called here
173+
require.NoError(t, err)
174+
rows, err := db.QueryContext(ctx, `SELECT * FROM "public"."redshift_data_sql_driver_test" WHERE id = 1`)
146175
require.NoError(t, err)
176+
defer func() {
177+
require.NoError(t, rows.Close())
178+
}()
179+
rowsAffected, err := result.RowsAffected()
180+
require.NoError(t, err)
181+
require.Equal(t, int64(1), rowsAffected)
147182
require.True(t, rows.Next())
148183
var id int64
149184
var createdAt sql.NullTime

errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,7 @@ var (
66
ErrNotSupported = errors.New("not supported")
77
ErrDSNEmpty = errors.New("dsn is empty")
88
ErrConnClosed = errors.New("connection closed")
9+
ErrBeforeCommit = errors.New("transaction is not committed")
10+
ErrNotInTx = errors.New("not in transaction")
11+
ErrInTx = errors.New("already in transaction")
912
)

result.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package redshiftdatasqldriver
22

33
import (
4+
"database/sql/driver"
45
"fmt"
56

67
"github.com/aws/aws-sdk-go-v2/service/redshiftdata"
8+
"github.com/aws/aws-sdk-go-v2/service/redshiftdata/types"
79
)
810

911
type redshiftDataResult struct {
@@ -16,10 +18,38 @@ func newResult(output *redshiftdata.DescribeStatementOutput) *redshiftDataResult
1618
affectedRows: output.ResultRows,
1719
}
1820
}
21+
22+
func newResultWithSubStatementData(st types.SubStatementData) *redshiftDataResult {
23+
debugLogger.Printf("[%s] create result", coalesce(st.Id))
24+
return &redshiftDataResult{
25+
affectedRows: st.ResultRows,
26+
}
27+
}
28+
1929
func (r *redshiftDataResult) LastInsertId() (int64, error) {
2030
return 0, fmt.Errorf("LastInsertId %w", ErrNotSupported)
2131
}
2232

2333
func (r *redshiftDataResult) RowsAffected() (int64, error) {
2434
return r.affectedRows, nil
2535
}
36+
37+
type redshiftDataDelayedResult struct {
38+
driver.Result
39+
}
40+
41+
func (r *redshiftDataDelayedResult) LastInsertId() (int64, error) {
42+
debugLogger.Printf("delayed result LastInsertId called")
43+
if r.Result != nil {
44+
return r.Result.LastInsertId()
45+
}
46+
return 0, ErrBeforeCommit
47+
}
48+
49+
func (r *redshiftDataDelayedResult) RowsAffected() (int64, error) {
50+
debugLogger.Printf("delayed result RowsAffected called")
51+
if r.Result != nil {
52+
return r.Result.RowsAffected()
53+
}
54+
return 0, ErrBeforeCommit
55+
}

tx.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package redshiftdatasqldriver
2+
3+
type redshiftDataTx struct {
4+
onCommit func() error
5+
onRollback func() error
6+
}
7+
8+
func (tx *redshiftDataTx) Commit() error {
9+
debugLogger.Printf("tx commit called")
10+
return tx.onCommit()
11+
}
12+
13+
func (tx *redshiftDataTx) Rollback() error {
14+
debugLogger.Printf("tx rollback called")
15+
return tx.onRollback()
16+
}

0 commit comments

Comments
 (0)