Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pkg/sqlcmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ type ConnectSettings struct {
HostNameInCertificate string
// ServerCertificate is the path to a certificate file to match against the server's TLS certificate
ServerCertificate string
// ServerNameOverride specifies the server name to use in the login packet.
// When set, the actual dial address comes from ServerName, but this value
// is sent in the TDS login packet for server validation.
ServerNameOverride string
Comment on lines +65 to +68
Copy link

Copilot AI Jan 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description/title mention adding a --server-name flag, but the repo currently has no CLI flag wiring for this new ConnectSettings.ServerNameOverride field (searching cmd/ shows no references). As-is, the new behavior is not reachable from the shipped sqlcmd commands; please add the flag (and help text) in the Cobra command(s) and map it into ConnectSettings.

Copilot uses AI. Check for mistakes.
}

func (c ConnectSettings) authenticationMethod() string {
Expand Down Expand Up @@ -100,6 +104,21 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
if err != nil {
return "", err
}

if connect.useServerNameOverride(protocol, connect.ServerName) {
overrideName, overrideInstance, _, _, err := splitServer(connect.ServerNameOverride)
if err != nil {
return "", err
}
if overrideName == "" {
overrideName = "."
}
serverName = overrideName
if overrideInstance != "" {
instance = overrideInstance
}
}

query := url.Values{}
connectionURL := &url.URL{
Scheme: "sqlserver",
Expand Down Expand Up @@ -176,3 +195,13 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
connectionURL.RawQuery = query.Encode()
return connectionURL.String(), nil
}

func (connect ConnectSettings) useServerNameOverride(protocol string, serverName string) bool {
if connect.ServerNameOverride == "" {
return false
}
if protocol == "np" || strings.HasPrefix(serverName, `\\`) {
return false
}
return true
}
51 changes: 51 additions & 0 deletions pkg/sqlcmd/dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sqlcmd

import (
"context"
"net"
"strings"
)

// proxyDialer implements mssql.HostDialer to allow specifying a server name
// for the TDS login packet that differs from the dial address. This enables
// tunneling connections through localhost while authenticating to the real server.
type proxyDialer struct {
serverName string
targetHost string
targetPort string
dialer *net.Dialer
}

func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if d.dialer == nil {
d.dialer = &net.Dialer{}
}
return d.dialer.DialContext(ctx, network, d.dialAddress(network, addr))
}

func (d *proxyDialer) HostName() string {
return d.serverName
}

func (d *proxyDialer) dialAddress(network, addr string) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}

if d.targetHost != "" {
host = d.targetHost
}
if d.targetPort != "" && isTCPNetwork(network) {
port = d.targetPort
}

return net.JoinHostPort(host, port)
}

func isTCPNetwork(network string) bool {
return strings.HasPrefix(network, "tcp")
}
51 changes: 51 additions & 0 deletions pkg/sqlcmd/dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sqlcmd

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestProxyDialerHostName(t *testing.T) {
d := &proxyDialer{serverName: "myserver.database.windows.net"}
assert.Equal(t, "myserver.database.windows.net", d.HostName())
}

func TestProxyDialerHostNameEmpty(t *testing.T) {
d := &proxyDialer{}
assert.Equal(t, "", d.HostName())
}

func TestProxyDialerInitializesNetDialer(t *testing.T) {
d := &proxyDialer{serverName: "test.server.net"}
assert.Nil(t, d.dialer)

// DialContext should fail with an invalid address, but that's fine for this test
// We just want to verify the dialer gets initialized
_, _ = d.DialContext(context.Background(), "tcp", "invalid:99999")
assert.NotNil(t, d.dialer)
}

func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) {
d := &proxyDialer{
targetHost: "proxy.local",
targetPort: "1444",
}

dialAddr := d.dialAddress("tcp", "server.example.com:1433")
assert.Equal(t, "proxy.local:1444", dialAddr)
}

func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) {
d := &proxyDialer{
targetHost: "proxy.local",
targetPort: "1444",
}

dialAddr := d.dialAddress("udp", "server.example.com:1434")
assert.Equal(t, "proxy.local:1434", dialAddr)
}
23 changes: 21 additions & 2 deletions pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"bufio"
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -259,7 +258,7 @@
connect = s.Connect
}

var connector driver.Connector
var connector *mssql.Connector
useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication()
if connect.RequiresPassword() && !nopw && connect.Password == "" {
var err error
Comment on lines +261 to 264
Copy link

Copilot AI Jan 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connector is now typed as *mssql.Connector, but the AAD path calls GetTokenBasedConnection(...) which currently returns driver.Connector (see pkg/sqlcmd/azure_auth.go). This will cause a type mismatch/compile error and also makes it hard to support ServerNameOverride for AAD connectors. Consider keeping connector as driver.Connector and only type-asserting to *mssql.Connector when setting Dialer (or update GetTokenBasedConnection to return *mssql.Connector consistently).

Copilot uses AI. Check for mistakes.
Expand All @@ -275,11 +274,31 @@
if !useAad {
connector, err = mssql.NewConnector(connstr)
} else {
connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod())

Check failure on line 277 in pkg/sqlcmd/sqlcmd.go

View workflow job for this annotation

GitHub Actions / lint-pr-changes

cannot use GetTokenBasedConnection(connstr, connect.authenticationMethod()) (value of interface type driver.Connector) as *mssql.Connector value in assignment: need type assertion) (typecheck)

Check failure on line 277 in pkg/sqlcmd/sqlcmd.go

View workflow job for this annotation

GitHub Actions / Go Vulnerability Check

cannot use GetTokenBasedConnection(connstr, connect.authenticationMethod()) (value of interface type driver.Connector) as *mssql.Connector value in assignment: need type assertion
}
if err != nil {
return err
}
if connect.ServerNameOverride != "" {
serverName, _, port, protocol, err := splitServer(connect.ServerName)
if err != nil {
return err
}
if serverName == "" {
serverName = "."
}
if connect.useServerNameOverride(protocol, connect.ServerName) {
targetPort := ""
if port > 0 {
targetPort = fmt.Sprintf("%d", port)
}
connector.Dialer = &proxyDialer{
serverName: connect.ServerNameOverride,
targetHost: serverName,
targetPort: targetPort,
}
}
}
db, err := sql.OpenDB(connector).Conn(context.Background())
if err != nil {
fmt.Fprintln(s.GetOutput(), err)
Expand Down
12 changes: 12 additions & 0 deletions pkg/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ func TestConnectionStringFromSqlCmd(t *testing.T) {
&ConnectSettings{ServerName: `tcp:someserver,1045`, Encrypt: "strict", HostNameInCertificate: "*.mydomain.com"},
"sqlserver://someserver:1045?encrypt=strict&hostnameincertificate=%2A.mydomain.com&protocol=tcp",
},
{
&ConnectSettings{ServerName: `tcp:proxyhost,1444`, ServerNameOverride: "realsql"},
"sqlserver://realsql:1444?protocol=tcp",
},
{
&ConnectSettings{ServerName: `proxyhost\instance`, ServerNameOverride: "realsql"},
"sqlserver://realsql/instance",
},
{
&ConnectSettings{ServerName: `proxyhost,1444`, ServerNameOverride: `realsql\inst`},
"sqlserver://realsql:1444/inst",
},
{
&ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd},
fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd),
Expand Down
Loading