Skip to content

Commit a88eeb3

Browse files
authored
feat(connection): add explicit connection parameter support
feat(connection): add explicit connection parameter support
2 parents bb5a076 + c669b41 commit a88eeb3

File tree

6 files changed

+423
-49
lines changed

6 files changed

+423
-49
lines changed

integration_test.go

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,9 @@ func setupTestContainer(t *testing.T) (*postgres.PostgresContainer, string, func
7575
return postgresContainer, connStr, cleanup
7676
}
7777

78-
func setupTestDatabase(t *testing.T) (*sql.DB, func()) {
78+
func setupTestDatabase(t *testing.T) (*sql.DB, string, func()) {
7979
_, connectionString, containerCleanup := setupTestContainer(t)
8080

81-
// Set environment variable for the app to use
82-
os.Setenv("POSTGRES_URL", connectionString)
83-
8481
// Connect to PostgreSQL
8582
db, err := sql.Open("postgres", connectionString)
8683
require.NoError(t, err)
@@ -134,56 +131,63 @@ func setupTestDatabase(t *testing.T) (*sql.DB, func()) {
134131
cleanup := func() {
135132
_, _ = db.ExecContext(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", testSchema))
136133
db.Close()
137-
os.Unsetenv("POSTGRES_URL")
138134
containerCleanup() // Clean up container
139135
}
140136

141-
return db, cleanup
137+
return db, connectionString, cleanup
142138
}
143139

144140
func TestIntegration_App_Connect(t *testing.T) {
145141
_, connectionString, cleanup := setupTestContainer(t)
146142
defer cleanup()
147143

148-
// Set environment variable for connection
149-
os.Setenv("POSTGRES_URL", connectionString)
150-
defer os.Unsetenv("POSTGRES_URL")
151-
152144
appInstance, err := app.New()
153145
require.NoError(t, err)
154146
defer appInstance.Disconnect()
155147

148+
// Test explicit connection with connection string
149+
err = appInstance.Connect(connectionString)
150+
require.NoError(t, err)
151+
156152
// Test that we can get current database
157153
dbName, err := appInstance.GetCurrentDatabase()
158154
assert.NoError(t, err)
159155
assert.NotEmpty(t, dbName)
160156
}
161157

162-
func TestIntegration_App_ConnectWithDatabaseURL(t *testing.T) {
158+
func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) {
163159
_, connectionString, cleanup := setupTestContainer(t)
164160
defer cleanup()
165161

166-
// Test with DATABASE_URL environment variable
167-
os.Setenv("DATABASE_URL", connectionString)
168-
defer os.Unsetenv("DATABASE_URL")
162+
// Test with POSTGRES_URL environment variable (backward compatibility via tryConnect)
163+
os.Setenv("POSTGRES_URL", connectionString)
164+
defer os.Unsetenv("POSTGRES_URL")
169165

170166
appInstance, err := app.New()
171167
require.NoError(t, err)
172168
defer appInstance.Disconnect()
173169

174-
// Test that connection works
170+
// Explicitly call ensureConnection which will trigger tryConnect() fallback
175171
err = appInstance.ValidateConnection()
176172
assert.NoError(t, err)
173+
174+
// Verify connection works
175+
dbName, err := appInstance.GetCurrentDatabase()
176+
assert.NoError(t, err)
177+
assert.NotEmpty(t, dbName)
177178
}
178179

179180
func TestIntegration_App_ListDatabases(t *testing.T) {
180-
_, cleanup := setupTestDatabase(t)
181+
_, connectionString, cleanup := setupTestDatabase(t)
181182
defer cleanup()
182183

183184
appInstance, err := app.New()
184185
require.NoError(t, err)
185186
defer appInstance.Disconnect()
186187

188+
err = appInstance.Connect(connectionString)
189+
require.NoError(t, err)
190+
187191
databases, err := appInstance.ListDatabases()
188192
assert.NoError(t, err)
189193
assert.NotEmpty(t, databases)
@@ -201,13 +205,16 @@ func TestIntegration_App_ListDatabases(t *testing.T) {
201205
}
202206

203207
func TestIntegration_App_ListSchemas(t *testing.T) {
204-
_, cleanup := setupTestDatabase(t)
208+
_, connectionString, cleanup := setupTestDatabase(t)
205209
defer cleanup()
206210

207211
appInstance, err := app.New()
208212
require.NoError(t, err)
209213
defer appInstance.Disconnect()
210214

215+
err = appInstance.Connect(connectionString)
216+
require.NoError(t, err)
217+
211218
schemas, err := appInstance.ListSchemas()
212219
assert.NoError(t, err)
213220
assert.NotEmpty(t, schemas)
@@ -223,13 +230,16 @@ func TestIntegration_App_ListSchemas(t *testing.T) {
223230
}
224231

225232
func TestIntegration_App_ListTables(t *testing.T) {
226-
_, cleanup := setupTestDatabase(t)
233+
_, connectionString, cleanup := setupTestDatabase(t)
227234
defer cleanup()
228235

229236
appInstance, err := app.New()
230237
require.NoError(t, err)
231238
defer appInstance.Disconnect()
232239

240+
err = appInstance.Connect(connectionString)
241+
require.NoError(t, err)
242+
233243
// List tables in test schema
234244
listOpts := &app.ListTablesOptions{
235245
Schema: "test_mcp_schema",
@@ -253,13 +263,16 @@ func TestIntegration_App_ListTables(t *testing.T) {
253263
}
254264

255265
func TestIntegration_App_ListTablesWithSize(t *testing.T) {
256-
_, cleanup := setupTestDatabase(t)
266+
_, connectionString, cleanup := setupTestDatabase(t)
257267
defer cleanup()
258268

259269
appInstance, err := app.New()
260270
require.NoError(t, err)
261271
defer appInstance.Disconnect()
262272

273+
err = appInstance.Connect(connectionString)
274+
require.NoError(t, err)
275+
263276
// List tables with size information
264277
listOpts := &app.ListTablesOptions{
265278
Schema: "test_mcp_schema",
@@ -280,13 +293,16 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
280293
}
281294

282295
func TestIntegration_App_DescribeTable(t *testing.T) {
283-
_, cleanup := setupTestDatabase(t)
296+
_, connectionString, cleanup := setupTestDatabase(t)
284297
defer cleanup()
285298

286299
appInstance, err := app.New()
287300
require.NoError(t, err)
288301
defer appInstance.Disconnect()
289302

303+
err = appInstance.Connect(connectionString)
304+
require.NoError(t, err)
305+
290306
columns, err := appInstance.DescribeTable("test_mcp_schema", "test_users")
291307
assert.NoError(t, err)
292308
assert.NotEmpty(t, columns)
@@ -322,13 +338,16 @@ func TestIntegration_App_DescribeTable(t *testing.T) {
322338
}
323339

324340
func TestIntegration_App_ExecuteQuery(t *testing.T) {
325-
_, cleanup := setupTestDatabase(t)
341+
_, connectionString, cleanup := setupTestDatabase(t)
326342
defer cleanup()
327343

328344
appInstance, err := app.New()
329345
require.NoError(t, err)
330346
defer appInstance.Disconnect()
331347

348+
err = appInstance.Connect(connectionString)
349+
require.NoError(t, err)
350+
332351
// Test simple SELECT query
333352
queryOpts := &app.ExecuteQueryOptions{
334353
Query: "SELECT id, name, email FROM test_mcp_schema.test_users WHERE active = true ORDER BY id",
@@ -352,13 +371,16 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) {
352371
}
353372

354373
func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
355-
_, cleanup := setupTestDatabase(t)
374+
_, connectionString, cleanup := setupTestDatabase(t)
356375
defer cleanup()
357376

358377
appInstance, err := app.New()
359378
require.NoError(t, err)
360379
defer appInstance.Disconnect()
361380

381+
err = appInstance.Connect(connectionString)
382+
require.NoError(t, err)
383+
362384
// Test query with limit
363385
queryOpts := &app.ExecuteQueryOptions{
364386
Query: "SELECT * FROM test_mcp_schema.test_users ORDER BY id",
@@ -375,13 +397,16 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
375397
}
376398

377399
func TestIntegration_App_ListIndexes(t *testing.T) {
378-
_, cleanup := setupTestDatabase(t)
400+
_, connectionString, cleanup := setupTestDatabase(t)
379401
defer cleanup()
380402

381403
appInstance, err := app.New()
382404
require.NoError(t, err)
383405
defer appInstance.Disconnect()
384406

407+
err = appInstance.Connect(connectionString)
408+
require.NoError(t, err)
409+
385410
indexes, err := appInstance.ListIndexes("test_mcp_schema", "test_users")
386411
assert.NoError(t, err)
387412
assert.NotEmpty(t, indexes)
@@ -412,13 +437,16 @@ func TestIntegration_App_ListIndexes(t *testing.T) {
412437
}
413438

414439
func TestIntegration_App_ExplainQuery(t *testing.T) {
415-
_, cleanup := setupTestDatabase(t)
440+
_, connectionString, cleanup := setupTestDatabase(t)
416441
defer cleanup()
417442

418443
appInstance, err := app.New()
419444
require.NoError(t, err)
420445
defer appInstance.Disconnect()
421446

447+
err = appInstance.Connect(connectionString)
448+
require.NoError(t, err)
449+
422450
// Test EXPLAIN query
423451
result, err := appInstance.ExplainQuery("SELECT * FROM test_mcp_schema.test_users WHERE active = true")
424452
require.NoError(t, err)
@@ -430,13 +458,16 @@ func TestIntegration_App_ExplainQuery(t *testing.T) {
430458
}
431459

432460
func TestIntegration_App_GetTableStats(t *testing.T) {
433-
_, cleanup := setupTestDatabase(t)
461+
_, connectionString, cleanup := setupTestDatabase(t)
434462
defer cleanup()
435463

436464
appInstance, err := app.New()
437465
require.NoError(t, err)
438466
defer appInstance.Disconnect()
439467

468+
err = appInstance.Connect(connectionString)
469+
require.NoError(t, err)
470+
440471
stats, err := appInstance.GetTableStats("test_mcp_schema", "test_users")
441472
assert.NoError(t, err)
442473
assert.NotNil(t, stats)
@@ -512,13 +543,16 @@ func BenchmarkIntegration_ListTables(b *testing.B) {
512543

513544
// Use a testing.T wrapper for setup functions
514545
t := &testing.T{}
515-
_, cleanup := setupTestDatabase(t)
546+
_, connectionString, cleanup := setupTestDatabase(t)
516547
defer cleanup()
517548

518549
appInstance, err := app.New()
519550
require.NoError(b, err)
520551
defer appInstance.Disconnect()
521552

553+
err = appInstance.Connect(connectionString)
554+
require.NoError(b, err)
555+
522556
listOpts := &app.ListTablesOptions{
523557
Schema: "test_mcp_schema",
524558
}
@@ -539,13 +573,16 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {
539573

540574
// Use a testing.T wrapper for setup functions
541575
t := &testing.T{}
542-
_, cleanup := setupTestDatabase(t)
576+
_, connectionString, cleanup := setupTestDatabase(t)
543577
defer cleanup()
544578

545579
appInstance, err := app.New()
546580
require.NoError(b, err)
547581
defer appInstance.Disconnect()
548582

583+
err = appInstance.Connect(connectionString)
584+
require.NoError(b, err)
585+
549586
queryOpts := &app.ExecuteQueryOptions{
550587
Query: "SELECT COUNT(*) FROM test_mcp_schema.test_users",
551588
}

internal/app/app.go

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ type App struct {
3232
logger *slog.Logger
3333
}
3434

35-
// New creates a new App instance and attempts to connect to the database.
35+
// New creates a new App instance without establishing a connection.
36+
// Use Connect() method or connect_database tool to establish connection.
3637
func New() (*App, error) {
3738
app := &App{
3839
client: NewPostgreSQLClient(),
3940
logger: logger.NewLogger("info"),
4041
}
4142

42-
// Attempt initial connection
43-
if err := app.tryConnect(); err != nil {
44-
app.logger.Warn("Could not connect to database on startup, will retry on first tool request", "error", err)
45-
}
43+
// Note: Connection is now explicit via Connect() or connect_database tool
44+
// Environment variables are still supported as fallback via tryConnect()
4645

4746
return app, nil
4847
}
@@ -52,6 +51,34 @@ func (a *App) SetLogger(logger *slog.Logger) {
5251
a.logger = logger
5352
}
5453

54+
// Connect establishes a database connection with the provided connection string.
55+
// If a connection already exists, it will be closed before establishing a new one.
56+
func (a *App) Connect(connectionString string) error {
57+
if connectionString == "" {
58+
return ErrNoConnectionString
59+
}
60+
61+
// Close existing connection if any
62+
if a.client != nil {
63+
if err := a.client.Ping(); err == nil {
64+
// Connection exists and is active, close it first
65+
if closeErr := a.client.Close(); closeErr != nil {
66+
a.logger.Warn("Failed to close existing connection", "error", closeErr)
67+
}
68+
}
69+
}
70+
71+
a.logger.Debug("Connecting to PostgreSQL database")
72+
73+
if err := a.client.Connect(connectionString); err != nil {
74+
a.logger.Error("Failed to connect to database", "error", err)
75+
return fmt.Errorf("failed to connect: %w", err)
76+
}
77+
78+
a.logger.Info("Successfully connected to PostgreSQL database")
79+
return nil
80+
}
81+
5582
// Disconnect closes the database connection.
5683
func (a *App) Disconnect() error {
5784
if a.client != nil {
@@ -280,9 +307,10 @@ func (a *App) ValidateConnection() error {
280307
return a.ensureConnection()
281308
}
282309

283-
// tryConnect attempts to connect to the database using environment variables.
310+
// tryConnect attempts to connect using environment variables as a fallback mechanism.
311+
// Returns ErrNoConnectionString if no environment variables are set.
284312
func (a *App) tryConnect() error {
285-
// Try environment variables
313+
// Try environment variables as fallback
286314
connectionString := os.Getenv("POSTGRES_URL")
287315
if connectionString == "" {
288316
connectionString = os.Getenv("DATABASE_URL")
@@ -292,15 +320,7 @@ func (a *App) tryConnect() error {
292320
return ErrNoConnectionString
293321
}
294322

295-
a.logger.Debug("Connecting to PostgreSQL database")
296-
297-
if err := a.client.Connect(connectionString); err != nil {
298-
a.logger.Error("Failed to connect to database", "error", err)
299-
return fmt.Errorf("failed to connect: %w", err)
300-
}
301-
302-
a.logger.Info("Successfully connected to PostgreSQL database")
303-
return nil
323+
return a.Connect(connectionString)
304324
}
305325

306326
// ensureConnection checks if the database connection is valid and attempts to reconnect if needed.

0 commit comments

Comments
 (0)