diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go index 95af0871..a1304e81 100644 --- a/pkg/sqlcmd/connect.go +++ b/pkg/sqlcmd/connect.go @@ -62,6 +62,10 @@ type ConnectSettings struct { HostNameInCertificate string // ServerCertificate is the path to a certificate file to match against the server's TLS certificate ServerCertificate string + // ServerNameOverride specifies the server name to use in the login packet. + // When set, the actual dial address comes from ServerName, but this value + // is sent in the TDS login packet for server validation. + ServerNameOverride string } func (c ConnectSettings) authenticationMethod() string { @@ -100,6 +104,21 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err if err != nil { return "", err } + + if connect.useServerNameOverride(protocol, connect.ServerName) { + overrideName, overrideInstance, _, _, err := splitServer(connect.ServerNameOverride) + if err != nil { + return "", err + } + if overrideName == "" { + overrideName = "." + } + serverName = overrideName + if overrideInstance != "" { + instance = overrideInstance + } + } + query := url.Values{} connectionURL := &url.URL{ Scheme: "sqlserver", @@ -176,3 +195,13 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err connectionURL.RawQuery = query.Encode() return connectionURL.String(), nil } + +func (connect ConnectSettings) useServerNameOverride(protocol string, serverName string) bool { + if connect.ServerNameOverride == "" { + return false + } + if protocol == "np" || strings.HasPrefix(serverName, `\\`) { + return false + } + return true +} diff --git a/pkg/sqlcmd/dialer.go b/pkg/sqlcmd/dialer.go new file mode 100644 index 00000000..994c1df1 --- /dev/null +++ b/pkg/sqlcmd/dialer.go @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "net" + "strings" +) + +// proxyDialer implements mssql.HostDialer to allow specifying a server name +// for the TDS login packet that differs from the dial address. This enables +// tunneling connections through localhost while authenticating to the real server. +type proxyDialer struct { + serverName string + targetHost string + targetPort string + dialer *net.Dialer +} + +func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if d.dialer == nil { + d.dialer = &net.Dialer{} + } + return d.dialer.DialContext(ctx, network, d.dialAddress(network, addr)) +} + +func (d *proxyDialer) HostName() string { + return d.serverName +} + +func (d *proxyDialer) dialAddress(network, addr string) string { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + + if d.targetHost != "" { + host = d.targetHost + } + if d.targetPort != "" && isTCPNetwork(network) { + port = d.targetPort + } + + return net.JoinHostPort(host, port) +} + +func isTCPNetwork(network string) bool { + return strings.HasPrefix(network, "tcp") +} diff --git a/pkg/sqlcmd/dialer_test.go b/pkg/sqlcmd/dialer_test.go new file mode 100644 index 00000000..a30c4a6c --- /dev/null +++ b/pkg/sqlcmd/dialer_test.go @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProxyDialerHostName(t *testing.T) { + d := &proxyDialer{serverName: "myserver.database.windows.net"} + assert.Equal(t, "myserver.database.windows.net", d.HostName()) +} + +func TestProxyDialerHostNameEmpty(t *testing.T) { + d := &proxyDialer{} + assert.Equal(t, "", d.HostName()) +} + +func TestProxyDialerInitializesNetDialer(t *testing.T) { + d := &proxyDialer{serverName: "test.server.net"} + assert.Nil(t, d.dialer) + + // DialContext should fail with an invalid address, but that's fine for this test + // We just want to verify the dialer gets initialized + _, _ = d.DialContext(context.Background(), "tcp", "invalid:99999") + assert.NotNil(t, d.dialer) +} + +func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) { + d := &proxyDialer{ + targetHost: "proxy.local", + targetPort: "1444", + } + + dialAddr := d.dialAddress("tcp", "server.example.com:1433") + assert.Equal(t, "proxy.local:1444", dialAddr) +} + +func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) { + d := &proxyDialer{ + targetHost: "proxy.local", + targetPort: "1444", + } + + dialAddr := d.dialAddress("udp", "server.example.com:1434") + assert.Equal(t, "proxy.local:1434", dialAddr) +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 5e572a94..f1ddc78b 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -273,13 +273,37 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { } if !useAad { - connector, err = mssql.NewConnector(connstr) + var c *mssql.Connector + c, err = mssql.NewConnector(connstr) + connector = c } else { connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod()) } if err != nil { return err } + if connect.ServerNameOverride != "" { + serverName, _, port, protocol, err := splitServer(connect.ServerName) + if err != nil { + return err + } + if serverName == "" { + serverName = "." + } + if connect.useServerNameOverride(protocol, connect.ServerName) { + targetPort := "" + if port > 0 { + targetPort = fmt.Sprintf("%d", port) + } + if mssqlConnector, ok := connector.(*mssql.Connector); ok { + mssqlConnector.Dialer = &proxyDialer{ + serverName: connect.ServerNameOverride, + targetHost: serverName, + targetPort: targetPort, + } + } + } + } db, err := sql.OpenDB(connector).Conn(context.Background()) if err != nil { fmt.Fprintln(s.GetOutput(), err) diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index dfe97d1a..408e69df 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -54,6 +54,18 @@ func TestConnectionStringFromSqlCmd(t *testing.T) { &ConnectSettings{ServerName: `tcp:someserver,1045`, Encrypt: "strict", HostNameInCertificate: "*.mydomain.com"}, "sqlserver://someserver:1045?encrypt=strict&hostnameincertificate=%2A.mydomain.com&protocol=tcp", }, + { + &ConnectSettings{ServerName: `tcp:proxyhost,1444`, ServerNameOverride: "realsql"}, + "sqlserver://realsql:1444?protocol=tcp", + }, + { + &ConnectSettings{ServerName: `proxyhost\instance`, ServerNameOverride: "realsql"}, + "sqlserver://realsql/instance", + }, + { + &ConnectSettings{ServerName: `proxyhost,1444`, ServerNameOverride: `realsql\inst`}, + "sqlserver://realsql:1444/inst", + }, { &ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd}, fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd),