@@ -2,6 +2,7 @@ package cmd
22
33import (
44 "context"
5+ "database/sql"
56 "errors"
67 "fmt"
78 "io"
@@ -11,6 +12,7 @@ import (
1112 "strings"
1213 "time"
1314
15+ _ "github.com/go-sql-driver/mysql"
1416 "github.com/google/cel-go/cel"
1517 "github.com/google/cel-go/ext"
1618 "github.com/jackc/pgx/v5"
@@ -140,17 +142,6 @@ func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) err
140142 return nil
141143}
142144
143- type checker struct {
144- Checks map [string ]cel.Program
145- Conf * config.Config
146- Dbenv * cel.Env
147- Dir string
148- Env * cel.Env
149- Envmap map [string ]string
150- Msgs map [string ]string
151- Stderr io.Writer
152- }
153-
154145// Determine if a query can be prepared based on the engine and the statement
155146// type.
156147func prepareable (sql config.SQL , raw * ast.RawStmt ) bool {
@@ -169,92 +160,151 @@ func prepareable(sql config.SQL, raw *ast.RawStmt) bool {
169160 return false
170161 }
171162 }
163+ // Almost all statements in MySQL can be prepared, so I'm just going to assume they can be
164+ // https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
165+ if sql .Engine == config .EngineMySQL {
166+ return true
167+ }
172168 return false
173169}
174170
175- func (c * checker ) checkSQL (ctx context.Context , sql config.SQL ) error {
171+ type preparer interface {
172+ Prepare (context.Context , string , string ) error
173+ }
174+
175+ type pgxPreparer struct {
176+ c * pgx.Conn
177+ }
178+
179+ func (p * pgxPreparer ) Prepare (ctx context.Context , name , query string ) error {
180+ _ , err := p .c .Prepare (ctx , name , query )
181+ return err
182+ }
183+
184+ type dbPreparer struct {
185+ db * sql.DB
186+ }
187+
188+ func (p * dbPreparer ) Prepare (ctx context.Context , name , query string ) error {
189+ _ , err := p .db .PrepareContext (ctx , query )
190+ return err
191+ }
192+
193+ type checker struct {
194+ Checks map [string ]cel.Program
195+ Conf * config.Config
196+ Dbenv * cel.Env
197+ Dir string
198+ Env * cel.Env
199+ Envmap map [string ]string
200+ Msgs map [string ]string
201+ Stderr io.Writer
202+ }
203+
204+ func (c * checker ) DSN (expr string ) (string , error ) {
205+ ast , issues := c .Dbenv .Compile (expr )
206+ if issues != nil && issues .Err () != nil {
207+ return "" , fmt .Errorf ("type-check error: database url %s" , issues .Err ())
208+ }
209+ prg , err := c .Dbenv .Program (ast )
210+ if err != nil {
211+ return "" , fmt .Errorf ("program construction error: database url %s" , err )
212+ }
213+ // Populate the environment variable map if it is empty
214+ if len (c .Envmap ) == 0 {
215+ for _ , e := range os .Environ () {
216+ k , v , _ := strings .Cut (e , "=" )
217+ c .Envmap [k ] = v
218+ }
219+ }
220+ out , _ , err := prg .Eval (map [string ]any {
221+ "env" : c .Envmap ,
222+ })
223+ if err != nil {
224+ return "" , fmt .Errorf ("expression error: %s" , err )
225+ }
226+ dsn , ok := out .Value ().(string )
227+ if ! ok {
228+ return "" , fmt .Errorf ("expression returned non-string value: %v" , out .Value ())
229+ }
230+ return dsn , nil
231+ }
232+
233+ func (c * checker ) checkSQL (ctx context.Context , s config.SQL ) error {
176234 // TODO: Create a separate function for this logic so we can
177- combo := config .Combine (* c .Conf , sql )
235+ combo := config .Combine (* c .Conf , s )
178236
179237 // TODO: This feels like a hack that will bite us later
180- joined := make ([]string , 0 , len (sql .Schema ))
181- for _ , s := range sql .Schema {
238+ joined := make ([]string , 0 , len (s .Schema ))
239+ for _ , s := range s .Schema {
182240 joined = append (joined , filepath .Join (c .Dir , s ))
183241 }
184- sql .Schema = joined
242+ s .Schema = joined
185243
186- joined = make ([]string , 0 , len (sql .Queries ))
187- for _ , q := range sql .Queries {
244+ joined = make ([]string , 0 , len (s .Queries ))
245+ for _ , q := range s .Queries {
188246 joined = append (joined , filepath .Join (c .Dir , q ))
189247 }
190- sql .Queries = joined
248+ s .Queries = joined
191249
192250 var name string
193251 parseOpts := opts.Parser {
194252 Debug : debug .Debug ,
195253 }
196254
197- result , failed := parse (ctx , name , c .Dir , sql , combo , parseOpts , c .Stderr )
255+ result , failed := parse (ctx , name , c .Dir , s , combo , parseOpts , c .Stderr )
198256 if failed {
199257 return ErrFailedChecks
200258 }
201259
202260 // TODO: Add MySQL support
203- var pgconn * pgx.Conn
204- if sql .Engine == config .EnginePostgreSQL && sql .Database != nil {
205- ast , issues := c .Dbenv .Compile (sql .Database .URL )
206- if issues != nil && issues .Err () != nil {
207- return fmt .Errorf ("type-check error: database url %s" , issues .Err ())
208- }
209- prg , err := c .Dbenv .Program (ast )
261+ var prep preparer
262+ if s .Database != nil {
263+ dburl , err := c .DSN (s .Database .URL )
210264 if err != nil {
211- return fmt . Errorf ( "program construction error: database url %s" , err )
265+ return err
212266 }
213- // Populate the environment variable map if it is empty
214- if len ( c . Envmap ) == 0 {
215- for _ , e := range os . Environ () {
216- k , v , _ := strings . Cut ( e , "=" )
217- c . Envmap [ k ] = v
267+ switch s . Engine {
268+ case config . EnginePostgreSQL :
269+ conn , err := pgx . Connect ( ctx , dburl )
270+ if err != nil {
271+ return fmt . Errorf ( "database: connection error: %s" , err )
218272 }
273+ if err := conn .Ping (ctx ); err != nil {
274+ return fmt .Errorf ("database: connection error: %s" , err )
275+ }
276+ defer conn .Close (ctx )
277+ prep = & pgxPreparer {conn }
278+ case config .EngineMySQL :
279+ db , err := sql .Open ("mysql" , dburl )
280+ if err != nil {
281+ return fmt .Errorf ("database: connection error: %s" , err )
282+ }
283+ if err := db .PingContext (ctx ); err != nil {
284+ return fmt .Errorf ("database: connection error: %s" , err )
285+ }
286+ defer db .Close ()
287+ prep = & dbPreparer {db }
288+ default :
289+ return fmt .Errorf ("unsupported database url: %s" , s .Engine )
219290 }
220- out , _ , err := prg .Eval (map [string ]any {
221- "env" : c .Envmap ,
222- })
223- if err != nil {
224- return fmt .Errorf ("expression error: %s" , err )
225- }
226- dburl , ok := out .Value ().(string )
227- if ! ok {
228- return fmt .Errorf ("expression returned non-string value: %v" , out .Value ())
229- }
230- fmt .Println ("URL" , dburl )
231- conn , err := pgx .Connect (ctx , dburl )
232- if err != nil {
233- return fmt .Errorf ("database: connection error: %s" , err )
234- }
235- if err := conn .Ping (ctx ); err != nil {
236- return fmt .Errorf ("database: connection error: %s" , err )
237- }
238- defer conn .Close (ctx )
239- pgconn = conn
240291 }
241292
242293 errored := false
243294 req := codeGenRequest (result , combo )
244295 cfg := vetConfig (req )
245296 for i , query := range req .Queries {
246297 original := result .Queries [i ]
247- if pgconn != nil && prepareable (sql , original .RawStmt ) {
298+ if prep != nil && prepareable (s , original .RawStmt ) {
248299 name := fmt .Sprintf ("sqlc_vet_%d_%d" , time .Now ().Unix (), i )
249- _ , err := pgconn .Prepare (ctx , name , query .Text )
250- if err != nil {
300+ if err := prep .Prepare (ctx , name , query .Text ); err != nil {
251301 fmt .Fprintf (c .Stderr , "%s: error preparing %s: %s\n " , query .Filename , query .Name , err )
252302 errored = true
253303 continue
254304 }
255305 }
256306 q := vetQuery (query )
257- for _ , name := range sql .Rules {
307+ for _ , name := range s .Rules {
258308 prg , ok := c .Checks [name ]
259309 if ! ok {
260310 return fmt .Errorf ("type-check error: a check with the name '%s' does not exist" , name )
0 commit comments