diff --git a/changes/20250905171217.feature b/changes/20250905171217.feature new file mode 100644 index 0000000000..842af8312e --- /dev/null +++ b/changes/20250905171217.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Define a transformation group diff --git a/changes/20250908111211.feature b/changes/20250908111211.feature new file mode 100644 index 0000000000..589aa73cff --- /dev/null +++ b/changes/20250908111211.feature @@ -0,0 +1 @@ +:sparkles: `[collection]` Added a `Range` function to populate slices of integers diff --git a/utils/collection/range.go b/utils/collection/range.go new file mode 100644 index 0000000000..653f74c62d --- /dev/null +++ b/utils/collection/range.go @@ -0,0 +1,33 @@ +package collection + +import "github.com/ARM-software/golang-utils/utils/field" + +func sign(x int) int { + if x < 0 { + return -1 + } + return 1 +} + +// Range returns a slice of integers similar to Python's built-in range(). +// https://docs.python.org/2/library/functions.html#range +// +// Note: The stop value is always exclusive. +func Range(start, stop int, step *int) []int { + s := field.OptionalInt(step, 1) + if s == 0 { + return []int{} + } + + // Compute length + length := 0 + if (s > 0 && start < stop) || (s < 0 && start > stop) { + length = (stop - start + s - sign(s)) / s + } + + result := make([]int, length) + for i, v := 0, start; i < length; i, v = i+1, v+s { + result[i] = v + } + return result +} diff --git a/utils/collection/range_test.go b/utils/collection/range_test.go new file mode 100644 index 0000000000..31b09195ae --- /dev/null +++ b/utils/collection/range_test.go @@ -0,0 +1,40 @@ +package collection + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ARM-software/golang-utils/utils/field" +) + +func TestRange(t *testing.T) { + tests := []struct { + start int + stop int + step *int + expected []int + }{ + + {2, 5, nil, []int{2, 3, 4}}, + {5, 2, nil, []int{}}, // empty, since stop < start + {2, 10, field.ToOptionalInt(2), []int{2, 4, 6, 8}}, + {0, 10, field.ToOptionalInt(3), []int{0, 3, 6, 9}}, + {1, 10, field.ToOptionalInt(3), []int{1, 4, 7}}, + {10, 2, field.ToOptionalInt(-2), []int{10, 8, 6, 4}}, + {5, -1, field.ToOptionalInt(-1), []int{5, 4, 3, 2, 1, 0}}, + {0, -5, field.ToOptionalInt(-2), []int{0, -2, -4}}, + {0, 5, nil, []int{0, 1, 2, 3, 4}}, + {0, 5, field.ToOptionalInt(0), []int{}}, + {2, 2, field.ToOptionalInt(1), []int{}}, + {2, 2, field.ToOptionalInt(-1), []int{}}, + } + + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("[%v,%v,%v]", test.start, test.stop, test.step), func(t *testing.T) { + assert.Equal(t, test.expected, Range(test.start, test.stop, test.step)) + }) + } +} diff --git a/utils/collection/search.go b/utils/collection/search.go index 77c9da9ccf..f6b5156412 100644 --- a/utils/collection/search.go +++ b/utils/collection/search.go @@ -71,8 +71,10 @@ func AnyFunc[S ~[]E, E any](s S, f func(E) bool) bool { return conditions.Any() } +type FilterFunc[E any] func(E) bool + // Filter returns a new slice that contains elements from the input slice which return true when they’re passed as a parameter to the provided filtering function f. 
-func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) { +func Filter[S ~[]E, E any](s S, f FilterFunc[E]) (result S) { result = make(S, 0, len(s)) for i := range s { @@ -84,8 +86,16 @@ func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) { return result } +type MapFunc[T1, T2 any] func(T1) T2 + +func IdentityMapFunc[T any]() MapFunc[T, T] { + return func(i T) T { + return i + } +} + // Map creates a new slice and populates it with the results of calling the provided function on every element in input slice. -func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) { +func Map[T1 any, T2 any](s []T1, f MapFunc[T1, T2]) (result []T2) { result = make([]T2, len(s)) for i := range s { @@ -97,12 +107,14 @@ func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) { // Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false. // This is functionally equivalent to slices.DeleteFunc but it returns a new slice. -func Reject[S ~[]E, E any](s S, f func(E) bool) S { +func Reject[S ~[]E, E any](s S, f FilterFunc[E]) S { return Filter(s, func(e E) bool { return !f(e) }) } +type ReduceFunc[T1, T2 any] func(T2, T1) T2 + // Reduce runs a reducer function f over all elements in the array, in ascending-index order, and accumulates them into a single value. -func Reduce[T1, T2 any](s []T1, accumulator T2, f func(T2, T1) T2) (result T2) { +func Reduce[T1, T2 any](s []T1, accumulator T2, f ReduceFunc[T1, T2]) (result T2) { result = accumulator for i := range s { result = f(result, s[i]) diff --git a/utils/parallelisation/group.go b/utils/parallelisation/group.go index a687c913ea..de91d757e4 100644 --- a/utils/parallelisation/group.go +++ b/utils/parallelisation/group.go @@ -224,6 +224,14 @@ type ICompoundExecutionGroup[T any] interface { // NewExecutionGroup returns an execution group which executes functions according to store options. func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + return NewOrderedExecutionGroup(func(ctx context.Context, index int, element T) error { + return executeFunc(ctx, element) + }, options...) +} + +// NewOrderedExecutionGroup returns an execution group which executes functions according to store options. It also keeps track of the input index. +func NewOrderedExecutionGroup[T any](executeFunc OrderedExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + opts := WithOptions(options...) 
return &ExecutionGroup[T]{ mu: deadlock.RWMutex{}, @@ -235,10 +243,12 @@ func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption type ExecuteFunc[T any] func(ctx context.Context, element T) error +type OrderedExecuteFunc[T any] func(ctx context.Context, index int, element T) error + type ExecutionGroup[T any] struct { mu deadlock.RWMutex functions []wrappedElement[T] - executeFunc ExecuteFunc[T] + executeFunc OrderedExecuteFunc[T] options StoreOptions } @@ -294,7 +304,7 @@ func (s *ExecutionGroup[T]) executeConcurrently(ctx context.Context, stopOnFirst g.SetLimit(workers) for i := range s.functions { g.Go(func() error { - _, subErr := s.executeFunction(gCtx, s.functions[i]) + _, subErr := s.executeFunction(gCtx, i, s.functions[i]) errCh <- subErr return subErr }) @@ -323,7 +333,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst collateErr := make([]error, funcNum) if reverse { for i := funcNum - 1; i >= 0; i-- { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i]) collateErr[funcNum-i-1] = subErr if shouldBreak { err = subErr @@ -338,7 +348,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst } } else { for i := range s.functions { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i]) collateErr[i] = subErr if shouldBreak { err = subErr @@ -359,7 +369,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst return } -func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElement[T]) (mustBreak bool, err error) { +func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, index int, w wrappedElement[T]) (mustBreak bool, err error) { err = DetermineContextError(ctx) if err != nil { mustBreak = true @@ -370,7 +380,9 @@ func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElemen mustBreak = true return } - err = w.Execute(ctx, s.executeFunc) + err = w.Execute(ctx, func(ctx context.Context, element T) error { + return s.executeFunc(ctx, index, element) + }) return } diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 30e909afdd..5852b253bd 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -12,8 +12,8 @@ import ( "time" "go.uber.org/atomic" - "golang.org/x/sync/errgroup" + "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" ) @@ -265,69 +265,35 @@ func WaitUntil(ctx context.Context, evalCondition func(ctx2 context.Context) (bo } } -func newWorker[JobType, ResultType any](ctx context.Context, f func(context.Context, JobType) (ResultType, bool, error), jobs chan JobType, results chan ResultType) (err error) { - for job := range jobs { - result, ok, subErr := f(ctx, job) - if subErr != nil { - err = commonerrors.WrapError(commonerrors.ErrUnexpected, subErr, "an error occurred whilst handling a job") - return - } - - err = DetermineContextError(ctx) - if err != nil { - return - } - - if ok { - results <- result - } +// WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. 
It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size +func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f TransformFunc[InputType, ResultType]) (results []ResultType, err error) { + g, err := workerPoolGroup[InputType, ResultType](ctx, numWorkers, jobs, f) + if err != nil { + return } - + results, err = g.Outputs(ctx) return } -// WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size -func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f func(context.Context, InputType) (ResultType, bool, error)) (results []ResultType, err error) { +func workerPoolGroup[I, O any](ctx context.Context, numWorkers int, jobs []I, f TransformFunc[I, O]) (g *TransformGroup[I, O], err error) { if numWorkers < 1 { err = commonerrors.New(commonerrors.ErrInvalid, "numWorkers must be greater than or equal to 1") return } - - numJobs := len(jobs) - jobsChan := make(chan InputType, numJobs) - resultsChan := make(chan ResultType, numJobs) - - g, gCtx := errgroup.WithContext(ctx) - g.SetLimit(numWorkers) - for range numWorkers { - g.Go(func() error { return newWorker(gCtx, f, jobsChan, resultsChan) }) - } - for i := range jobs { - if DetermineContextError(ctx) != nil { - break - } - jobsChan <- jobs[i] - } - - close(jobsChan) - err = g.Wait() - close(resultsChan) - if err == nil { - err = DetermineContextError(ctx) - } + g = NewTransformGroup[I, O](f, Workers(numWorkers), JoinErrors) + err = g.Inputs(ctx, jobs...) if err != nil { return } - - for result := range resultsChan { - results = append(results, result) + err = g.Transform(ctx) + if err != nil { + return } - return } // Filter is similar to collection.Filter but uses parallelisation. -func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) (result []T, err error) { +func Filter[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) (result []T, err error) { result, err = WorkerPool[T, T](ctx, numWorkers, s, func(fCtx context.Context, item T) (r T, ok bool, fErr error) { fErr = DetermineContextError(fCtx) if fErr != nil { @@ -340,9 +306,8 @@ func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ( return } -// Map is similar to collection.Map but uses parallelisation. -func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) T2) (result []T2, err error) { - result, err = WorkerPool[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) { +func mapGroup[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (*TransformGroup[T1, T2], error) { + return workerPoolGroup[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) { fErr = DetermineContextError(fCtx) if fErr != nil { return @@ -351,10 +316,29 @@ func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) ok = true return }) +} + +// Map is similar to collection.Map but uses parallelisation. 
+func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) { + g, err := mapGroup[T1, T2](ctx, numWorkers, s, f) + if err != nil { + return + } + result, err = g.Outputs(ctx) + return +} + +// OrderedMap is similar to Map but ensures the results are in the same order as the input. +func OrderedMap[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) { + g, err := mapGroup[T1, T2](ctx, numWorkers, s, f) + if err != nil { + return + } + result, err = g.OrderedOutputs(ctx) return } // Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false. -func Reject[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ([]T, error) { +func Reject[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) ([]T, error) { return Filter[T](ctx, numWorkers, s, func(e T) bool { return !f(e) }) } diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index fe9d2305c1..a0e63f17b8 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -19,8 +19,10 @@ import ( "go.uber.org/atomic" "go.uber.org/goleak" + "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/field" ) var ( @@ -636,3 +638,35 @@ func TestMap(t *testing.T) { errortest.AssertError(t, err, commonerrors.ErrCancelled) }) } + +func TestMapAndOrderedMap(t *testing.T) { + defer goleak.VerifyNone(t) + ctx := context.Background() + mapped, err := OrderedMap(ctx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + require.NoError(t, err) + assert.Equal(t, []string{"Hello world 1", "Hello world 2"}, mapped) + mapped, err = OrderedMap(ctx, 3, []int64{1, 2, 3, 4}, func(x int64) string { + return strconv.FormatInt(x, 10) + }) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3", "4"}, mapped) + t.Run("cancelled context", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := Map(cancelledCtx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) + + in := collection.Range(0, 1000, field.ToOptionalInt(5)) + mappedInt, err := OrderedMap(ctx, 3, in, collection.IdentityMapFunc[int]()) + require.NoError(t, err) + assert.Equal(t, in, mappedInt) + mappedInt, err = Map(ctx, 3, in, collection.IdentityMapFunc[int]()) + require.NoError(t, err) + assert.NotEqual(t, in, mappedInt) + assert.ElementsMatch(t, in, mappedInt) +} diff --git a/utils/parallelisation/transform.go b/utils/parallelisation/transform.go new file mode 100644 index 0000000000..d12fa94294 --- /dev/null +++ b/utils/parallelisation/transform.go @@ -0,0 +1,161 @@ +package parallelisation + +import ( + "context" + + "go.uber.org/atomic" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/field" +) + +// TransformFunc defines a transformation function which converts an input into an output. 
+type TransformFunc[I any, O any] func(context.Context, I) (output O, success bool, err error) + +type resultElement[O any] struct { + r O + index int +} +type results[O any] struct { + terminated *atomic.Bool + r chan resultElement[O] +} + +func (r *results[O]) Append(o *resultElement[O]) { + if o == nil { + return + } + if !r.terminated.Load() { + r.r <- *o + } +} + +func (r *results[O]) Results(ctx context.Context) (slice []O, err error) { + if !r.terminated.Swap(true) { + close(r.r) + } + err = DetermineContextError(ctx) + if err != nil { + return + } + slice = make([]O, 0, len(r.r)) + for output := range r.r { + err = DetermineContextError(ctx) + if err != nil { + return + } + slice = append(slice, output.r) + } + return +} + +func (r *results[O]) OrderedResults(ctx context.Context) (slice []O, err error) { + if !r.terminated.Swap(true) { + close(r.r) + } + err = DetermineContextError(ctx) + if err != nil { + return + } + values := make(map[int]O, len(r.r)) + slice = make([]O, len(r.r)) + for output := range r.r { + err = DetermineContextError(ctx) + if err != nil { + return + } + values[output.index] = output.r + } + for i := 0; i < len(slice); i++ { + slice[i] = values[i] + } + return +} + +func newResults[O any](numberOfInput *int) *results[O] { + i := field.OptionalInt(numberOfInput, 0) + var channel chan resultElement[O] + if i <= 0 { + channel = make(chan resultElement[O]) + } else { + channel = make(chan resultElement[O], i) + } + + return &results[O]{ + terminated: atomic.NewBool(false), + r: channel, + } +} + +type TransformGroup[I any, O any] struct { + ExecutionGroup[I] + results *atomic.Pointer[results[O]] +} + +func (g *TransformGroup[I, O]) appendResult(o *resultElement[O]) { + r := g.results.Load() + if r != nil { + r.Append(o) + } +} + +// Inputs registers inputs to transform. +func (g *TransformGroup[I, O]) Inputs(ctx context.Context, i ...I) error { + for j := range i { + err := DetermineContextError(ctx) + if err != nil { + return err + } + g.RegisterFunction(i[j]) + } + return nil +} + +// Outputs returns the outputs of all inputs which were successfully transformed when Transform was called. +func (g *TransformGroup[I, O]) Outputs(ctx context.Context) ([]O, error) { + r := g.results.Load() + if r == nil { + return nil, commonerrors.UndefinedVariable("results") + } + return r.Results(ctx) +} + +// OrderedOutputs returns the outputs of all inputs which were successfully transformed when Transform was called. The outputs are returned in the same order as the input slice. +func (g *TransformGroup[I, O]) OrderedOutputs(ctx context.Context) ([]O, error) { + r := g.results.Load() + if r == nil { + return nil, commonerrors.UndefinedVariable("results") + } + return r.OrderedResults(ctx) +} + +// Transform performs the transformation over all registered inputs. +func (g *TransformGroup[I, O]) Transform(ctx context.Context) error { + g.results.Store(newResults[O](field.ToOptionalInt(g.Len()))) + return g.ExecutionGroup.Execute(ctx) +} + +// NewTransformGroup returns a group transforming inputs into outputs. 
+// To register inputs, call the Inputs function. +// To perform the transformation of the registered inputs, call Transform. +// To retrieve the outputs, call Outputs or OrderedOutputs. +func NewTransformGroup[I any, O any](transform TransformFunc[I, O], options ...StoreOption) *TransformGroup[I, O] { + g := &TransformGroup[I, O]{ + results: atomic.NewPointer[results[O]](newResults[O](nil)), + } + g.ExecutionGroup = *NewOrderedExecutionGroup[I](func(fCtx context.Context, index int, i I) error { + err := DetermineContextError(fCtx) + if err != nil { + return err + } + o, success, err := transform(fCtx, i) + if err != nil { + return commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred whilst handling an input [%+v]", i) + } + if success { + g.appendResult(&resultElement[O]{index: index, r: o}) + } + return nil + }, options...) + return g +} diff --git a/utils/parallelisation/transform_test.go b/utils/parallelisation/transform_test.go new file mode 100644 index 0000000000..457822581d --- /dev/null +++ b/utils/parallelisation/transform_test.go @@ -0,0 +1,74 @@ +package parallelisation + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/ARM-software/golang-utils/utils/collection" +) + +func TestNewTransformGroup(t *testing.T) { + defer goleak.VerifyNone(t) + tr := func(ctx context.Context, i string) (o int, success bool, err error) { + err = DetermineContextError(ctx) + if err != nil { + return + } + o, err = strconv.Atoi(i) + if err == nil { + success = true + } + return + } + g := NewTransformGroup[string, int](tr, RetainAfterExecution, Parallel) + assert.Zero(t, g.Len()) + o, err := g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + numberOfInput := 50 + in := collection.Range(0, numberOfInput, nil) + in2 := make([]string, numberOfInput) + for i := 0; i < numberOfInput; i++ { + in2[i] = strconv.Itoa(i) + } + err = g.Inputs(context.Background(), in2...) + require.NoError(t, err) + assert.Equal(t, numberOfInput, g.Len()) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.ElementsMatch(t, in, o) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Equal(t, in, o) + err = g.Inputs(context.Background(), in2...) 
+ require.NoError(t, err) + assert.Equal(t, 2*numberOfInput, g.Len()) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.ElementsMatch(t, append(in, in...), o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Equal(t, append(in, in...), o) +} diff --git a/utils/signing/signing_test.go b/utils/signing/signing_test.go index f00d27b626..5584bd9355 100644 --- a/utils/signing/signing_test.go +++ b/utils/signing/signing_test.go @@ -37,7 +37,7 @@ func TestSigning(t *testing.T) { signature, err := signer.Sign(message) require.NoError(t, err) - ok, err := signer.Verify([]byte(faker.Word()), signature) + ok, err := signer.Verify([]byte(faker.Word()+faker.Word()), signature) require.NoError(t, err) assert.False(t, ok) }) @@ -48,7 +48,7 @@ func TestSigning(t *testing.T) { signer, err := NewEd25519SignerFromSeed(faker.Word()) require.NoError(t, err) - wrongSignature, err := signer.Sign([]byte(faker.Word())) + wrongSignature, err := signer.Sign([]byte(faker.Word() + faker.Word())) require.NoError(t, err) ok, err := signer.Verify(message, wrongSignature)
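For reviewers, the sketch below shows how the pieces introduced in this change (collection.Range, parallelisation.OrderedMap and the new TransformGroup) are meant to fit together. It is a minimal, hypothetical example rather than code from this change: it only uses names that appear in the hunks above (Range, ToOptionalInt, OrderedMap, NewTransformGroup, Inputs, Transform, OrderedOutputs) and assumes that Parallel and Workers are exported StoreOptions, as they are used in the tests.

package main

import (
	"context"
	"fmt"
	"strconv"

	"github.com/ARM-software/golang-utils/utils/collection"
	"github.com/ARM-software/golang-utils/utils/field"
	"github.com/ARM-software/golang-utils/utils/parallelisation"
)

func main() {
	ctx := context.Background()

	// Range mirrors Python's range(): start is inclusive, stop is exclusive and step is optional.
	evens := collection.Range(0, 10, field.ToOptionalInt(2)) // [0 2 4 6 8]

	// OrderedMap maps over a worker pool while preserving the order of the input slice.
	asStrings, err := parallelisation.OrderedMap(ctx, 4, evens, func(i int) string {
		return strconv.Itoa(i)
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(asStrings) // [0 2 4 6 8]

	// TransformGroup offers finer control: register inputs, run Transform, then collect outputs.
	parse := func(_ context.Context, s string) (int, bool, error) {
		v, convErr := strconv.Atoi(s)
		if convErr != nil {
			return 0, false, nil // skip values which do not parse instead of failing the whole group
		}
		return v, true, nil
	}
	g := parallelisation.NewTransformGroup[string, int](parse, parallelisation.Parallel, parallelisation.Workers(4))
	if err := g.Inputs(ctx, asStrings...); err != nil {
		panic(err)
	}
	if err := g.Transform(ctx); err != nil {
		panic(err)
	}
	ints, err := g.OrderedOutputs(ctx)
	if err != nil {
		panic(err)
	}
	fmt.Println(ints) // [0 2 4 6 8]
}

The difference between Outputs and OrderedOutputs comes from the index carried by each result element: Outputs returns results in completion order, whereas OrderedOutputs reorders them to match the registered inputs.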