Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ type ListTablesOptions struct {

// ExecuteQueryOptions represents options for executing queries.
type ExecuteQueryOptions struct {
Query string `json:"query"`
Args []interface{} `json:"args,omitempty"`
Limit int `json:"limit,omitempty"`
Query string `json:"query"`
Args []any `json:"args,omitempty"`
Limit int `json:"limit,omitempty"`
}

// App represents the main application structure.
Expand Down Expand Up @@ -281,7 +281,7 @@ func (a *App) GetCurrentDatabase() (string, error) {
}

// ExplainQuery returns the execution plan for a query.
func (a *App) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) {
func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) {
if err := a.ensureConnection(); err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions internal/app/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,16 +358,16 @@ func validateQuery(query string) error {
}

// processRows processes query result rows and handles type conversion.
func processRows(rows *sql.Rows) ([][]interface{}, error) {
func processRows(rows *sql.Rows) ([][]any, error) {
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %w", err)
}

var result [][]interface{}
var result [][]any
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
Expand All @@ -389,7 +389,7 @@ func processRows(rows *sql.Rows) ([][]interface{}, error) {
}

// ExecuteQuery executes a SELECT query and returns the results.
func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) (*QueryResult, error) {
func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryResult, error) {
if c.db == nil {
return nil, ErrNoDatabaseConnection
}
Expand Down Expand Up @@ -425,7 +425,7 @@ func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) (
}

// ExplainQuery returns the execution plan for a query.
func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) {
func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...any) (*QueryResult, error) {
if c.db == nil {
return nil, ErrNoDatabaseConnection
}
Expand Down
13 changes: 7 additions & 6 deletions internal/app/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ var (
ErrQueryRequired = errors.New("query is required")
ErrInvalidQuery = errors.New("only SELECT and WITH queries are allowed")
ErrNoConnectionString = errors.New(
"no database connection string provided. Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable",
"no database connection string provided. " +
"Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable",
)
ErrNoDatabaseConnection = errors.New("no database connection")
ErrTableNotFound = errors.New("table does not exist")
Expand Down Expand Up @@ -69,9 +70,9 @@ type IndexInfo struct {

// QueryResult represents the result of a query execution.
type QueryResult struct {
Columns []string `json:"columns"`
Rows [][]interface{} `json:"rows"`
RowCount int `json:"row_count"`
Columns []string `json:"columns"`
Rows [][]any `json:"rows"`
RowCount int `json:"row_count"`
}

// ConnectionManager handles database connection operations.
Expand Down Expand Up @@ -99,8 +100,8 @@ type TableExplorer interface {

// QueryExecutor handles query operations.
type QueryExecutor interface {
ExecuteQuery(query string, args ...interface{}) (*QueryResult, error)
ExplainQuery(query string, args ...interface{}) (*QueryResult, error)
ExecuteQuery(query string, args ...any) (*QueryResult, error)
ExplainQuery(query string, args ...any) (*QueryResult, error)
}

// PostgreSQLClient interface combines all database operations.
Expand Down
217 changes: 121 additions & 96 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ var version = "dev"
// Error variables for static errors.
var (
ErrInvalidConnectionParameters = errors.New("invalid connection parameters")
ErrHostRequired = errors.New("host is required")
ErrUserRequired = errors.New("user is required")
ErrDatabaseRequired = errors.New("database is required")
)

// ConnectionParams represents individual database connection parameters.
Expand All @@ -39,13 +42,13 @@ type ConnectionParams struct {
func buildConnectionString(params ConnectionParams) (string, error) {
// Validate required parameters
if params.Host == "" {
return "", errors.New("host is required")
return "", ErrHostRequired
}
if params.User == "" {
return "", errors.New("user is required")
return "", ErrUserRequired
}
if params.Database == "" {
return "", errors.New("database is required")
return "", ErrDatabaseRequired
}

// Set defaults
Expand All @@ -59,26 +62,126 @@ func buildConnectionString(params ConnectionParams) (string, error) {
sslMode = "prefer" // PostgreSQL default SSL mode
}

// Build connection string
// Build connection string using net.JoinHostPort pattern
hostPort := fmt.Sprintf("%s:%d", params.Host, port)
connStr := fmt.Sprintf(
"postgres://%s:%s@%s:%d/%s?sslmode=%s",
"postgres://%s:%s@%s/%s?sslmode=%s",
params.User,
params.Password,
params.Host,
port,
hostPort,
params.Database,
sslMode,
)

return connStr, nil
}

// extractConnectionParams extracts connection parameters from args.
func extractConnectionParams(args map[string]any) ConnectionParams {
params := ConnectionParams{
Host: "localhost", // Default
}

if host, ok := args["host"].(string); ok && host != "" {
params.Host = host
}

if portFloat, ok := args["port"].(float64); ok {
params.Port = int(portFloat)
}

if user, ok := args["user"].(string); ok {
params.User = user
}

if password, ok := args["password"].(string); ok {
params.Password = password
}

if database, ok := args["database"].(string); ok {
params.Database = database
}

if sslmode, ok := args["sslmode"].(string); ok {
params.SSLMode = sslmode
}

return params
}

// getConnectionString determines the connection string from args.
func getConnectionString(
args map[string]any,
debugLogger *slog.Logger,
) (string, error) {
// Check if full connection URL is provided
if connURL, ok := args["connection_url"].(string); ok && connURL != "" {
debugLogger.Debug("Using provided connection URL")
return connURL, nil
}

// Build connection string from individual parameters
params := extractConnectionParams(args)
connectionString, err := buildConnectionString(params)
if err != nil {
debugLogger.Error("Failed to build connection string", "error", err)
return "", fmt.Errorf("invalid connection parameters: %w", err)
}

debugLogger.Debug("Built connection string from parameters",
"host", params.Host, "port", params.Port, "database", params.Database)
return connectionString, nil
}

// handleConnectDatabaseRequest handles the connect_database tool request.
func handleConnectDatabaseRequest(
args map[string]any,
appInstance *app.App,
debugLogger *slog.Logger,
) (*mcp.CallToolResult, error) {
debugLogger.Debug("Received connect_database tool request", "args", args)

connectionString, err := getConnectionString(args, debugLogger)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Attempt to connect
if err := appInstance.Connect(connectionString); err != nil {
debugLogger.Error("Failed to connect to database", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil
}

// Get current database name to confirm connection
dbName, err := appInstance.GetCurrentDatabase()
if err != nil {
debugLogger.Warn("Connected but failed to get database name", "error", err)
dbName = "unknown"
}

debugLogger.Info("Successfully connected to database", "database", dbName)

response := map[string]any{
"status": "connected",
"database": dbName,
"message": "Successfully connected to database: " + dbName,
}

jsonData, err := json.Marshal(response)
if err != nil {
debugLogger.Error("Failed to marshal connection response", "error", err)
return mcp.NewToolResultError("Failed to format connection response"), nil
}

return mcp.NewToolResultText(string(jsonData)), nil
}

// setupConnectDatabaseTool creates and registers the connect_database tool.
func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLogger *slog.Logger) {
connectDBTool := mcp.NewTool("connect_database",
mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or connection URL"),
mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or URL"),
mcp.WithString("connection_url",
mcp.Description("Full PostgreSQL connection URL (postgres://user:password@host:port/dbname?sslmode=mode). If provided, individual parameters are ignored."),
mcp.Description("Full PostgreSQL connection URL. If provided, individual parameters are ignored."),
),
mcp.WithString("host",
mcp.Description("Database host (default: localhost)"),
Expand All @@ -101,85 +204,7 @@ func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLo
)

s.AddTool(connectDBTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.GetArguments()
debugLogger.Debug("Received connect_database tool request", "args", args)

var connectionString string

// Check if full connection URL is provided
if connURL, ok := args["connection_url"].(string); ok && connURL != "" {
connectionString = connURL
debugLogger.Debug("Using provided connection URL")
} else {
// Build connection string from individual parameters
params := ConnectionParams{}

if host, ok := args["host"].(string); ok && host != "" {
params.Host = host
} else {
params.Host = "localhost" // Default
}

if portFloat, ok := args["port"].(float64); ok {
params.Port = int(portFloat)
}
// Port will default to 5432 in buildConnectionString if 0

if user, ok := args["user"].(string); ok {
params.User = user
}

if password, ok := args["password"].(string); ok {
params.Password = password
}

if database, ok := args["database"].(string); ok {
params.Database = database
}

if sslmode, ok := args["sslmode"].(string); ok {
params.SSLMode = sslmode
}

// Validate and build connection string
var err error
connectionString, err = buildConnectionString(params)
if err != nil {
debugLogger.Error("Failed to build connection string", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("Invalid connection parameters: %v", err)), nil
}

debugLogger.Debug("Built connection string from parameters", "host", params.Host, "port", params.Port, "database", params.Database)
}

// Attempt to connect
if err := appInstance.Connect(connectionString); err != nil {
debugLogger.Error("Failed to connect to database", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil
}

// Get current database name to confirm connection
dbName, err := appInstance.GetCurrentDatabase()
if err != nil {
debugLogger.Warn("Connected but failed to get database name", "error", err)
dbName = "unknown"
}

debugLogger.Info("Successfully connected to database", "database", dbName)

response := map[string]interface{}{
"status": "connected",
"database": dbName,
"message": fmt.Sprintf("Successfully connected to database: %s", dbName),
}

jsonData, err := json.Marshal(response)
if err != nil {
debugLogger.Error("Failed to marshal connection response", "error", err)
return mcp.NewToolResultError("Failed to format connection response"), nil
}

return mcp.NewToolResultText(string(jsonData)), nil
return handleConnectDatabaseRequest(request.GetArguments(), appInstance, debugLogger)
})
}

Expand Down Expand Up @@ -289,7 +314,7 @@ func setupListTablesTool(s *server.MCPServer, appInstance *app.App, debugLogger

// handleTableSchemaToolRequest handles tool requests that require table and optional schema parameters.
func handleTableSchemaToolRequest(
args map[string]interface{},
args map[string]any,
debugLogger *slog.Logger,
toolName string,
) (string, string, error) {
Expand All @@ -311,7 +336,7 @@ func handleTableSchemaToolRequest(
}

// marshalToJSON converts data to JSON and handles errors.
func marshalToJSON(data interface{}, debugLogger *slog.Logger, errorMsg string) ([]byte, error) {
func marshalToJSON(data any, debugLogger *slog.Logger, errorMsg string) ([]byte, error) {
jsonData, err := json.Marshal(data)
if err != nil {
debugLogger.Error("Failed to marshal data to JSON", "error", err, "context", errorMsg)
Expand All @@ -325,8 +350,8 @@ type TableToolConfig struct {
Name string
Description string
TableDesc string
Operation func(appInstance *app.App, schema, table string) (interface{}, error)
SuccessMsg func(result interface{}, schema, table string) (string, []any)
Operation func(appInstance *app.App, schema, table string) (any, error)
SuccessMsg func(result any, schema, table string) (string, []any)
ErrorMsg string
}

Expand Down Expand Up @@ -375,10 +400,10 @@ func setupDescribeTableTool(s *server.MCPServer, appInstance *app.App, debugLogg
Name: "describe_table",
Description: "Get detailed information about a table's structure (columns, types, constraints)",
TableDesc: "Table name to describe",
Operation: func(appInstance *app.App, schema, table string) (interface{}, error) {
Operation: func(appInstance *app.App, schema, table string) (any, error) {
return appInstance.DescribeTable(schema, table)
},
SuccessMsg: func(result interface{}, schema, table string) (string, []any) {
SuccessMsg: func(result any, schema, table string) (string, []any) {
columns, ok := result.([]*app.ColumnInfo)
if !ok {
return "Error processing result", []any{"error", "type assertion failed"}
Expand Down Expand Up @@ -449,10 +474,10 @@ func setupListIndexesTool(s *server.MCPServer, appInstance *app.App, debugLogger
Name: "list_indexes",
Description: "List indexes for a specific table",
TableDesc: "Table name to list indexes for",
Operation: func(appInstance *app.App, schema, table string) (interface{}, error) {
Operation: func(appInstance *app.App, schema, table string) (any, error) {
return appInstance.ListIndexes(schema, table)
},
SuccessMsg: func(result interface{}, schema, table string) (string, []any) {
SuccessMsg: func(result any, schema, table string) (string, []any) {
indexes, ok := result.([]*app.IndexInfo)
if !ok {
return "Error processing result", []any{"error", "type assertion failed"}
Expand Down
Loading