From 656b90d28935734ee1cec9cdf8f92cff0c93eee9 Mon Sep 17 00:00:00 2001 From: Sylvain <1552102+sgaunet@users.noreply.github.com> Date: Sun, 7 Dec 2025 21:55:26 +0100 Subject: [PATCH] refactor(app): accept client dependency for better testability - Refactor App.New() to accept PostgreSQLClient interface parameter - Add NewDefault() convenience constructor for production use - Update all tests to inject mock clients via New() constructor - Remove tight coupling between App and PostgreSQLClientImpl - Enable dependency injection pattern for improved testability This follows constructor dependency injection pattern and makes it easier to test App without database dependencies. Implements #28 --- integration_test.go | 36 +++++++-------- internal/app/app.go | 18 ++++++-- internal/app/app_test.go | 92 ++++++++++++++++---------------------- main.go | 6 +-- main_test.go | 18 ++++---- main_tool_coverage_test.go | 6 +-- 6 files changed, 87 insertions(+), 89 deletions(-) diff --git a/integration_test.go b/integration_test.go index 3ac21e6..03ac5fa 100644 --- a/integration_test.go +++ b/integration_test.go @@ -141,7 +141,7 @@ func TestIntegration_App_Connect(t *testing.T) { _, connectionString, cleanup := setupTestContainer(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -165,7 +165,7 @@ func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) { os.Setenv("POSTGRES_URL", connectionString) defer os.Unsetenv("POSTGRES_URL") - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -185,7 +185,7 @@ func TestIntegration_App_ListDatabases(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -214,7 +214,7 @@ func TestIntegration_App_ListSchemas(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -241,7 +241,7 @@ func TestIntegration_App_ListTables(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -276,7 +276,7 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -308,7 +308,7 @@ func TestIntegration_App_DescribeTable(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -355,7 +355,7 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -390,7 +390,7 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -418,7 +418,7 @@ func TestIntegration_App_ListIndexes(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -503,7 +503,7 @@ func TestIntegration_App_ListIndexes_SpecialCharacters(t *testing.T) { require.NoError(t, err) // Test with app - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -554,7 +554,7 @@ func TestIntegration_App_ExplainQuery(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -577,7 +577,7 @@ func TestIntegration_App_GetTableStats(t *testing.T) { _, connectionString, cleanup := setupTestDatabase(t) defer cleanup() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -604,7 +604,7 @@ func TestIntegration_App_ErrorHandling(t *testing.T) { os.Setenv("POSTGRES_URL", connectionString) defer os.Unsetenv("POSTGRES_URL") - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) defer appInstance.Disconnect() @@ -635,7 +635,7 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) { defer cleanup() // Test validation without environment variable - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(t, err) ctx := context.Background() @@ -648,7 +648,7 @@ func TestIntegration_App_ConnectionValidation(t *testing.T) { defer os.Unsetenv("POSTGRES_URL") // Create new instance with environment variable set - appInstance2, err := app.New() + appInstance2, err := app.NewDefault() require.NoError(t, err) defer appInstance2.Disconnect() @@ -670,7 +670,7 @@ func BenchmarkIntegration_ListTables(b *testing.B) { ctx := context.Background() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(b, err) defer appInstance.Disconnect() @@ -702,7 +702,7 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) { ctx := context.Background() - appInstance, err := app.New() + appInstance, err := app.NewDefault() require.NoError(b, err) defer appInstance.Disconnect() diff --git a/internal/app/app.go b/internal/app/app.go index a2dd4fe..e919d74 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -33,11 +33,23 @@ type App struct { logger *slog.Logger } -// New creates a new App instance without establishing a connection. +// New creates a new App instance with the provided PostgreSQLClient. +// This constructor accepts a client implementation for dependency injection, +// making it easy to inject mocks or alternative implementations for testing. +func New(client PostgreSQLClient) *App { + return &App{ + client: client, + logger: logger.NewLogger("info"), + } +} + +// NewDefault creates a new App instance with a default PostgreSQLClient. // Use Connect() method or connect_database tool to establish connection. -func New() (*App, error) { +// This is a convenience constructor for production use. +func NewDefault() (*App, error) { + client := NewPostgreSQLClient() app := &App{ - client: NewPostgreSQLClient(), + client: client, logger: logger.NewLogger("info"), } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index df6e74f..b1be252 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -117,7 +117,16 @@ func (m *MockPostgreSQLClient) GetDB() *sql.DB { } func TestNew(t *testing.T) { - app, err := New() + mockClient := &MockPostgreSQLClient{} + app := New(mockClient) + assert.NotNil(t, app) + assert.NotNil(t, app.client) + assert.NotNil(t, app.logger) + assert.Equal(t, mockClient, app.client) +} + +func TestNewDefault(t *testing.T) { + app, err := NewDefault() assert.NoError(t, err) assert.NotNil(t, app) assert.NotNil(t, app.client) @@ -125,7 +134,7 @@ func TestNew(t *testing.T) { } func TestApp_SetLogger(t *testing.T) { - app, _ := New() + app, _ := NewDefault() originalLogger := app.logger // Create a new logger @@ -137,9 +146,8 @@ func TestApp_SetLogger(t *testing.T) { } func TestApp_Disconnect(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) mockClient.On("Close").Return(nil) @@ -149,17 +157,15 @@ func TestApp_Disconnect(t *testing.T) { } func TestApp_DisconnectWithNilClient(t *testing.T) { - app, _ := New() - app.client = nil + app := New(nil) err := app.Disconnect() assert.NoError(t, err) } func TestApp_ValidateConnection(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) mockClient.On("Ping", mock.Anything).Return(nil) @@ -169,8 +175,7 @@ func TestApp_ValidateConnection(t *testing.T) { } func TestApp_ValidateConnectionNilClient(t *testing.T) { - app, _ := New() - app.client = nil + app := New(nil) err := app.ValidateConnection(context.Background()) assert.Error(t, err) @@ -178,9 +183,8 @@ func TestApp_ValidateConnectionNilClient(t *testing.T) { } func TestApp_ValidateConnectionPingError(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) // Mock ping failure and reconnection failure (no env vars set) pingError := errors.New("ping failed") @@ -193,9 +197,8 @@ func TestApp_ValidateConnectionPingError(t *testing.T) { } func TestApp_ListDatabases(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedDatabases := []*DatabaseInfo{ {Name: "db1", Owner: "user1", Encoding: "UTF8"}, @@ -212,9 +215,8 @@ func TestApp_ListDatabases(t *testing.T) { } func TestApp_ListDatabasesConnectionError(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedError := errors.New("connection error") mockClient.On("Ping", mock.Anything).Return(expectedError) @@ -229,9 +231,8 @@ func TestApp_ListDatabasesConnectionError(t *testing.T) { } func TestApp_GetCurrentDatabase(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedDB := "testdb" @@ -245,9 +246,8 @@ func TestApp_GetCurrentDatabase(t *testing.T) { } func TestApp_ListSchemas(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedSchemas := []*SchemaInfo{ {Name: "public", Owner: "postgres"}, @@ -264,9 +264,8 @@ func TestApp_ListSchemas(t *testing.T) { } func TestApp_ListTables(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedTables := []*TableInfo{ {Schema: "public", Name: "users", Type: "table", Owner: "user"}, @@ -287,9 +286,8 @@ func TestApp_ListTables(t *testing.T) { } func TestApp_ListTablesWithDefaultSchema(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedTables := []*TableInfo{ {Schema: "public", Name: "users", Type: "table", Owner: "user"}, @@ -307,9 +305,8 @@ func TestApp_ListTablesWithDefaultSchema(t *testing.T) { } func TestApp_ListTablesWithNilOptions(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedTables := []*TableInfo{ {Schema: "public", Name: "users", Type: "table", Owner: "user"}, @@ -325,9 +322,8 @@ func TestApp_ListTablesWithNilOptions(t *testing.T) { } func TestApp_ListTablesWithSize(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) tablesWithStats := []*TableInfo{ { @@ -357,9 +353,8 @@ func TestApp_ListTablesWithSize(t *testing.T) { } func TestApp_DescribeTable(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedColumns := []*ColumnInfo{ {Name: "id", DataType: "integer", IsNullable: false}, @@ -376,7 +371,7 @@ func TestApp_DescribeTable(t *testing.T) { } func TestApp_DescribeTableEmptyTableName(t *testing.T) { - app, _ := New() + app, _ := NewDefault() columns, err := app.DescribeTable(context.Background(), "public", "") assert.Error(t, err) @@ -385,9 +380,8 @@ func TestApp_DescribeTableEmptyTableName(t *testing.T) { } func TestApp_DescribeTableDefaultSchema(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedColumns := []*ColumnInfo{ {Name: "id", DataType: "integer", IsNullable: false}, @@ -403,9 +397,8 @@ func TestApp_DescribeTableDefaultSchema(t *testing.T) { } func TestApp_ExecuteQuery(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedResult := &QueryResult{ Columns: []string{"id", "name"}, @@ -427,9 +420,8 @@ func TestApp_ExecuteQuery(t *testing.T) { } func TestApp_ExecuteQueryWithLimit(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) originalResult := &QueryResult{ Columns: []string{"id", "name"}, @@ -453,7 +445,7 @@ func TestApp_ExecuteQueryWithLimit(t *testing.T) { } func TestApp_ExecuteQueryNilOptions(t *testing.T) { - app, _ := New() + app, _ := NewDefault() result, err := app.ExecuteQuery(context.Background(), nil) assert.Error(t, err) @@ -462,7 +454,7 @@ func TestApp_ExecuteQueryNilOptions(t *testing.T) { } func TestApp_ExecuteQueryEmptyQuery(t *testing.T) { - app, _ := New() + app, _ := NewDefault() opts := &ExecuteQueryOptions{} @@ -473,9 +465,8 @@ func TestApp_ExecuteQueryEmptyQuery(t *testing.T) { } func TestApp_ExplainQuery(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedResult := &QueryResult{ Columns: []string{"QUERY PLAN"}, @@ -493,7 +484,7 @@ func TestApp_ExplainQuery(t *testing.T) { } func TestApp_ExplainQueryEmptyQuery(t *testing.T) { - app, _ := New() + app, _ := NewDefault() result, err := app.ExplainQuery(context.Background(), "") assert.Error(t, err) @@ -502,9 +493,8 @@ func TestApp_ExplainQueryEmptyQuery(t *testing.T) { } func TestApp_GetTableStats(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedStats := &TableInfo{ Schema: "public", @@ -523,9 +513,8 @@ func TestApp_GetTableStats(t *testing.T) { } func TestApp_ListIndexes(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) expectedIndexes := []*IndexInfo{ {Name: "users_pkey", Table: "users", Columns: []string{"id"}, IsUnique: true, IsPrimary: true}, @@ -542,9 +531,8 @@ func TestApp_ListIndexes(t *testing.T) { } func TestApp_Connect_Success(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) connectionString := "postgres://user:pass@localhost/db" @@ -558,7 +546,7 @@ func TestApp_Connect_Success(t *testing.T) { } func TestApp_Connect_EmptyString(t *testing.T) { - app, _ := New() + app, _ := NewDefault() err := app.Connect(context.Background(), "") assert.Error(t, err) @@ -566,9 +554,8 @@ func TestApp_Connect_EmptyString(t *testing.T) { } func TestApp_Connect_ReconnectClosesExisting(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) connectionString := "postgres://user:pass@localhost/db" @@ -583,9 +570,8 @@ func TestApp_Connect_ReconnectClosesExisting(t *testing.T) { } func TestApp_Connect_ConnectError(t *testing.T) { - app, _ := New() mockClient := &MockPostgreSQLClient{} - app.client = mockClient + app := New(mockClient) connectionString := "postgres://user:pass@localhost/db" expectedError := errors.New("connection failed") diff --git a/main.go b/main.go index 5515d58..75da9a8 100644 --- a/main.go +++ b/main.go @@ -659,8 +659,8 @@ func handleCommandLineFlags() { // initializeApp creates and configures the application instance. func initializeApp() (*app.App, *slog.Logger) { - // Initialize the app - appInstance, err := app.New() + // Initialize the app with default client + appInstance, err := app.NewDefault() if err != nil { log.Fatalf("Failed to initialize app: %v", err) } @@ -732,7 +732,7 @@ func main() { // Create a custom StdioServer with context support stdioServer := server.NewStdioServer(s) - if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil && err != context.Canceled { + if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil && !errors.Is(err, context.Canceled) { debugLogger.Error("Server error", "error", err) fmt.Fprintf(os.Stderr, "Server error: %v\n", err) return diff --git a/main_test.go b/main_test.go index ada34ed..afcd2d4 100644 --- a/main_test.go +++ b/main_test.go @@ -98,7 +98,7 @@ func (m *MockApp) Disconnect() error { func TestSetupListDatabasesTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -109,7 +109,7 @@ func TestSetupListDatabasesTool(t *testing.T) { func TestSetupListSchemasTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -120,7 +120,7 @@ func TestSetupListSchemasTool(t *testing.T) { func TestSetupListTablesTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -131,7 +131,7 @@ func TestSetupListTablesTool(t *testing.T) { func TestSetupDescribeTableTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -142,7 +142,7 @@ func TestSetupDescribeTableTool(t *testing.T) { func TestSetupExecuteQueryTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -153,7 +153,7 @@ func TestSetupExecuteQueryTool(t *testing.T) { func TestSetupListIndexesTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -164,7 +164,7 @@ func TestSetupListIndexesTool(t *testing.T) { func TestSetupExplainQueryTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -175,7 +175,7 @@ func TestSetupExplainQueryTool(t *testing.T) { func TestSetupGetTableStatsTool(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -186,7 +186,7 @@ func TestSetupGetTableStatsTool(t *testing.T) { func TestRegisterAllTools(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - realApp, err := app.New() + realApp, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() diff --git a/main_tool_coverage_test.go b/main_tool_coverage_test.go index 70934ee..35d8a06 100644 --- a/main_tool_coverage_test.go +++ b/main_tool_coverage_test.go @@ -20,7 +20,7 @@ type MockTool struct { // Test that all tool setup functions can be called without panicking func TestAllToolSetupFunctions(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - appInstance, err := app.New() + appInstance, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -62,7 +62,7 @@ func TestAllToolSetupFunctions(t *testing.T) { // Test parameter validation error handling in tool handlers func TestToolParameterValidationErrors(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - appInstance, err := app.New() + appInstance, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default() @@ -116,7 +116,7 @@ func TestJSONResponseHelpers(t *testing.T) { // Test the registerAllTools function func TestRegisterAllToolsFunction(t *testing.T) { s := server.NewMCPServer("test", "1.0.0") - appInstance, err := app.New() + appInstance, err := app.NewDefault() assert.NoError(t, err) logger := slog.Default()