From 2d5ce156af52e0ade885bf4d09f6307efc4dd694 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Fri, 5 Sep 2025 17:13:11 +0100 Subject: [PATCH 1/4] :sparkles: `[parallelisation]` Define a transformation group --- changes/20250905171217.feature | 1 + utils/parallelisation/parallelisation.go | 55 ++--------- utils/parallelisation/transform.go | 121 +++++++++++++++++++++++ 3 files changed, 128 insertions(+), 49 deletions(-) create mode 100644 changes/20250905171217.feature create mode 100644 utils/parallelisation/transform.go diff --git a/changes/20250905171217.feature b/changes/20250905171217.feature new file mode 100644 index 0000000000..842af8312e --- /dev/null +++ b/changes/20250905171217.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Define a transformation group diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 30e909afdd..e11924dcc2 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -12,7 +12,6 @@ import ( "time" "go.uber.org/atomic" - "golang.org/x/sync/errgroup" "github.com/ARM-software/golang-utils/utils/commonerrors" ) @@ -265,64 +264,22 @@ func WaitUntil(ctx context.Context, evalCondition func(ctx2 context.Context) (bo } } -func newWorker[JobType, ResultType any](ctx context.Context, f func(context.Context, JobType) (ResultType, bool, error), jobs chan JobType, results chan ResultType) (err error) { - for job := range jobs { - result, ok, subErr := f(ctx, job) - if subErr != nil { - err = commonerrors.WrapError(commonerrors.ErrUnexpected, subErr, "an error occurred whilst handling a job") - return - } - - err = DetermineContextError(ctx) - if err != nil { - return - } - - if ok { - results <- result - } - } - - return -} - // WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f func(context.Context, InputType) (ResultType, bool, error)) (results []ResultType, err error) { if numWorkers < 1 { err = commonerrors.New(commonerrors.ErrInvalid, "numWorkers must be greater than or equal to 1") return } - - numJobs := len(jobs) - jobsChan := make(chan InputType, numJobs) - resultsChan := make(chan ResultType, numJobs) - - g, gCtx := errgroup.WithContext(ctx) - g.SetLimit(numWorkers) - for range numWorkers { - g.Go(func() error { return newWorker(gCtx, f, jobsChan, resultsChan) }) - } - for i := range jobs { - if DetermineContextError(ctx) != nil { - break - } - jobsChan <- jobs[i] - } - - close(jobsChan) - err = g.Wait() - close(resultsChan) - if err == nil { - err = DetermineContextError(ctx) - } + g := NewTransformGroup[InputType, ResultType](f, Workers(numWorkers), JoinErrors) + err = g.Inputs(ctx, jobs...) 
if err != nil { return } - - for result := range resultsChan { - results = append(results, result) + err = g.Transform(ctx) + if err != nil { + return } - + results, err = g.Outputs(ctx) return } diff --git a/utils/parallelisation/transform.go b/utils/parallelisation/transform.go new file mode 100644 index 0000000000..0248c9125d --- /dev/null +++ b/utils/parallelisation/transform.go @@ -0,0 +1,121 @@ +package parallelisation + +import ( + "context" + + "go.uber.org/atomic" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/field" +) + +type TransformFunc[I any, O any] func(context.Context, I) (output O, success bool, err error) + +type results[O any] struct { + terminated *atomic.Bool + r chan O +} + +func (r *results[O]) Append(o O) { + if !r.terminated.Load() { + r.r <- o + } +} + +func (r *results[O]) Results(ctx context.Context) (slice []O, err error) { + if !r.terminated.Swap(true) { + close(r.r) + } + err = DetermineContextError(ctx) + if err != nil { + return + } + slice = make([]O, 0, len(r.r)) + for output := range r.r { + err = DetermineContextError(ctx) + if err != nil { + return + } + slice = append(slice, output) + } + return +} + +func newResults[O any](numberOfInput *int) *results[O] { + i := field.OptionalInt(numberOfInput, 0) + var channel chan O + if i <= 0 { + channel = make(chan O) + } else { + channel = make(chan O, i) + } + + return &results[O]{ + terminated: atomic.NewBool(false), + r: channel, + } +} + +type TransformGroup[I any, O any] struct { + ExecutionGroup[I] + results *atomic.Pointer[results[O]] +} + +func (g *TransformGroup[I, O]) appendResult(o O) { + r := g.results.Load() + if r != nil { + r.Append(o) + } +} + +// Inputs registers inputs to transform. +func (g *TransformGroup[I, O]) Inputs(ctx context.Context, i ...I) error { + for j := range i { + err := DetermineContextError(ctx) + if err != nil { + return err + } + g.RegisterFunction(i[j]) + } + return nil +} + +// Outputs returns any input which have been transformed when the Transform function was called. +func (g *TransformGroup[I, O]) Outputs(ctx context.Context) ([]O, error) { + r := g.results.Load() + if r == nil { + return nil, commonerrors.UndefinedVariable("results") + } + return r.Results(ctx) +} + +// Transform actually performs the transformation +func (g *TransformGroup[I, O]) Transform(ctx context.Context) error { + g.results.Store(newResults[O](field.ToOptionalInt(g.Len()))) + return g.ExecutionGroup.Execute(ctx) +} + +// NewTransformGroup returns a group transforming inputs into outputs. +// To register inputs, call the Input function +// To perform the transformation of inputs, then call Transform +// To retrieve the output, then call Output +func NewTransformGroup[I any, O any](transform TransformFunc[I, O], options ...StoreOption) *TransformGroup[I, O] { + g := &TransformGroup[I, O]{ + results: atomic.NewPointer[results[O]](newResults[O](nil)), + } + g.ExecutionGroup = *NewExecutionGroup[I](func(fCtx context.Context, i I) error { + err := DetermineContextError(fCtx) + if err != nil { + return err + } + o, success, err := transform(fCtx, i) + if err != nil { + return commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred whilst handling an input [%+v]", i) + } + if success { + g.appendResult(o) + } + return nil + }, options...) 
+ return g +} From 5a2f3d500bd68374d9b457e38c1c1594d2c084ff Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Mon, 8 Sep 2025 11:26:42 +0100 Subject: [PATCH 2/4] :sparkles: Keep tracks of the input order :green_heart: Added tests --- changes/20250908111211.feature | 1 + utils/collection/range.go | 33 ++++++++++++ utils/collection/range_test.go | 40 ++++++++++++++ utils/parallelisation/group.go | 24 ++++++--- utils/parallelisation/transform.go | 62 +++++++++++++++++---- utils/parallelisation/transform_test.go | 72 +++++++++++++++++++++++++ 6 files changed, 215 insertions(+), 17 deletions(-) create mode 100644 changes/20250908111211.feature create mode 100644 utils/collection/range.go create mode 100644 utils/collection/range_test.go create mode 100644 utils/parallelisation/transform_test.go diff --git a/changes/20250908111211.feature b/changes/20250908111211.feature new file mode 100644 index 0000000000..589aa73cff --- /dev/null +++ b/changes/20250908111211.feature @@ -0,0 +1 @@ +:sparkles: `[collection]` Added a `Range` function to populate slices of integers diff --git a/utils/collection/range.go b/utils/collection/range.go new file mode 100644 index 0000000000..653f74c62d --- /dev/null +++ b/utils/collection/range.go @@ -0,0 +1,33 @@ +package collection + +import "github.com/ARM-software/golang-utils/utils/field" + +func sign(x int) int { + if x < 0 { + return -1 + } + return 1 +} + +// Range returns a slice of integers similar to Python's built-in range(). +// https://docs.python.org/2/library/functions.html#range +// +// Note: The stop value is always exclusive. +func Range(start, stop int, step *int) []int { + s := field.OptionalInt(step, 1) + if s == 0 { + return []int{} + } + + // Compute length + length := 0 + if (s > 0 && start < stop) || (s < 0 && start > stop) { + length = (stop - start + s - sign(s)) / s + } + + result := make([]int, length) + for i, v := 0, start; i < length; i, v = i+1, v+s { + result[i] = v + } + return result +} diff --git a/utils/collection/range_test.go b/utils/collection/range_test.go new file mode 100644 index 0000000000..31b09195ae --- /dev/null +++ b/utils/collection/range_test.go @@ -0,0 +1,40 @@ +package collection + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ARM-software/golang-utils/utils/field" +) + +func TestRange(t *testing.T) { + tests := []struct { + start int + stop int + step *int + expected []int + }{ + + {2, 5, nil, []int{2, 3, 4}}, + {5, 2, nil, []int{}}, // empty, since stop < start + {2, 10, field.ToOptionalInt(2), []int{2, 4, 6, 8}}, + {0, 10, field.ToOptionalInt(3), []int{0, 3, 6, 9}}, + {1, 10, field.ToOptionalInt(3), []int{1, 4, 7}}, + {10, 2, field.ToOptionalInt(-2), []int{10, 8, 6, 4}}, + {5, -1, field.ToOptionalInt(-1), []int{5, 4, 3, 2, 1, 0}}, + {0, -5, field.ToOptionalInt(-2), []int{0, -2, -4}}, + {0, 5, nil, []int{0, 1, 2, 3, 4}}, + {0, 5, field.ToOptionalInt(0), []int{}}, + {2, 2, field.ToOptionalInt(1), []int{}}, + {2, 2, field.ToOptionalInt(-1), []int{}}, + } + + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("[%v,%v,%v]", test.start, test.stop, test.step), func(t *testing.T) { + assert.Equal(t, test.expected, Range(test.start, test.stop, test.step)) + }) + } +} diff --git a/utils/parallelisation/group.go b/utils/parallelisation/group.go index a687c913ea..de91d757e4 100644 --- a/utils/parallelisation/group.go +++ b/utils/parallelisation/group.go @@ -224,6 +224,14 @@ type ICompoundExecutionGroup[T any] interface { // NewExecutionGroup returns an execution group 
which executes functions according to store options. func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + return NewOrderedExecutionGroup(func(ctx context.Context, index int, element T) error { + return executeFunc(ctx, element) + }, options...) +} + +// NewOrderedExecutionGroup returns an execution group which executes functions according to store options. It also keeps track of the input index. +func NewOrderedExecutionGroup[T any](executeFunc OrderedExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + opts := WithOptions(options...) return &ExecutionGroup[T]{ mu: deadlock.RWMutex{}, @@ -235,10 +243,12 @@ func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption type ExecuteFunc[T any] func(ctx context.Context, element T) error +type OrderedExecuteFunc[T any] func(ctx context.Context, index int, element T) error + type ExecutionGroup[T any] struct { mu deadlock.RWMutex functions []wrappedElement[T] - executeFunc ExecuteFunc[T] + executeFunc OrderedExecuteFunc[T] options StoreOptions } @@ -294,7 +304,7 @@ func (s *ExecutionGroup[T]) executeConcurrently(ctx context.Context, stopOnFirst g.SetLimit(workers) for i := range s.functions { g.Go(func() error { - _, subErr := s.executeFunction(gCtx, s.functions[i]) + _, subErr := s.executeFunction(gCtx, i, s.functions[i]) errCh <- subErr return subErr }) @@ -323,7 +333,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst collateErr := make([]error, funcNum) if reverse { for i := funcNum - 1; i >= 0; i-- { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i]) collateErr[funcNum-i-1] = subErr if shouldBreak { err = subErr @@ -338,7 +348,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst } } else { for i := range s.functions { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i]) collateErr[i] = subErr if shouldBreak { err = subErr @@ -359,7 +369,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst return } -func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElement[T]) (mustBreak bool, err error) { +func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, index int, w wrappedElement[T]) (mustBreak bool, err error) { err = DetermineContextError(ctx) if err != nil { mustBreak = true @@ -370,7 +380,9 @@ func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElemen mustBreak = true return } - err = w.Execute(ctx, s.executeFunc) + err = w.Execute(ctx, func(ctx context.Context, element T) error { + return s.executeFunc(ctx, index, element) + }) return } diff --git a/utils/parallelisation/transform.go b/utils/parallelisation/transform.go index 0248c9125d..d12fa94294 100644 --- a/utils/parallelisation/transform.go +++ b/utils/parallelisation/transform.go @@ -9,16 +9,24 @@ import ( "github.com/ARM-software/golang-utils/utils/field" ) +// TransformFunc defines a transformation function which converts an input into an output. 
type TransformFunc[I any, O any] func(context.Context, I) (output O, success bool, err error) +type resultElement[O any] struct { + r O + index int +} type results[O any] struct { terminated *atomic.Bool - r chan O + r chan resultElement[O] } -func (r *results[O]) Append(o O) { +func (r *results[O]) Append(o *resultElement[O]) { + if o == nil { + return + } if !r.terminated.Load() { - r.r <- o + r.r <- *o } } @@ -36,18 +44,41 @@ func (r *results[O]) Results(ctx context.Context) (slice []O, err error) { if err != nil { return } - slice = append(slice, output) + slice = append(slice, output.r) + } + return +} + +func (r *results[O]) OrderedResults(ctx context.Context) (slice []O, err error) { + if !r.terminated.Swap(true) { + close(r.r) + } + err = DetermineContextError(ctx) + if err != nil { + return + } + values := make(map[int]O, len(r.r)) + slice = make([]O, len(r.r)) + for output := range r.r { + err = DetermineContextError(ctx) + if err != nil { + return + } + values[output.index] = output.r + } + for i := 0; i < len(slice); i++ { + slice[i] = values[i] } return } func newResults[O any](numberOfInput *int) *results[O] { i := field.OptionalInt(numberOfInput, 0) - var channel chan O + var channel chan resultElement[O] if i <= 0 { - channel = make(chan O) + channel = make(chan resultElement[O]) } else { - channel = make(chan O, i) + channel = make(chan resultElement[O], i) } return &results[O]{ @@ -61,7 +92,7 @@ type TransformGroup[I any, O any] struct { results *atomic.Pointer[results[O]] } -func (g *TransformGroup[I, O]) appendResult(o O) { +func (g *TransformGroup[I, O]) appendResult(o *resultElement[O]) { r := g.results.Load() if r != nil { r.Append(o) @@ -89,7 +120,16 @@ func (g *TransformGroup[I, O]) Outputs(ctx context.Context) ([]O, error) { return r.Results(ctx) } -// Transform actually performs the transformation +// OrderedOutputs returns any input which have been transformed when the Transform function was called. The returned output is in the same order as the input slice. +func (g *TransformGroup[I, O]) OrderedOutputs(ctx context.Context) ([]O, error) { + r := g.results.Load() + if r == nil { + return nil, commonerrors.UndefinedVariable("results") + } + return r.OrderedResults(ctx) +} + +// Transform actually performs the transformation over all registered inputs. func (g *TransformGroup[I, O]) Transform(ctx context.Context) error { g.results.Store(newResults[O](field.ToOptionalInt(g.Len()))) return g.ExecutionGroup.Execute(ctx) @@ -103,7 +143,7 @@ func NewTransformGroup[I any, O any](transform TransformFunc[I, O], options ...S g := &TransformGroup[I, O]{ results: atomic.NewPointer[results[O]](newResults[O](nil)), } - g.ExecutionGroup = *NewExecutionGroup[I](func(fCtx context.Context, i I) error { + g.ExecutionGroup = *NewOrderedExecutionGroup[I](func(fCtx context.Context, index int, i I) error { err := DetermineContextError(fCtx) if err != nil { return err @@ -113,7 +153,7 @@ func NewTransformGroup[I any, O any](transform TransformFunc[I, O], options ...S return commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred whilst handling an input [%+v]", i) } if success { - g.appendResult(o) + g.appendResult(&resultElement[O]{index: index, r: o}) } return nil }, options...) 
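For reviewers: a minimal usage sketch of the `TransformGroup` workflow introduced above — register inputs with `Inputs`, run `Transform`, then collect results with `Outputs` or, to keep the input order, the new `OrderedOutputs`. This is an illustration, not part of the patch; the `RetainAfterExecution` and `Parallel` options are simply the ones exercised by `transform_test.go` below, and the import path is the repository's own `utils/parallelisation` package.

```go
package example

import (
	"context"
	"strconv"

	"github.com/ARM-software/golang-utils/utils/parallelisation"
)

// ParseAll converts strings to integers with a TransformGroup, dropping values that do not
// parse (success=false) and returning the parsed values in the same order as the inputs.
func ParseAll(ctx context.Context, inputs []string) ([]int, error) {
	parse := func(ctx context.Context, s string) (int, bool, error) {
		n, err := strconv.Atoi(s)
		if err != nil {
			// Returning success=false with no error skips this input instead of failing the group.
			return 0, false, nil
		}
		return n, true, nil
	}

	// Options mirror those used in transform_test.go below.
	g := parallelisation.NewTransformGroup[string, int](parse, parallelisation.RetainAfterExecution, parallelisation.Parallel)

	if err := g.Inputs(ctx, inputs...); err != nil {
		return nil, err
	}
	if err := g.Transform(ctx); err != nil {
		return nil, err
	}
	// OrderedOutputs (added in this patch) returns results in input order;
	// Outputs returns them in completion order.
	return g.OrderedOutputs(ctx)
}
```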
diff --git a/utils/parallelisation/transform_test.go b/utils/parallelisation/transform_test.go new file mode 100644 index 0000000000..a76f88c4ad --- /dev/null +++ b/utils/parallelisation/transform_test.go @@ -0,0 +1,72 @@ +package parallelisation + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ARM-software/golang-utils/utils/collection" +) + +func TestNewTransformGroup(t *testing.T) { + tr := func(ctx context.Context, i string) (o int, success bool, err error) { + err = DetermineContextError(ctx) + if err != nil { + return + } + o, err = strconv.Atoi(i) + if err == nil { + success = true + } + return + } + g := NewTransformGroup[string, int](tr, RetainAfterExecution, Parallel) + assert.Zero(t, g.Len()) + o, err := g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + numberOfInput := 50 + in := collection.Range(0, numberOfInput, nil) + in2 := make([]string, numberOfInput) + for i := 0; i < numberOfInput; i++ { + in2[i] = strconv.Itoa(i) + } + err = g.Inputs(context.Background(), in2...) + require.NoError(t, err) + assert.Equal(t, numberOfInput, g.Len()) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.ElementsMatch(t, in, o) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Equal(t, in, o) + err = g.Inputs(context.Background(), in2...) + require.NoError(t, err) + assert.Equal(t, 2*numberOfInput, g.Len()) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.Empty(t, o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.Outputs(context.Background()) + require.NoError(t, err) + assert.ElementsMatch(t, append(in, in...), o) + err = g.Transform(context.Background()) + require.NoError(t, err) + o, err = g.OrderedOutputs(context.Background()) + require.NoError(t, err) + assert.Equal(t, append(in, in...), o) +} From 6bbc7166ecac039b591a891c1024d213ddf11652 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Mon, 8 Sep 2025 12:10:15 +0100 Subject: [PATCH 3/4] Added an `OrderedMap` function --- utils/collection/search.go | 20 +++++++-- utils/parallelisation/parallelisation.go | 43 +++++++++++++++---- utils/parallelisation/parallelisation_test.go | 34 +++++++++++++++ utils/parallelisation/transform_test.go | 2 + 4 files changed, 87 insertions(+), 12 deletions(-) diff --git a/utils/collection/search.go b/utils/collection/search.go index 77c9da9ccf..f6b5156412 100644 --- a/utils/collection/search.go +++ b/utils/collection/search.go @@ -71,8 +71,10 @@ func AnyFunc[S ~[]E, E any](s S, f func(E) bool) bool { return conditions.Any() } +type FilterFunc[E any] func(E) bool + // Filter returns a new slice that contains elements from the input slice which return true when they’re passed as a parameter to the provided filtering function f. 
-func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) { +func Filter[S ~[]E, E any](s S, f FilterFunc[E]) (result S) { result = make(S, 0, len(s)) for i := range s { @@ -84,8 +86,16 @@ func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) { return result } +type MapFunc[T1, T2 any] func(T1) T2 + +func IdentityMapFunc[T any]() MapFunc[T, T] { + return func(i T) T { + return i + } +} + // Map creates a new slice and populates it with the results of calling the provided function on every element in input slice. -func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) { +func Map[T1 any, T2 any](s []T1, f MapFunc[T1, T2]) (result []T2) { result = make([]T2, len(s)) for i := range s { @@ -97,12 +107,14 @@ func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) { // Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false. // This is functionally equivalent to slices.DeleteFunc but it returns a new slice. -func Reject[S ~[]E, E any](s S, f func(E) bool) S { +func Reject[S ~[]E, E any](s S, f FilterFunc[E]) S { return Filter(s, func(e E) bool { return !f(e) }) } +type ReduceFunc[T1, T2 any] func(T2, T1) T2 + // Reduce runs a reducer function f over all elements in the array, in ascending-index order, and accumulates them into a single value. -func Reduce[T1, T2 any](s []T1, accumulator T2, f func(T2, T1) T2) (result T2) { +func Reduce[T1, T2 any](s []T1, accumulator T2, f ReduceFunc[T1, T2]) (result T2) { result = accumulator for i := range s { result = f(result, s[i]) diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index e11924dcc2..5852b253bd 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -13,6 +13,7 @@ import ( "go.uber.org/atomic" + "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" ) @@ -265,12 +266,21 @@ func WaitUntil(ctx context.Context, evalCondition func(ctx2 context.Context) (bo } // WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size -func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f func(context.Context, InputType) (ResultType, bool, error)) (results []ResultType, err error) { +func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f TransformFunc[InputType, ResultType]) (results []ResultType, err error) { + g, err := workerPoolGroup[InputType, ResultType](ctx, numWorkers, jobs, f) + if err != nil { + return + } + results, err = g.Outputs(ctx) + return +} + +func workerPoolGroup[I, O any](ctx context.Context, numWorkers int, jobs []I, f TransformFunc[I, O]) (g *TransformGroup[I, O], err error) { if numWorkers < 1 { err = commonerrors.New(commonerrors.ErrInvalid, "numWorkers must be greater than or equal to 1") return } - g := NewTransformGroup[InputType, ResultType](f, Workers(numWorkers), JoinErrors) + g = NewTransformGroup[I, O](f, Workers(numWorkers), JoinErrors) err = g.Inputs(ctx, jobs...) 
if err != nil { return @@ -279,12 +289,11 @@ func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, if err != nil { return } - results, err = g.Outputs(ctx) return } // Filter is similar to collection.Filter but uses parallelisation. -func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) (result []T, err error) { +func Filter[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) (result []T, err error) { result, err = WorkerPool[T, T](ctx, numWorkers, s, func(fCtx context.Context, item T) (r T, ok bool, fErr error) { fErr = DetermineContextError(fCtx) if fErr != nil { @@ -297,9 +306,8 @@ func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ( return } -// Map is similar to collection.Map but uses parallelisation. -func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) T2) (result []T2, err error) { - result, err = WorkerPool[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) { +func mapGroup[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (*TransformGroup[T1, T2], error) { + return workerPoolGroup[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) { fErr = DetermineContextError(fCtx) if fErr != nil { return @@ -308,10 +316,29 @@ func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) ok = true return }) +} + +// Map is similar to collection.Map but uses parallelisation. +func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) { + g, err := mapGroup[T1, T2](ctx, numWorkers, s, f) + if err != nil { + return + } + result, err = g.Outputs(ctx) + return +} + +// OrderedMap is similar to Map but ensures the results are in the same order as the input. +func OrderedMap[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) { + g, err := mapGroup[T1, T2](ctx, numWorkers, s, f) + if err != nil { + return + } + result, err = g.OrderedOutputs(ctx) return } // Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false. 
-func Reject[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ([]T, error) { +func Reject[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) ([]T, error) { return Filter[T](ctx, numWorkers, s, func(e T) bool { return !f(e) }) } diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index fe9d2305c1..a0e63f17b8 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -19,8 +19,10 @@ import ( "go.uber.org/atomic" "go.uber.org/goleak" + "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/field" ) var ( @@ -636,3 +638,35 @@ func TestMap(t *testing.T) { errortest.AssertError(t, err, commonerrors.ErrCancelled) }) } + +func TestMapAndOrderedMap(t *testing.T) { + defer goleak.VerifyNone(t) + ctx := context.Background() + mapped, err := OrderedMap(ctx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + require.NoError(t, err) + assert.Equal(t, []string{"Hello world 1", "Hello world 2"}, mapped) + mapped, err = OrderedMap(ctx, 3, []int64{1, 2, 3, 4}, func(x int64) string { + return strconv.FormatInt(x, 10) + }) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3", "4"}, mapped) + t.Run("cancelled context", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := Map(cancelledCtx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) + + in := collection.Range(0, 1000, field.ToOptionalInt(5)) + mappedInt, err := OrderedMap(ctx, 3, in, collection.IdentityMapFunc[int]()) + require.NoError(t, err) + assert.Equal(t, in, mappedInt) + mappedInt, err = Map(ctx, 3, in, collection.IdentityMapFunc[int]()) + require.NoError(t, err) + assert.NotEqual(t, in, mappedInt) + assert.ElementsMatch(t, in, mappedInt) +} diff --git a/utils/parallelisation/transform_test.go b/utils/parallelisation/transform_test.go index a76f88c4ad..457822581d 100644 --- a/utils/parallelisation/transform_test.go +++ b/utils/parallelisation/transform_test.go @@ -7,11 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/ARM-software/golang-utils/utils/collection" ) func TestNewTransformGroup(t *testing.T) { + defer goleak.VerifyNone(t) tr := func(ctx context.Context, i string) (o int, success bool, err error) { err = DetermineContextError(ctx) if err != nil { From 2ed092381457f788364dfdbb0a7ded7334e7e54d Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Mon, 8 Sep 2025 12:17:42 +0100 Subject: [PATCH 4/4] :green_heart: fix flaky test --- utils/signing/signing_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/signing/signing_test.go b/utils/signing/signing_test.go index f00d27b626..5584bd9355 100644 --- a/utils/signing/signing_test.go +++ b/utils/signing/signing_test.go @@ -37,7 +37,7 @@ func TestSigning(t *testing.T) { signature, err := signer.Sign(message) require.NoError(t, err) - ok, err := signer.Verify([]byte(faker.Word()), signature) + ok, err := signer.Verify([]byte(faker.Word()+faker.Word()), signature) require.NoError(t, err) assert.False(t, ok) }) @@ -48,7 +48,7 @@ func TestSigning(t *testing.T) { 
signer, err := NewEd25519SignerFromSeed(faker.Word()) require.NoError(t, err) - wrongSignature, err := signer.Sign([]byte(faker.Word())) + wrongSignature, err := signer.Sign([]byte(faker.Word() + faker.Word())) require.NoError(t, err) ok, err := signer.Verify(message, wrongSignature)
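As a companion illustration of the helpers added in patch 3 (again, a sketch rather than part of the series), the snippet below exercises `collection.Range` together with `Map` and the new `OrderedMap`, in the same way as `TestMapAndOrderedMap`:

```go
package example

import (
	"context"
	"fmt"
	"strconv"

	"github.com/ARM-software/golang-utils/utils/collection"
	"github.com/ARM-software/golang-utils/utils/parallelisation"
)

func Demo(ctx context.Context) error {
	// Range(0, 10, nil) defaults the step to 1 and excludes the stop value: [0 1 2 ... 9].
	in := collection.Range(0, 10, nil)

	toString := func(i int) string { return strconv.Itoa(i) }

	// OrderedMap returns the transformed values in the same order as the input slice.
	ordered, err := parallelisation.OrderedMap(ctx, 3, in, toString)
	if err != nil {
		return err
	}
	fmt.Println(ordered) // [0 1 2 3 4 5 6 7 8 9]

	// Map performs the same transformation but only guarantees the set of results,
	// not their order, since workers may finish in any order.
	unordered, err := parallelisation.Map(ctx, 3, in, toString)
	if err != nil {
		return err
	}
	fmt.Println(len(unordered)) // 10

	return nil
}
```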