diff --git a/changes/20251205145557.feature b/changes/20251205145557.feature new file mode 100644 index 0000000000..1710378d56 --- /dev/null +++ b/changes/20251205145557.feature @@ -0,0 +1 @@ +:sparkles: `parallelisation` Add support for more context when returning from RunActionWithParallelCheck diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 5852b253bd..f57d808adb 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -12,6 +12,7 @@ import ( "time" "go.uber.org/atomic" + "golang.org/x/sync/errgroup" "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" @@ -210,38 +211,85 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati } } +type ( + ActionFunc func(ctx context.Context) (err error) + CheckFunc func(ctx context.Context) (ok bool) + CheckWithResultFunc[T any] func(ctx context.Context) (res T, ok bool) + ResultCheckFunc[T any] func(res T) (err error) +) + // RunActionWithParallelCheck runs an action with a check in parallel -// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled. -func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error { - err := DetermineContextError(ctx) +// The function performing the check should return true if the check should be repeated; false otherwise it should not. +// For more context about how the check ended, a result can be returned. If the check did not have the expected result +// then the whole function would be cancelled. +func RunActionWithParallelCheckAndResult[T any](ctx context.Context, action ActionFunc, checkAction CheckWithResultFunc[T], onCheckResult ResultCheckFunc[T], checkPeriod time.Duration) (res T, ok bool, err error) { + err = DetermineContextError(ctx) if err != nil { - return err + return } + + var errGroup errgroup.Group + cancelStore := NewCancelFunctionsStore() defer cancelStore.Cancel() + cancellableCtx, cancelFunc := context.WithCancel(ctx) cancelStore.RegisterCancelFunction(cancelFunc) - go func(ctx context.Context, store *CancelFunctionStore) { - for { - select { - case <-ctx.Done(): - store.Cancel() - return - default: - if !checkAction(ctx) { - store.Cancel() + + errGroup.Go(func() error { + return func(ctx context.Context, store *CancelFunctionStore) (err error) { + defer store.Cancel() + for { + select { + case <-ctx.Done(): return + default: + res, ok = checkAction(ctx) + + err = onCheckResult(res) + if err != nil { + return + } + + if !ok { + return + } + + SleepWithContext(ctx, checkPeriod) } - SleepWithContext(ctx, checkPeriod) } - } - }(cancellableCtx, cancelStore) + }(cancellableCtx, cancelStore) + }) + err = action(cancellableCtx) - err2 := DetermineContextError(cancellableCtx) - if err2 != nil { - return err2 + if errCtx := DetermineContextError(cancellableCtx); errCtx != nil { + err = errCtx } - return err + cancelFunc() + + if egErr := errGroup.Wait(); commonerrors.Ignore(egErr, commonerrors.ErrCancelled, commonerrors.ErrTimeout) != nil { + err = egErr + return + } + + return +} + +// RunActionWithParallelCheck runs an action with a check in parallel +// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result then the whole function would be cancelled. +func RunActionWithParallelCheck(ctx context.Context, action ActionFunc, checkAction CheckFunc, checkPeriod time.Duration) (err error) { + _, _, err = RunActionWithParallelCheckAndResult( + ctx, + action, + func(ctx context.Context) (_ struct{}, ok bool) { + ok = checkAction(ctx) + return + }, + func(_ struct{}) error { return nil }, + checkPeriod, + ) + + return } // WaitUntil waits for a condition evaluated by evalCondition to be verified diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index a0e63f17b8..2ec31007aa 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/go-faker/faker/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -418,6 +419,200 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) { assert.GreaterOrEqual(t, counter.Load(), int32(1)) } +func TestRunActionWithParallelCheckAndResult(t *testing.T) { + type parallelisationCheckResult struct { + checks int32 + status string + } + + t.Run("Happy", func(t *testing.T) { + defer goleak.VerifyNone(t) + + checkCounter := atomic.NewInt32(0) + checkResultCounter := atomic.NewInt32(0) + + res, ok, err := RunActionWithParallelCheckAndResult( + context.Background(), + func(ctx context.Context) (err error) { + time.Sleep(120 * time.Millisecond) + return + }, + func(ctx context.Context) (res parallelisationCheckResult, ok bool) { + return parallelisationCheckResult{ + checks: checkCounter.Inc(), + status: "healthy", + }, true + }, + func(_ parallelisationCheckResult) error { + checkResultCounter.Inc() + return nil + }, + 10*time.Millisecond, + ) + + require.NoError(t, err) + require.True(t, ok) + + assert.GreaterOrEqual(t, res.checks, int32(10)) + assert.Equal(t, res.checks, checkCounter.Load()) + assert.Equal(t, "healthy", res.status) + assert.Equal(t, checkCounter.Load(), checkResultCounter.Load()) + }) + + t.Run("Check Fails With Reason", func(t *testing.T) { + defer goleak.VerifyNone(t) + + checkCounter := atomic.NewInt32(0) + checkResultCounter := atomic.NewInt32(0) + actionStarted := atomic.NewBool(false) + + status := "adrien" + + res, ok, err := RunActionWithParallelCheckAndResult( + context.Background(), + func(ctx context.Context) error { + actionStarted.Store(true) + <-ctx.Done() + return DetermineContextError(ctx) + }, + func(ctx context.Context) (res parallelisationCheckResult, ok bool) { + if n := checkCounter.Inc(); n >= 5 { + return parallelisationCheckResult{ + checks: n, + status: status, + }, false + } else { + return parallelisationCheckResult{ + checks: n, + status: "ok", + }, true + } + }, + func(_ parallelisationCheckResult) error { + checkResultCounter.Inc() + return nil + }, + 5*time.Millisecond, + ) + + require.True(t, actionStarted.Load()) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + + require.False(t, ok) + assert.Equal(t, status, res.status) + assert.Equal(t, int32(5), res.checks) + assert.Equal(t, int32(5), checkCounter.Load()) + assert.Equal(t, checkCounter.Load(), checkResultCounter.Load()) + }) + t.Run("Action Error (no context cancel)", func(t *testing.T) { + defer goleak.VerifyNone(t) + + checkCounter := atomic.NewInt32(0) + checkResultCounter := atomic.NewInt32(0) + status := "abdel" + + res, ok, err := RunActionWithParallelCheckAndResult( + context.Background(), + func(ctx context.Context) error { + time.Sleep(30 * time.Millisecond) + return commonerrors.New(commonerrors.ErrForbidden, faker.Sentence()) + }, + func(ctx context.Context) (parallelisationCheckResult, bool) { + return parallelisationCheckResult{ + checks: checkCounter.Inc(), + status: status, + }, true + }, + func(_ parallelisationCheckResult) error { + checkResultCounter.Inc() + return nil + }, + 5*time.Millisecond, + ) + + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrForbidden) + require.True(t, ok) + + assert.Equal(t, status, res.status) + assert.GreaterOrEqual(t, res.checks, int32(1)) + assert.Equal(t, res.checks, checkCounter.Load()) + assert.Equal(t, checkCounter.Load(), checkResultCounter.Load()) + }) + + t.Run("Context cancel", func(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + checkCounter := atomic.NewInt32(0) + checkResultCounter := atomic.NewInt32(0) + status := "kem" + + res, ok, err := RunActionWithParallelCheckAndResult( + ctx, + func(ctx context.Context) error { + <-ctx.Done() + return DetermineContextError(ctx) + }, + func(ctx context.Context) (parallelisationCheckResult, bool) { + return parallelisationCheckResult{ + checks: checkCounter.Inc(), + status: status, + }, true + }, + func(_ parallelisationCheckResult) error { + checkResultCounter.Inc() + return nil + }, + 5*time.Millisecond, + ) + + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrTimeout) + assert.True(t, ok) + assert.GreaterOrEqual(t, res.checks, int32(1)) + assert.Equal(t, res.checks, checkCounter.Load()) + assert.Equal(t, checkCounter.Load(), checkResultCounter.Load()) + }) + + t.Run("Check result error", func(t *testing.T) { + defer goleak.VerifyNone(t) + + checkCounter := atomic.NewInt32(0) + checkResultCounter := atomic.NewInt32(0) + status := "kem" + + res, ok, err := RunActionWithParallelCheckAndResult( + context.Background(), + func(ctx context.Context) error { + <-ctx.Done() + return DetermineContextError(ctx) + }, + func(ctx context.Context) (parallelisationCheckResult, bool) { + return parallelisationCheckResult{ + checks: checkCounter.Inc(), + status: status, + }, true + }, + func(_ parallelisationCheckResult) error { + checkResultCounter.Inc() + return commonerrors.ErrUnexpected + }, + 5*time.Millisecond, + ) + + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrUnexpected) + assert.True(t, ok) + assert.GreaterOrEqual(t, res.checks, int32(1)) + assert.Equal(t, res.checks, checkCounter.Load()) + assert.Equal(t, checkCounter.Load(), checkResultCounter.Load()) + }) +} + func TestWaitUntil(t *testing.T) { defer goleak.VerifyNone(t) verifiedCondition := func(ctx context.Context) (bool, error) {