From b6c30de9338806807a09d3d11f618fefe96ff2fe Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 13 Aug 2025 06:38:31 +0100 Subject: [PATCH 1/3] :sparkles: `[parallelisation]` Extend the cancel/close store to have various ways of functioning (e.g. parallel, sequential, reverse order, etc.) --- changes/20250813063457.feature | 1 + utils/parallelisation/cancel_functions.go | 186 +++++++++++++++--- .../parallelisation/cancel_functions_test.go | 56 ++++-- utils/parallelisation/onclose.go | 36 +++- utils/parallelisation/onclose_test.go | 103 ++++++++-- 5 files changed, 316 insertions(+), 66 deletions(-) create mode 100644 changes/20250813063457.feature diff --git a/changes/20250813063457.feature b/changes/20250813063457.feature new file mode 100644 index 0000000000..822612ff34 --- /dev/null +++ b/changes/20250813063457.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Extend the cancel/close store to have various ways of functioning (e.g. parallel, sequential, reverse order, etc.) diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index cce577b019..cafc2f9bf1 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -15,22 +15,99 @@ import ( "github.com/ARM-software/golang-utils/utils/reflection" ) -func newFunctionStore[T any](clearOnExecution, stopOnFirstError bool, executeFunc func(context.Context, T) error) *store[T] { +type StoreOptions struct { + clearOnExecution bool + stopOnFirstError bool + sequential bool + reverse bool +} + +type StoreOption func(*StoreOptions) *StoreOptions + +// StopOnFirstError stops store execution on first error. +var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.stopOnFirstError = true + return o +} + +// ExecuteAll executes all functions in the store even if an error is raised. the first error raised is then returned. +var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.stopOnFirstError = false + return o +} + +// ClearAfterExecution clears the store after execution. +var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.clearOnExecution = true + return o +} + +// RetainAfterExecution keep the store intact after execution (no reset). +var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.clearOnExecution = false + return o +} + +// Parallel ensures every function registered in the store is executed concurrently in the order they were registered. +var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.sequential = false + return o +} + +// Sequential ensures every function registered in the store is executed sequentially in the order they were registered. +var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.sequential = true + return o +} + +// SequentialInReverse ensures every function registered in the store is executed sequentially but in the reverse order they were registered. +var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + return o + } + o.sequential = true + o.reverse = true + return o +} + +func newFunctionStore[T any](executeFunc func(context.Context, T) error, options ...StoreOption) *store[T] { + + opts := &StoreOptions{} + + for i := range options { + opts = options[i](opts) + } return &store[T]{ - mu: deadlock.RWMutex{}, - functions: make([]T, 0), - executeFunc: executeFunc, - clearOnExecution: clearOnExecution, - stopOnFirstError: stopOnFirstError, + mu: deadlock.RWMutex{}, + functions: make([]T, 0), + executeFunc: executeFunc, + options: *opts, } } type store[T any] struct { - mu deadlock.RWMutex - functions []T - executeFunc func(ctx context.Context, element T) error - clearOnExecution bool - stopOnFirstError bool + mu deadlock.RWMutex + functions []T + executeFunc func(ctx context.Context, element T) error + options StoreOptions } func (s *store[T]) RegisterFunction(function ...T) { @@ -45,32 +122,87 @@ func (s *store[T]) Len() int { return len(s.functions) } -func (s *store[T]) Execute(ctx context.Context) error { +func (s *store[T]) Execute(ctx context.Context) (err error) { defer s.mu.Unlock() s.mu.Lock() if reflection.IsEmpty(s.executeFunc) { - return commonerrors.New(commonerrors.ErrUndefined, "the cancel store was not initialised correctly") + return commonerrors.New(commonerrors.ErrUndefined, "the store was not initialised correctly") + } + + if s.options.sequential { + err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse) + } else { + err = s.executeConcurrently(ctx, s.options.stopOnFirstError) } + + if err == nil && s.options.clearOnExecution { + s.functions = make([]T, 0, len(s.functions)) + } + return +} + +func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool) error { g, gCtx := errgroup.WithContext(ctx) - if !s.stopOnFirstError { + if !stopOnFirstError { gCtx = ctx } g.SetLimit(len(s.functions)) for i := range s.functions { g.Go(func() error { - err := DetermineContextError(gCtx) - if err != nil { - return err - } - return s.executeFunc(gCtx, s.functions[i]) + _, subErr := s.executeFunction(gCtx, s.functions[i]) + return subErr }) } - err := g.Wait() - if err == nil && s.clearOnExecution { - s.functions = make([]T, 0, len(s.functions)) + return g.Wait() +} + +func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse bool) (err error) { + err = DetermineContextError(ctx) + if err != nil { + return + } + if reverse { + for i := len(s.functions) - 1; i >= 0; i-- { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } else { + for i := range s.functions { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } + + return +} + +func (s *store[T]) executeFunction(ctx context.Context, element T) (shouldBreak bool, err error) { + err = DetermineContextError(ctx) + if err != nil { + shouldBreak = true + return } - return err + err = s.executeFunc(ctx, element) + return } type CancelFunctionStore struct { @@ -90,12 +222,12 @@ func (s *CancelFunctionStore) Len() int { return s.store.Len() } -// NewCancelFunctionsStore creates a store for cancel functions. -func NewCancelFunctionsStore() *CancelFunctionStore { +// NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed. +func NewCancelFunctionsStore(options ...StoreOption) *CancelFunctionStore { return &CancelFunctionStore{ - store: *newFunctionStore[context.CancelFunc](true, false, func(_ context.Context, cancelFunc context.CancelFunc) error { + store: *newFunctionStore[context.CancelFunc](func(_ context.Context, cancelFunc context.CancelFunc) error { cancelFunc() return nil - }), + }, append(options, ClearAfterExecution, ExecuteAll)...), } } diff --git a/utils/parallelisation/cancel_functions_test.go b/utils/parallelisation/cancel_functions_test.go index 2ef6f5889a..75163f5081 100644 --- a/utils/parallelisation/cancel_functions_test.go +++ b/utils/parallelisation/cancel_functions_test.go @@ -9,36 +9,54 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" ) +func testCancelStore(t *testing.T, store *CancelFunctionStore) { + t.Helper() + require.NotNil(t, store) + // Set up some fake CancelFuncs to make sure they are called + called1 := false + called2 := false + cancelFunc1 := func() { + called1 = true + } + cancelFunc2 := func() { + called2 = true + } + + store.RegisterCancelFunction(cancelFunc1, cancelFunc2) + + assert.Equal(t, 2, store.Len()) + assert.False(t, called1) + assert.False(t, called2) + store.Cancel() + + assert.True(t, called1) + assert.True(t, called2) +} + // Given a CancelFunctionsStore // Functions can be registered // and all functions will be called func TestCancelFunctionStore(t *testing.T) { t.Run("valid cancel store", func(t *testing.T) { - // Set up some fake CancelFuncs to make sure they are called - called1 := false - called2 := false - cancelFunc1 := func() { - called1 = true - } - cancelFunc2 := func() { - called2 = true - } - - store := NewCancelFunctionsStore() - - store.RegisterCancelFunction(cancelFunc1, cancelFunc2) - - assert.Equal(t, 2, store.Len()) - - store.Cancel() - assert.True(t, called1) - assert.True(t, called2) + t.Run("parallel", func(t *testing.T) { + testCancelStore(t, NewCancelFunctionsStore()) + }) + t.Run("sequential", func(t *testing.T) { + testCancelStore(t, NewCancelFunctionsStore(Sequential)) + }) + t.Run("reverse", func(t *testing.T) { + testCancelStore(t, NewCancelFunctionsStore(SequentialInReverse)) + }) + t.Run("execute all", func(t *testing.T) { + testCancelStore(t, NewCancelFunctionsStore(StopOnFirstError)) + }) }) t.Run("incorrectly initialised cancel store", func(t *testing.T) { diff --git a/utils/parallelisation/onclose.go b/utils/parallelisation/onclose.go index 1bf23ac4f0..b239522387 100644 --- a/utils/parallelisation/onclose.go +++ b/utils/parallelisation/onclose.go @@ -25,13 +25,22 @@ func (s *CloserStore) Len() int { // NewCloserStore returns a store of io.Closer object which will all be closed concurrently on Close(). The first error received will be returned func NewCloserStore(stopOnFirstError bool) *CloserStore { + option := ExecuteAll + if stopOnFirstError { + option = StopOnFirstError + } + return NewCloserStoreWithOptions(option, Parallel) +} + +// NewCloserStoreWithOptions returns a store of io.Closer object which will all be closed on Close(). The first error received if any will be returned +func NewCloserStoreWithOptions(opts ...StoreOption) *CloserStore { return &CloserStore{ - store: *newFunctionStore[io.Closer](false, stopOnFirstError, func(_ context.Context, closerObj io.Closer) error { + store: *newFunctionStore[io.Closer](func(_ context.Context, closerObj io.Closer) error { if closerObj == nil { return commonerrors.UndefinedVariable("closer object") } return closerObj.Close() - }), + }, append(opts, RetainAfterExecution)...), } } @@ -90,11 +99,26 @@ func (s *CloseFunctionStore) Len() int { return s.store.Len() } -// NewCloseFunctionStoreStore returns a store closing functions which will all be called concurrently on Close(). The first error received will be returned. -func NewCloseFunctionStoreStore(stopOnFirstError bool) *CloseFunctionStore { +// NewCloseFunctionStore returns a store closing functions which will all be called on Close(). The first error received if any will be returned. +func NewCloseFunctionStore(options ...StoreOption) *CloseFunctionStore { return &CloseFunctionStore{ - store: *newFunctionStore[CloseFunc](false, stopOnFirstError, func(_ context.Context, closerObj CloseFunc) error { + store: *newFunctionStore[CloseFunc](func(_ context.Context, closerObj CloseFunc) error { return closerObj() - }), + }, append(options, RetainAfterExecution)...), + } +} + +// NewCloseFunctionStoreStore is exactly the same as NewConcurrentCloseFunctionStore but without a typo in the name. +func NewCloseFunctionStoreStore(stopOnFirstError bool) *CloseFunctionStore { + return NewConcurrentCloseFunctionStore(stopOnFirstError) +} + +// NewConcurrentCloseFunctionStore returns a store closing functions which will all be called concurrently on Close(). The first error received will be returned. +// Prefer using NewCloseFunctionStore where possible +func NewConcurrentCloseFunctionStore(stopOnFirstError bool) *CloseFunctionStore { + option := ExecuteAll + if stopOnFirstError { + option = StopOnFirstError } + return NewCloseFunctionStore(option, Parallel) } diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index df1efd620a..436fa42bff 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -44,18 +45,92 @@ func TestCloseAll(t *testing.T) { } func TestCancelOnClose(t *testing.T) { - closeStore := NewCloseFunctionStoreStore(true) - ctx1, cancel := context.WithCancel(context.Background()) - closeStore.RegisterCancelFunction(cancel) - ctx2, cancel := context.WithCancel(context.Background()) - closeStore.RegisterCancelFunction(cancel) - ctx3, cancel := context.WithCancel(context.Background()) - closeStore.RegisterCancelFunction(cancel) - require.NoError(t, DetermineContextError(ctx1)) - require.NoError(t, DetermineContextError(ctx2)) - require.NoError(t, DetermineContextError(ctx3)) - require.NoError(t, closeStore.Close()) - errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) - errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) - errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + t.Run("parallel", func(t *testing.T) { + closeStore := NewCloseFunctionStoreStore(true) + ctx1, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx2, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx3, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + require.NoError(t, closeStore.Close()) + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + }) + t.Run("sequentially", func(t *testing.T) { + closeStore := NewCloseFunctionStore(StopOnFirstError, Sequential) + ctx1, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx2, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx3, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + require.NoError(t, closeStore.Close()) + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + }) + t.Run("reverse", func(t *testing.T) { + closeStore := NewCloseFunctionStore(StopOnFirstError, SequentialInReverse) + ctx1, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx2, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + ctx3, cancel := context.WithCancel(context.Background()) + closeStore.RegisterCancelFunction(cancel) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + require.NoError(t, closeStore.Close()) + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx2), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + }) +} + +func TestStopOnFirstError(t *testing.T) { + t.Run("sequentially", func(t *testing.T) { + closeStore := NewCloseFunctionStore(StopOnFirstError, Sequential) + ctx1, cancel1 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) + ctx2, cancel2 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) + ctx3, cancel3 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) + errortest.AssertError(t, DetermineContextError(ctx1), commonerrors.ErrCancelled) + assert.NoError(t, DetermineContextError(ctx2)) + assert.NoError(t, DetermineContextError(ctx3)) + }) + t.Run("reverse", func(t *testing.T) { + closeStore := NewCloseFunctionStore(StopOnFirstError, SequentialInReverse) + ctx1, cancel1 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel1(); return DetermineContextError(ctx1) }) + ctx2, cancel2 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel2(); return DetermineContextError(ctx2) }) + ctx3, cancel3 := context.WithCancel(context.Background()) + closeStore.RegisterCloseFunction(func() error { cancel3(); return DetermineContextError(ctx3) }) + assert.Equal(t, 3, closeStore.Len()) + require.NoError(t, DetermineContextError(ctx1)) + require.NoError(t, DetermineContextError(ctx2)) + require.NoError(t, DetermineContextError(ctx3)) + errortest.AssertError(t, closeStore.Close(), commonerrors.ErrCancelled) + assert.NoError(t, DetermineContextError(ctx1)) + assert.NoError(t, DetermineContextError(ctx2)) + errortest.AssertError(t, DetermineContextError(ctx3), commonerrors.ErrCancelled) + }) } From d106cf186d0fd3dcd41bd6deaafde19e9dfe81cb Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 13 Aug 2025 10:54:28 +0100 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Abdelrahman Abdelraouf --- utils/parallelisation/cancel_functions.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index cafc2f9bf1..5e1ddcea7f 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -141,7 +141,7 @@ func (s *store[T]) Execute(ctx context.Context) (err error) { return } -func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool) error { +func (s *store[T]) executeInParallel(ctx context.Context, stopOnFirstError bool) error { g, gCtx := errgroup.WithContext(ctx) if !stopOnFirstError { gCtx = ctx @@ -164,8 +164,8 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re } if reverse { for i := len(s.functions) - 1; i >= 0; i-- { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) - if shouldBreak { + mustBreak, subErr := s.executeFunction(ctx, s.functions[i]) + if mustBreak { err = subErr return } @@ -195,10 +195,10 @@ func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, re return } -func (s *store[T]) executeFunction(ctx context.Context, element T) (shouldBreak bool, err error) { +func (s *store[T]) executeFunction(ctx context.Context, element T) (mustBreak bool, err error) { err = DetermineContextError(ctx) if err != nil { - shouldBreak = true + mustBreak = true return } err = s.executeFunc(ctx, element) @@ -222,7 +222,7 @@ func (s *CancelFunctionStore) Len() int { return s.store.Len() } -// NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed. +// NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed and cleared. In other words, options `RetainAfterExecution` and `StopOnFirstError` would be discarded if selected to create the Cancel store func NewCancelFunctionsStore(options ...StoreOption) *CancelFunctionStore { return &CancelFunctionStore{ store: *newFunctionStore[context.CancelFunc](func(_ context.Context, cancelFunc context.CancelFunc) error { From 992d419f6e57dea604f80396f4ab4ab1c8e988a8 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 13 Aug 2025 10:54:47 +0100 Subject: [PATCH 3/3] Update utils/parallelisation/cancel_functions.go --- utils/parallelisation/cancel_functions.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index 5e1ddcea7f..42ee63e08c 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -132,7 +132,7 @@ func (s *store[T]) Execute(ctx context.Context) (err error) { if s.options.sequential { err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse) } else { - err = s.executeConcurrently(ctx, s.options.stopOnFirstError) + err = s.executeInParallel(ctx, s.options.stopOnFirstError) } if err == nil && s.options.clearOnExecution {