-
Notifications
You must be signed in to change notification settings - Fork 82
feat(connect): add --server-name flag for tunneled connections #678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT license. | ||
|
|
||
| package sqlcmd | ||
|
|
||
| import ( | ||
| "context" | ||
| "net" | ||
| "strings" | ||
| ) | ||
|
|
||
| // proxyDialer implements mssql.HostDialer to allow specifying a server name | ||
| // for the TDS login packet that differs from the dial address. This enables | ||
| // tunneling connections through localhost while authenticating to the real server. | ||
| type proxyDialer struct { | ||
| serverName string | ||
| targetHost string | ||
| targetPort string | ||
| dialer *net.Dialer | ||
| } | ||
|
|
||
| func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
| if d.dialer == nil { | ||
| d.dialer = &net.Dialer{} | ||
| } | ||
| return d.dialer.DialContext(ctx, network, d.dialAddress(network, addr)) | ||
| } | ||
|
|
||
| func (d *proxyDialer) HostName() string { | ||
| return d.serverName | ||
| } | ||
|
|
||
| func (d *proxyDialer) dialAddress(network, addr string) string { | ||
| host, port, err := net.SplitHostPort(addr) | ||
| if err != nil { | ||
| return addr | ||
| } | ||
|
|
||
| if d.targetHost != "" { | ||
| host = d.targetHost | ||
| } | ||
| if d.targetPort != "" && isTCPNetwork(network) { | ||
| port = d.targetPort | ||
| } | ||
|
|
||
| return net.JoinHostPort(host, port) | ||
| } | ||
|
|
||
| func isTCPNetwork(network string) bool { | ||
| return strings.HasPrefix(network, "tcp") | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT license. | ||
|
|
||
| package sqlcmd | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| ) | ||
|
|
||
| func TestProxyDialerHostName(t *testing.T) { | ||
| d := &proxyDialer{serverName: "myserver.database.windows.net"} | ||
| assert.Equal(t, "myserver.database.windows.net", d.HostName()) | ||
| } | ||
|
|
||
| func TestProxyDialerHostNameEmpty(t *testing.T) { | ||
| d := &proxyDialer{} | ||
| assert.Equal(t, "", d.HostName()) | ||
| } | ||
|
|
||
| func TestProxyDialerInitializesNetDialer(t *testing.T) { | ||
| d := &proxyDialer{serverName: "test.server.net"} | ||
| assert.Nil(t, d.dialer) | ||
|
|
||
| // DialContext should fail with an invalid address, but that's fine for this test | ||
| // We just want to verify the dialer gets initialized | ||
| _, _ = d.DialContext(context.Background(), "tcp", "invalid:99999") | ||
| assert.NotNil(t, d.dialer) | ||
| } | ||
|
|
||
| func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) { | ||
| d := &proxyDialer{ | ||
| targetHost: "proxy.local", | ||
| targetPort: "1444", | ||
| } | ||
|
|
||
| dialAddr := d.dialAddress("tcp", "server.example.com:1433") | ||
| assert.Equal(t, "proxy.local:1444", dialAddr) | ||
| } | ||
|
|
||
| func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) { | ||
| d := &proxyDialer{ | ||
| targetHost: "proxy.local", | ||
| targetPort: "1444", | ||
| } | ||
|
|
||
| dialAddr := d.dialAddress("udp", "server.example.com:1434") | ||
| assert.Equal(t, "proxy.local:1434", dialAddr) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,6 @@ | |
| "bufio" | ||
| "context" | ||
| "database/sql" | ||
| "database/sql/driver" | ||
| "errors" | ||
| "fmt" | ||
| "io" | ||
|
|
@@ -259,7 +258,7 @@ | |
| connect = s.Connect | ||
| } | ||
|
|
||
| var connector driver.Connector | ||
| var connector *mssql.Connector | ||
| useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication() | ||
| if connect.RequiresPassword() && !nopw && connect.Password == "" { | ||
| var err error | ||
|
Comment on lines
+261
to
264
|
||
|
|
@@ -275,11 +274,31 @@ | |
| if !useAad { | ||
| connector, err = mssql.NewConnector(connstr) | ||
| } else { | ||
| connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod()) | ||
|
Check failure on line 277 in pkg/sqlcmd/sqlcmd.go
|
||
| } | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if connect.ServerNameOverride != "" { | ||
| serverName, _, port, protocol, err := splitServer(connect.ServerName) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if serverName == "" { | ||
| serverName = "." | ||
| } | ||
| if connect.useServerNameOverride(protocol, connect.ServerName) { | ||
| targetPort := "" | ||
| if port > 0 { | ||
| targetPort = fmt.Sprintf("%d", port) | ||
| } | ||
| connector.Dialer = &proxyDialer{ | ||
| serverName: connect.ServerNameOverride, | ||
| targetHost: serverName, | ||
| targetPort: targetPort, | ||
| } | ||
| } | ||
| } | ||
| db, err := sql.OpenDB(connector).Conn(context.Background()) | ||
| if err != nil { | ||
| fmt.Fprintln(s.GetOutput(), err) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description/title mention adding a
--server-nameflag, but the repo currently has no CLI flag wiring for this newConnectSettings.ServerNameOverridefield (searchingcmd/shows no references). As-is, the new behavior is not reachable from the shippedsqlcmdcommands; please add the flag (and help text) in the Cobra command(s) and map it intoConnectSettings.