Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251205145557.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `parallisation` Add support for more context when returning from RunActionWithParallelCheck
91 changes: 68 additions & 23 deletions utils/parallelisation/parallelisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"go.uber.org/atomic"
"golang.org/x/sync/errgroup"

"github.com/ARM-software/golang-utils/utils/collection"
"github.com/ARM-software/golang-utils/utils/commonerrors"
Expand Down Expand Up @@ -210,38 +211,82 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati
}
}

// RunActionWithParallelCheck runs an action with a check in parallel
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error {
err := DetermineContextError(ctx)
if err != nil {
return err
}
func runActionAndWait[T any](ctx context.Context, errGroup *errgroup.Group, action func(ctx context.Context) error, checkAction func(ctx context.Context) (res T, ok bool), onCheckResult func(res T) error, checkPeriod time.Duration) (res T, ok bool, err error) {
cancelStore := NewCancelFunctionsStore()
defer cancelStore.Cancel()

cancellableCtx, cancelFunc := context.WithCancel(ctx)
cancelStore.RegisterCancelFunction(cancelFunc)
go func(ctx context.Context, store *CancelFunctionStore) {
for {
select {
case <-ctx.Done():
store.Cancel()
return
default:
if !checkAction(ctx) {
store.Cancel()

errGroup.Go(func() error {
return func(ctx context.Context, store *CancelFunctionStore) (err error) {
defer store.Cancel()
for {
select {
case <-ctx.Done():
return
default:
res, ok = checkAction(ctx)

err = onCheckResult(res)
if err != nil {
return
}

if !ok {
return
}

SleepWithContext(ctx, checkPeriod)
}
SleepWithContext(ctx, checkPeriod)
}
}
}(cancellableCtx, cancelStore)
}(cancellableCtx, cancelStore)
})

err = action(cancellableCtx)
err2 := DetermineContextError(cancellableCtx)
if err2 != nil {
return err2
if errCtx := DetermineContextError(cancellableCtx); errCtx != nil {
err = errCtx
}
cancelFunc()

if egErr := errGroup.Wait(); commonerrors.Ignore(egErr, commonerrors.ErrCancelled, commonerrors.ErrTimeout) != nil {
err = egErr
return
}

return
}

// RunActionWithParallelCheck runs an action with a check in parallel
// The function performing the check should return true if the check was favourable; false otherwise.
// For more context, a result can be returned. If the check did not have the expected result and the
// whole function would be cancelled.
func RunActionWithParallelCheckAndResult[T any](ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) (res T, ok bool), onCheckResult func(res T) error, checkPeriod time.Duration) (res T, ok bool, err error) {
err = DetermineContextError(ctx)
if err != nil {
return
}
return err

var errGroup errgroup.Group
res, ok, err = runActionAndWait(ctx, &errGroup, action, checkAction, onCheckResult, checkPeriod)
return
}

// RunActionWithParallelCheck runs an action with a check in parallel
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) (err error) {
_, _, err = RunActionWithParallelCheckAndResult(
ctx,
action,
func(ctx context.Context) (_ struct{}, ok bool) {
ok = checkAction(ctx)
return
},
func(_ struct{}) error { return nil },
checkPeriod,
)

return
}

// WaitUntil waits for a condition evaluated by evalCondition to be verified
Expand Down
195 changes: 195 additions & 0 deletions utils/parallelisation/parallelisation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -418,6 +419,200 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
assert.GreaterOrEqual(t, counter.Load(), int32(1))
}

func TestRunActionWithParallelCheckAndResult(t *testing.T) {
type parallelisationCheckResult struct {
checks int32
status string
}

t.Run("Happy", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) (err error) {
time.Sleep(120 * time.Millisecond)
return
},
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: "healthy",
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
10*time.Millisecond,
)

require.NoError(t, err)
require.True(t, ok)

assert.GreaterOrEqual(t, res.checks, int32(10))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, "healthy", res.status)
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Check Fails With Reason", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
actionStarted := atomic.NewBool(false)

status := "adrien"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
actionStarted.Store(true)
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
if n := checkCounter.Inc(); n >= 5 {
return parallelisationCheckResult{
checks: n,
status: status,
}, false
} else {
return parallelisationCheckResult{
checks: n,
status: "ok",
}, true
}
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.True(t, actionStarted.Load())
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)

require.False(t, ok)
assert.Equal(t, status, res.status)
assert.Equal(t, int32(5), res.checks)
assert.Equal(t, int32(5), checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})
t.Run("Action Error (no context cancel)", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "abdel"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
time.Sleep(30 * time.Millisecond)
return commonerrors.New(commonerrors.ErrForbidden, faker.Sentence())
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrForbidden)
require.True(t, ok)

assert.Equal(t, status, res.status)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Context cancel", func(t *testing.T) {
defer goleak.VerifyNone(t)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "kem"

res, ok, err := RunActionWithParallelCheckAndResult(
ctx,
func(ctx context.Context) error {
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrTimeout)
assert.True(t, ok)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Check result error", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "kem"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return commonerrors.ErrUnexpected
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrUnexpected)
assert.True(t, ok)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})
}

func TestWaitUntil(t *testing.T) {
defer goleak.VerifyNone(t)
verifiedCondition := func(ctx context.Context) (bool, error) {
Expand Down
Loading