diff --git a/session.go b/session.go index 342988f64..f05dfa65e 100644 --- a/session.go +++ b/session.go @@ -53,6 +53,7 @@ type session struct { stateTimer *internal.EventTimer peerTimer *internal.EventTimer sentReset bool + stopCh chan struct{} stopOnce sync.Once targetDefaultApplVerID string @@ -98,6 +99,9 @@ type stopReq struct{} func (s *session) stop() { // Stop once. s.stopOnce.Do(func() { + if s.stopCh != nil { + close(s.stopCh) + } s.admin <- stopReq{} }) } @@ -596,7 +600,13 @@ func (s *session) initiateLogoutInReplyTo(reason string, inReplyTo *Message) (er return } s.log.OnEvent("Inititated logout request") - time.AfterFunc(s.LogoutTimeout, func() { s.sessionEvent <- internal.LogoutTimeout }) + time.AfterFunc(s.LogoutTimeout, func() { + select { + case <-s.stopCh: + return + case s.sessionEvent <- internal.LogoutTimeout: + } + }) return } diff --git a/session_factory.go b/session_factory.go index 1b79f5ee1..cd7a9bafa 100644 --- a/session_factory.go +++ b/session_factory.go @@ -90,6 +90,7 @@ func (f sessionFactory) newSession( s = &session{ sessionID: sessionID, stopOnce: sync.Once{}, + stopCh: make(chan struct{}), } var validatorSettings = defaultValidatorSettings diff --git a/session_leak_test.go b/session_leak_test.go new file mode 100644 index 000000000..c49d3ed3a --- /dev/null +++ b/session_leak_test.go @@ -0,0 +1,87 @@ +package quickfix + +import ( + "bytes" + "runtime/pprof" + "strings" + "testing" + "time" + + "github.com/quickfixgo/quickfix/internal" +) + +type testLog struct{} + +func (testLog) OnIncoming([]byte) {} +func (testLog) OnOutgoing([]byte) {} +func (testLog) OnEvent(string) {} +func (testLog) OnEventf(string, ...interface{}) {} + +// testApp is a no-op Application for tests. +type testApp struct{} + +func (testApp) OnCreate(SessionID) {} +func (testApp) OnLogon(SessionID) {} +func (testApp) OnLogout(SessionID) {} +func (testApp) ToAdmin(*Message, SessionID) {} +func (testApp) FromAdmin(*Message, SessionID) MessageRejectError { return nil } +func (testApp) ToApp(*Message, SessionID) error { return nil } +func (testApp) FromApp(*Message, SessionID) MessageRejectError { return nil } + +func newTimerOnlySession() *session { + tr, _ := internal.NewUTCTimeRange(internal.NewTimeOfDay(0, 0, 0), internal.NewTimeOfDay(23, 59, 59), nil) + s := &session{ + store: &memoryStore{}, + log: testLog{}, + sessionID: SessionID{BeginString: BeginStringFIX44, SenderCompID: "S", TargetCompID: "T"}, + messageOut: make(chan []byte, 2), + messageIn: make(chan fixIn), + sessionEvent: make(chan internal.Event), + messageEvent: make(chan bool), + application: testApp{}, + stopCh: make(chan struct{}), + SessionSettings: internal.SessionSettings{ + SessionTime: tr, + }, + } + return s +} + +func countGoroutinesContaining(substr string) int { + var buf bytes.Buffer + _ = pprof.Lookup("goroutine").WriteTo(&buf, 2) + return strings.Count(buf.String(), substr) +} + +func TestLogonTimeoutDoesNotLeakGoroutine(t *testing.T) { + s := newTimerOnlySession() + s.InitiateLogon = true + s.LogonTimeout = 10 * time.Millisecond + + baseline := countGoroutinesContaining("stateMachine).Connect.func1") + + s.stateMachine.Connect(s) + close(s.stopCh) + time.Sleep(4 * s.LogonTimeout) + + if got := countGoroutinesContaining("stateMachine).Connect.func1"); got > baseline { + t.Fatalf("logon timeout goroutine leaked: baseline=%d current=%d", baseline, got) + } +} + +func TestLogoutTimeoutDoesNotLeakGoroutine(t *testing.T) { + s := newTimerOnlySession() + s.LogoutTimeout = 10 * time.Millisecond + baseline := countGoroutinesContaining("initiateLogoutInReplyTo") + s.stateMachine.Connect(s) + + if err := s.initiateLogout("bye"); err != nil { + t.Fatalf("initiateLogout returned error: %v", err) + } + close(s.stopCh) + time.Sleep(4 * s.LogoutTimeout) + + if got := countGoroutinesContaining("initiateLogoutInReplyTo"); got > baseline { + t.Fatalf("logout timeout goroutine leaked: baseline=%d current=%d", baseline, got) + } +} diff --git a/session_state.go b/session_state.go index 6fe4dded7..15e8b96dc 100644 --- a/session_state.go +++ b/session_state.go @@ -65,7 +65,13 @@ func (sm *stateMachine) Connect(session *session) { sm.setState(session, logonState{}) // Fire logon timeout event after the pre-configured delay period. - time.AfterFunc(session.LogonTimeout, func() { session.sessionEvent <- internal.LogonTimeout }) + time.AfterFunc(session.LogonTimeout, func() { + select { + case <-session.stopCh: + return + case session.sessionEvent <- internal.LogonTimeout: + } + }) } func (sm *stateMachine) Stop(session *session) {