Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 67 additions & 35 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,14 @@ func TestIntegration_App_Connect(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

ctx := context.Background()

// Test explicit connection with connection string
err = appInstance.Connect(connectionString)
err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// Test that we can get current database
dbName, err := appInstance.GetCurrentDatabase()
dbName, err := appInstance.GetCurrentDatabase(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, dbName)
}
Expand All @@ -167,12 +169,14 @@ func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

ctx := context.Background()

// Explicitly call ensureConnection which will trigger tryConnect() fallback
err = appInstance.ValidateConnection()
err = appInstance.ValidateConnection(ctx)
assert.NoError(t, err)

// Verify connection works
dbName, err := appInstance.GetCurrentDatabase()
dbName, err := appInstance.GetCurrentDatabase(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, dbName)
}
Expand All @@ -185,10 +189,12 @@ func TestIntegration_App_ListDatabases(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

databases, err := appInstance.ListDatabases()
databases, err := appInstance.ListDatabases(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, databases)

Expand All @@ -212,10 +218,12 @@ func TestIntegration_App_ListSchemas(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

schemas, err := appInstance.ListSchemas()
schemas, err := appInstance.ListSchemas(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, schemas)

Expand All @@ -237,15 +245,17 @@ func TestIntegration_App_ListTables(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// List tables in test schema
listOpts := &app.ListTablesOptions{
Schema: "test_mcp_schema",
}

tables, err := appInstance.ListTables(listOpts)
tables, err := appInstance.ListTables(ctx, listOpts)
assert.NoError(t, err)
assert.NotEmpty(t, tables)

Expand All @@ -270,7 +280,9 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// List tables with size information
Expand All @@ -279,7 +291,7 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
IncludeSize: true,
}

tables, err := appInstance.ListTables(listOpts)
tables, err := appInstance.ListTables(ctx, listOpts)
assert.NoError(t, err)
assert.NotEmpty(t, tables)

Expand All @@ -300,10 +312,12 @@ func TestIntegration_App_DescribeTable(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

columns, err := appInstance.DescribeTable("test_mcp_schema", "test_users")
columns, err := appInstance.DescribeTable(ctx, "test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotEmpty(t, columns)

Expand Down Expand Up @@ -345,15 +359,17 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// Test simple SELECT query
queryOpts := &app.ExecuteQueryOptions{
Query: "SELECT id, name, email FROM test_mcp_schema.test_users WHERE active = true ORDER BY id",
}

result, err := appInstance.ExecuteQuery(queryOpts)
result, err := appInstance.ExecuteQuery(ctx, queryOpts)
assert.NoError(t, err)
assert.NotNil(t, result)

Expand All @@ -378,7 +394,9 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// Test query with limit
Expand All @@ -387,7 +405,7 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
Limit: 2,
}

result, err := appInstance.ExecuteQuery(queryOpts)
result, err := appInstance.ExecuteQuery(ctx, queryOpts)
assert.NoError(t, err)
assert.NotNil(t, result)

Expand All @@ -404,10 +422,12 @@ func TestIntegration_App_ListIndexes(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

indexes, err := appInstance.ListIndexes("test_mcp_schema", "test_users")
indexes, err := appInstance.ListIndexes(ctx, "test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotEmpty(t, indexes)

Expand Down Expand Up @@ -487,10 +507,10 @@ func TestIntegration_App_ListIndexes_SpecialCharacters(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

indexes, err := appInstance.ListIndexes(testSchema, "test_table")
indexes, err := appInstance.ListIndexes(ctx, testSchema, "test_table")
assert.NoError(t, err)
assert.NotEmpty(t, indexes)

Expand Down Expand Up @@ -538,11 +558,13 @@ func TestIntegration_App_ExplainQuery(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

// Test EXPLAIN query
result, err := appInstance.ExplainQuery("SELECT * FROM test_mcp_schema.test_users WHERE active = true")
result, err := appInstance.ExplainQuery(ctx, "SELECT * FROM test_mcp_schema.test_users WHERE active = true")
require.NoError(t, err)
require.NotNil(t, result)

Expand All @@ -559,10 +581,12 @@ func TestIntegration_App_GetTableStats(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
ctx := context.Background()

err = appInstance.Connect(ctx, connectionString)
require.NoError(t, err)

stats, err := appInstance.GetTableStats("test_mcp_schema", "test_users")
stats, err := appInstance.GetTableStats(ctx, "test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotNil(t, stats)

Expand All @@ -584,22 +608,24 @@ func TestIntegration_App_ErrorHandling(t *testing.T) {
require.NoError(t, err)
defer appInstance.Disconnect()

ctx := context.Background()

// Test query to non-existent table
_, err = appInstance.DescribeTable("public", "nonexistent_table")
_, err = appInstance.DescribeTable(ctx, "public", "nonexistent_table")
assert.Error(t, err)

// Test invalid query
queryOpts := &app.ExecuteQueryOptions{
Query: "INVALID SQL QUERY",
}
_, err = appInstance.ExecuteQuery(queryOpts)
_, err = appInstance.ExecuteQuery(ctx, queryOpts)
assert.Error(t, err)

// Test non-existent schema
listOpts := &app.ListTablesOptions{
Schema: "nonexistent_schema",
}
tables, err := appInstance.ListTables(listOpts)
tables, err := appInstance.ListTables(ctx, listOpts)
assert.NoError(t, err) // This might succeed but return empty results
assert.Empty(t, tables)
}
Expand All @@ -612,7 +638,9 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) {
appInstance, err := app.New()
require.NoError(t, err)

err = appInstance.ValidateConnection()
ctx := context.Background()

err = appInstance.ValidateConnection(ctx)
assert.Error(t, err)

// Set environment variable and test validation
Expand All @@ -624,7 +652,7 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) {
require.NoError(t, err)
defer appInstance2.Disconnect()

err = appInstance2.ValidateConnection()
err = appInstance2.ValidateConnection(ctx)
assert.NoError(t, err)
}

Expand All @@ -640,11 +668,13 @@ func BenchmarkIntegration_ListTables(b *testing.B) {
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

ctx := context.Background()

appInstance, err := app.New()
require.NoError(b, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
err = appInstance.Connect(ctx, connectionString)
require.NoError(b, err)

listOpts := &app.ListTablesOptions{
Expand All @@ -653,7 +683,7 @@ func BenchmarkIntegration_ListTables(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := appInstance.ListTables(listOpts)
_, err := appInstance.ListTables(ctx, listOpts)
if err != nil {
b.Fatal(err)
}
Expand All @@ -670,11 +700,13 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

ctx := context.Background()

appInstance, err := app.New()
require.NoError(b, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
err = appInstance.Connect(ctx, connectionString)
require.NoError(b, err)

queryOpts := &app.ExecuteQueryOptions{
Expand All @@ -683,7 +715,7 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := appInstance.ExecuteQuery(queryOpts)
_, err := appInstance.ExecuteQuery(ctx, queryOpts)
if err != nil {
b.Fatal(err)
}
Expand Down
Loading
Loading