diff --git a/integration_test.go b/integration_test.go index df462ca..3ac21e6 100644 --- a/integration_test.go +++ b/integration_test.go @@ -145,12 +145,14 @@ func TestIntegration_App_Connect(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() + ctx := context.Background() + // Test explicit connection with connection string - err = appInstance.Connect(connectionString) + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // Test that we can get current database - dbName, err := appInstance.GetCurrentDatabase() + dbName, err := appInstance.GetCurrentDatabase(ctx) assert.NoError(t, err) assert.NotEmpty(t, dbName) } @@ -167,12 +169,14 @@ func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() + ctx := context.Background() + // Explicitly call ensureConnection which will trigger tryConnect() fallback - err = appInstance.ValidateConnection() + err = appInstance.ValidateConnection(ctx) assert.NoError(t, err) // Verify connection works - dbName, err := appInstance.GetCurrentDatabase() + dbName, err := appInstance.GetCurrentDatabase(ctx) assert.NoError(t, err) assert.NotEmpty(t, dbName) } @@ -185,10 +189,12 @@ func TestIntegration_App_ListDatabases(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - databases, err := appInstance.ListDatabases() + databases, err := appInstance.ListDatabases(ctx) assert.NoError(t, err) assert.NotEmpty(t, databases) @@ -212,10 +218,12 @@ func TestIntegration_App_ListSchemas(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - schemas, err := appInstance.ListSchemas() + schemas, err := appInstance.ListSchemas(ctx) assert.NoError(t, err) assert.NotEmpty(t, schemas) @@ -237,7 +245,9 @@ func TestIntegration_App_ListTables(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // List tables in test schema @@ -245,7 +255,7 @@ func TestIntegration_App_ListTables(t *testing.T) { Schema: "test_mcp_schema", } - tables, err := appInstance.ListTables(listOpts) + tables, err := appInstance.ListTables(ctx, listOpts) assert.NoError(t, err) assert.NotEmpty(t, tables) @@ -270,7 +280,9 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // List tables with size information @@ -279,7 +291,7 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) { IncludeSize: true, } - tables, err := appInstance.ListTables(listOpts) + tables, err := appInstance.ListTables(ctx, listOpts) assert.NoError(t, err) assert.NotEmpty(t, tables) @@ -300,10 +312,12 @@ func TestIntegration_App_DescribeTable(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - columns, err := appInstance.DescribeTable("test_mcp_schema", "test_users") + columns, err := appInstance.DescribeTable(ctx, "test_mcp_schema", "test_users") assert.NoError(t, err) assert.NotEmpty(t, columns) @@ -345,7 +359,9 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // Test simple SELECT query @@ -353,7 +369,7 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) { Query: "SELECT id, name, email FROM test_mcp_schema.test_users WHERE active = true ORDER BY id", } - result, err := appInstance.ExecuteQuery(queryOpts) + result, err := appInstance.ExecuteQuery(ctx, queryOpts) assert.NoError(t, err) assert.NotNil(t, result) @@ -378,7 +394,9 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // Test query with limit @@ -387,7 +405,7 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) { Limit: 2, } - result, err := appInstance.ExecuteQuery(queryOpts) + result, err := appInstance.ExecuteQuery(ctx, queryOpts) assert.NoError(t, err) assert.NotNil(t, result) @@ -404,10 +422,12 @@ func TestIntegration_App_ListIndexes(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - indexes, err := appInstance.ListIndexes("test_mcp_schema", "test_users") + indexes, err := appInstance.ListIndexes(ctx, "test_mcp_schema", "test_users") assert.NoError(t, err) assert.NotEmpty(t, indexes) @@ -487,10 +507,10 @@ func TestIntegration_App_ListIndexes_SpecialCharacters(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - indexes, err := appInstance.ListIndexes(testSchema, "test_table") + indexes, err := appInstance.ListIndexes(ctx, testSchema, "test_table") assert.NoError(t, err) assert.NotEmpty(t, indexes) @@ -538,11 +558,13 @@ func TestIntegration_App_ExplainQuery(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) // Test EXPLAIN query - result, err := appInstance.ExplainQuery("SELECT * FROM test_mcp_schema.test_users WHERE active = true") + result, err := appInstance.ExplainQuery(ctx, "SELECT * FROM test_mcp_schema.test_users WHERE active = true") require.NoError(t, err) require.NotNil(t, result) @@ -559,10 +581,12 @@ func TestIntegration_App_GetTableStats(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + ctx := context.Background() + + err = appInstance.Connect(ctx, connectionString) require.NoError(t, err) - stats, err := appInstance.GetTableStats("test_mcp_schema", "test_users") + stats, err := appInstance.GetTableStats(ctx, "test_mcp_schema", "test_users") assert.NoError(t, err) assert.NotNil(t, stats) @@ -584,22 +608,24 @@ func TestIntegration_App_ErrorHandling(t *testing.T) { require.NoError(t, err) defer appInstance.Disconnect() + ctx := context.Background() + // Test query to non-existent table - _, err = appInstance.DescribeTable("public", "nonexistent_table") + _, err = appInstance.DescribeTable(ctx, "public", "nonexistent_table") assert.Error(t, err) // Test invalid query queryOpts := &app.ExecuteQueryOptions{ Query: "INVALID SQL QUERY", } - _, err = appInstance.ExecuteQuery(queryOpts) + _, err = appInstance.ExecuteQuery(ctx, queryOpts) assert.Error(t, err) // Test non-existent schema listOpts := &app.ListTablesOptions{ Schema: "nonexistent_schema", } - tables, err := appInstance.ListTables(listOpts) + tables, err := appInstance.ListTables(ctx, listOpts) assert.NoError(t, err) // This might succeed but return empty results assert.Empty(t, tables) } @@ -612,7 +638,9 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) { appInstance, err := app.New() require.NoError(t, err) - err = appInstance.ValidateConnection() + ctx := context.Background() + + err = appInstance.ValidateConnection(ctx) assert.Error(t, err) // Set environment variable and test validation @@ -624,7 +652,7 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) { require.NoError(t, err) defer appInstance2.Disconnect() - err = appInstance2.ValidateConnection() + err = appInstance2.ValidateConnection(ctx) assert.NoError(t, err) } @@ -640,11 +668,13 @@ func BenchmarkIntegration_ListTables(b *testing.B) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() + ctx := context.Background() + appInstance, err := app.New() require.NoError(b, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + err = appInstance.Connect(ctx, connectionString) require.NoError(b, err) listOpts := &app.ListTablesOptions{ @@ -653,7 +683,7 @@ func BenchmarkIntegration_ListTables(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := appInstance.ListTables(listOpts) + _, err := appInstance.ListTables(ctx, listOpts) if err != nil { b.Fatal(err) } @@ -670,11 +700,13 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() + ctx := context.Background() + appInstance, err := app.New() require.NoError(b, err) defer appInstance.Disconnect() - err = appInstance.Connect(connectionString) + err = appInstance.Connect(ctx, connectionString) require.NoError(b, err) queryOpts := &app.ExecuteQueryOptions{ @@ -683,7 +715,7 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := appInstance.ExecuteQuery(queryOpts) + _, err := appInstance.ExecuteQuery(ctx, queryOpts) if err != nil { b.Fatal(err) } diff --git a/internal/app/app.go b/internal/app/app.go index 017d677..a2dd4fe 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -1,6 +1,7 @@ package app import ( + "context" "fmt" "log/slog" "os" @@ -53,14 +54,14 @@ func (a *App) SetLogger(logger *slog.Logger) { // Connect establishes a database connection with the provided connection string. // If a connection already exists, it will be closed before establishing a new one. -func (a *App) Connect(connectionString string) error { +func (a *App) Connect(ctx context.Context, connectionString string) error { if connectionString == "" { return ErrNoConnectionString } // Close existing connection if any if a.client != nil { - if err := a.client.Ping(); err == nil { + if err := a.client.Ping(ctx); err == nil { // Connection exists and is active, close it first if closeErr := a.client.Close(); closeErr != nil { a.logger.Warn("Failed to close existing connection", "error", closeErr) @@ -70,7 +71,7 @@ func (a *App) Connect(connectionString string) error { a.logger.Debug("Connecting to PostgreSQL database") - if err := a.client.Connect(connectionString); err != nil { + if err := a.client.Connect(ctx, connectionString); err != nil { a.logger.Error("Failed to connect to database", "error", err) return fmt.Errorf("failed to connect: %w", err) } @@ -90,14 +91,14 @@ func (a *App) Disconnect() error { } // ListDatabases returns a list of all databases. -func (a *App) ListDatabases() ([]*DatabaseInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ListDatabases(ctx context.Context) ([]*DatabaseInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } a.logger.Debug("Listing databases") - databases, err := a.client.ListDatabases() + databases, err := a.client.ListDatabases(ctx) if err != nil { a.logger.Error("Failed to list databases", "error", err) return nil, fmt.Errorf("failed to list databases: %w", err) @@ -108,14 +109,14 @@ func (a *App) ListDatabases() ([]*DatabaseInfo, error) { } // ListSchemas returns a list of schemas in the current database. -func (a *App) ListSchemas() ([]*SchemaInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ListSchemas(ctx context.Context) ([]*SchemaInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } a.logger.Debug("Listing schemas") - schemas, err := a.client.ListSchemas() + schemas, err := a.client.ListSchemas(ctx) if err != nil { a.logger.Error("Failed to list schemas", "error", err) return nil, fmt.Errorf("failed to list schemas: %w", err) @@ -126,8 +127,8 @@ func (a *App) ListSchemas() ([]*SchemaInfo, error) { } // ListTables returns a list of tables in the specified schema. -func (a *App) ListTables(opts *ListTablesOptions) ([]*TableInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ListTables(ctx context.Context, opts *ListTablesOptions) ([]*TableInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -143,13 +144,13 @@ func (a *App) ListTables(opts *ListTablesOptions) ([]*TableInfo, error) { // Use optimized query when stats are requested to avoid N+1 query pattern if opts != nil && opts.IncludeSize { - tables, err = a.client.ListTablesWithStats(schema) + tables, err = a.client.ListTablesWithStats(ctx, schema) if err != nil { a.logger.Error("Failed to list tables with stats", "error", err, "schema", schema) return nil, fmt.Errorf("failed to list tables with stats: %w", err) } } else { - tables, err = a.client.ListTables(schema) + tables, err = a.client.ListTables(ctx, schema) if err != nil { a.logger.Error("Failed to list tables", "error", err, "schema", schema) return nil, fmt.Errorf("failed to list tables: %w", err) @@ -161,8 +162,8 @@ func (a *App) ListTables(opts *ListTablesOptions) ([]*TableInfo, error) { } // DescribeTable returns detailed information about a table's structure. -func (a *App) DescribeTable(schema, table string) ([]*ColumnInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) DescribeTable(ctx context.Context, schema, table string) ([]*ColumnInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -176,7 +177,7 @@ func (a *App) DescribeTable(schema, table string) ([]*ColumnInfo, error) { a.logger.Debug("Describing table", "schema", schema, "table", table) - columns, err := a.client.DescribeTable(schema, table) + columns, err := a.client.DescribeTable(ctx, schema, table) if err != nil { a.logger.Error("Failed to describe table", "error", err, "schema", schema, "table", table) return nil, fmt.Errorf("failed to describe table: %w", err) @@ -187,8 +188,8 @@ func (a *App) DescribeTable(schema, table string) ([]*ColumnInfo, error) { } // GetTableStats returns statistics for a specific table. -func (a *App) GetTableStats(schema, table string) (*TableInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) GetTableStats(ctx context.Context, schema, table string) (*TableInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -202,7 +203,7 @@ func (a *App) GetTableStats(schema, table string) (*TableInfo, error) { a.logger.Debug("Getting table stats", "schema", schema, "table", table) - stats, err := a.client.GetTableStats(schema, table) + stats, err := a.client.GetTableStats(ctx, schema, table) if err != nil { a.logger.Error("Failed to get table stats", "error", err, "schema", schema, "table", table) return nil, fmt.Errorf("failed to get table stats: %w", err) @@ -213,8 +214,8 @@ func (a *App) GetTableStats(schema, table string) (*TableInfo, error) { } // ListIndexes returns a list of indexes for the specified table. -func (a *App) ListIndexes(schema, table string) ([]*IndexInfo, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ListIndexes(ctx context.Context, schema, table string) ([]*IndexInfo, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -228,7 +229,7 @@ func (a *App) ListIndexes(schema, table string) ([]*IndexInfo, error) { a.logger.Debug("Listing indexes", "schema", schema, "table", table) - indexes, err := a.client.ListIndexes(schema, table) + indexes, err := a.client.ListIndexes(ctx, schema, table) if err != nil { a.logger.Error("Failed to list indexes", "error", err, "schema", schema, "table", table) return nil, fmt.Errorf("failed to list indexes: %w", err) @@ -239,8 +240,8 @@ func (a *App) ListIndexes(schema, table string) ([]*IndexInfo, error) { } // ExecuteQuery executes a read-only query and returns the results. -func (a *App) ExecuteQuery(opts *ExecuteQueryOptions) (*QueryResult, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ExecuteQuery(ctx context.Context, opts *ExecuteQueryOptions) (*QueryResult, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -250,7 +251,7 @@ func (a *App) ExecuteQuery(opts *ExecuteQueryOptions) (*QueryResult, error) { a.logger.Debug("Executing query", "query", opts.Query) - result, err := a.client.ExecuteQuery(opts.Query, opts.Args...) + result, err := a.client.ExecuteQuery(ctx, opts.Query, opts.Args...) if err != nil { a.logger.Error("Failed to execute query", "error", err, "query", opts.Query) return nil, fmt.Errorf("failed to execute query: %w", err) @@ -267,12 +268,12 @@ func (a *App) ExecuteQuery(opts *ExecuteQueryOptions) (*QueryResult, error) { } // GetCurrentDatabase returns the name of the current database. -func (a *App) GetCurrentDatabase() (string, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) GetCurrentDatabase(ctx context.Context) (string, error) { + if err := a.ensureConnection(ctx); err != nil { return "", err } - dbName, err := a.client.GetCurrentDatabase() + dbName, err := a.client.GetCurrentDatabase(ctx) if err != nil { return "", fmt.Errorf("failed to get current database: %w", err) } @@ -280,8 +281,8 @@ func (a *App) GetCurrentDatabase() (string, error) { } // ExplainQuery returns the execution plan for a query. -func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) { - if err := a.ensureConnection(); err != nil { +func (a *App) ExplainQuery(ctx context.Context, query string, args ...any) (*QueryResult, error) { + if err := a.ensureConnection(ctx); err != nil { return nil, err } @@ -291,7 +292,7 @@ func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) { a.logger.Debug("Explaining query", "query", query) - result, err := a.client.ExplainQuery(query, args...) + result, err := a.client.ExplainQuery(ctx, query, args...) if err != nil { a.logger.Error("Failed to explain query", "error", err, "query", query) return nil, fmt.Errorf("failed to explain query: %w", err) @@ -302,13 +303,13 @@ func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) { } // ValidateConnection checks if the database connection is valid (for backward compatibility). -func (a *App) ValidateConnection() error { - return a.ensureConnection() +func (a *App) ValidateConnection(ctx context.Context) error { + return a.ensureConnection(ctx) } // tryConnect attempts to connect using environment variables as a fallback mechanism. // Returns ErrNoConnectionString if no environment variables are set. -func (a *App) tryConnect() error { +func (a *App) tryConnect(ctx context.Context) error { // Try environment variables as fallback connectionString := os.Getenv("POSTGRES_URL") if connectionString == "" { @@ -319,21 +320,23 @@ func (a *App) tryConnect() error { return ErrNoConnectionString } - return a.Connect(connectionString) + return a.Connect(ctx, connectionString) } // ensureConnection checks if the database connection is valid and attempts to reconnect if needed. -func (a *App) ensureConnection() error { +func (a *App) ensureConnection(ctx context.Context) error { if a.client == nil { return ErrConnectionRequired } - // Test current connection - if err := a.client.Ping(); err != nil { + // Test current connection with request context + if err := a.client.Ping(ctx); err != nil { a.logger.Debug("Database connection lost, attempting to reconnect", "error", err) - // Attempt to reconnect - if reconnectErr := a.tryConnect(); reconnectErr != nil { + // Attempt to reconnect using background context + // Reconnection is infrastructure work and shouldn't be cancelled by request timeout + reconnectCtx := context.Background() + if reconnectErr := a.tryConnect(reconnectCtx); reconnectErr != nil { //nolint:contextcheck // Intentional: reconnection must not be cancelled by request context a.logger.Error("Failed to reconnect to database", "ping_error", err, "reconnect_error", reconnectErr) return ErrConnectionRequired } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 49e8367..df6e74f 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -1,6 +1,7 @@ package app import ( + "context" "database/sql" "errors" "log/slog" @@ -15,8 +16,8 @@ type MockPostgreSQLClient struct { mock.Mock } -func (m *MockPostgreSQLClient) Connect(connectionString string) error { - args := m.Called(connectionString) +func (m *MockPostgreSQLClient) Connect(ctx context.Context, connectionString string) error { + args := m.Called(ctx, connectionString) return args.Error(0) } @@ -25,82 +26,82 @@ func (m *MockPostgreSQLClient) Close() error { return args.Error(0) } -func (m *MockPostgreSQLClient) Ping() error { - args := m.Called() +func (m *MockPostgreSQLClient) Ping(ctx context.Context) error { + args := m.Called(ctx) return args.Error(0) } -func (m *MockPostgreSQLClient) ListDatabases() ([]*DatabaseInfo, error) { - args := m.Called() +func (m *MockPostgreSQLClient) ListDatabases(ctx context.Context) ([]*DatabaseInfo, error) { + args := m.Called(ctx) if databases, ok := args.Get(0).([]*DatabaseInfo); ok { return databases, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) GetCurrentDatabase() (string, error) { - args := m.Called() +func (m *MockPostgreSQLClient) GetCurrentDatabase(ctx context.Context) (string, error) { + args := m.Called(ctx) return args.String(0), args.Error(1) } -func (m *MockPostgreSQLClient) ListSchemas() ([]*SchemaInfo, error) { - args := m.Called() +func (m *MockPostgreSQLClient) ListSchemas(ctx context.Context) ([]*SchemaInfo, error) { + args := m.Called(ctx) if schemas, ok := args.Get(0).([]*SchemaInfo); ok { return schemas, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) ListTables(schema string) ([]*TableInfo, error) { - args := m.Called(schema) +func (m *MockPostgreSQLClient) ListTables(ctx context.Context, schema string) ([]*TableInfo, error) { + args := m.Called(ctx, schema) if tables, ok := args.Get(0).([]*TableInfo); ok { return tables, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) ListTablesWithStats(schema string) ([]*TableInfo, error) { - args := m.Called(schema) +func (m *MockPostgreSQLClient) ListTablesWithStats(ctx context.Context, schema string) ([]*TableInfo, error) { + args := m.Called(ctx, schema) if tables, ok := args.Get(0).([]*TableInfo); ok { return tables, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) DescribeTable(schema, table string) ([]*ColumnInfo, error) { - args := m.Called(schema, table) +func (m *MockPostgreSQLClient) DescribeTable(ctx context.Context, schema, table string) ([]*ColumnInfo, error) { + args := m.Called(ctx, schema, table) if columns, ok := args.Get(0).([]*ColumnInfo); ok { return columns, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) GetTableStats(schema, table string) (*TableInfo, error) { - args := m.Called(schema, table) +func (m *MockPostgreSQLClient) GetTableStats(ctx context.Context, schema, table string) (*TableInfo, error) { + args := m.Called(ctx, schema, table) if stats, ok := args.Get(0).(*TableInfo); ok { return stats, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) ListIndexes(schema, table string) ([]*IndexInfo, error) { - args := m.Called(schema, table) +func (m *MockPostgreSQLClient) ListIndexes(ctx context.Context, schema, table string) ([]*IndexInfo, error) { + args := m.Called(ctx, schema, table) if indexes, ok := args.Get(0).([]*IndexInfo); ok { return indexes, args.Error(1) } return nil, args.Error(1) } -func (m *MockPostgreSQLClient) ExecuteQuery(query string, args ...interface{}) (*QueryResult, error) { - mockArgs := m.Called(query, args) +func (m *MockPostgreSQLClient) ExecuteQuery(ctx context.Context, query string, args ...interface{}) (*QueryResult, error) { + mockArgs := m.Called(ctx, query, args) if result, ok := mockArgs.Get(0).(*QueryResult); ok { return result, mockArgs.Error(1) } return nil, mockArgs.Error(1) } -func (m *MockPostgreSQLClient) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) { - mockArgs := m.Called(query, args) +func (m *MockPostgreSQLClient) ExplainQuery(ctx context.Context, query string, args ...interface{}) (*QueryResult, error) { + mockArgs := m.Called(ctx, query, args) if result, ok := mockArgs.Get(0).(*QueryResult); ok { return result, mockArgs.Error(1) } @@ -160,9 +161,9 @@ func TestApp_ValidateConnection(t *testing.T) { mockClient := &MockPostgreSQLClient{} app.client = mockClient - mockClient.On("Ping").Return(nil) + mockClient.On("Ping", mock.Anything).Return(nil) - err := app.ValidateConnection() + err := app.ValidateConnection(context.Background()) assert.NoError(t, err) mockClient.AssertExpectations(t) } @@ -171,7 +172,7 @@ func TestApp_ValidateConnectionNilClient(t *testing.T) { app, _ := New() app.client = nil - err := app.ValidateConnection() + err := app.ValidateConnection(context.Background()) assert.Error(t, err) assert.Equal(t, ErrConnectionRequired, err) } @@ -183,9 +184,9 @@ func TestApp_ValidateConnectionPingError(t *testing.T) { // Mock ping failure and reconnection failure (no env vars set) pingError := errors.New("ping failed") - mockClient.On("Ping").Return(pingError) + mockClient.On("Ping", mock.Anything).Return(pingError) - err := app.ValidateConnection() + err := app.ValidateConnection(context.Background()) assert.Error(t, err) assert.Equal(t, ErrConnectionRequired, err) mockClient.AssertExpectations(t) @@ -201,10 +202,10 @@ func TestApp_ListDatabases(t *testing.T) { {Name: "db2", Owner: "user2", Encoding: "UTF8"}, } - mockClient.On("Ping").Return(nil) - mockClient.On("ListDatabases").Return(expectedDatabases, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListDatabases", mock.Anything).Return(expectedDatabases, nil) - databases, err := app.ListDatabases() + databases, err := app.ListDatabases(context.Background()) assert.NoError(t, err) assert.Equal(t, expectedDatabases, databases) mockClient.AssertExpectations(t) @@ -216,9 +217,9 @@ func TestApp_ListDatabasesConnectionError(t *testing.T) { app.client = mockClient expectedError := errors.New("connection error") - mockClient.On("Ping").Return(expectedError) + mockClient.On("Ping", mock.Anything).Return(expectedError) - databases, err := app.ListDatabases() + databases, err := app.ListDatabases(context.Background()) assert.Error(t, err) assert.Nil(t, databases) // After our refactoring, ping failure leads to reconnection attempt, which fails due to no env vars, @@ -234,10 +235,10 @@ func TestApp_GetCurrentDatabase(t *testing.T) { expectedDB := "testdb" - mockClient.On("Ping").Return(nil) - mockClient.On("GetCurrentDatabase").Return(expectedDB, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("GetCurrentDatabase", mock.Anything).Return(expectedDB, nil) - dbName, err := app.GetCurrentDatabase() + dbName, err := app.GetCurrentDatabase(context.Background()) assert.NoError(t, err) assert.Equal(t, expectedDB, dbName) mockClient.AssertExpectations(t) @@ -253,10 +254,10 @@ func TestApp_ListSchemas(t *testing.T) { {Name: "private", Owner: "user"}, } - mockClient.On("Ping").Return(nil) - mockClient.On("ListSchemas").Return(expectedSchemas, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListSchemas", mock.Anything).Return(expectedSchemas, nil) - schemas, err := app.ListSchemas() + schemas, err := app.ListSchemas(context.Background()) assert.NoError(t, err) assert.Equal(t, expectedSchemas, schemas) mockClient.AssertExpectations(t) @@ -276,10 +277,10 @@ func TestApp_ListTables(t *testing.T) { Schema: "public", } - mockClient.On("Ping").Return(nil) - mockClient.On("ListTables", "public").Return(expectedTables, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListTables", mock.Anything, "public").Return(expectedTables, nil) - tables, err := app.ListTables(opts) + tables, err := app.ListTables(context.Background(), opts) assert.NoError(t, err) assert.Equal(t, expectedTables, tables) mockClient.AssertExpectations(t) @@ -296,10 +297,10 @@ func TestApp_ListTablesWithDefaultSchema(t *testing.T) { opts := &ListTablesOptions{} - mockClient.On("Ping").Return(nil) - mockClient.On("ListTables", DefaultSchema).Return(expectedTables, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListTables", mock.Anything, DefaultSchema).Return(expectedTables, nil) - tables, err := app.ListTables(opts) + tables, err := app.ListTables(context.Background(), opts) assert.NoError(t, err) assert.Equal(t, expectedTables, tables) mockClient.AssertExpectations(t) @@ -314,10 +315,10 @@ func TestApp_ListTablesWithNilOptions(t *testing.T) { {Schema: "public", Name: "users", Type: "table", Owner: "user"}, } - mockClient.On("Ping").Return(nil) - mockClient.On("ListTables", DefaultSchema).Return(expectedTables, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListTables", mock.Anything, DefaultSchema).Return(expectedTables, nil) - tables, err := app.ListTables(nil) + tables, err := app.ListTables(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, expectedTables, tables) mockClient.AssertExpectations(t) @@ -344,10 +345,10 @@ func TestApp_ListTablesWithSize(t *testing.T) { IncludeSize: true, } - mockClient.On("Ping").Return(nil) - mockClient.On("ListTablesWithStats", "public").Return(tablesWithStats, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListTablesWithStats", mock.Anything, "public").Return(tablesWithStats, nil) - tables, err := app.ListTables(opts) + tables, err := app.ListTables(context.Background(), opts) assert.NoError(t, err) assert.Len(t, tables, 1) assert.Equal(t, int64(1000), tables[0].RowCount) @@ -365,10 +366,10 @@ func TestApp_DescribeTable(t *testing.T) { {Name: "name", DataType: "varchar(255)", IsNullable: true}, } - mockClient.On("Ping").Return(nil) - mockClient.On("DescribeTable", "public", "users").Return(expectedColumns, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("DescribeTable", mock.Anything, "public", "users").Return(expectedColumns, nil) - columns, err := app.DescribeTable("public", "users") + columns, err := app.DescribeTable(context.Background(), "public", "users") assert.NoError(t, err) assert.Equal(t, expectedColumns, columns) mockClient.AssertExpectations(t) @@ -377,7 +378,7 @@ func TestApp_DescribeTable(t *testing.T) { func TestApp_DescribeTableEmptyTableName(t *testing.T) { app, _ := New() - columns, err := app.DescribeTable("public", "") + columns, err := app.DescribeTable(context.Background(), "public", "") assert.Error(t, err) assert.Nil(t, columns) assert.Contains(t, err.Error(), "database connection failed") @@ -392,10 +393,10 @@ func TestApp_DescribeTableDefaultSchema(t *testing.T) { {Name: "id", DataType: "integer", IsNullable: false}, } - mockClient.On("Ping").Return(nil) - mockClient.On("DescribeTable", DefaultSchema, "users").Return(expectedColumns, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("DescribeTable", mock.Anything, DefaultSchema, "users").Return(expectedColumns, nil) - columns, err := app.DescribeTable("", "users") + columns, err := app.DescribeTable(context.Background(), "", "users") assert.NoError(t, err) assert.Equal(t, expectedColumns, columns) mockClient.AssertExpectations(t) @@ -416,10 +417,10 @@ func TestApp_ExecuteQuery(t *testing.T) { Query: "SELECT id, name FROM users", } - mockClient.On("Ping").Return(nil) - mockClient.On("ExecuteQuery", "SELECT id, name FROM users", []interface{}(nil)).Return(expectedResult, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ExecuteQuery", mock.Anything, "SELECT id, name FROM users", []interface{}(nil)).Return(expectedResult, nil) - result, err := app.ExecuteQuery(opts) + result, err := app.ExecuteQuery(context.Background(), opts) assert.NoError(t, err) assert.Equal(t, expectedResult, result) mockClient.AssertExpectations(t) @@ -441,10 +442,10 @@ func TestApp_ExecuteQueryWithLimit(t *testing.T) { Limit: 2, } - mockClient.On("Ping").Return(nil) - mockClient.On("ExecuteQuery", "SELECT id, name FROM users", []interface{}(nil)).Return(originalResult, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ExecuteQuery", mock.Anything, "SELECT id, name FROM users", []interface{}(nil)).Return(originalResult, nil) - result, err := app.ExecuteQuery(opts) + result, err := app.ExecuteQuery(context.Background(), opts) assert.NoError(t, err) assert.Len(t, result.Rows, 2) assert.Equal(t, 2, result.RowCount) @@ -454,7 +455,7 @@ func TestApp_ExecuteQueryWithLimit(t *testing.T) { func TestApp_ExecuteQueryNilOptions(t *testing.T) { app, _ := New() - result, err := app.ExecuteQuery(nil) + result, err := app.ExecuteQuery(context.Background(), nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "database connection failed") @@ -465,7 +466,7 @@ func TestApp_ExecuteQueryEmptyQuery(t *testing.T) { opts := &ExecuteQueryOptions{} - result, err := app.ExecuteQuery(opts) + result, err := app.ExecuteQuery(context.Background(), opts) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "database connection failed") @@ -482,10 +483,10 @@ func TestApp_ExplainQuery(t *testing.T) { RowCount: 1, } - mockClient.On("Ping").Return(nil) - mockClient.On("ExplainQuery", "SELECT * FROM users", []interface{}(nil)).Return(expectedResult, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ExplainQuery", mock.Anything, "SELECT * FROM users", []interface{}(nil)).Return(expectedResult, nil) - result, err := app.ExplainQuery("SELECT * FROM users") + result, err := app.ExplainQuery(context.Background(), "SELECT * FROM users") assert.NoError(t, err) assert.Equal(t, expectedResult, result) mockClient.AssertExpectations(t) @@ -494,7 +495,7 @@ func TestApp_ExplainQuery(t *testing.T) { func TestApp_ExplainQueryEmptyQuery(t *testing.T) { app, _ := New() - result, err := app.ExplainQuery("") + result, err := app.ExplainQuery(context.Background(), "") assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "database connection failed") @@ -512,10 +513,10 @@ func TestApp_GetTableStats(t *testing.T) { Size: "5MB", } - mockClient.On("Ping").Return(nil) - mockClient.On("GetTableStats", "public", "users").Return(expectedStats, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("GetTableStats", mock.Anything, "public", "users").Return(expectedStats, nil) - stats, err := app.GetTableStats("public", "users") + stats, err := app.GetTableStats(context.Background(), "public", "users") assert.NoError(t, err) assert.Equal(t, expectedStats, stats) mockClient.AssertExpectations(t) @@ -531,10 +532,10 @@ func TestApp_ListIndexes(t *testing.T) { {Name: "idx_users_email", Table: "users", Columns: []string{"email"}, IsUnique: true, IsPrimary: false}, } - mockClient.On("Ping").Return(nil) - mockClient.On("ListIndexes", "public", "users").Return(expectedIndexes, nil) + mockClient.On("Ping", mock.Anything).Return(nil) + mockClient.On("ListIndexes", mock.Anything, "public", "users").Return(expectedIndexes, nil) - indexes, err := app.ListIndexes("public", "users") + indexes, err := app.ListIndexes(context.Background(), "public", "users") assert.NoError(t, err) assert.Equal(t, expectedIndexes, indexes) mockClient.AssertExpectations(t) @@ -548,10 +549,10 @@ func TestApp_Connect_Success(t *testing.T) { connectionString := "postgres://user:pass@localhost/db" // Mock expectations - mockClient.On("Ping").Return(errors.New("not connected")) // No existing connection - mockClient.On("Connect", connectionString).Return(nil) + mockClient.On("Ping", mock.Anything).Return(errors.New("not connected")) // No existing connection + mockClient.On("Connect", mock.Anything, connectionString).Return(nil) - err := app.Connect(connectionString) + err := app.Connect(context.Background(), connectionString) assert.NoError(t, err) mockClient.AssertExpectations(t) } @@ -559,7 +560,7 @@ func TestApp_Connect_Success(t *testing.T) { func TestApp_Connect_EmptyString(t *testing.T) { app, _ := New() - err := app.Connect("") + err := app.Connect(context.Background(), "") assert.Error(t, err) assert.Equal(t, ErrNoConnectionString, err) } @@ -572,11 +573,11 @@ func TestApp_Connect_ReconnectClosesExisting(t *testing.T) { connectionString := "postgres://user:pass@localhost/db" // Mock expectations for reconnection scenario - mockClient.On("Ping").Return(nil).Once() // Existing connection is alive + mockClient.On("Ping", mock.Anything).Return(nil).Once() // Existing connection is alive mockClient.On("Close").Return(nil).Once() // Close existing - mockClient.On("Connect", connectionString).Return(nil) + mockClient.On("Connect", mock.Anything, connectionString).Return(nil) - err := app.Connect(connectionString) + err := app.Connect(context.Background(), connectionString) assert.NoError(t, err) mockClient.AssertExpectations(t) } @@ -590,10 +591,10 @@ func TestApp_Connect_ConnectError(t *testing.T) { expectedError := errors.New("connection failed") // Mock expectations - mockClient.On("Ping").Return(errors.New("not connected")) // No existing connection - mockClient.On("Connect", connectionString).Return(expectedError) + mockClient.On("Ping", mock.Anything).Return(errors.New("not connected")) // No existing connection + mockClient.On("Connect", mock.Anything, connectionString).Return(expectedError) - err := app.Connect(connectionString) + err := app.Connect(context.Background(), connectionString) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to connect") mockClient.AssertExpectations(t) diff --git a/internal/app/client.go b/internal/app/client.go index dd92b2d..7e5cec0 100644 --- a/internal/app/client.go +++ b/internal/app/client.go @@ -22,13 +22,13 @@ func NewPostgreSQLClient() *PostgreSQLClientImpl { } // Connect establishes a connection to the PostgreSQL database. -func (c *PostgreSQLClientImpl) Connect(connectionString string) error { +func (c *PostgreSQLClientImpl) Connect(ctx context.Context, connectionString string) error { db, err := sql.Open("postgres", connectionString) if err != nil { return fmt.Errorf("failed to open database connection: %w", err) } - if err := db.PingContext(context.Background()); err != nil { + if err := db.PingContext(ctx); err != nil { _ = db.Close() return fmt.Errorf("failed to ping database: %w", err) } @@ -50,11 +50,11 @@ func (c *PostgreSQLClientImpl) Close() error { } // Ping checks if the database connection is alive. -func (c *PostgreSQLClientImpl) Ping() error { +func (c *PostgreSQLClientImpl) Ping(ctx context.Context) error { if c.db == nil { return ErrNoDatabaseConnection } - if err := c.db.PingContext(context.Background()); err != nil { + if err := c.db.PingContext(ctx); err != nil { return fmt.Errorf("failed to ping database: %w", err) } return nil @@ -66,7 +66,7 @@ func (c *PostgreSQLClientImpl) GetDB() *sql.DB { } // ListDatabases returns a list of all databases on the server. -func (c *PostgreSQLClientImpl) ListDatabases() ([]*DatabaseInfo, error) { +func (c *PostgreSQLClientImpl) ListDatabases(ctx context.Context) ([]*DatabaseInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -77,7 +77,7 @@ func (c *PostgreSQLClientImpl) ListDatabases() ([]*DatabaseInfo, error) { WHERE datistemplate = false ORDER BY datname` - rows, err := c.db.QueryContext(context.Background(), query) + rows, err := c.db.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("failed to list databases: %w", err) } @@ -99,13 +99,13 @@ func (c *PostgreSQLClientImpl) ListDatabases() ([]*DatabaseInfo, error) { } // GetCurrentDatabase returns the name of the current database. -func (c *PostgreSQLClientImpl) GetCurrentDatabase() (string, error) { +func (c *PostgreSQLClientImpl) GetCurrentDatabase(ctx context.Context) (string, error) { if c.db == nil { return "", ErrNoDatabaseConnection } var dbName string - err := c.db.QueryRowContext(context.Background(), "SELECT current_database()").Scan(&dbName) + err := c.db.QueryRowContext(ctx, "SELECT current_database()").Scan(&dbName) if err != nil { return "", fmt.Errorf("failed to get current database: %w", err) } @@ -114,7 +114,7 @@ func (c *PostgreSQLClientImpl) GetCurrentDatabase() (string, error) { } // ListSchemas returns a list of schemas in the current database. -func (c *PostgreSQLClientImpl) ListSchemas() ([]*SchemaInfo, error) { +func (c *PostgreSQLClientImpl) ListSchemas(ctx context.Context) ([]*SchemaInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -125,7 +125,7 @@ func (c *PostgreSQLClientImpl) ListSchemas() ([]*SchemaInfo, error) { WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast') ORDER BY schema_name` - rows, err := c.db.QueryContext(context.Background(), query) + rows, err := c.db.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("failed to list schemas: %w", err) } @@ -147,7 +147,7 @@ func (c *PostgreSQLClientImpl) ListSchemas() ([]*SchemaInfo, error) { } // ListTables returns a list of tables in the specified schema. -func (c *PostgreSQLClientImpl) ListTables(schema string) ([]*TableInfo, error) { +func (c *PostgreSQLClientImpl) ListTables(ctx context.Context, schema string) ([]*TableInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -174,7 +174,7 @@ func (c *PostgreSQLClientImpl) ListTables(schema string) ([]*TableInfo, error) { WHERE schemaname = $1 ORDER BY tablename` - rows, err := c.db.QueryContext(context.Background(), query, schema) + rows, err := c.db.QueryContext(ctx, query, schema) if err != nil { return nil, fmt.Errorf("failed to list tables: %w", err) } @@ -198,7 +198,7 @@ func (c *PostgreSQLClientImpl) ListTables(schema string) ([]*TableInfo, error) { // ListTablesWithStats returns a list of tables with size and row count statistics in a single optimized query. // This eliminates the N+1 query pattern by joining table metadata with pg_stat_user_tables. // For tables where statistics show 0 rows, it falls back to COUNT(*) to get actual row counts. -func (c *PostgreSQLClientImpl) ListTablesWithStats(schema string) ([]*TableInfo, error) { +func (c *PostgreSQLClientImpl) ListTablesWithStats(ctx context.Context, schema string) ([]*TableInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -239,7 +239,7 @@ func (c *PostgreSQLClientImpl) ListTablesWithStats(schema string) ([]*TableInfo, ON t.schemaname = s.schemaname AND t.tablename = s.relname ORDER BY t.tablename` - rows, err := c.db.QueryContext(context.Background(), query, schema) + rows, err := c.db.QueryContext(ctx, query, schema) if err != nil { return nil, fmt.Errorf("failed to list tables with stats: %w", err) } @@ -264,7 +264,7 @@ func (c *PostgreSQLClientImpl) ListTablesWithStats(schema string) ([]*TableInfo, if table.RowCount == 0 && table.Type == "table" { countQuery := `SELECT COUNT(*) FROM "` + table.Schema + `"."` + table.Name + `"` var actualCount int64 - if err := c.db.QueryRowContext(context.Background(), countQuery).Scan(&actualCount); err != nil { + if err := c.db.QueryRowContext(ctx, countQuery).Scan(&actualCount); err != nil { // Log warning but don't fail the entire operation continue } @@ -276,7 +276,7 @@ func (c *PostgreSQLClientImpl) ListTablesWithStats(schema string) ([]*TableInfo, } // DescribeTable returns detailed column information for a table. -func (c *PostgreSQLClientImpl) DescribeTable(schema, table string) ([]*ColumnInfo, error) { +func (c *PostgreSQLClientImpl) DescribeTable(ctx context.Context, schema, table string) ([]*ColumnInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -295,7 +295,7 @@ func (c *PostgreSQLClientImpl) DescribeTable(schema, table string) ([]*ColumnInf WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position` - rows, err := c.db.QueryContext(context.Background(), query, schema, table) + rows, err := c.db.QueryContext(ctx, query, schema, table) if err != nil { return nil, fmt.Errorf("failed to describe table: %w", err) } @@ -323,7 +323,7 @@ func (c *PostgreSQLClientImpl) DescribeTable(schema, table string) ([]*ColumnInf } // GetTableStats returns statistics for a specific table. -func (c *PostgreSQLClientImpl) GetTableStats(schema, table string) (*TableInfo, error) { +func (c *PostgreSQLClientImpl) GetTableStats(ctx context.Context, schema, table string) (*TableInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -345,7 +345,7 @@ func (c *PostgreSQLClientImpl) GetTableStats(schema, table string) (*TableInfo, WHERE schemaname = $1 AND relname = $2` var rowCount sql.NullInt64 - err := c.db.QueryRowContext(context.Background(), countQuery, schema, table).Scan(&rowCount) + err := c.db.QueryRowContext(ctx, countQuery, schema, table).Scan(&rowCount) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to get table stats: %w", err) } @@ -356,7 +356,7 @@ func (c *PostgreSQLClientImpl) GetTableStats(schema, table string) (*TableInfo, // Use string concatenation instead of fmt.Sprintf for security actualCountQuery := `SELECT COUNT(*) FROM "` + schema + `"."` + table + `"` var actualCount int64 - err := c.db.QueryRowContext(context.Background(), actualCountQuery).Scan(&actualCount) + err := c.db.QueryRowContext(ctx, actualCountQuery).Scan(&actualCount) if err != nil { return nil, fmt.Errorf("failed to get actual row count: %w", err) } @@ -369,7 +369,7 @@ func (c *PostgreSQLClientImpl) GetTableStats(schema, table string) (*TableInfo, } // ListIndexes returns a list of indexes for the specified table. -func (c *PostgreSQLClientImpl) ListIndexes(schema, table string) ([]*IndexInfo, error) { +func (c *PostgreSQLClientImpl) ListIndexes(ctx context.Context, schema, table string) ([]*IndexInfo, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -396,7 +396,7 @@ func (c *PostgreSQLClientImpl) ListIndexes(schema, table string) ([]*IndexInfo, GROUP BY i.relname, t.relname, ix.indisunique, ix.indisprimary, am.amname ORDER BY i.relname` - rows, err := c.db.QueryContext(context.Background(), query, schema, table) + rows, err := c.db.QueryContext(ctx, query, schema, table) if err != nil { return nil, fmt.Errorf("failed to list indexes: %w", err) } @@ -466,7 +466,7 @@ func processRows(rows *sql.Rows) ([][]any, error) { } // ExecuteQuery executes a SELECT query and returns the results. -func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryResult, error) { +func (c *PostgreSQLClientImpl) ExecuteQuery(ctx context.Context, query string, args ...any) (*QueryResult, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -475,7 +475,7 @@ func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryRe return nil, err } - rows, err := c.db.QueryContext(context.Background(), query, args...) + rows, err := c.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to execute query: %w", err) } @@ -502,7 +502,7 @@ func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryRe } // ExplainQuery returns the execution plan for a query. -func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...any) (*QueryResult, error) { +func (c *PostgreSQLClientImpl) ExplainQuery(ctx context.Context, query string, args ...any) (*QueryResult, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -514,7 +514,7 @@ func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...any) (*QueryRe // Construct the EXPLAIN query explainQuery := "EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) " + query - rows, err := c.db.QueryContext(context.Background(), explainQuery, args...) + rows, err := c.db.QueryContext(ctx, explainQuery, args...) if err != nil { return nil, fmt.Errorf("failed to execute explain query: %w", err) } diff --git a/internal/app/client_mocked_test.go b/internal/app/client_mocked_test.go index a86bb2a..3502cd6 100644 --- a/internal/app/client_mocked_test.go +++ b/internal/app/client_mocked_test.go @@ -1,6 +1,7 @@ package app import ( + "context" "database/sql" "testing" @@ -64,7 +65,7 @@ func TestPostgreSQLClient_ConnectValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := client.Connect(tt.connectionStr) + err := client.Connect(context.Background(), tt.connectionStr) if tt.expectError { assert.Error(t, err) } else { @@ -131,7 +132,7 @@ func TestPostgreSQLClient_QueryValidationLogic(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Test the validation logic that would happen in ExecuteQuery // by calling it without a real database connection - _, err := client.ExecuteQuery(tt.query) + _, err := client.ExecuteQuery(context.Background(), tt.query) if tt.shouldAllow { // Should fail with connection error, not validation error @@ -156,7 +157,7 @@ func TestPostgreSQLClient_StateManagement(t *testing.T) { assert.NoError(t, err) // Test Ping on fresh client - err = client.Ping() + err = client.Ping(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") @@ -171,55 +172,55 @@ func TestPostgreSQLClient_ErrorScenarios(t *testing.T) { // Test all methods that check for db == nil t.Run("ListDatabases", func(t *testing.T) { - _, err := client.ListDatabases() + _, err := client.ListDatabases(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("GetCurrentDatabase", func(t *testing.T) { - _, err := client.GetCurrentDatabase() + _, err := client.GetCurrentDatabase(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("ListSchemas", func(t *testing.T) { - _, err := client.ListSchemas() + _, err := client.ListSchemas(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("ListTables", func(t *testing.T) { - _, err := client.ListTables("public") + _, err := client.ListTables(context.Background(), "public") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("DescribeTable", func(t *testing.T) { - _, err := client.DescribeTable("public", "users") + _, err := client.DescribeTable(context.Background(), "public", "users") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("GetTableStats", func(t *testing.T) { - _, err := client.GetTableStats("public", "users") + _, err := client.GetTableStats(context.Background(), "public", "users") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("ListIndexes", func(t *testing.T) { - _, err := client.ListIndexes("public", "users") + _, err := client.ListIndexes(context.Background(), "public", "users") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("ExecuteQuery", func(t *testing.T) { - _, err := client.ExecuteQuery("SELECT 1") + _, err := client.ExecuteQuery(context.Background(), "SELECT 1") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) t.Run("ExplainQuery", func(t *testing.T) { - _, err := client.ExplainQuery("SELECT 1") + _, err := client.ExplainQuery(context.Background(), "SELECT 1") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) @@ -243,15 +244,15 @@ func TestPostgreSQLClient_SchemaDefaults(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // All these will fail with "no database connection" but exercise the schema defaulting logic - _, err := client.GetTableStats(tt.schema, tt.table) + _, err := client.GetTableStats(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.ListIndexes(tt.schema, tt.table) + _, err = client.ListIndexes(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.DescribeTable(tt.schema, tt.table) + _, err = client.DescribeTable(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) diff --git a/internal/app/client_test.go b/internal/app/client_test.go index f5489d6..db1be62 100644 --- a/internal/app/client_test.go +++ b/internal/app/client_test.go @@ -1,6 +1,7 @@ package app import ( + "context" "database/sql" "fmt" "testing" @@ -70,7 +71,7 @@ func TestPostgreSQLClient_Connect_InvalidConnectionString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := client.Connect(tt.connectionString) + err := client.Connect(context.Background(), tt.connectionString) if tt.expectError { assert.Error(t, err) } else { @@ -89,7 +90,7 @@ func TestPostgreSQLClient_CloseWithoutConnection(t *testing.T) { func TestPostgreSQLClient_PingWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - err := client.Ping() + err := client.Ping(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") } @@ -102,7 +103,7 @@ func TestPostgreSQLClient_GetDBWithoutConnection(t *testing.T) { func TestPostgreSQLClient_ListDatabasesWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - databases, err := client.ListDatabases() + databases, err := client.ListDatabases(context.Background()) assert.Error(t, err) assert.Nil(t, databases) assert.Contains(t, err.Error(), "no database connection") @@ -110,7 +111,7 @@ func TestPostgreSQLClient_ListDatabasesWithoutConnection(t *testing.T) { func TestPostgreSQLClient_GetCurrentDatabaseWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - dbName, err := client.GetCurrentDatabase() + dbName, err := client.GetCurrentDatabase(context.Background()) assert.Error(t, err) assert.Empty(t, dbName) assert.Contains(t, err.Error(), "no database connection") @@ -118,7 +119,7 @@ func TestPostgreSQLClient_GetCurrentDatabaseWithoutConnection(t *testing.T) { func TestPostgreSQLClient_ListSchemasWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - schemas, err := client.ListSchemas() + schemas, err := client.ListSchemas(context.Background()) assert.Error(t, err) assert.Nil(t, schemas) assert.Contains(t, err.Error(), "no database connection") @@ -126,7 +127,7 @@ func TestPostgreSQLClient_ListSchemasWithoutConnection(t *testing.T) { func TestPostgreSQLClient_ListTablesWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - tables, err := client.ListTables("public") + tables, err := client.ListTables(context.Background(), "public") assert.Error(t, err) assert.Nil(t, tables) assert.Contains(t, err.Error(), "no database connection") @@ -134,7 +135,7 @@ func TestPostgreSQLClient_ListTablesWithoutConnection(t *testing.T) { func TestPostgreSQLClient_ListTablesWithEmptySchema(t *testing.T) { client := NewPostgreSQLClient() - tables, err := client.ListTables("") + tables, err := client.ListTables(context.Background(), "") assert.Error(t, err) assert.Nil(t, tables) assert.Contains(t, err.Error(), "no database connection") @@ -142,7 +143,7 @@ func TestPostgreSQLClient_ListTablesWithEmptySchema(t *testing.T) { func TestPostgreSQLClient_DescribeTableWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - columns, err := client.DescribeTable("public", "users") + columns, err := client.DescribeTable(context.Background(), "public", "users") assert.Error(t, err) assert.Nil(t, columns) assert.Contains(t, err.Error(), "no database connection") @@ -150,7 +151,7 @@ func TestPostgreSQLClient_DescribeTableWithoutConnection(t *testing.T) { func TestPostgreSQLClient_DescribeTableWithEmptySchema(t *testing.T) { client := NewPostgreSQLClient() - columns, err := client.DescribeTable("", "users") + columns, err := client.DescribeTable(context.Background(), "", "users") assert.Error(t, err) assert.Nil(t, columns) assert.Contains(t, err.Error(), "no database connection") @@ -158,7 +159,7 @@ func TestPostgreSQLClient_DescribeTableWithEmptySchema(t *testing.T) { func TestPostgreSQLClient_GetTableStatsWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - stats, err := client.GetTableStats("public", "users") + stats, err := client.GetTableStats(context.Background(), "public", "users") assert.Error(t, err) assert.Nil(t, stats) assert.Contains(t, err.Error(), "no database connection") @@ -166,7 +167,7 @@ func TestPostgreSQLClient_GetTableStatsWithoutConnection(t *testing.T) { func TestPostgreSQLClient_GetTableStatsWithEmptySchema(t *testing.T) { client := NewPostgreSQLClient() - stats, err := client.GetTableStats("", "users") + stats, err := client.GetTableStats(context.Background(), "", "users") assert.Error(t, err) assert.Nil(t, stats) assert.Contains(t, err.Error(), "no database connection") @@ -174,7 +175,7 @@ func TestPostgreSQLClient_GetTableStatsWithEmptySchema(t *testing.T) { func TestPostgreSQLClient_ListIndexesWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - indexes, err := client.ListIndexes("public", "users") + indexes, err := client.ListIndexes(context.Background(), "public", "users") assert.Error(t, err) assert.Nil(t, indexes) assert.Contains(t, err.Error(), "no database connection") @@ -182,7 +183,7 @@ func TestPostgreSQLClient_ListIndexesWithoutConnection(t *testing.T) { func TestPostgreSQLClient_ListIndexesWithEmptySchema(t *testing.T) { client := NewPostgreSQLClient() - indexes, err := client.ListIndexes("", "users") + indexes, err := client.ListIndexes(context.Background(), "", "users") assert.Error(t, err) assert.Nil(t, indexes) assert.Contains(t, err.Error(), "no database connection") @@ -190,7 +191,7 @@ func TestPostgreSQLClient_ListIndexesWithEmptySchema(t *testing.T) { func TestPostgreSQLClient_ExecuteQueryWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - result, err := client.ExecuteQuery("SELECT 1") + result, err := client.ExecuteQuery(context.Background(), "SELECT 1") assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "no database connection") @@ -269,7 +270,7 @@ func TestPostgreSQLClient_ExecuteQueryInvalidQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := client.ExecuteQuery(tt.query) + result, err := client.ExecuteQuery(context.Background(), tt.query) if tt.expectError { assert.Error(t, err) if tt.errorMsg == "only SELECT and WITH queries are allowed" { @@ -288,7 +289,7 @@ func TestPostgreSQLClient_ExecuteQueryInvalidQueries(t *testing.T) { func TestPostgreSQLClient_ExplainQueryWithoutConnection(t *testing.T) { client := NewPostgreSQLClient() - result, err := client.ExplainQuery("SELECT 1") + result, err := client.ExplainQuery(context.Background(), "SELECT 1") assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "no database connection") @@ -314,7 +315,7 @@ func TestPostgreSQLClient_ExplainQueryValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // This will fail due to no real connection, but we're testing the query validation - result, err := client.ExplainQuery(tt.query) + result, err := client.ExplainQuery(context.Background(), tt.query) assert.Error(t, err) assert.Nil(t, result) // Should fail with connection error since no real connection @@ -329,7 +330,7 @@ func TestConnectionStringValidation(t *testing.T) { client := &PostgreSQLClientImpl{} // Test that Connect properly validates and handles errors - err := client.Connect("postgres://invaliduser:invalidpass@nonexistenthost:5432/nonexistentdb") + err := client.Connect(context.Background(), "postgres://invaliduser:invalidpass@nonexistenthost:5432/nonexistentdb") assert.Error(t, err) assert.Contains(t, err.Error(), "failed to ping database") } @@ -378,19 +379,19 @@ func TestDefaultSchemaHandling(t *testing.T) { t.Run(fmt.Sprintf("schema_%s", tt.inputSchema), func(t *testing.T) { // These will fail due to no connection, but we can verify // that the schema parameter is properly processed - _, err := client.ListTables(tt.inputSchema) + _, err := client.ListTables(context.Background(), tt.inputSchema) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.DescribeTable(tt.inputSchema, "test_table") + _, err = client.DescribeTable(context.Background(), tt.inputSchema, "test_table") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.ListIndexes(tt.inputSchema, "test_table") + _, err = client.ListIndexes(context.Background(), tt.inputSchema, "test_table") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.GetTableStats(tt.inputSchema, "test_table") + _, err = client.GetTableStats(context.Background(), tt.inputSchema, "test_table") assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) @@ -424,15 +425,15 @@ func TestSQLQueryConstruction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Test that functions handle schema and table parameters properly - _, err := client.DescribeTable(tt.schema, tt.table) + _, err := client.DescribeTable(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.ListIndexes(tt.schema, tt.table) + _, err = client.ListIndexes(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") - _, err = client.GetTableStats(tt.schema, tt.table) + _, err = client.GetTableStats(context.Background(), tt.schema, tt.table) assert.Error(t, err) assert.Contains(t, err.Error(), "no database connection") }) diff --git a/internal/app/interfaces.go b/internal/app/interfaces.go index 040e421..2976cb7 100644 --- a/internal/app/interfaces.go +++ b/internal/app/interfaces.go @@ -1,6 +1,7 @@ package app import ( + "context" "database/sql" "errors" ) @@ -77,32 +78,32 @@ type QueryResult struct { // ConnectionManager handles database connection operations. type ConnectionManager interface { - Connect(connectionString string) error + Connect(ctx context.Context, connectionString string) error Close() error - Ping() error + Ping(ctx context.Context) error GetDB() *sql.DB } // DatabaseExplorer handles database-level operations. type DatabaseExplorer interface { - ListDatabases() ([]*DatabaseInfo, error) - GetCurrentDatabase() (string, error) - ListSchemas() ([]*SchemaInfo, error) + ListDatabases(ctx context.Context) ([]*DatabaseInfo, error) + GetCurrentDatabase(ctx context.Context) (string, error) + ListSchemas(ctx context.Context) ([]*SchemaInfo, error) } // TableExplorer handles table-level operations. type TableExplorer interface { - ListTables(schema string) ([]*TableInfo, error) - ListTablesWithStats(schema string) ([]*TableInfo, error) - DescribeTable(schema, table string) ([]*ColumnInfo, error) - GetTableStats(schema, table string) (*TableInfo, error) - ListIndexes(schema, table string) ([]*IndexInfo, error) + ListTables(ctx context.Context, schema string) ([]*TableInfo, error) + ListTablesWithStats(ctx context.Context, schema string) ([]*TableInfo, error) + DescribeTable(ctx context.Context, schema, table string) ([]*ColumnInfo, error) + GetTableStats(ctx context.Context, schema, table string) (*TableInfo, error) + ListIndexes(ctx context.Context, schema, table string) ([]*IndexInfo, error) } // QueryExecutor handles query operations. type QueryExecutor interface { - ExecuteQuery(query string, args ...any) (*QueryResult, error) - ExplainQuery(query string, args ...any) (*QueryResult, error) + ExecuteQuery(ctx context.Context, query string, args ...any) (*QueryResult, error) + ExplainQuery(ctx context.Context, query string, args ...any) (*QueryResult, error) } // PostgreSQLClient interface combines all database operations. diff --git a/main.go b/main.go index ecbb16a..950a289 100644 --- a/main.go +++ b/main.go @@ -135,6 +135,7 @@ func getConnectionString( // handleConnectDatabaseRequest handles the connect_database tool request. func handleConnectDatabaseRequest( + ctx context.Context, args map[string]any, appInstance *app.App, debugLogger *slog.Logger, @@ -147,13 +148,13 @@ func handleConnectDatabaseRequest( } // Attempt to connect - if err := appInstance.Connect(connectionString); err != nil { + if err := appInstance.Connect(ctx, connectionString); err != nil { debugLogger.Error("Failed to connect to database", "error", err) return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil } // Get current database name to confirm connection - dbName, err := appInstance.GetCurrentDatabase() + dbName, err := appInstance.GetCurrentDatabase(ctx) if err != nil { debugLogger.Warn("Connected but failed to get database name", "error", err) dbName = "unknown" @@ -204,7 +205,7 @@ func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLo ) s.AddTool(connectDBTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleConnectDatabaseRequest(request.GetArguments(), appInstance, debugLogger) + return handleConnectDatabaseRequest(ctx, request.GetArguments(), appInstance, debugLogger) }) } @@ -218,7 +219,7 @@ func setupListDatabasesTool(s *server.MCPServer, appInstance *app.App, debugLogg debugLogger.Debug("Received list_databases tool request") // List databases - databases, err := appInstance.ListDatabases() + databases, err := appInstance.ListDatabases(ctx) if err != nil { debugLogger.Error("Failed to list databases", "error", err) return mcp.NewToolResultError(fmt.Sprintf("Failed to list databases: %v", err)), nil @@ -246,7 +247,7 @@ func setupListSchemasTool(s *server.MCPServer, appInstance *app.App, debugLogger debugLogger.Debug("Received list_schemas tool request") // List schemas - schemas, err := appInstance.ListSchemas() + schemas, err := appInstance.ListSchemas(ctx) if err != nil { debugLogger.Error("Failed to list schemas", "error", err) return mcp.NewToolResultError(fmt.Sprintf("Failed to list schemas: %v", err)), nil @@ -294,7 +295,7 @@ func setupListTablesTool(s *server.MCPServer, appInstance *app.App, debugLogger debugLogger.Debug("Processing list_tables request", "schema", opts.Schema, "include_size", opts.IncludeSize) // List tables - tables, err := appInstance.ListTables(opts) + tables, err := appInstance.ListTables(ctx, opts) if err != nil { debugLogger.Error("Failed to list tables", "error", err) return mcp.NewToolResultError(fmt.Sprintf("Failed to list tables: %v", err)), nil @@ -350,7 +351,7 @@ type TableToolConfig struct { Name string Description string TableDesc string - Operation func(appInstance *app.App, schema, table string) (any, error) + Operation func(ctx context.Context, appInstance *app.App, schema, table string) (any, error) SuccessMsg func(result any, schema, table string) (string, []any) ErrorMsg string } @@ -377,7 +378,7 @@ func setupTableTool(s *server.MCPServer, appInstance *app.App, debugLogger *slog return mcp.NewToolResultError(err.Error()), nil } - result, err := config.Operation(appInstance, schema, table) + result, err := config.Operation(ctx, appInstance, schema, table) if err != nil { debugLogger.Error("Failed to "+config.ErrorMsg, "error", err, "schema", schema, "table", table) return mcp.NewToolResultError(fmt.Sprintf("Failed to %s: %v", config.ErrorMsg, err)), nil @@ -400,8 +401,8 @@ func setupDescribeTableTool(s *server.MCPServer, appInstance *app.App, debugLogg Name: "describe_table", Description: "Get detailed information about a table's structure (columns, types, constraints)", TableDesc: "Table name to describe", - Operation: func(appInstance *app.App, schema, table string) (any, error) { - return appInstance.DescribeTable(schema, table) + Operation: func(ctx context.Context, appInstance *app.App, schema, table string) (any, error) { + return appInstance.DescribeTable(ctx, schema, table) }, SuccessMsg: func(result any, schema, table string) (string, []any) { columns, ok := result.([]*app.ColumnInfo) @@ -450,7 +451,7 @@ func setupExecuteQueryTool(s *server.MCPServer, appInstance *app.App, debugLogge debugLogger.Debug("Processing execute_query request", "query", query, "limit", opts.Limit) // Execute query - result, err := appInstance.ExecuteQuery(opts) + result, err := appInstance.ExecuteQuery(ctx, opts) if err != nil { debugLogger.Error("Failed to execute query", "error", err, "query", query) return mcp.NewToolResultError(fmt.Sprintf("Failed to execute query: %v", err)), nil @@ -474,8 +475,8 @@ func setupListIndexesTool(s *server.MCPServer, appInstance *app.App, debugLogger Name: "list_indexes", Description: "List indexes for a specific table", TableDesc: "Table name to list indexes for", - Operation: func(appInstance *app.App, schema, table string) (any, error) { - return appInstance.ListIndexes(schema, table) + Operation: func(ctx context.Context, appInstance *app.App, schema, table string) (any, error) { + return appInstance.ListIndexes(ctx, schema, table) }, SuccessMsg: func(result any, schema, table string) (string, []any) { indexes, ok := result.([]*app.IndexInfo) @@ -512,7 +513,7 @@ func setupExplainQueryTool(s *server.MCPServer, appInstance *app.App, debugLogge debugLogger.Debug("Processing explain_query request", "query", query) // Explain query - result, err := appInstance.ExplainQuery(query) + result, err := appInstance.ExplainQuery(ctx, query) if err != nil { debugLogger.Error("Failed to explain query", "error", err, "query", query) return mcp.NewToolResultError(fmt.Sprintf("Failed to explain query: %v", err)), nil @@ -563,7 +564,7 @@ func setupGetTableStatsTool(s *server.MCPServer, appInstance *app.App, debugLogg debugLogger.Debug("Processing get_table_stats request", "schema", schema, "table", table) // Get table stats - stats, err := appInstance.GetTableStats(schema, table) + stats, err := appInstance.GetTableStats(ctx, schema, table) if err != nil { debugLogger.Error("Failed to get table stats", "error", err, "schema", schema, "table", table) return mcp.NewToolResultError(fmt.Sprintf("Failed to get table stats: %v", err)), nil diff --git a/main_additional_test.go b/main_additional_test.go index e76f70c..36d8e93 100644 --- a/main_additional_test.go +++ b/main_additional_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "os" "testing" @@ -101,7 +102,7 @@ func TestInitializeApp_Implementation(t *testing.T) { app.SetLogger(logger) // App should be in disconnected state initially (without environment variables) - err := app.ValidateConnection() + err := app.ValidateConnection(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "database connection failed") } diff --git a/main_tool_coverage_test.go b/main_tool_coverage_test.go index f898b6e..70934ee 100644 --- a/main_tool_coverage_test.go +++ b/main_tool_coverage_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "testing" @@ -136,7 +137,7 @@ func TestInitializeAppCoverage(t *testing.T) { assert.NotNil(t, logger) // Test that the app is properly initialized - err := app.ValidateConnection() + err := app.ValidateConnection(context.Background()) assert.Error(t, err) // Should error because no connection established // Test setting logger