Skip to content

Commit 6d26d24

Browse files
parallisation Add support for more context when returning from RunActionWithParallelCheck
1 parent 4c8c804 commit 6d26d24

File tree

3 files changed

+184
-12
lines changed

3 files changed

+184
-12
lines changed

changes/20251205145557.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `parallisation` Add support for more context when returning from RunActionWithParallelCheck

utils/parallelisation/parallelisation.go

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package parallelisation
99
import (
1010
"context"
1111
"reflect"
12+
"sync"
1213
"time"
1314

1415
"go.uber.org/atomic"
@@ -210,38 +211,72 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati
210211
}
211212
}
212213

213-
// RunActionWithParallelCheck runs an action with a check in parallel
214-
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
215-
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error {
216-
err := DetermineContextError(ctx)
217-
if err != nil {
218-
return err
219-
}
214+
func runActionAndWait[T any](ctx context.Context, wg *sync.WaitGroup, action func(ctx context.Context) error, checkAction func(ctx context.Context) (res T, ok bool), checkPeriod time.Duration) (res T, ok bool, err error) {
220215
cancelStore := NewCancelFunctionsStore()
221216
defer cancelStore.Cancel()
217+
222218
cancellableCtx, cancelFunc := context.WithCancel(ctx)
223219
cancelStore.RegisterCancelFunction(cancelFunc)
220+
221+
wg.Add(1)
224222
go func(ctx context.Context, store *CancelFunctionStore) {
223+
defer wg.Done()
225224
for {
226225
select {
227226
case <-ctx.Done():
228227
store.Cancel()
229228
return
230229
default:
231-
if !checkAction(ctx) {
230+
res, ok = checkAction(ctx)
231+
if !ok {
232232
store.Cancel()
233233
return
234234
}
235+
235236
SleepWithContext(ctx, checkPeriod)
236237
}
237238
}
238239
}(cancellableCtx, cancelStore)
240+
239241
err = action(cancellableCtx)
240-
err2 := DetermineContextError(cancellableCtx)
241-
if err2 != nil {
242-
return err2
242+
if errCtx := DetermineContextError(cancellableCtx); errCtx != nil {
243+
err = errCtx
244+
}
245+
246+
return
247+
}
248+
249+
// RunActionWithParallelCheck runs an action with a check in parallel
250+
// The function performing the check should return true if the check was favourable; false otherwise.
251+
// For more context, a result can be returned. If the check did not have the expected result and the
252+
// whole function would be cancelled.
253+
func RunActionWithParallelCheckAndResult[T any](ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) (res T, ok bool), checkPeriod time.Duration) (res T, ok bool, err error) {
254+
err = DetermineContextError(ctx)
255+
if err != nil {
256+
return
243257
}
244-
return err
258+
259+
var wg sync.WaitGroup
260+
defer wg.Wait()
261+
262+
res, ok, err = runActionAndWait(ctx, &wg, action, checkAction, checkPeriod)
263+
return
264+
}
265+
266+
// RunActionWithParallelCheck runs an action with a check in parallel
267+
// The function performing the check should return true if the check was favourable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
268+
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) (err error) {
269+
_, _, err = RunActionWithParallelCheckAndResult(
270+
ctx,
271+
action,
272+
func(ctx context.Context) (_ struct{}, ok bool) {
273+
ok = checkAction(ctx)
274+
return
275+
},
276+
checkPeriod,
277+
)
278+
279+
return
245280
}
246281

247282
// WaitUntil waits for a condition evaluated by evalCondition to be verified

utils/parallelisation/parallelisation_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"testing"
1515
"time"
1616

17+
"github.com/go-faker/faker/v4"
1718
"github.com/stretchr/testify/assert"
1819
"github.com/stretchr/testify/require"
1920
"go.uber.org/atomic"
@@ -418,6 +419,141 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
418419
assert.GreaterOrEqual(t, counter.Load(), int32(1))
419420
}
420421

422+
func TestRunActionWithParallelCheckAndResult(t *testing.T) {
423+
type parallelisationCheckResult struct {
424+
checks int32
425+
status string
426+
}
427+
428+
t.Run("Happy", func(t *testing.T) {
429+
defer goleak.VerifyNone(t)
430+
431+
counter := atomic.NewInt32(0)
432+
433+
res, ok, err := RunActionWithParallelCheckAndResult(
434+
context.Background(),
435+
func(ctx context.Context) (err error) {
436+
time.Sleep(120 * time.Millisecond)
437+
return
438+
},
439+
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
440+
return parallelisationCheckResult{
441+
checks: counter.Inc(),
442+
status: "healthy",
443+
}, true
444+
},
445+
10*time.Millisecond,
446+
)
447+
448+
require.NoError(t, err)
449+
require.True(t, ok)
450+
451+
assert.GreaterOrEqual(t, res.checks, int32(10))
452+
assert.Equal(t, res.checks, counter.Load())
453+
assert.Equal(t, "healthy", res.status)
454+
})
455+
456+
t.Run("Check Fails With Reason", func(t *testing.T) {
457+
defer goleak.VerifyNone(t)
458+
459+
counter := atomic.NewInt32(0)
460+
actionStarted := atomic.NewBool(false)
461+
462+
status := "adrien"
463+
464+
res, ok, err := RunActionWithParallelCheckAndResult(
465+
context.Background(),
466+
func(ctx context.Context) error {
467+
actionStarted.Store(true)
468+
<-ctx.Done()
469+
return DetermineContextError(ctx)
470+
},
471+
func(ctx context.Context) (res parallelisationCheckResult, ok bool) {
472+
if n := counter.Inc(); n >= 5 {
473+
return parallelisationCheckResult{
474+
checks: n,
475+
status: status,
476+
}, false
477+
} else {
478+
return parallelisationCheckResult{
479+
checks: n,
480+
status: "ok",
481+
}, true
482+
}
483+
},
484+
5*time.Millisecond,
485+
)
486+
487+
require.True(t, actionStarted.Load())
488+
require.Error(t, err)
489+
errortest.AssertError(t, err, commonerrors.ErrCancelled)
490+
491+
require.False(t, ok)
492+
assert.Equal(t, status, res.status)
493+
assert.Equal(t, int32(5), res.checks)
494+
assert.Equal(t, int32(5), counter.Load())
495+
})
496+
t.Run("Action Error (no context cancel)", func(t *testing.T) {
497+
defer goleak.VerifyNone(t)
498+
499+
counter := atomic.NewInt32(0)
500+
status := "abdel"
501+
502+
res, ok, err := RunActionWithParallelCheckAndResult(
503+
context.Background(),
504+
func(ctx context.Context) error {
505+
time.Sleep(30 * time.Millisecond)
506+
return commonerrors.New(commonerrors.ErrForbidden, faker.Sentence())
507+
},
508+
func(ctx context.Context) (parallelisationCheckResult, bool) {
509+
return parallelisationCheckResult{
510+
checks: counter.Inc(),
511+
status: status,
512+
}, true
513+
},
514+
5*time.Millisecond,
515+
)
516+
517+
require.Error(t, err)
518+
errortest.AssertError(t, err, commonerrors.ErrForbidden)
519+
require.True(t, ok)
520+
521+
assert.Equal(t, status, res.status)
522+
assert.GreaterOrEqual(t, res.checks, int32(1))
523+
assert.Equal(t, res.checks, counter.Load())
524+
})
525+
526+
t.Run("Context cancel", func(t *testing.T) {
527+
defer goleak.VerifyNone(t)
528+
529+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
530+
defer cancel()
531+
532+
counter := atomic.NewInt32(0)
533+
status := "kem"
534+
535+
res, ok, err := RunActionWithParallelCheckAndResult(
536+
ctx,
537+
func(ctx context.Context) error {
538+
<-ctx.Done()
539+
return DetermineContextError(ctx)
540+
},
541+
func(ctx context.Context) (parallelisationCheckResult, bool) {
542+
return parallelisationCheckResult{
543+
checks: counter.Inc(),
544+
status: status,
545+
}, true
546+
},
547+
5*time.Millisecond,
548+
)
549+
550+
require.Error(t, err)
551+
errortest.AssertError(t, err, commonerrors.ErrTimeout)
552+
assert.True(t, ok)
553+
assert.GreaterOrEqual(t, res.checks, int32(1))
554+
})
555+
}
556+
421557
func TestWaitUntil(t *testing.T) {
422558
defer goleak.VerifyNone(t)
423559
verifiedCondition := func(ctx context.Context) (bool, error) {

0 commit comments

Comments
 (0)