Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251205145557.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `parallelisation` Add support for more context when returning from RunActionWithParallelCheck
88 changes: 68 additions & 20 deletions utils/parallelisation/parallelisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"go.uber.org/atomic"
"golang.org/x/sync/errgroup"

"github.com/ARM-software/golang-utils/utils/collection"
"github.com/ARM-software/golang-utils/utils/commonerrors"
Expand Down Expand Up @@ -210,38 +211,85 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati
}
}

type (
ActionFunc func(ctx context.Context) (err error)
CheckFunc func(ctx context.Context) (ok bool)
CheckWithResultFunc[T any] func(ctx context.Context) (res T, ok bool)
ResultCheckFunc[T any] func(res T) (err error)
)

// RunActionWithParallelCheck runs an action with a check in parallel
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error {
err := DetermineContextError(ctx)
// The function performing the check should return true if the check should be repeated; false otherwise it should not.
// For more context about how the check ended, a result can be returned. If the check did not have the expected result
// then the whole function would be cancelled.
func RunActionWithParallelCheckAndResult[T any](ctx context.Context, action ActionFunc, checkAction CheckWithResultFunc[T], onCheckResult ResultCheckFunc[T], checkPeriod time.Duration) (res T, ok bool, err error) {
err = DetermineContextError(ctx)
if err != nil {
return err
return
}

var errGroup errgroup.Group

cancelStore := NewCancelFunctionsStore()
defer cancelStore.Cancel()

cancellableCtx, cancelFunc := context.WithCancel(ctx)
cancelStore.RegisterCancelFunction(cancelFunc)
go func(ctx context.Context, store *CancelFunctionStore) {
for {
select {
case <-ctx.Done():
store.Cancel()
return
default:
if !checkAction(ctx) {
store.Cancel()

errGroup.Go(func() error {
return func(ctx context.Context, store *CancelFunctionStore) (err error) {
defer store.Cancel()
for {
select {
case <-ctx.Done():
return
default:
res, ok = checkAction(ctx)

err = onCheckResult(res)
if err != nil {
return
}

if !ok {
return
}

SleepWithContext(ctx, checkPeriod)
}
SleepWithContext(ctx, checkPeriod)
}
}
}(cancellableCtx, cancelStore)
}(cancellableCtx, cancelStore)
})

err = action(cancellableCtx)
err2 := DetermineContextError(cancellableCtx)
if err2 != nil {
return err2
if errCtx := DetermineContextError(cancellableCtx); errCtx != nil {
err = errCtx
}
return err
cancelFunc()

if egErr := errGroup.Wait(); commonerrors.Ignore(egErr, commonerrors.ErrCancelled, commonerrors.ErrTimeout) != nil {
err = egErr
return
}

return
}

// RunActionWithParallelCheck runs an action with a check in parallel
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result then the whole function would be cancelled.
func RunActionWithParallelCheck(ctx context.Context, action ActionFunc, checkAction CheckFunc, checkPeriod time.Duration) (err error) {
_, _, err = RunActionWithParallelCheckAndResult(
ctx,
action,
func(ctx context.Context) (_ struct{}, ok bool) {
ok = checkAction(ctx)
return
},
func(_ struct{}) error { return nil },
checkPeriod,
)

return
}

// WaitUntil waits for a condition evaluated by evalCondition to be verified
Expand Down
195 changes: 195 additions & 0 deletions utils/parallelisation/parallelisation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -418,6 +419,200 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
assert.GreaterOrEqual(t, counter.Load(), int32(1))
}

func TestRunActionWithParallelCheckAndResult(t *testing.T) {
type parallelisationCheckResult struct {
checks int32
status string
}

t.Run("Happy", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) (err error) {
time.Sleep(120 * time.Millisecond)
return
},
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: "healthy",
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
10*time.Millisecond,
)

require.NoError(t, err)
require.True(t, ok)

assert.GreaterOrEqual(t, res.checks, int32(10))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, "healthy", res.status)
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Check Fails With Reason", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
actionStarted := atomic.NewBool(false)

status := "adrien"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
actionStarted.Store(true)
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
if n := checkCounter.Inc(); n >= 5 {
return parallelisationCheckResult{
checks: n,
status: status,
}, false
} else {
return parallelisationCheckResult{
checks: n,
status: "ok",
}, true
}
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.True(t, actionStarted.Load())
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)

require.False(t, ok)
assert.Equal(t, status, res.status)
assert.Equal(t, int32(5), res.checks)
assert.Equal(t, int32(5), checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})
t.Run("Action Error (no context cancel)", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "abdel"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
time.Sleep(30 * time.Millisecond)
return commonerrors.New(commonerrors.ErrForbidden, faker.Sentence())
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrForbidden)
require.True(t, ok)

assert.Equal(t, status, res.status)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Context cancel", func(t *testing.T) {
defer goleak.VerifyNone(t)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "kem"

res, ok, err := RunActionWithParallelCheckAndResult(
ctx,
func(ctx context.Context) error {
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return nil
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrTimeout)
assert.True(t, ok)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})

t.Run("Check result error", func(t *testing.T) {
defer goleak.VerifyNone(t)

checkCounter := atomic.NewInt32(0)
checkResultCounter := atomic.NewInt32(0)
status := "kem"

res, ok, err := RunActionWithParallelCheckAndResult(
context.Background(),
func(ctx context.Context) error {
<-ctx.Done()
return DetermineContextError(ctx)
},
func(ctx context.Context) (parallelisationCheckResult, bool) {
return parallelisationCheckResult{
checks: checkCounter.Inc(),
status: status,
}, true
},
func(_ parallelisationCheckResult) error {
checkResultCounter.Inc()
return commonerrors.ErrUnexpected
},
5*time.Millisecond,
)

require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrUnexpected)
assert.True(t, ok)
assert.GreaterOrEqual(t, res.checks, int32(1))
assert.Equal(t, res.checks, checkCounter.Load())
assert.Equal(t, checkCounter.Load(), checkResultCounter.Load())
})
}

func TestWaitUntil(t *testing.T) {
defer goleak.VerifyNone(t)
verifiedCondition := func(ctx context.Context) (bool, error) {
Expand Down
Loading