diff --git a/internal/app/app.go b/internal/app/app.go index bf4fa56..6940f51 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -21,9 +21,9 @@ type ListTablesOptions struct { // ExecuteQueryOptions represents options for executing queries. type ExecuteQueryOptions struct { - Query string `json:"query"` - Args []interface{} `json:"args,omitempty"` - Limit int `json:"limit,omitempty"` + Query string `json:"query"` + Args []any `json:"args,omitempty"` + Limit int `json:"limit,omitempty"` } // App represents the main application structure. @@ -281,7 +281,7 @@ func (a *App) GetCurrentDatabase() (string, error) { } // ExplainQuery returns the execution plan for a query. -func (a *App) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) { +func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) { if err := a.ensureConnection(); err != nil { return nil, err } diff --git a/internal/app/client.go b/internal/app/client.go index c7bb070..f42abcc 100644 --- a/internal/app/client.go +++ b/internal/app/client.go @@ -358,16 +358,16 @@ func validateQuery(query string) error { } // processRows processes query result rows and handles type conversion. -func processRows(rows *sql.Rows) ([][]interface{}, error) { +func processRows(rows *sql.Rows) ([][]any, error) { columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("failed to get columns: %w", err) } - var result [][]interface{} + var result [][]any for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) for i := range values { valuePtrs[i] = &values[i] } @@ -389,7 +389,7 @@ func processRows(rows *sql.Rows) ([][]interface{}, error) { } // ExecuteQuery executes a SELECT query and returns the results. -func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) (*QueryResult, error) { +func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryResult, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } @@ -425,7 +425,7 @@ func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) ( } // ExplainQuery returns the execution plan for a query. -func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) { +func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...any) (*QueryResult, error) { if c.db == nil { return nil, ErrNoDatabaseConnection } diff --git a/internal/app/interfaces.go b/internal/app/interfaces.go index 6b42a47..4786bbe 100644 --- a/internal/app/interfaces.go +++ b/internal/app/interfaces.go @@ -15,7 +15,8 @@ var ( ErrQueryRequired = errors.New("query is required") ErrInvalidQuery = errors.New("only SELECT and WITH queries are allowed") ErrNoConnectionString = errors.New( - "no database connection string provided. Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable", + "no database connection string provided. " + + "Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable", ) ErrNoDatabaseConnection = errors.New("no database connection") ErrTableNotFound = errors.New("table does not exist") @@ -69,9 +70,9 @@ type IndexInfo struct { // QueryResult represents the result of a query execution. type QueryResult struct { - Columns []string `json:"columns"` - Rows [][]interface{} `json:"rows"` - RowCount int `json:"row_count"` + Columns []string `json:"columns"` + Rows [][]any `json:"rows"` + RowCount int `json:"row_count"` } // ConnectionManager handles database connection operations. @@ -99,8 +100,8 @@ type TableExplorer interface { // QueryExecutor handles query operations. type QueryExecutor interface { - ExecuteQuery(query string, args ...interface{}) (*QueryResult, error) - ExplainQuery(query string, args ...interface{}) (*QueryResult, error) + ExecuteQuery(query string, args ...any) (*QueryResult, error) + ExplainQuery(query string, args ...any) (*QueryResult, error) } // PostgreSQLClient interface combines all database operations. diff --git a/main.go b/main.go index 3d8cc2b..d3f9bde 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,9 @@ var version = "dev" // Error variables for static errors. var ( ErrInvalidConnectionParameters = errors.New("invalid connection parameters") + ErrHostRequired = errors.New("host is required") + ErrUserRequired = errors.New("user is required") + ErrDatabaseRequired = errors.New("database is required") ) // ConnectionParams represents individual database connection parameters. @@ -39,13 +42,13 @@ type ConnectionParams struct { func buildConnectionString(params ConnectionParams) (string, error) { // Validate required parameters if params.Host == "" { - return "", errors.New("host is required") + return "", ErrHostRequired } if params.User == "" { - return "", errors.New("user is required") + return "", ErrUserRequired } if params.Database == "" { - return "", errors.New("database is required") + return "", ErrDatabaseRequired } // Set defaults @@ -59,13 +62,13 @@ func buildConnectionString(params ConnectionParams) (string, error) { sslMode = "prefer" // PostgreSQL default SSL mode } - // Build connection string + // Build connection string using net.JoinHostPort pattern + hostPort := fmt.Sprintf("%s:%d", params.Host, port) connStr := fmt.Sprintf( - "postgres://%s:%s@%s:%d/%s?sslmode=%s", + "postgres://%s:%s@%s/%s?sslmode=%s", params.User, params.Password, - params.Host, - port, + hostPort, params.Database, sslMode, ) @@ -73,12 +76,112 @@ func buildConnectionString(params ConnectionParams) (string, error) { return connStr, nil } +// extractConnectionParams extracts connection parameters from args. +func extractConnectionParams(args map[string]any) ConnectionParams { + params := ConnectionParams{ + Host: "localhost", // Default + } + + if host, ok := args["host"].(string); ok && host != "" { + params.Host = host + } + + if portFloat, ok := args["port"].(float64); ok { + params.Port = int(portFloat) + } + + if user, ok := args["user"].(string); ok { + params.User = user + } + + if password, ok := args["password"].(string); ok { + params.Password = password + } + + if database, ok := args["database"].(string); ok { + params.Database = database + } + + if sslmode, ok := args["sslmode"].(string); ok { + params.SSLMode = sslmode + } + + return params +} + +// getConnectionString determines the connection string from args. +func getConnectionString( + args map[string]any, + debugLogger *slog.Logger, +) (string, error) { + // Check if full connection URL is provided + if connURL, ok := args["connection_url"].(string); ok && connURL != "" { + debugLogger.Debug("Using provided connection URL") + return connURL, nil + } + + // Build connection string from individual parameters + params := extractConnectionParams(args) + connectionString, err := buildConnectionString(params) + if err != nil { + debugLogger.Error("Failed to build connection string", "error", err) + return "", fmt.Errorf("invalid connection parameters: %w", err) + } + + debugLogger.Debug("Built connection string from parameters", + "host", params.Host, "port", params.Port, "database", params.Database) + return connectionString, nil +} + +// handleConnectDatabaseRequest handles the connect_database tool request. +func handleConnectDatabaseRequest( + args map[string]any, + appInstance *app.App, + debugLogger *slog.Logger, +) (*mcp.CallToolResult, error) { + debugLogger.Debug("Received connect_database tool request", "args", args) + + connectionString, err := getConnectionString(args, debugLogger) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Attempt to connect + if err := appInstance.Connect(connectionString); err != nil { + debugLogger.Error("Failed to connect to database", "error", err) + return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil + } + + // Get current database name to confirm connection + dbName, err := appInstance.GetCurrentDatabase() + if err != nil { + debugLogger.Warn("Connected but failed to get database name", "error", err) + dbName = "unknown" + } + + debugLogger.Info("Successfully connected to database", "database", dbName) + + response := map[string]any{ + "status": "connected", + "database": dbName, + "message": "Successfully connected to database: " + dbName, + } + + jsonData, err := json.Marshal(response) + if err != nil { + debugLogger.Error("Failed to marshal connection response", "error", err) + return mcp.NewToolResultError("Failed to format connection response"), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + // setupConnectDatabaseTool creates and registers the connect_database tool. func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLogger *slog.Logger) { connectDBTool := mcp.NewTool("connect_database", - mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or connection URL"), + mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or URL"), mcp.WithString("connection_url", - mcp.Description("Full PostgreSQL connection URL (postgres://user:password@host:port/dbname?sslmode=mode). If provided, individual parameters are ignored."), + mcp.Description("Full PostgreSQL connection URL. If provided, individual parameters are ignored."), ), mcp.WithString("host", mcp.Description("Database host (default: localhost)"), @@ -101,85 +204,7 @@ func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLo ) s.AddTool(connectDBTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - args := request.GetArguments() - debugLogger.Debug("Received connect_database tool request", "args", args) - - var connectionString string - - // Check if full connection URL is provided - if connURL, ok := args["connection_url"].(string); ok && connURL != "" { - connectionString = connURL - debugLogger.Debug("Using provided connection URL") - } else { - // Build connection string from individual parameters - params := ConnectionParams{} - - if host, ok := args["host"].(string); ok && host != "" { - params.Host = host - } else { - params.Host = "localhost" // Default - } - - if portFloat, ok := args["port"].(float64); ok { - params.Port = int(portFloat) - } - // Port will default to 5432 in buildConnectionString if 0 - - if user, ok := args["user"].(string); ok { - params.User = user - } - - if password, ok := args["password"].(string); ok { - params.Password = password - } - - if database, ok := args["database"].(string); ok { - params.Database = database - } - - if sslmode, ok := args["sslmode"].(string); ok { - params.SSLMode = sslmode - } - - // Validate and build connection string - var err error - connectionString, err = buildConnectionString(params) - if err != nil { - debugLogger.Error("Failed to build connection string", "error", err) - return mcp.NewToolResultError(fmt.Sprintf("Invalid connection parameters: %v", err)), nil - } - - debugLogger.Debug("Built connection string from parameters", "host", params.Host, "port", params.Port, "database", params.Database) - } - - // Attempt to connect - if err := appInstance.Connect(connectionString); err != nil { - debugLogger.Error("Failed to connect to database", "error", err) - return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil - } - - // Get current database name to confirm connection - dbName, err := appInstance.GetCurrentDatabase() - if err != nil { - debugLogger.Warn("Connected but failed to get database name", "error", err) - dbName = "unknown" - } - - debugLogger.Info("Successfully connected to database", "database", dbName) - - response := map[string]interface{}{ - "status": "connected", - "database": dbName, - "message": fmt.Sprintf("Successfully connected to database: %s", dbName), - } - - jsonData, err := json.Marshal(response) - if err != nil { - debugLogger.Error("Failed to marshal connection response", "error", err) - return mcp.NewToolResultError("Failed to format connection response"), nil - } - - return mcp.NewToolResultText(string(jsonData)), nil + return handleConnectDatabaseRequest(request.GetArguments(), appInstance, debugLogger) }) } @@ -289,7 +314,7 @@ func setupListTablesTool(s *server.MCPServer, appInstance *app.App, debugLogger // handleTableSchemaToolRequest handles tool requests that require table and optional schema parameters. func handleTableSchemaToolRequest( - args map[string]interface{}, + args map[string]any, debugLogger *slog.Logger, toolName string, ) (string, string, error) { @@ -311,7 +336,7 @@ func handleTableSchemaToolRequest( } // marshalToJSON converts data to JSON and handles errors. -func marshalToJSON(data interface{}, debugLogger *slog.Logger, errorMsg string) ([]byte, error) { +func marshalToJSON(data any, debugLogger *slog.Logger, errorMsg string) ([]byte, error) { jsonData, err := json.Marshal(data) if err != nil { debugLogger.Error("Failed to marshal data to JSON", "error", err, "context", errorMsg) @@ -325,8 +350,8 @@ type TableToolConfig struct { Name string Description string TableDesc string - Operation func(appInstance *app.App, schema, table string) (interface{}, error) - SuccessMsg func(result interface{}, schema, table string) (string, []any) + Operation func(appInstance *app.App, schema, table string) (any, error) + SuccessMsg func(result any, schema, table string) (string, []any) ErrorMsg string } @@ -375,10 +400,10 @@ func setupDescribeTableTool(s *server.MCPServer, appInstance *app.App, debugLogg Name: "describe_table", Description: "Get detailed information about a table's structure (columns, types, constraints)", TableDesc: "Table name to describe", - Operation: func(appInstance *app.App, schema, table string) (interface{}, error) { + Operation: func(appInstance *app.App, schema, table string) (any, error) { return appInstance.DescribeTable(schema, table) }, - SuccessMsg: func(result interface{}, schema, table string) (string, []any) { + SuccessMsg: func(result any, schema, table string) (string, []any) { columns, ok := result.([]*app.ColumnInfo) if !ok { return "Error processing result", []any{"error", "type assertion failed"} @@ -449,10 +474,10 @@ func setupListIndexesTool(s *server.MCPServer, appInstance *app.App, debugLogger Name: "list_indexes", Description: "List indexes for a specific table", TableDesc: "Table name to list indexes for", - Operation: func(appInstance *app.App, schema, table string) (interface{}, error) { + Operation: func(appInstance *app.App, schema, table string) (any, error) { return appInstance.ListIndexes(schema, table) }, - SuccessMsg: func(result interface{}, schema, table string) (string, []any) { + SuccessMsg: func(result any, schema, table string) (string, []any) { indexes, ok := result.([]*app.IndexInfo) if !ok { return "Error processing result", []any{"error", "type assertion failed"}