diff --git a/changes/20250813180357.feature b/changes/20250813180357.feature new file mode 100644 index 0000000000..ee2f04e0ee --- /dev/null +++ b/changes/20250813180357.feature @@ -0,0 +1 @@ +:sparkles: [parallelisation] Add an option to collate all the errors found during a function store execution diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index 42ee63e08c..daa6ab8885 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -20,8 +20,8 @@ type StoreOptions struct { stopOnFirstError bool sequential bool reverse bool + joinErrors bool } - type StoreOption func(*StoreOptions) *StoreOptions // StopOnFirstError stops store execution on first error. @@ -30,6 +30,18 @@ var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { return o } o.stopOnFirstError = true + o.joinErrors = false + return o +} + +// JoinErrors will collate any errors which happened when executing functions in store. +// This option should not be used in combination to StopOnFirstError. +var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.stopOnFirstError = false + o.joinErrors = true return o } @@ -130,9 +142,9 @@ func (s *store[T]) Execute(ctx context.Context) (err error) { } if s.options.sequential { - err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse) + err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors) } else { - err = s.executeInParallel(ctx, s.options.stopOnFirstError) + err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors) } if err == nil && s.options.clearOnExecution { @@ -141,31 +153,48 @@ func (s *store[T]) Execute(ctx context.Context) (err error) { return } -func (s *store[T]) executeInParallel(ctx context.Context, stopOnFirstError bool) error { +func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error { g, gCtx := errgroup.WithContext(ctx) if !stopOnFirstError { gCtx = ctx } - g.SetLimit(len(s.functions)) + funcNum := len(s.functions) + errCh := make(chan error, funcNum) + g.SetLimit(funcNum) for i := range s.functions { g.Go(func() error { _, subErr := s.executeFunction(gCtx, s.functions[i]) + errCh <- subErr return subErr }) } + err := g.Wait() + close(errCh) + if collateErrors { + collateErr := make([]error, funcNum) + i := 0 + for subErr := range errCh { + collateErr[i] = subErr + i++ + } + err = commonerrors.Join(collateErr...) + } - return g.Wait() + return err } -func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse bool) (err error) { +func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) { err = DetermineContextError(ctx) if err != nil { return } + funcNum := len(s.functions) + collateErr := make([]error, funcNum) if reverse { - for i := len(s.functions) - 1; i >= 0; i-- { - mustBreak, subErr := s.executeFunction(ctx, s.functions[i]) - if mustBreak { + for i := funcNum - 1; i >= 0; i-- { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[funcNum-i-1] = subErr + if shouldBreak { err = subErr return } @@ -179,6 +208,7 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re } else { for i := range s.functions { shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[i] = subErr if shouldBreak { err = subErr return @@ -192,6 +222,9 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re } } + if collateErrors { + err = commonerrors.Join(collateErr...) + } return } diff --git a/utils/parallelisation/onclose.go b/utils/parallelisation/onclose.go index b239522387..5f3d3fa37a 100644 --- a/utils/parallelisation/onclose.go +++ b/utils/parallelisation/onclose.go @@ -51,6 +51,13 @@ func CloseAll(cs ...io.Closer) error { return group.Close() } +// CloseAllAndCollateErrors calls concurrently Close on all io.Closer implementations passed as arguments and returns the errors encountered +func CloseAllAndCollateErrors(cs ...io.Closer) error { + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, JoinErrors) + group.RegisterFunction(cs...) + return group.Close() +} + // CloseAllWithContext is similar to CloseAll but can be controlled using a context. func CloseAllWithContext(ctx context.Context, cs ...io.Closer) error { group := NewCloserStore(false) @@ -58,6 +65,13 @@ func CloseAllWithContext(ctx context.Context, cs ...io.Closer) error { return group.Execute(ctx) } +// CloseAllWithContextAndCollateErrors is similar to CloseAllAndCollateErrors but can be controlled using a context. +func CloseAllWithContextAndCollateErrors(ctx context.Context, cs ...io.Closer) error { + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, JoinErrors) + group.RegisterFunction(cs...) + return group.Execute(ctx) +} + // CloseAllFunc calls concurrently all Close functions passed as arguments and returns the first error encountered func CloseAllFunc(cs ...CloseFunc) error { group := NewCloseFunctionStoreStore(false) @@ -65,6 +79,13 @@ func CloseAllFunc(cs ...CloseFunc) error { return group.Close() } +// CloseAllFuncAndCollateErrors calls concurrently all Close functions passed as arguments and returns the errors encountered +func CloseAllFuncAndCollateErrors(cs ...CloseFunc) error { + group := NewCloseFunctionStore(ExecuteAll, Parallel, JoinErrors) + group.RegisterFunction(cs...) + return group.Close() +} + type CloseFunc func() error type CloseFunctionStore struct { diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index 436fa42bff..2287796c11 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -2,6 +2,7 @@ package parallelisation import ( "context" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -25,6 +26,16 @@ func TestCloseAll(t *testing.T) { require.NoError(t, CloseAll(closerMock, closerMock, closerMock)) }) + t.Run("close and join errors", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).MinTimes(1) + + require.NoError(t, CloseAllAndCollateErrors(closerMock, closerMock, closerMock)) + }) + t.Run("close with error", func(t *testing.T) { ctlr := gomock.NewController(t) defer ctlr.Finish() @@ -36,12 +47,29 @@ func TestCloseAll(t *testing.T) { errortest.AssertError(t, CloseAll(closerMock, closerMock, closerMock), closeError) }) + t.Run("close with errors", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + closeError := commonerrors.ErrUnexpected + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(closeError).MinTimes(1) + + errortest.AssertError(t, CloseAllAndCollateErrors(closerMock, closerMock, closerMock), closeError) + }) + t.Run("close with 1 error", func(t *testing.T) { closeError := commonerrors.ErrUnexpected errortest.AssertError(t, CloseAllFunc(func() error { return nil }, func() error { return nil }, func() error { return closeError }, func() error { return nil }), closeError) }) + t.Run("close with 1 error but error collection", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + + errortest.AssertError(t, CloseAllFuncAndCollateErrors(func() error { return nil }, func() error { return nil }, func() error { return closeError }, func() error { return nil }), closeError) + }) + } func TestCancelOnClose(t *testing.T) { @@ -98,39 +126,63 @@ func TestCancelOnClose(t *testing.T) { }) } -func TestStopOnFirstError(t *testing.T) { - t.Run("sequentially", func(t *testing.T) { - closeStore := NewCloseFunctionStore(StopOnFirstError, Sequential) - ctx1, cancel1 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) - ctx2, cancel2 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) - ctx3, cancel3 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) - assert.Equal(t, 3, closeStore.Len()) - require.NoError(t, DetermineContextError(ctx1)) - require.NoError(t, DetermineContextError(ctx2)) - require.NoError(t, DetermineContextError(ctx3)) - errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) - errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) - assert.NoError(t, DetermineContextError(ctx2)) - assert.NoError(t, DetermineContextError(ctx3)) - }) - t.Run("reverse", func(t *testing.T) { - closeStore := NewCloseFunctionStore(StopOnFirstError, SequentialInReverse) - ctx1, cancel1 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) - ctx2, cancel2 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) - ctx3, cancel3 := context.WithCancel(context.Background()) - closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) - assert.Equal(t, 3, closeStore.Len()) - require.NoError(t, DetermineContextError(ctx1)) - require.NoError(t, DetermineContextError(ctx2)) - require.NoError(t, DetermineContextError(ctx3)) - errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) - assert.NoError(t, DetermineContextError(ctx1)) - assert.NoError(t, DetermineContextError(ctx2)) - errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) - }) +func TestSequentialExecution(t *testing.T) { + tests := []struct { + option StoreOption + }{ + {StopOnFirstError}, + {JoinErrors}, + } + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("%v-%#v", i, test.option), func(t *testing.T) { + opt := test.option(&StoreOptions{}) + t.Run("sequentially", func(t *testing.T) { + closeStore := NewCloseFunctionStore(test.option, Sequential) + ctx1, cancel1 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) + ctx2, cancel2 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) + ctx3, cancel3 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + + errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + if opt.stopOnFirstError { + assert.NoError(t, DetermineContextError(ctx2)) + assert.NoError(t, DetermineContextError(ctx3)) + } else { + errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + } + + }) + t.Run("reverse", func(t *testing.T) { + closeStore := NewCloseFunctionStore(test.option, SequentialInReverse) + ctx1, cancel1 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) + ctx2, cancel2 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) + ctx3, cancel3 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) + if opt.stopOnFirstError { + assert.NoError(t, DetermineContextError(ctx1)) + assert.NoError(t, DetermineContextError(ctx2)) + } else { + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) + } + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + }) + }) + } }