From 611d16529c6fffabb31a957354f0161a4a0394d9 Mon Sep 17 00:00:00 2001 From: Jimmy Stridh Date: Fri, 30 Jan 2026 00:07:43 +0100 Subject: [PATCH 1/2] feat(connect): add --server-name flag for tunneled connections Allows specifying the server name sent in the TDS LOGIN7 packet separately from the dial address. Fixes connections through SSH tunnels or proxies to Azure SQL where the server validates hostname. Refs: #576 --- pkg/sqlcmd/connect.go | 29 ++++++++++++++++++++++ pkg/sqlcmd/dialer.go | 51 +++++++++++++++++++++++++++++++++++++++ pkg/sqlcmd/dialer_test.go | 51 +++++++++++++++++++++++++++++++++++++++ pkg/sqlcmd/sqlcmd.go | 23 ++++++++++++++++-- pkg/sqlcmd/sqlcmd_test.go | 12 +++++++++ 5 files changed, 164 insertions(+), 2 deletions(-) create mode 100644 pkg/sqlcmd/dialer.go create mode 100644 pkg/sqlcmd/dialer_test.go diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go index 95af0871..a1304e81 100644 --- a/pkg/sqlcmd/connect.go +++ b/pkg/sqlcmd/connect.go @@ -62,6 +62,10 @@ type ConnectSettings struct { HostNameInCertificate string // ServerCertificate is the path to a certificate file to match against the server's TLS certificate ServerCertificate string + // ServerNameOverride specifies the server name to use in the login packet. + // When set, the actual dial address comes from ServerName, but this value + // is sent in the TDS login packet for server validation. + ServerNameOverride string } func (c ConnectSettings) authenticationMethod() string { @@ -100,6 +104,21 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err if err != nil { return "", err } + + if connect.useServerNameOverride(protocol, connect.ServerName) { + overrideName, overrideInstance, _, _, err := splitServer(connect.ServerNameOverride) + if err != nil { + return "", err + } + if overrideName == "" { + overrideName = "." + } + serverName = overrideName + if overrideInstance != "" { + instance = overrideInstance + } + } + query := url.Values{} connectionURL := &url.URL{ Scheme: "sqlserver", @@ -176,3 +195,13 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err connectionURL.RawQuery = query.Encode() return connectionURL.String(), nil } + +func (connect ConnectSettings) useServerNameOverride(protocol string, serverName string) bool { + if connect.ServerNameOverride == "" { + return false + } + if protocol == "np" || strings.HasPrefix(serverName, `\\`) { + return false + } + return true +} diff --git a/pkg/sqlcmd/dialer.go b/pkg/sqlcmd/dialer.go new file mode 100644 index 00000000..994c1df1 --- /dev/null +++ b/pkg/sqlcmd/dialer.go @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "net" + "strings" +) + +// proxyDialer implements mssql.HostDialer to allow specifying a server name +// for the TDS login packet that differs from the dial address. This enables +// tunneling connections through localhost while authenticating to the real server. +type proxyDialer struct { + serverName string + targetHost string + targetPort string + dialer *net.Dialer +} + +func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if d.dialer == nil { + d.dialer = &net.Dialer{} + } + return d.dialer.DialContext(ctx, network, d.dialAddress(network, addr)) +} + +func (d *proxyDialer) HostName() string { + return d.serverName +} + +func (d *proxyDialer) dialAddress(network, addr string) string { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + + if d.targetHost != "" { + host = d.targetHost + } + if d.targetPort != "" && isTCPNetwork(network) { + port = d.targetPort + } + + return net.JoinHostPort(host, port) +} + +func isTCPNetwork(network string) bool { + return strings.HasPrefix(network, "tcp") +} diff --git a/pkg/sqlcmd/dialer_test.go b/pkg/sqlcmd/dialer_test.go new file mode 100644 index 00000000..a30c4a6c --- /dev/null +++ b/pkg/sqlcmd/dialer_test.go @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProxyDialerHostName(t *testing.T) { + d := &proxyDialer{serverName: "myserver.database.windows.net"} + assert.Equal(t, "myserver.database.windows.net", d.HostName()) +} + +func TestProxyDialerHostNameEmpty(t *testing.T) { + d := &proxyDialer{} + assert.Equal(t, "", d.HostName()) +} + +func TestProxyDialerInitializesNetDialer(t *testing.T) { + d := &proxyDialer{serverName: "test.server.net"} + assert.Nil(t, d.dialer) + + // DialContext should fail with an invalid address, but that's fine for this test + // We just want to verify the dialer gets initialized + _, _ = d.DialContext(context.Background(), "tcp", "invalid:99999") + assert.NotNil(t, d.dialer) +} + +func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) { + d := &proxyDialer{ + targetHost: "proxy.local", + targetPort: "1444", + } + + dialAddr := d.dialAddress("tcp", "server.example.com:1433") + assert.Equal(t, "proxy.local:1444", dialAddr) +} + +func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) { + d := &proxyDialer{ + targetHost: "proxy.local", + targetPort: "1444", + } + + dialAddr := d.dialAddress("udp", "server.example.com:1434") + assert.Equal(t, "proxy.local:1434", dialAddr) +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 5e572a94..d30eadc9 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -7,7 +7,6 @@ import ( "bufio" "context" "database/sql" - "database/sql/driver" "errors" "fmt" "io" @@ -259,7 +258,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { connect = s.Connect } - var connector driver.Connector + var connector *mssql.Connector useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication() if connect.RequiresPassword() && !nopw && connect.Password == "" { var err error @@ -280,6 +279,26 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { if err != nil { return err } + if connect.ServerNameOverride != "" { + serverName, _, port, protocol, err := splitServer(connect.ServerName) + if err != nil { + return err + } + if serverName == "" { + serverName = "." + } + if connect.useServerNameOverride(protocol, connect.ServerName) { + targetPort := "" + if port > 0 { + targetPort = fmt.Sprintf("%d", port) + } + connector.Dialer = &proxyDialer{ + serverName: connect.ServerNameOverride, + targetHost: serverName, + targetPort: targetPort, + } + } + } db, err := sql.OpenDB(connector).Conn(context.Background()) if err != nil { fmt.Fprintln(s.GetOutput(), err) diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index dfe97d1a..408e69df 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -54,6 +54,18 @@ func TestConnectionStringFromSqlCmd(t *testing.T) { &ConnectSettings{ServerName: `tcp:someserver,1045`, Encrypt: "strict", HostNameInCertificate: "*.mydomain.com"}, "sqlserver://someserver:1045?encrypt=strict&hostnameincertificate=%2A.mydomain.com&protocol=tcp", }, + { + &ConnectSettings{ServerName: `tcp:proxyhost,1444`, ServerNameOverride: "realsql"}, + "sqlserver://realsql:1444?protocol=tcp", + }, + { + &ConnectSettings{ServerName: `proxyhost\instance`, ServerNameOverride: "realsql"}, + "sqlserver://realsql/instance", + }, + { + &ConnectSettings{ServerName: `proxyhost,1444`, ServerNameOverride: `realsql\inst`}, + "sqlserver://realsql:1444/inst", + }, { &ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd}, fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd), From eb57f2a2c93a1787baceffe4459cdd4c905e98c6 Mon Sep 17 00:00:00 2001 From: Jimmy Stridh Date: Fri, 30 Jan 2026 08:16:24 +0100 Subject: [PATCH 2/2] Fix connector typing for server-name override --- pkg/sqlcmd/sqlcmd.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index d30eadc9..f1ddc78b 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -7,6 +7,7 @@ import ( "bufio" "context" "database/sql" + "database/sql/driver" "errors" "fmt" "io" @@ -258,7 +259,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { connect = s.Connect } - var connector *mssql.Connector + var connector driver.Connector useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication() if connect.RequiresPassword() && !nopw && connect.Password == "" { var err error @@ -272,7 +273,9 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { } if !useAad { - connector, err = mssql.NewConnector(connstr) + var c *mssql.Connector + c, err = mssql.NewConnector(connstr) + connector = c } else { connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod()) } @@ -292,10 +295,12 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { if port > 0 { targetPort = fmt.Sprintf("%d", port) } - connector.Dialer = &proxyDialer{ - serverName: connect.ServerNameOverride, - targetHost: serverName, - targetPort: targetPort, + if mssqlConnector, ok := connector.(*mssql.Connector); ok { + mssqlConnector.Dialer = &proxyDialer{ + serverName: connect.ServerNameOverride, + targetHost: serverName, + targetPort: targetPort, + } } } }