Skip to content

Commit 192d16e

Browse files
committed
feat(context): propagate request context to database operations
Add context.Context parameter to all database operations to enable proper request cancellation, timeout handling, and request tracing. Changes: - Add context parameter to all PostgreSQLClient interface methods - Update client implementation to use provided context instead of Background() - Propagate context through App layer methods - Update MCP tool handlers to pass request context - Implement smart reconnection logic (uses Background() for reconnection) - Update all tests to work with context-enabled signatures This enables MCP clients to cancel in-flight database operations and set request-level timeouts that properly propagate to PostgreSQL. Closes #39
1 parent b4f5f71 commit 192d16e

File tree

10 files changed

+299
-257
lines changed

10 files changed

+299
-257
lines changed

integration_test.go

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,14 @@ func TestIntegration_App_Connect(t *testing.T) {
145145
require.NoError(t, err)
146146
defer appInstance.Disconnect()
147147

148+
ctx := context.Background()
149+
148150
// Test explicit connection with connection string
149-
err = appInstance.Connect(connectionString)
151+
err = appInstance.Connect(ctx, connectionString)
150152
require.NoError(t, err)
151153

152154
// Test that we can get current database
153-
dbName, err := appInstance.GetCurrentDatabase()
155+
dbName, err := appInstance.GetCurrentDatabase(ctx)
154156
assert.NoError(t, err)
155157
assert.NotEmpty(t, dbName)
156158
}
@@ -167,12 +169,14 @@ func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) {
167169
require.NoError(t, err)
168170
defer appInstance.Disconnect()
169171

172+
ctx := context.Background()
173+
170174
// Explicitly call ensureConnection which will trigger tryConnect() fallback
171-
err = appInstance.ValidateConnection()
175+
err = appInstance.ValidateConnection(ctx)
172176
assert.NoError(t, err)
173177

174178
// Verify connection works
175-
dbName, err := appInstance.GetCurrentDatabase()
179+
dbName, err := appInstance.GetCurrentDatabase(ctx)
176180
assert.NoError(t, err)
177181
assert.NotEmpty(t, dbName)
178182
}
@@ -185,10 +189,12 @@ func TestIntegration_App_ListDatabases(t *testing.T) {
185189
require.NoError(t, err)
186190
defer appInstance.Disconnect()
187191

188-
err = appInstance.Connect(connectionString)
192+
ctx := context.Background()
193+
194+
err = appInstance.Connect(ctx, connectionString)
189195
require.NoError(t, err)
190196

191-
databases, err := appInstance.ListDatabases()
197+
databases, err := appInstance.ListDatabases(ctx)
192198
assert.NoError(t, err)
193199
assert.NotEmpty(t, databases)
194200

@@ -212,10 +218,12 @@ func TestIntegration_App_ListSchemas(t *testing.T) {
212218
require.NoError(t, err)
213219
defer appInstance.Disconnect()
214220

215-
err = appInstance.Connect(connectionString)
221+
ctx := context.Background()
222+
223+
err = appInstance.Connect(ctx, connectionString)
216224
require.NoError(t, err)
217225

218-
schemas, err := appInstance.ListSchemas()
226+
schemas, err := appInstance.ListSchemas(ctx)
219227
assert.NoError(t, err)
220228
assert.NotEmpty(t, schemas)
221229

@@ -237,15 +245,17 @@ func TestIntegration_App_ListTables(t *testing.T) {
237245
require.NoError(t, err)
238246
defer appInstance.Disconnect()
239247

240-
err = appInstance.Connect(connectionString)
248+
ctx := context.Background()
249+
250+
err = appInstance.Connect(ctx, connectionString)
241251
require.NoError(t, err)
242252

243253
// List tables in test schema
244254
listOpts := &app.ListTablesOptions{
245255
Schema: "test_mcp_schema",
246256
}
247257

248-
tables, err := appInstance.ListTables(listOpts)
258+
tables, err := appInstance.ListTables(ctx, listOpts)
249259
assert.NoError(t, err)
250260
assert.NotEmpty(t, tables)
251261

@@ -270,7 +280,9 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
270280
require.NoError(t, err)
271281
defer appInstance.Disconnect()
272282

273-
err = appInstance.Connect(connectionString)
283+
ctx := context.Background()
284+
285+
err = appInstance.Connect(ctx, connectionString)
274286
require.NoError(t, err)
275287

276288
// List tables with size information
@@ -279,7 +291,7 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
279291
IncludeSize: true,
280292
}
281293

282-
tables, err := appInstance.ListTables(listOpts)
294+
tables, err := appInstance.ListTables(ctx, listOpts)
283295
assert.NoError(t, err)
284296
assert.NotEmpty(t, tables)
285297

@@ -300,10 +312,12 @@ func TestIntegration_App_DescribeTable(t *testing.T) {
300312
require.NoError(t, err)
301313
defer appInstance.Disconnect()
302314

303-
err = appInstance.Connect(connectionString)
315+
ctx := context.Background()
316+
317+
err = appInstance.Connect(ctx, connectionString)
304318
require.NoError(t, err)
305319

306-
columns, err := appInstance.DescribeTable("test_mcp_schema", "test_users")
320+
columns, err := appInstance.DescribeTable(ctx, "test_mcp_schema", "test_users")
307321
assert.NoError(t, err)
308322
assert.NotEmpty(t, columns)
309323

@@ -345,15 +359,17 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) {
345359
require.NoError(t, err)
346360
defer appInstance.Disconnect()
347361

348-
err = appInstance.Connect(connectionString)
362+
ctx := context.Background()
363+
364+
err = appInstance.Connect(ctx, connectionString)
349365
require.NoError(t, err)
350366

351367
// Test simple SELECT query
352368
queryOpts := &app.ExecuteQueryOptions{
353369
Query: "SELECT id, name, email FROM test_mcp_schema.test_users WHERE active = true ORDER BY id",
354370
}
355371

356-
result, err := appInstance.ExecuteQuery(queryOpts)
372+
result, err := appInstance.ExecuteQuery(ctx, queryOpts)
357373
assert.NoError(t, err)
358374
assert.NotNil(t, result)
359375

@@ -378,7 +394,9 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
378394
require.NoError(t, err)
379395
defer appInstance.Disconnect()
380396

381-
err = appInstance.Connect(connectionString)
397+
ctx := context.Background()
398+
399+
err = appInstance.Connect(ctx, connectionString)
382400
require.NoError(t, err)
383401

384402
// Test query with limit
@@ -387,7 +405,7 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
387405
Limit: 2,
388406
}
389407

390-
result, err := appInstance.ExecuteQuery(queryOpts)
408+
result, err := appInstance.ExecuteQuery(ctx, queryOpts)
391409
assert.NoError(t, err)
392410
assert.NotNil(t, result)
393411

@@ -404,10 +422,12 @@ func TestIntegration_App_ListIndexes(t *testing.T) {
404422
require.NoError(t, err)
405423
defer appInstance.Disconnect()
406424

407-
err = appInstance.Connect(connectionString)
425+
ctx := context.Background()
426+
427+
err = appInstance.Connect(ctx, connectionString)
408428
require.NoError(t, err)
409429

410-
indexes, err := appInstance.ListIndexes("test_mcp_schema", "test_users")
430+
indexes, err := appInstance.ListIndexes(ctx, "test_mcp_schema", "test_users")
411431
assert.NoError(t, err)
412432
assert.NotEmpty(t, indexes)
413433

@@ -487,10 +507,10 @@ func TestIntegration_App_ListIndexes_SpecialCharacters(t *testing.T) {
487507
require.NoError(t, err)
488508
defer appInstance.Disconnect()
489509

490-
err = appInstance.Connect(connectionString)
510+
err = appInstance.Connect(ctx, connectionString)
491511
require.NoError(t, err)
492512

493-
indexes, err := appInstance.ListIndexes(testSchema, "test_table")
513+
indexes, err := appInstance.ListIndexes(ctx, testSchema, "test_table")
494514
assert.NoError(t, err)
495515
assert.NotEmpty(t, indexes)
496516

@@ -538,11 +558,13 @@ func TestIntegration_App_ExplainQuery(t *testing.T) {
538558
require.NoError(t, err)
539559
defer appInstance.Disconnect()
540560

541-
err = appInstance.Connect(connectionString)
561+
ctx := context.Background()
562+
563+
err = appInstance.Connect(ctx, connectionString)
542564
require.NoError(t, err)
543565

544566
// Test EXPLAIN query
545-
result, err := appInstance.ExplainQuery("SELECT * FROM test_mcp_schema.test_users WHERE active = true")
567+
result, err := appInstance.ExplainQuery(ctx, "SELECT * FROM test_mcp_schema.test_users WHERE active = true")
546568
require.NoError(t, err)
547569
require.NotNil(t, result)
548570

@@ -559,10 +581,12 @@ func TestIntegration_App_GetTableStats(t *testing.T) {
559581
require.NoError(t, err)
560582
defer appInstance.Disconnect()
561583

562-
err = appInstance.Connect(connectionString)
584+
ctx := context.Background()
585+
586+
err = appInstance.Connect(ctx, connectionString)
563587
require.NoError(t, err)
564588

565-
stats, err := appInstance.GetTableStats("test_mcp_schema", "test_users")
589+
stats, err := appInstance.GetTableStats(ctx, "test_mcp_schema", "test_users")
566590
assert.NoError(t, err)
567591
assert.NotNil(t, stats)
568592

@@ -584,22 +608,24 @@ func TestIntegration_App_ErrorHandling(t *testing.T) {
584608
require.NoError(t, err)
585609
defer appInstance.Disconnect()
586610

611+
ctx := context.Background()
612+
587613
// Test query to non-existent table
588-
_, err = appInstance.DescribeTable("public", "nonexistent_table")
614+
_, err = appInstance.DescribeTable(ctx, "public", "nonexistent_table")
589615
assert.Error(t, err)
590616

591617
// Test invalid query
592618
queryOpts := &app.ExecuteQueryOptions{
593619
Query: "INVALID SQL QUERY",
594620
}
595-
_, err = appInstance.ExecuteQuery(queryOpts)
621+
_, err = appInstance.ExecuteQuery(ctx, queryOpts)
596622
assert.Error(t, err)
597623

598624
// Test non-existent schema
599625
listOpts := &app.ListTablesOptions{
600626
Schema: "nonexistent_schema",
601627
}
602-
tables, err := appInstance.ListTables(listOpts)
628+
tables, err := appInstance.ListTables(ctx, listOpts)
603629
assert.NoError(t, err) // This might succeed but return empty results
604630
assert.Empty(t, tables)
605631
}
@@ -612,7 +638,9 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) {
612638
appInstance, err := app.New()
613639
require.NoError(t, err)
614640

615-
err = appInstance.ValidateConnection()
641+
ctx := context.Background()
642+
643+
err = appInstance.ValidateConnection(ctx)
616644
assert.Error(t, err)
617645

618646
// Set environment variable and test validation
@@ -624,7 +652,7 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) {
624652
require.NoError(t, err)
625653
defer appInstance2.Disconnect()
626654

627-
err = appInstance2.ValidateConnection()
655+
err = appInstance2.ValidateConnection(ctx)
628656
assert.NoError(t, err)
629657
}
630658

@@ -640,11 +668,13 @@ func BenchmarkIntegration_ListTables(b *testing.B) {
640668
_, connectionString, cleanup := setupTestDatabase(t)
641669
defer cleanup()
642670

671+
ctx := context.Background()
672+
643673
appInstance, err := app.New()
644674
require.NoError(b, err)
645675
defer appInstance.Disconnect()
646676

647-
err = appInstance.Connect(connectionString)
677+
err = appInstance.Connect(ctx, connectionString)
648678
require.NoError(b, err)
649679

650680
listOpts := &app.ListTablesOptions{
@@ -653,7 +683,7 @@ func BenchmarkIntegration_ListTables(b *testing.B) {
653683

654684
b.ResetTimer()
655685
for i := 0; i < b.N; i++ {
656-
_, err := appInstance.ListTables(listOpts)
686+
_, err := appInstance.ListTables(ctx, listOpts)
657687
if err != nil {
658688
b.Fatal(err)
659689
}
@@ -670,11 +700,13 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {
670700
_, connectionString, cleanup := setupTestDatabase(t)
671701
defer cleanup()
672702

703+
ctx := context.Background()
704+
673705
appInstance, err := app.New()
674706
require.NoError(b, err)
675707
defer appInstance.Disconnect()
676708

677-
err = appInstance.Connect(connectionString)
709+
err = appInstance.Connect(ctx, connectionString)
678710
require.NoError(b, err)
679711

680712
queryOpts := &app.ExecuteQueryOptions{
@@ -683,7 +715,7 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {
683715

684716
b.ResetTimer()
685717
for i := 0; i < b.N; i++ {
686-
_, err := appInstance.ExecuteQuery(queryOpts)
718+
_, err := appInstance.ExecuteQuery(ctx, queryOpts)
687719
if err != nil {
688720
b.Fatal(err)
689721
}

0 commit comments

Comments
 (0)