Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250813180357.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: [parallelisation] Add an option to collate all the errors found during a function store execution
53 changes: 43 additions & 10 deletions utils/parallelisation/cancel_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type StoreOptions struct {
stopOnFirstError bool
sequential bool
reverse bool
joinErrors bool
}

type StoreOption func(*StoreOptions) *StoreOptions

// StopOnFirstError stops store execution on first error.
Expand All @@ -30,6 +30,18 @@ var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions {
return o
}
o.stopOnFirstError = true
o.joinErrors = false
return o
}

// JoinErrors will collate any errors which happened when executing functions in store.
// This option should not be used in combination to StopOnFirstError.
var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.stopOnFirstError = false
o.joinErrors = true
return o
}

Expand Down Expand Up @@ -130,9 +142,9 @@ func (s *store[T]) Execute(ctx context.Context) (err error) {
}

if s.options.sequential {
err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse)
err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors)
} else {
err = s.executeInParallel(ctx, s.options.stopOnFirstError)
err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors)
}

if err == nil && s.options.clearOnExecution {
Expand All @@ -141,31 +153,48 @@ func (s *store[T]) Execute(ctx context.Context) (err error) {
return
}

func (s *store[T]) executeInParallel(ctx context.Context, stopOnFirstError bool) error {
func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error {
g, gCtx := errgroup.WithContext(ctx)
if !stopOnFirstError {
gCtx = ctx
}
g.SetLimit(len(s.functions))
funcNum := len(s.functions)
errCh := make(chan error, funcNum)
g.SetLimit(funcNum)
for i := range s.functions {
g.Go(func() error {
_, subErr := s.executeFunction(gCtx, s.functions[i])
errCh <- subErr
return subErr
})
}
err := g.Wait()
close(errCh)
if collateErrors {
collateErr := make([]error, funcNum)
i := 0
for subErr := range errCh {
collateErr[i] = subErr
i++
}
err = commonerrors.Join(collateErr...)
}

return g.Wait()
return err
}

func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse bool) (err error) {
func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) {
err = DetermineContextError(ctx)
if err != nil {
return
}
funcNum := len(s.functions)
collateErr := make([]error, funcNum)
if reverse {
for i := len(s.functions) - 1; i >= 0; i-- {
mustBreak, subErr := s.executeFunction(ctx, s.functions[i])
if mustBreak {
for i := funcNum - 1; i >= 0; i-- {
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
collateErr[funcNum-i-1] = subErr
if shouldBreak {
err = subErr
return
}
Expand All @@ -179,6 +208,7 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re
} else {
for i := range s.functions {
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
collateErr[i] = subErr
if shouldBreak {
err = subErr
return
Expand All @@ -192,6 +222,9 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re
}
}

if collateErrors {
err = commonerrors.Join(collateErr...)
}
return
}

Expand Down
21 changes: 21 additions & 0 deletions utils/parallelisation/onclose.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,41 @@ func CloseAll(cs ...io.Closer) error {
return group.Close()
}

// CloseAllAndCollateErrors calls concurrently Close on all io.Closer implementations passed as arguments and returns the errors encountered
func CloseAllAndCollateErrors(cs ...io.Closer) error {
group := NewCloserStoreWithOptions(ExecuteAll, Parallel, JoinErrors)
group.RegisterFunction(cs...)
return group.Close()
}

// CloseAllWithContext is similar to CloseAll but can be controlled using a context.
func CloseAllWithContext(ctx context.Context, cs ...io.Closer) error {
group := NewCloserStore(false)
group.RegisterFunction(cs...)
return group.Execute(ctx)
}

// CloseAllWithContextAndCollateErrors is similar to CloseAllAndCollateErrors but can be controlled using a context.
func CloseAllWithContextAndCollateErrors(ctx context.Context, cs ...io.Closer) error {
group := NewCloserStoreWithOptions(ExecuteAll, Parallel, JoinErrors)
group.RegisterFunction(cs...)
return group.Execute(ctx)
}

// CloseAllFunc calls concurrently all Close functions passed as arguments and returns the first error encountered
func CloseAllFunc(cs ...CloseFunc) error {
group := NewCloseFunctionStoreStore(false)
group.RegisterFunction(cs...)
return group.Close()
}

// CloseAllFuncAndCollateErrors calls concurrently all Close functions passed as arguments and returns the errors encountered
func CloseAllFuncAndCollateErrors(cs ...CloseFunc) error {
group := NewCloseFunctionStore(ExecuteAll, Parallel, JoinErrors)
group.RegisterFunction(cs...)
return group.Close()
}

type CloseFunc func() error

type CloseFunctionStore struct {
Expand Down
122 changes: 87 additions & 35 deletions utils/parallelisation/onclose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package parallelisation

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -25,6 +26,16 @@ func TestCloseAll(t *testing.T) {
require.NoError(t, CloseAll(closerMock, closerMock, closerMock))
})

t.Run("close and join errors", func(t *testing.T) {
ctlr := gomock.NewController(t)
defer ctlr.Finish()

closerMock := mocks.NewMockCloser(ctlr)
closerMock.EXPECT().Close().Return(nil).MinTimes(1)

require.NoError(t, CloseAllAndCollateErrors(closerMock, closerMock, closerMock))
})

t.Run("close with error", func(t *testing.T) {
ctlr := gomock.NewController(t)
defer ctlr.Finish()
Expand All @@ -36,12 +47,29 @@ func TestCloseAll(t *testing.T) {
errortest.AssertError(t, CloseAll(closerMock, closerMock, closerMock), closeError)
})

t.Run("close with errors", func(t *testing.T) {
ctlr := gomock.NewController(t)
defer ctlr.Finish()
closeError := commonerrors.ErrUnexpected

closerMock := mocks.NewMockCloser(ctlr)
closerMock.EXPECT().Close().Return(closeError).MinTimes(1)

errortest.AssertError(t, CloseAllAndCollateErrors(closerMock, closerMock, closerMock), closeError)
})

t.Run("close with 1 error", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected

errortest.AssertError(t, CloseAllFunc(func() error { return nil }, func() error { return nil }, func() error { return closeError }, func() error { return nil }), closeError)
})

t.Run("close with 1 error but error collection", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected

errortest.AssertError(t, CloseAllFuncAndCollateErrors(func() error { return nil }, func() error { return nil }, func() error { return closeError }, func() error { return nil }), closeError)
})

}

func TestCancelOnClose(t *testing.T) {
Expand Down Expand Up @@ -98,39 +126,63 @@ func TestCancelOnClose(t *testing.T) {
})
}

func TestStopOnFirstError(t *testing.T) {
t.Run("sequentially", func(t *testing.T) {
closeStore := NewCloseFunctionStore(StopOnFirstError, Sequential)
ctx1, cancel1 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) })
ctx2, cancel2 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) })
ctx3, cancel3 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) })
assert.Equal(t, 3, closeStore.Len())
require.NoError(t, DetermineContextError(ctx1))
require.NoError(t, DetermineContextError(ctx2))
require.NoError(t, DetermineContextError(ctx3))
errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled)
errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled)
assert.NoError(t, DetermineContextError(ctx2))
assert.NoError(t, DetermineContextError(ctx3))
})
t.Run("reverse", func(t *testing.T) {
closeStore := NewCloseFunctionStore(StopOnFirstError, SequentialInReverse)
ctx1, cancel1 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) })
ctx2, cancel2 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) })
ctx3, cancel3 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) })
assert.Equal(t, 3, closeStore.Len())
require.NoError(t, DetermineContextError(ctx1))
require.NoError(t, DetermineContextError(ctx2))
require.NoError(t, DetermineContextError(ctx3))
errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled)
assert.NoError(t, DetermineContextError(ctx1))
assert.NoError(t, DetermineContextError(ctx2))
errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled)
})
func TestSequentialExecution(t *testing.T) {
tests := []struct {
option StoreOption
}{
{StopOnFirstError},
{JoinErrors},
}
for i := range tests {
test := tests[i]
t.Run(fmt.Sprintf("%v-%#v", i, test.option), func(t *testing.T) {
opt := test.option(&StoreOptions{})
t.Run("sequentially", func(t *testing.T) {
closeStore := NewCloseFunctionStore(test.option, Sequential)
ctx1, cancel1 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) })
ctx2, cancel2 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) })
ctx3, cancel3 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) })
assert.Equal(t, 3, closeStore.Len())
require.NoError(t, DetermineContextError(ctx1))
require.NoError(t, DetermineContextError(ctx2))
require.NoError(t, DetermineContextError(ctx3))

errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled)
errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled)
if opt.stopOnFirstError {
assert.NoError(t, DetermineContextError(ctx2))
assert.NoError(t, DetermineContextError(ctx3))
} else {
errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled)
errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled)
}

})
t.Run("reverse", func(t *testing.T) {
closeStore := NewCloseFunctionStore(test.option, SequentialInReverse)
ctx1, cancel1 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) })
ctx2, cancel2 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) })
ctx3, cancel3 := context.WithCancel(context.Background())
closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) })
assert.Equal(t, 3, closeStore.Len())
require.NoError(t, DetermineContextError(ctx1))
require.NoError(t, DetermineContextError(ctx2))
require.NoError(t, DetermineContextError(ctx3))
errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled)
if opt.stopOnFirstError {
assert.NoError(t, DetermineContextError(ctx1))
assert.NoError(t, DetermineContextError(ctx2))
} else {
errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled)
errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled)
}
errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled)
})
})
}
}
Loading