Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,7 @@ func (drv) OpenConnector(name string) (driver.Connector, error) {
//
// The tracer may be nil.
func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer) driver.Connector {
return &connector{
name: sqliteURI,
tracer: tracer,
connInitFunc: connInitFunc,
}
return ConnectorWithOpts(sqliteURI, connInitFunc, WithTracer(tracer))
}

// ConnectorWithLogger returns a [driver.Connector] for the given connection
Expand All @@ -137,11 +133,48 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
//
// The tracer may also be nil.
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
return &connector{
return ConnectorWithOpts(sqliteURI, connInitFunc, WithTracer(tracer), WithConnLogger(makeLogger))
}

// ConnectorWithOpts returns a [driver.Connector] for the given connection parameters, optionally
// configured with one or more [ConnectorOpt]s.
func ConnectorWithOpts(sqliteURI string, connInitFunc ConnInitFunc, opts ...ConnectorOpt) driver.Connector {
p := &connector{
name: sqliteURI,
tracer: tracer,
makeLogger: makeLogger,
connInitFunc: connInitFunc,
openFlags: sqliteh.OpenFlagsDefault, // default flags unless [WithOpenFlags] option is used
}
for _, opt := range opts {
opt(p)
}
return p
}

// ConnectorOpt is an option to [ConnectorWithOpts].
type ConnectorOpt func(p *connector)

// WithTracer returns a [ConnectorOpt] that configures the [driver.Connector]
// to enable tracing on new connections using the given [sqliteh.Tracer].
func WithTracer(tracer sqliteh.Tracer) ConnectorOpt {
return func(p *connector) {
p.tracer = tracer
}
}

// WithConnLogger returns a [ConnectorOpt] that configures the [driver.Connector]
// to use a [ConnLogger] returned by the provided makeLogger when opening new
// connections.
func WithConnLogger(makeLogger func() ConnLogger) ConnectorOpt {
return func(p *connector) {
p.makeLogger = makeLogger
}
}

// WithOpenFlags returns a [ConnectorOpt] that configures the [driver.Connector]
// to use the given [sqliteh.OpenFlags] when opening new connections.
func WithOpenFlags(openFlags sqliteh.OpenFlags) ConnectorOpt {
return func(p *connector) {
p.openFlags = openFlags
}
}

Expand All @@ -150,11 +183,19 @@ type connector struct {
tracer sqliteh.Tracer // or nil
makeLogger func() ConnLogger // or nil
connInitFunc ConnInitFunc
openFlags sqliteh.OpenFlags
}

// Driver implements [driver.Connector.Driver].
func (p *connector) Driver() driver.Driver { return drv{} }

// Connect implements [driver.Connector.Connect]
func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
db, err := Open(p.name, sqliteh.OpenFlagsDefault, "")
return p.ConnectFlags(ctx, sqliteh.OpenFlagsDefault)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to use the passed flags here instead of the default?

Suggested change
return p.ConnectFlags(ctx, sqliteh.OpenFlagsDefault)
return p.ConnectFlags(ctx, p.openFlags)

}

func (p *connector) ConnectFlags(ctx context.Context, flags sqliteh.OpenFlags) (driver.Conn, error) {
db, err := Open(p.name, flags, "")
if err != nil {
if ec, ok := err.(sqliteh.ErrCode); ok {
e := &Error{
Expand Down
Loading