Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250820002654.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `[parallelisation]` Added new groups (ContextualFunctionGroup) and new Store options to configure the execution (number of workers, single execution, etc.)
244 changes: 6 additions & 238 deletions utils/parallelisation/cancel_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,245 +5,14 @@

package parallelisation

import (
"context"

"github.com/sasha-s/go-deadlock"
"golang.org/x/sync/errgroup"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/reflection"
)

type StoreOptions struct {
clearOnExecution bool
stopOnFirstError bool
sequential bool
reverse bool
joinErrors bool
}
type StoreOption func(*StoreOptions) *StoreOptions

// StopOnFirstError stops store execution on first error.
var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.stopOnFirstError = true
o.joinErrors = false
return o
}

// JoinErrors will collate any errors which happened when executing functions in store.
// This option should not be used in combination to StopOnFirstError.
var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.stopOnFirstError = false
o.joinErrors = true
return o
}

// ExecuteAll executes all functions in the store even if an error is raised. the first error raised is then returned.
var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.stopOnFirstError = false
return o
}

// ClearAfterExecution clears the store after execution.
var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.clearOnExecution = true
return o
}

// RetainAfterExecution keep the store intact after execution (no reset).
var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.clearOnExecution = false
return o
}

// Parallel ensures every function registered in the store is executed concurrently in the order they were registered.
var Parallel StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.sequential = false
return o
}

// Sequential ensures every function registered in the store is executed sequentially in the order they were registered.
var Sequential StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.sequential = true
return o
}

// SequentialInReverse ensures every function registered in the store is executed sequentially but in the reverse order they were registered.
var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions {
if o == nil {
return o
}
o.sequential = true
o.reverse = true
return o
}

func newFunctionStore[T any](executeFunc func(context.Context, T) error, options ...StoreOption) *store[T] {

opts := &StoreOptions{}

for i := range options {
opts = options[i](opts)
}
return &store[T]{
mu: deadlock.RWMutex{},
functions: make([]T, 0),
executeFunc: executeFunc,
options: *opts,
}
}

type store[T any] struct {
mu deadlock.RWMutex
functions []T
executeFunc func(ctx context.Context, element T) error
options StoreOptions
}

func (s *store[T]) RegisterFunction(function ...T) {
defer s.mu.Unlock()
s.mu.Lock()
s.functions = append(s.functions, function...)
}

func (s *store[T]) Len() int {
defer s.mu.RUnlock()
s.mu.RLock()
return len(s.functions)
}

func (s *store[T]) Execute(ctx context.Context) (err error) {
defer s.mu.Unlock()
s.mu.Lock()
if reflection.IsEmpty(s.executeFunc) {
return commonerrors.New(commonerrors.ErrUndefined, "the store was not initialised correctly")
}

if s.options.sequential {
err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors)
} else {
err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors)
}

if err == nil && s.options.clearOnExecution {
s.functions = make([]T, 0, len(s.functions))
}
return
}

func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error {
g, gCtx := errgroup.WithContext(ctx)
if !stopOnFirstError {
gCtx = ctx
}
funcNum := len(s.functions)
errCh := make(chan error, funcNum)
g.SetLimit(funcNum)
for i := range s.functions {
g.Go(func() error {
_, subErr := s.executeFunction(gCtx, s.functions[i])
errCh <- subErr
return subErr
})
}
err := g.Wait()
close(errCh)
if collateErrors {
collateErr := make([]error, funcNum)
i := 0
for subErr := range errCh {
collateErr[i] = subErr
i++
}
err = commonerrors.Join(collateErr...)
}

return err
}

func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) {
err = DetermineContextError(ctx)
if err != nil {
return
}
funcNum := len(s.functions)
collateErr := make([]error, funcNum)
if reverse {
for i := funcNum - 1; i >= 0; i-- {
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
collateErr[funcNum-i-1] = subErr
if shouldBreak {
err = subErr
return
}
if subErr != nil && err == nil {
err = subErr
if stopOnFirstError {
return
}
}
}
} else {
for i := range s.functions {
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
collateErr[i] = subErr
if shouldBreak {
err = subErr
return
}
if subErr != nil && err == nil {
err = subErr
if stopOnFirstError {
return
}
}
}
}

if collateErrors {
err = commonerrors.Join(collateErr...)
}
return
}

func (s *store[T]) executeFunction(ctx context.Context, element T) (mustBreak bool, err error) {
err = DetermineContextError(ctx)
if err != nil {
mustBreak = true
return
}
err = s.executeFunc(ctx, element)
return
}
import "context"

type CancelFunctionStore struct {
store[context.CancelFunc]
ExecutionGroup[context.CancelFunc]
}

func (s *CancelFunctionStore) RegisterCancelFunction(cancel ...context.CancelFunc) {
s.store.RegisterFunction(cancel...)
s.ExecutionGroup.RegisterFunction(cancel...)
}

// Cancel will execute the cancel functions in the store. Any errors will be ignored and Execute() is recommended if you need to know if a cancellation failed
Expand All @@ -252,15 +21,14 @@ func (s *CancelFunctionStore) Cancel() {
}

func (s *CancelFunctionStore) Len() int {
return s.store.Len()
return s.ExecutionGroup.Len()
}

// NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed and cleared. In other words, options `RetainAfterExecution` and `StopOnFirstError` would be discarded if selected to create the Cancel store
func NewCancelFunctionsStore(options ...StoreOption) *CancelFunctionStore {
return &CancelFunctionStore{
store: *newFunctionStore[context.CancelFunc](func(_ context.Context, cancelFunc context.CancelFunc) error {
cancelFunc()
return nil
ExecutionGroup: *NewExecutionGroup[context.CancelFunc](func(ctx context.Context, cancelFunc context.CancelFunc) error {
return WrapCancelToContextualFunc(cancelFunc)(ctx)
}, append(options, ClearAfterExecution, ExecuteAll)...),
}
}
39 changes: 39 additions & 0 deletions utils/parallelisation/contextual.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package parallelisation

import (
"context"

"github.com/ARM-software/golang-utils/utils/commonerrors"
)

// DetermineContextError determines what the context error is if any.
func DetermineContextError(ctx context.Context) error {
return commonerrors.ConvertContextError(ctx.Err())
}

type ContextualFunctionGroup struct {
ExecutionGroup[ContextualFunc]
}

// NewContextualGroup returns a group executing contextual functions.
func NewContextualGroup(options ...StoreOption) *ContextualFunctionGroup {
return &ContextualFunctionGroup{
ExecutionGroup: *NewExecutionGroup[ContextualFunc](func(ctx context.Context, contextualF ContextualFunc) error {
return contextualF(ctx)
}, options...),
}
}

// ForEach executes all the contextual functions according to the store options and returns an error if one occurred.
func ForEach(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error {
group := NewContextualGroup(ExecuteAll(executionOptions).Options()...)
group.RegisterFunction(contextualFunc...)
return group.Execute(ctx)
}

// BreakOnError executes each functions in the group until an error is found or the context gets cancelled.
func BreakOnError(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error {
group := NewContextualGroup(StopOnFirstError(executionOptions).Options()...)
group.RegisterFunction(contextualFunc...)
return group.Execute(ctx)
}
48 changes: 48 additions & 0 deletions utils/parallelisation/contextual_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package parallelisation

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
)

func TestForEach(t *testing.T) {
cancelFunc := func() {}
t.Run("close with 1 error", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected

errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError)
})

t.Run("close with 1 error but error collection", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected
errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError)
})

t.Run("close with 1 error but error collection", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected
errortest.AssertError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError)
})

t.Run("close with 1 error but sequential", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected
errortest.AssertError(t, ForEach(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError)
errortest.AssertError(t, BreakOnError(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError)
})

t.Run("close with cancellation", func(t *testing.T) {
closeError := commonerrors.ErrUnexpected
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
errortest.AssertError(t, ForEach(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled)
errortest.AssertError(t, BreakOnError(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled)
})

t.Run("break on error with no error", func(t *testing.T) {
require.NoError(t, BreakOnError(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc)))
})
}
Loading
Loading