Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 64 additions & 27 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@ func setupTestContainer(t *testing.T) (*postgres.PostgresContainer, string, func
return postgresContainer, connStr, cleanup
}

func setupTestDatabase(t *testing.T) (*sql.DB, func()) {
func setupTestDatabase(t *testing.T) (*sql.DB, string, func()) {
_, connectionString, containerCleanup := setupTestContainer(t)

// Set environment variable for the app to use
os.Setenv("POSTGRES_URL", connectionString)

// Connect to PostgreSQL
db, err := sql.Open("postgres", connectionString)
require.NoError(t, err)
Expand Down Expand Up @@ -134,56 +131,63 @@ func setupTestDatabase(t *testing.T) (*sql.DB, func()) {
cleanup := func() {
_, _ = db.ExecContext(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", testSchema))
db.Close()
os.Unsetenv("POSTGRES_URL")
containerCleanup() // Clean up container
}

return db, cleanup
return db, connectionString, cleanup
}

func TestIntegration_App_Connect(t *testing.T) {
_, connectionString, cleanup := setupTestContainer(t)
defer cleanup()

// Set environment variable for connection
os.Setenv("POSTGRES_URL", connectionString)
defer os.Unsetenv("POSTGRES_URL")

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

// Test explicit connection with connection string
err = appInstance.Connect(connectionString)
require.NoError(t, err)

// Test that we can get current database
dbName, err := appInstance.GetCurrentDatabase()
assert.NoError(t, err)
assert.NotEmpty(t, dbName)
}

func TestIntegration_App_ConnectWithDatabaseURL(t *testing.T) {
func TestIntegration_App_ConnectWithEnvironmentVariable(t *testing.T) {
_, connectionString, cleanup := setupTestContainer(t)
defer cleanup()

// Test with DATABASE_URL environment variable
os.Setenv("DATABASE_URL", connectionString)
defer os.Unsetenv("DATABASE_URL")
// Test with POSTGRES_URL environment variable (backward compatibility via tryConnect)
os.Setenv("POSTGRES_URL", connectionString)
defer os.Unsetenv("POSTGRES_URL")

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

// Test that connection works
// Explicitly call ensureConnection which will trigger tryConnect() fallback
err = appInstance.ValidateConnection()
assert.NoError(t, err)

// Verify connection works
dbName, err := appInstance.GetCurrentDatabase()
assert.NoError(t, err)
assert.NotEmpty(t, dbName)
}

func TestIntegration_App_ListDatabases(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

databases, err := appInstance.ListDatabases()
assert.NoError(t, err)
assert.NotEmpty(t, databases)
Expand All @@ -201,13 +205,16 @@ func TestIntegration_App_ListDatabases(t *testing.T) {
}

func TestIntegration_App_ListSchemas(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

schemas, err := appInstance.ListSchemas()
assert.NoError(t, err)
assert.NotEmpty(t, schemas)
Expand All @@ -223,13 +230,16 @@ func TestIntegration_App_ListSchemas(t *testing.T) {
}

func TestIntegration_App_ListTables(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

// List tables in test schema
listOpts := &app.ListTablesOptions{
Schema: "test_mcp_schema",
Expand All @@ -253,13 +263,16 @@ func TestIntegration_App_ListTables(t *testing.T) {
}

func TestIntegration_App_ListTablesWithSize(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

// List tables with size information
listOpts := &app.ListTablesOptions{
Schema: "test_mcp_schema",
Expand All @@ -280,13 +293,16 @@ func TestIntegration_App_ListTablesWithSize(t *testing.T) {
}

func TestIntegration_App_DescribeTable(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

columns, err := appInstance.DescribeTable("test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotEmpty(t, columns)
Expand Down Expand Up @@ -322,13 +338,16 @@ func TestIntegration_App_DescribeTable(t *testing.T) {
}

func TestIntegration_App_ExecuteQuery(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

// Test simple SELECT query
queryOpts := &app.ExecuteQueryOptions{
Query: "SELECT id, name, email FROM test_mcp_schema.test_users WHERE active = true ORDER BY id",
Expand All @@ -352,13 +371,16 @@ func TestIntegration_App_ExecuteQuery(t *testing.T) {
}

func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

// Test query with limit
queryOpts := &app.ExecuteQueryOptions{
Query: "SELECT * FROM test_mcp_schema.test_users ORDER BY id",
Expand All @@ -375,13 +397,16 @@ func TestIntegration_App_ExecuteQueryWithLimit(t *testing.T) {
}

func TestIntegration_App_ListIndexes(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

indexes, err := appInstance.ListIndexes("test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotEmpty(t, indexes)
Expand Down Expand Up @@ -412,13 +437,16 @@ func TestIntegration_App_ListIndexes(t *testing.T) {
}

func TestIntegration_App_ExplainQuery(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

// Test EXPLAIN query
result, err := appInstance.ExplainQuery("SELECT * FROM test_mcp_schema.test_users WHERE active = true")
require.NoError(t, err)
Expand All @@ -430,13 +458,16 @@ func TestIntegration_App_ExplainQuery(t *testing.T) {
}

func TestIntegration_App_GetTableStats(t *testing.T) {
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(t, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(t, err)

stats, err := appInstance.GetTableStats("test_mcp_schema", "test_users")
assert.NoError(t, err)
assert.NotNil(t, stats)
Expand Down Expand Up @@ -512,13 +543,16 @@ func BenchmarkIntegration_ListTables(b *testing.B) {

// Use a testing.T wrapper for setup functions
t := &testing.T{}
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(b, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(b, err)

listOpts := &app.ListTablesOptions{
Schema: "test_mcp_schema",
}
Expand All @@ -539,13 +573,16 @@ func BenchmarkIntegration_ExecuteQuery(b *testing.B) {

// Use a testing.T wrapper for setup functions
t := &testing.T{}
_, cleanup := setupTestDatabase(t)
_, connectionString, cleanup := setupTestDatabase(t)
defer cleanup()

appInstance, err := app.New()
require.NoError(b, err)
defer appInstance.Disconnect()

err = appInstance.Connect(connectionString)
require.NoError(b, err)

queryOpts := &app.ExecuteQueryOptions{
Query: "SELECT COUNT(*) FROM test_mcp_schema.test_users",
}
Expand Down
52 changes: 36 additions & 16 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ type App struct {
logger *slog.Logger
}

// New creates a new App instance and attempts to connect to the database.
// New creates a new App instance without establishing a connection.
// Use Connect() method or connect_database tool to establish connection.
func New() (*App, error) {
app := &App{
client: NewPostgreSQLClient(),
logger: logger.NewLogger("info"),
}

// Attempt initial connection
if err := app.tryConnect(); err != nil {
app.logger.Warn("Could not connect to database on startup, will retry on first tool request", "error", err)
}
// Note: Connection is now explicit via Connect() or connect_database tool
// Environment variables are still supported as fallback via tryConnect()

return app, nil
}
Expand All @@ -52,6 +51,34 @@ func (a *App) SetLogger(logger *slog.Logger) {
a.logger = logger
}

// Connect establishes a database connection with the provided connection string.
// If a connection already exists, it will be closed before establishing a new one.
func (a *App) Connect(connectionString string) error {
if connectionString == "" {
return ErrNoConnectionString
}

// Close existing connection if any
if a.client != nil {
if err := a.client.Ping(); err == nil {
// Connection exists and is active, close it first
if closeErr := a.client.Close(); closeErr != nil {
a.logger.Warn("Failed to close existing connection", "error", closeErr)
}
}
}

a.logger.Debug("Connecting to PostgreSQL database")

if err := a.client.Connect(connectionString); err != nil {
a.logger.Error("Failed to connect to database", "error", err)
return fmt.Errorf("failed to connect: %w", err)
}

a.logger.Info("Successfully connected to PostgreSQL database")
return nil
}

// Disconnect closes the database connection.
func (a *App) Disconnect() error {
if a.client != nil {
Expand Down Expand Up @@ -280,9 +307,10 @@ func (a *App) ValidateConnection() error {
return a.ensureConnection()
}

// tryConnect attempts to connect to the database using environment variables.
// tryConnect attempts to connect using environment variables as a fallback mechanism.
// Returns ErrNoConnectionString if no environment variables are set.
func (a *App) tryConnect() error {
// Try environment variables
// Try environment variables as fallback
connectionString := os.Getenv("POSTGRES_URL")
if connectionString == "" {
connectionString = os.Getenv("DATABASE_URL")
Expand All @@ -292,15 +320,7 @@ func (a *App) tryConnect() error {
return ErrNoConnectionString
}

a.logger.Debug("Connecting to PostgreSQL database")

if err := a.client.Connect(connectionString); err != nil {
a.logger.Error("Failed to connect to database", "error", err)
return fmt.Errorf("failed to connect: %w", err)
}

a.logger.Info("Successfully connected to PostgreSQL database")
return nil
return a.Connect(connectionString)
}

// ensureConnection checks if the database connection is valid and attempts to reconnect if needed.
Expand Down
Loading
Loading