diff --git a/README.md b/README.md index fe26e192..defe15cc 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,24 @@ sqlcmd If no current context exists, `sqlcmd` (with no connection parameters) reverts to the original ODBC `sqlcmd` behavior of creating an interactive session to the default local instance on port 1433 using trusted authentication, otherwise it will create an interactive session to the current context. +### Interactive Mode Commands + +In interactive mode, `sqlcmd` supports several special commands. The `EXIT` command can execute a query and use its result as the exit code: + +``` +1> EXIT(SELECT 100) +``` + +For complex queries, `EXIT(query)` can span multiple lines. When parentheses are unbalanced, `sqlcmd` prompts for continuation: + +``` +1> EXIT(SELECT 1 + -> + 2 + -> + 3) +``` + +The query result (6 in this example) becomes the process exit code. + ## Sqlcmd The `sqlcmd` project aims to be a complete port of the original ODBC sqlcmd to the `Go` language, utilizing the [go-mssqldb][] driver. For full documentation of the tool and installation instructions, see [go-sqlcmd-utility][]. @@ -134,7 +152,6 @@ The following switches have different behavior in this version of `sqlcmd` compa - More information about client/server encryption negotiation can be found at - `-u` The generated Unicode output file will have the UTF16 Little-Endian Byte-order mark (BOM) written to it. - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. -- All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. - `-i` doesn't handle a comma `,` in a file name correctly unless the file name argument is triple quoted. For example: `sqlcmd -i """select,100.sql"""` will try to open a file named `sql,100.sql` while `sqlcmd -i "select,100.sql"` will try to open two files `select` and `100.sql` - If using a single `-i` flag to pass multiple file names, there must be a space after the `-i`. Example: `-i file1.sql file2.sql` diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..af870703 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -208,10 +208,100 @@ func (c Commands) SetBatchTerminator(terminator string) error { return nil } +// isExitParenBalanced checks if the parentheses in an EXIT command argument are balanced. +// It tracks quotes to avoid counting parens inside string literals. +// It handles SQL Server's quote escaping: ” inside single-quoted strings, "" inside double-quoted strings, and ]] inside bracket identifiers. +// It also ignores parentheses inside SQL comments (-- single-line and /* multi-line */). +func isExitParenBalanced(s string) bool { + depth := 0 + var quote rune + inLineComment := false + inBlockComment := false + runes := []rune(s) + for i := 0; i < len(runes); i++ { + c := runes[i] + + // Handle line comment state + if inLineComment { + // Line comment ends at newline + if c == '\n' { + inLineComment = false + } + continue + } + + // Handle block comment state + if inBlockComment { + // Check for end of block comment + if c == '*' && i+1 < len(runes) && runes[i+1] == '/' { + inBlockComment = false + i++ // skip the '/' + } + continue + } + + switch { + case quote != 0: + // Inside a quoted string + if c == quote { + // Check for escaped quote ('' or ]]) + if i+1 < len(runes) && runes[i+1] == quote { + i++ // skip the escaped quote + } else { + quote = 0 + } + } + case c == '-' && i+1 < len(runes) && runes[i+1] == '-': + // Start of single-line comment + inLineComment = true + i++ // skip the second '-' + case c == '/' && i+1 < len(runes) && runes[i+1] == '*': + // Start of block comment + inBlockComment = true + i++ // skip the '*' + case c == '\'' || c == '"': + quote = c + case c == '[': + quote = ']' // SQL Server bracket quoting + case c == '(': + depth++ + case c == ')': + depth-- + } + } + return depth == 0 +} + +// readExitContinuation reads additional lines from the console until the EXIT +// parentheses are balanced. This enables multi-line EXIT(query) in interactive mode. +func readExitContinuation(s *Sqlcmd, params string) (string, error) { + var builder strings.Builder + builder.WriteString(params) + + // Save original prompt and restore it when done (if batch is initialized) + if s.batch != nil { + originalPrompt := s.Prompt() + defer s.lineIo.SetPrompt(originalPrompt) + } + + for !isExitParenBalanced(builder.String()) { + // Show continuation prompt + s.lineIo.SetPrompt(" -> ") + line, err := s.lineIo.Readline() + if err != nil { + return "", err + } + builder.WriteString(SqlcmdEol) + builder.WriteString(line) + } + return builder.String(), nil +} + // exitCommand has 3 modes. // With no (), it just exits without running any query // With () it runs whatever batch is in the buffer then exits // With any text between () it runs the text as a query then exits +// In interactive mode, if parentheses are unbalanced, it prompts for continuation lines. func exitCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 { return ErrExitRequested @@ -220,9 +310,29 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { if params == "" { return ErrExitRequested } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + + // Check if we have an opening paren + if !strings.HasPrefix(params, "(") { return InvalidCommandError("EXIT", line) } + + // If parentheses are unbalanced, try to read continuation lines (interactive mode only) + if !isExitParenBalanced(params) { + if s.lineIo == nil { + // Not in interactive mode, can't read more lines + return InvalidCommandError("EXIT", line) + } + var err error + params, err = readExitContinuation(s, params) + if err != nil { + return err + } + } + + if !strings.HasSuffix(params, ")") { + return InvalidCommandError("EXIT", line) + } + // First we save the current batch query1 := s.batch.String() if len(query1) > 0 { diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..6184ea48 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -5,7 +5,9 @@ package sqlcmd import ( "bytes" + "errors" "fmt" + "io" "os" "strings" "testing" @@ -458,3 +460,213 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } +func TestIsExitParenBalanced(t *testing.T) { + tests := []struct { + input string + balanced bool + }{ + {"()", true}, + {"(select 1)", true}, + {"(select 1", false}, + {"(select (1 + 2))", true}, + {"(select ')')", true}, // paren inside string + {"(select \"(\")", true}, // paren inside double-quoted string + {"(select [col)])", true}, // paren inside bracket-quoted identifier + {"(select 1) extra", true}, // balanced even with trailing text + {"((nested))", true}, + {"((nested)", false}, + {"", true}, // empty string is balanced + {"no parens", true}, // no parens is balanced + {"(", false}, + {")", false}, // depth goes -1, not balanced + {"(test))", false}, // depth goes -1 at end + {"(select 'can''t')", true}, // escaped single quote + {"(select [col]]name])", true}, // escaped bracket identifier + {"(select 'it''s a )test')", true}, // escaped quote with paren + {"(select [a]]])", true}, // escaped bracket with paren + // SQL comment tests + {"(select 1 -- unmatched (\n)", true}, // line comment with paren + {"(select 1 /* ( */ )", true}, // block comment with paren + {"(select /* nested ( */ 1)", true}, // block comment in middle + {"(select 1 -- comment\n+ 2)", true}, // line comment continues to next line + {"(select /* multi\nline\n( */ 1)", true}, // multi-line block comment + {"(select 1 -- ) still need close\n)", true}, // paren in line comment doesn't count + {"(select 1 /* ) */ + /* ( */ 2)", true}, // multiple block comments + {"(select 1 -- (\n-- )\n)", true}, // multiple line comments + {"(select '-- not a comment (' )", true}, // -- inside string is not a comment + {"(select '/* not a comment (' )", true}, // /* inside string is not a comment + {"(select 1 /* unclosed comment", false}, // unclosed block comment, missing ) + {"(select 1) -- trailing comment (", true}, // trailing comment after balanced + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + result := isExitParenBalanced(test.input) + assert.Equal(t, test.balanced, result, "isExitParenBalanced(%q)", test.input) + }) + } +} + +func TestReadExitContinuation(t *testing.T) { + t.Run("reads continuation lines until balanced", func(t *testing.T) { + s := &Sqlcmd{} + lines := []string{"+ 2)", ""} + lineIndex := 0 + promptSet := "" + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + if lineIndex >= len(lines) { + return "", io.EOF + } + line := lines[lineIndex] + lineIndex++ + return line, nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + s.lineIo.SetPrompt("") + + result, err := readExitContinuation(s, "(select 1") + assert.NoError(t, err) + assert.Equal(t, "(select 1"+SqlcmdEol+"+ 2)", result) + + // Verify prompt was set + tc := s.lineIo.(*testConsole) + promptSet = tc.PromptText + assert.Equal(t, " -> ", promptSet) + }) + + t.Run("returns error on readline failure", func(t *testing.T) { + s := &Sqlcmd{} + expectedErr := errors.New("readline error") + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + return "", expectedErr + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + + _, err := readExitContinuation(s, "(select 1") + assert.Equal(t, expectedErr, err) + }) + + t.Run("handles multiple continuation lines", func(t *testing.T) { + s := &Sqlcmd{} + lines := []string{"+ 2", "+ 3", ")"} + lineIndex := 0 + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + if lineIndex >= len(lines) { + return "", io.EOF + } + line := lines[lineIndex] + lineIndex++ + return line, nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + + result, err := readExitContinuation(s, "(select 1") + assert.NoError(t, err) + assert.Equal(t, "(select 1"+SqlcmdEol+"+ 2"+SqlcmdEol+"+ 3"+SqlcmdEol+")", result) + }) + + t.Run("returns immediately if already balanced", func(t *testing.T) { + s := &Sqlcmd{} + readLineCalled := false + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + readLineCalled = true + return "", nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + + result, err := readExitContinuation(s, "(select 1)") + assert.NoError(t, err) + assert.Equal(t, "(select 1)", result) + assert.False(t, readLineCalled, "Readline should not be called for balanced input") + }) + + t.Run("restores original prompt when batch is initialized", func(t *testing.T) { + s := &Sqlcmd{} + s.batch = NewBatch(nil, nil) + lines := []string{")"} + lineIndex := 0 + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + if lineIndex >= len(lines) { + return "", io.EOF + } + line := lines[lineIndex] + lineIndex++ + return line, nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + s.lineIo.SetPrompt("1> ") + + result, err := readExitContinuation(s, "(select 1") + assert.NoError(t, err) + assert.Equal(t, "(select 1"+SqlcmdEol+")", result) + // After function returns, prompt should be restored to original + tc := s.lineIo.(*testConsole) + assert.Equal(t, "1> ", tc.PromptText) + }) +} + +func TestExitCommandNonInteractiveUnbalanced(t *testing.T) { + // Test that unbalanced parentheses in non-interactive mode returns InvalidCommandError + s := &Sqlcmd{} + s.lineIo = nil // non-interactive mode + + err := exitCommand(s, []string{"(select 1"}, 1) + assert.EqualError(t, err, InvalidCommandError("EXIT", 1).Error(), "unbalanced parens in non-interactive should error") +} + +// TestExitCommandMultiLineInteractive is an integration test that exercises the full +// multi-line EXIT flow: starting with unbalanced parentheses, reading continuation lines +// from the console, executing the combined query, and returning the correct exit code. +func TestExitCommandMultiLineInteractive(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Set up mock console to provide continuation lines + continuationLines := []string{"+ 2", ")"} + lineIndex := 0 + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + if lineIndex >= len(continuationLines) { + return "", io.EOF + } + line := continuationLines[lineIndex] + lineIndex++ + return line, nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + + // Initialize batch so exitCommand can work with it + s.batch = NewBatch(nil, nil) + + // Call exitCommand with unbalanced parentheses - this should: + // 1. Detect unbalanced parens in "(select 1" + // 2. Read continuation lines "+ 2" and ")" from the mock console + // 3. Combine into "(select 1\r\n+ 2\r\n)" and execute + // 4. Return ErrExitRequested with Exitcode set to 3 (1+2) + err := exitCommand(s, []string{"(select 1"}, 1) + + assert.Equal(t, ErrExitRequested, err, "exitCommand should return ErrExitRequested") + assert.Equal(t, 3, s.Exitcode, "Exitcode should be 3 (result of 'select 1 + 2')") +}