Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Check out source
uses: actions/checkout@v2
- name: Start MongoDB
uses: supercharge/mongodb-github-action@1.8.0
uses: supercharge/mongodb-github-action@1.12.1
with:
mongodb-replica-set: replicaset
- name: Unit test
Expand Down Expand Up @@ -173,7 +173,7 @@ jobs:
- name: Check out source
uses: actions/checkout@v2
- name: Start MongoDB
uses: supercharge/mongodb-github-action@1.8.0
uses: supercharge/mongodb-github-action@1.12.1
with:
mongodb-replica-set: replicaset
- name: Install ruby
Expand Down
25 changes: 23 additions & 2 deletions config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ const (
// This allows for alternate socket hosts for connecting to a session for failover.
// (i.e.) SocketConnectHost1, SocketConnectHost2... must be consecutive and have a matching SocketConnectPort<n>.
//
// Required: Yes for initiators
// Required: Yes for initiators on socket connections
//
// Default: None
//
Expand All @@ -570,7 +570,7 @@ const (
// This allows for alternate socket ports for connecting to a session for failover.
// (i.e.) SocketConnectPort1, SocketConnectPort2... must be consecutive and have a matching SocketConnectHost<n>.
//
// Required: Yes for initiators
// Required: Yes for initiators on socket connections
//
// Default: None
//
Expand Down Expand Up @@ -647,6 +647,27 @@ const (
// Valid Values:
// - Any string
ProxyPassword string = "ProxyPassword"

// WebsocketLocation sets the websocket endpoint to attempt to connect to.
// Setting this would override any SocketConnectHost and SocketConnectPort settings and connect using websocket
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - A websocket endpoint - eg. wss://example.com/ws
WebsocketLocation string = "WebsocketLocation"

// WebsocketOrigin sets the websocket origin to attempt to connect from.
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - url - eg. http://localhost/
WebsocketOrigin string = "WebsocketOrigin"
)

const (
Expand Down
92 changes: 88 additions & 4 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,113 @@
package quickfix

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"

"golang.org/x/net/proxy"
"golang.org/x/net/websocket"

"github.com/quickfixgo/quickfix/config"
)

func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) {
type Dialer interface {
Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (net.Conn, error)
}

type TCPDialer struct {
ctxDialer proxy.ContextDialer
}

func (d *TCPDialer) Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (conn net.Conn, err error) {
address := session.SocketConnectAddress[attempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

conn, err = d.ctxDialer.DialContext(ctx, "tcp", address)

if err != nil {
return
} else if tlsConfig != nil {
// Unless InsecureSkipVerify is true, server name config is required for TLS
// to verify the received certificate
if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 {
serverName := address
if c := strings.LastIndex(serverName, ":"); c > 0 {
serverName = serverName[:c]
}
tlsConfig.ServerName = serverName
}
tlsConn := tls.Client(conn, tlsConfig)
if err = tlsConn.Handshake(); err != nil {

session.log.OnEventf("Failed handshake: %v", err)
return
}
conn = tlsConn
}

return
}

type WebsocketDialer struct {
wsConfig *websocket.Config
}

func (d *WebsocketDialer) Dial(ctx context.Context, session *session, _ int, tlsConfig *tls.Config) (conn net.Conn, err error) {
session.log.OnEventf("Connecting to: %v", d.wsConfig.Location)

d.wsConfig.TlsConfig = tlsConfig
conn, err = d.wsConfig.DialContext(ctx)
return
}

func loadDialerConfig(settings *SessionSettings) (dialer Dialer, err error) {

if settings.HasSetting(config.WebsocketLocation) {
var location string
location, err = settings.Setting(config.WebsocketLocation)
if err != nil {
return nil, err
}

var origin string
origin, err = settings.Setting(config.WebsocketOrigin)
if err != nil {
return nil, err
}

var wsConfig *websocket.Config
wsConfig, err = websocket.NewConfig(location, origin)
if err != nil {
return nil, err
}

dialer = &WebsocketDialer{
wsConfig: wsConfig,
}
return
}

stdDialer := &net.Dialer{}
dialer = &TCPDialer{
ctxDialer: stdDialer,
}
if settings.HasSetting(config.SocketTimeout) {
timeout, err := settings.DurationSetting(config.SocketTimeout)
if err != nil {
timeoutInt, err := settings.IntSetting(config.SocketTimeout)
if err != nil {
return stdDialer, err
return nil, err
}

stdDialer.Timeout = time.Duration(timeoutInt) * time.Second
} else {
stdDialer.Timeout = timeout
}
}
dialer = stdDialer

if !settings.HasSetting(config.ProxyType) {
return
Expand Down Expand Up @@ -81,7 +163,9 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, er
}

if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok {
dialer = contextDialer
dialer = &TCPDialer{
ctxDialer: contextDialer,
}
} else {
err = fmt.Errorf("proxy does not support context dialer")
return
Expand Down
25 changes: 22 additions & 3 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *DialerTestSuite) TestLoadDialerNoSettings() {
dialer, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().Nil(err)

stdDialer, ok := dialer.(*net.Dialer)
stdDialer, ok := dialer.(*TCPDialer).ctxDialer.(*net.Dialer)
s.Require().True(ok)
s.Require().NotNil(stdDialer)
s.Zero(stdDialer.Timeout)
Expand All @@ -53,7 +53,7 @@ func (s *DialerTestSuite) TestLoadDialerWithTimeout() {
dialer, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().Nil(err)

stdDialer, ok := dialer.(*net.Dialer)
stdDialer, ok := dialer.(*TCPDialer).ctxDialer.(*net.Dialer)
s.Require().True(ok)
s.Require().NotNil(stdDialer)
s.EqualValues(10*time.Second, stdDialer.Timeout)
Expand All @@ -73,7 +73,7 @@ func (s *DialerTestSuite) TestLoadDialerSocksProxy() {
s.Require().Nil(err)
s.Require().NotNil(dialer)

_, ok := dialer.(*net.Dialer)
_, ok := dialer.(*TCPDialer).ctxDialer.(*net.Dialer)
s.Require().False(ok)
}

Expand All @@ -90,3 +90,22 @@ func (s *DialerTestSuite) TestLoadDialerSocksProxyInvalidPort() {
_, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().NotNil(err)
}

func (s *DialerTestSuite) TestLoadDialerWebsocket() {
s.settings.GlobalSettings().Set(config.WebsocketLocation, "ws://example.com/ws")
s.settings.GlobalSettings().Set(config.WebsocketOrigin, "http://localhost/")

dialer, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().NoError(err)

wsDialer, ok := dialer.(*WebsocketDialer)
s.Require().True(ok)
s.Equal("ws://example.com/ws", wsDialer.wsConfig.Location.String())
s.Equal("http://localhost/", wsDialer.wsConfig.Origin.String())
}

func (s *DialerTestSuite) TestLoadDialerWebsocketMissingOrigin() {
s.settings.GlobalSettings().Set(config.WebsocketLocation, "ws://example.com/ws")
_, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().Error(err)
}
33 changes: 6 additions & 27 deletions initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ import (
"bufio"
"context"
"crypto/tls"
"strings"
"sync"
"time"

"golang.org/x/net/proxy"
)

// Initiator initiates connections and processes messages for all sessions.
Expand All @@ -48,12 +45,12 @@ func (i *Initiator) Start() (err error) {
// TODO: move into session factory.
var tlsConfig *tls.Config
if tlsConfig, err = loadTLSConfig(settings); err != nil {
return
return err
}

var dialer proxy.ContextDialer
var dialer Dialer
if dialer, err = loadDialerConfig(settings); err != nil {
return
return err
}

i.wg.Add(1)
Expand Down Expand Up @@ -143,7 +140,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo
return true
}

func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) {
func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer Dialer) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand Down Expand Up @@ -180,30 +177,12 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
var msgIn chan fixIn
var msgOut chan []byte

address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

netConn, err := dialer.DialContext(ctx, "tcp", address)
netConn, err := dialer.Dial(ctx, session, connectionAttempt, tlsConfig)
if err != nil {
session.log.OnEventf("Failed to connect: %v", err)
goto reconnect
} else if tlsConfig != nil {
// Unless InsecureSkipVerify is true, server name config is required for TLS
// to verify the received certificate
if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 {
serverName := address
if c := strings.LastIndex(serverName, ":"); c > 0 {
serverName = serverName[:c]
}
tlsConfig.ServerName = serverName
}
tlsConn := tls.Client(netConn, tlsConfig)
if err = tlsConn.Handshake(); err != nil {
session.log.OnEventf("Failed handshake: %v", err)
goto reconnect
}
netConn = tlsConn
}
session.log.OnEventf("connected to remote address: %v", netConn.RemoteAddr().String())

msgIn = make(chan fixIn, session.InChanCapacity)
msgOut = make(chan []byte)
Expand Down
7 changes: 7 additions & 0 deletions session_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,13 @@ func (f sessionFactory) buildInitiatorSettings(session *session, settings *Sessi
func (f sessionFactory) configureSocketConnectAddress(session *session, settings *SessionSettings) (err error) {
session.SocketConnectAddress = []string{}

if !settings.HasSetting(config.SocketConnectHost) {
if !settings.HasSetting(config.WebsocketLocation) {
err = errors.New("SocketConnectHost must be specified if WebsocketLocation is not specified")
}
return
}

var socketConnectHost, socketConnectPort string
for i := 0; ; {

Expand Down
9 changes: 9 additions & 0 deletions session_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,15 @@ func (s *SessionFactorySuite) TestConfigureSocketConnectAddress() {
}
}

func (s *SessionFactorySuite) TestConfigureSocketConnectAddressWebsocketOnly() {
sess := new(session)
s.SessionSettings.Set(config.WebsocketLocation, "wss://example.com/ws")

err := s.configureSocketConnectAddress(sess, s.SessionSettings)
s.Require().NoError(err)
s.Empty(sess.SocketConnectAddress)
}

func (s *SessionFactorySuite) TestConfigureSocketConnectAddressMulti() {
session := new(session)
s.SessionSettings.Set(config.SocketConnectHost, "127.0.0.1")
Expand Down
Loading