diff --git a/changes/20250731140445.feature b/changes/20250731140445.feature new file mode 100644 index 0000000000..93bada3c01 --- /dev/null +++ b/changes/20250731140445.feature @@ -0,0 +1 @@ +:sparkles: Add support for gracefully killing child processes diff --git a/changes/20250804122854.feature b/changes/20250804122854.feature new file mode 100644 index 0000000000..5f053f18bf --- /dev/null +++ b/changes/20250804122854.feature @@ -0,0 +1 @@ +:sparkles: `[collection]` added collection functional operations `Map`, `Filter`, `Reject`, `Reduce` diff --git a/changes/20250804122923.feature b/changes/20250804122923.feature new file mode 100644 index 0000000000..32223b424c --- /dev/null +++ b/changes/20250804122923.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` added parallelised collection functional operations `Map`, `Filter`, `Reject` diff --git a/changes/20250804130842.feature b/changes/20250804130842.feature new file mode 100644 index 0000000000..5c99f84f97 --- /dev/null +++ b/changes/20250804130842.feature @@ -0,0 +1 @@ +:sparkles: `[proc]` added a function to find processes based on name diff --git a/utils/collection/conditions.go b/utils/collection/conditions.go index d83a6ba074..1f4d3741f1 100644 --- a/utils/collection/conditions.go +++ b/utils/collection/conditions.go @@ -110,7 +110,7 @@ func (c *Conditions) Xor() bool { return Xor(*c...) } -// OneHot performs an `OnHot` operation on all conditions +// OneHot performs an `OneHot` operation on all conditions func (c *Conditions) OneHot() bool { if c == nil { return false diff --git a/utils/collection/search.go b/utils/collection/search.go index d21924b95f..77c9da9ccf 100644 --- a/utils/collection/search.go +++ b/utils/collection/search.go @@ -71,6 +71,45 @@ func AnyFunc[S ~[]E, E any](s S, f func(E) bool) bool { return conditions.Any() } +// Filter returns a new slice that contains elements from the input slice which return true when they’re passed as a parameter to the provided filtering function f. 
// Filter returns a new slice holding, in their original order, the elements of s
// for which f returns true.
//
// The input slice is never modified. The result is always non-nil (possibly
// empty), even when s is nil or no element matches, and it keeps the named
// slice type S of the input.
func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) {
	// Capacity is sized for the worst case (every element kept) so that append
	// never has to grow the backing array.
	result = make(S, 0, len(s))

	for i := range s {
		if f(s[i]) {
			result = append(result, s[i])
		}
	}

	return result
}

// Map returns a new slice of the same length as s whose i-th element is f(s[i]).
//
// The input slice is never modified and the result is always non-nil, even when
// s is nil.
func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) {
	result = make([]T2, len(s))

	for i := range s {
		result[i] = f(s[i])
	}

	return result
}

// Reject is the opposite of Filter: it returns a new slice holding, in their
// original order, the elements of s for which f returns false.
// This is functionally equivalent to slices.DeleteFunc but it returns a new
// slice instead of modifying s in place.
func Reject[S ~[]E, E any](s S, f func(E) bool) S {
	return Filter(s, func(e E) bool { return !f(e) })
}

// Reduce folds s into a single value: starting from accumulator, it calls
// f(current, element) for every element in ascending-index order and returns
// the last value produced. When s is empty, accumulator is returned unchanged.
func Reduce[T1, T2 any](s []T1, accumulator T2, f func(T2, T1) T2) (result T2) {
	result = accumulator
	for i := range s {
		result = f(result, s[i])
	}
	return result
}

// AnyEmpty returns whether there is one entry in the slice which is empty.
// If strict, then whitespaces are considered as empty strings func AnyEmpty(strict bool, slice []string) bool { diff --git a/utils/collection/search_test.go b/utils/collection/search_test.go index eb36479ae9..e8aaa31f6c 100644 --- a/utils/collection/search_test.go +++ b/utils/collection/search_test.go @@ -5,6 +5,8 @@ package collection import ( + "fmt" + "strconv" "testing" "github.com/go-faker/faker/v4" @@ -110,3 +112,44 @@ func TestAllNotEmpty(t *testing.T) { assert.False(t, AllNotEmpty(false, []string{faker.Username(), "", faker.Name(), "", faker.Sentence()})) assert.True(t, AllNotEmpty(false, []string{faker.Username(), faker.Name(), faker.Sentence()})) } + +func TestFilterReject(t *testing.T) { + nums := []int{1, 2, 3, 4, 5} + assert.ElementsMatch(t, []int{2, 4}, Filter(nums, func(n int) bool { + return n%2 == 0 + })) + assert.ElementsMatch(t, []int{1, 3, 5}, Reject(nums, func(n int) bool { + return n%2 == 0 + })) + assert.ElementsMatch(t, []int{4, 5}, Filter(nums, func(n int) bool { + return n > 3 + })) + assert.ElementsMatch(t, []int{1, 2, 3}, Reject(nums, func(n int) bool { + return n > 3 + })) + assert.ElementsMatch(t, []string{"foo", "bar"}, Filter([]string{"", "foo", "", "bar", ""}, func(x string) bool { + return len(x) > 0 + })) + assert.ElementsMatch(t, []string{"", "", ""}, Reject([]string{"", "foo", "", "bar", ""}, func(x string) bool { + return len(x) > 0 + })) +} + +func TestMap(t *testing.T) { + mapped := Map([]int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + assert.ElementsMatch(t, []string{"Hello world 1", "Hello world 2"}, mapped) + mapped = Map([]int64{1, 2, 3, 4}, func(x int64) string { + return strconv.FormatInt(x, 10) + }) + assert.ElementsMatch(t, []string{"1", "2", "3", "4"}, mapped) +} + +func TestReduce(t *testing.T) { + nums := []int{1, 2, 3, 4, 5} + sumOfNums := Reduce(nums, 0, func(acc, n int) int { + return acc + n + }) + assert.Equal(t, sumOfNums, 15) +} diff --git 
a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 9550de6eb6..93fc596619 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -23,20 +23,21 @@ func DetermineContextError(ctx context.Context) error { } type result struct { - Item interface{} + Item any err error } // Parallelise parallelises an action over as many goroutines as specified by the argList and retrieves all the results when all the goroutines are done. -func Parallelise(argList interface{}, action func(arg interface{}) (interface{}, error), resultType reflect.Type) (results interface{}, err error) { +// To control the number of goroutines spawned, prefer WorkerPool +func Parallelise(argList any, action func(arg any) (any, error), resultType reflect.Type) (results any, err error) { keepReturn := resultType != nil argListValue := reflect.ValueOf(argList) length := argListValue.Len() channel := make(chan result, length) for i := 0; i < length; i++ { - go func(args reflect.Value, actionFunc func(arg interface{}) (interface{}, error)) { + go func(args reflect.Value, actionFunc func(arg any) (any, error)) { var r result - r.Item, r.err = func(v reflect.Value) (interface{}, error) { + r.Item, r.err = func(v reflect.Value) (any, error) { return actionFunc(v.Interface()) }(args) channel <- r @@ -306,13 +307,19 @@ func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, for range numWorkers { g.Go(func() error { return newWorker(gCtx, f, jobsChan, resultsChan) }) } - for _, job := range jobs { - jobsChan <- job + for i := range jobs { + if DetermineContextError(ctx) != nil { + break + } + jobsChan <- jobs[i] } close(jobsChan) err = g.Wait() close(resultsChan) + if err == nil { + err = DetermineContextError(ctx) + } if err != nil { return } @@ -323,3 +330,36 @@ func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, return } + +// Filter is similar to collection.Filter but 
uses parallelisation. +func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) (result []T, err error) { + result, err = WorkerPool[T, T](ctx, numWorkers, s, func(fCtx context.Context, item T) (r T, ok bool, fErr error) { + fErr = DetermineContextError(fCtx) + if fErr != nil { + return + } + ok = f(item) + r = item + return + }) + return +} + +// Map is similar to collection.Map but uses parallelisation. +func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) T2) (result []T2, err error) { + result, err = WorkerPool[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) { + fErr = DetermineContextError(fCtx) + if fErr != nil { + return + } + r = f(item) + ok = true + return + }) + return +} + +// Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false. +func Reject[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ([]T, error) { + return Filter[T](ctx, numWorkers, s, func(e T) bool { return !f(e) }) +} diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index 54fdb17f0c..fe9d2305c1 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -10,6 +10,7 @@ import ( "fmt" "math/rand" "reflect" + "strconv" "testing" "time" @@ -378,6 +379,7 @@ func runActionWithParallelCheckHappy(t *testing.T, ctx context.Context) { } err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond) require.NoError(t, err) + assert.Equal(t, int32(15), counter.Load()) } func runActionWithParallelCheckFail(t *testing.T, ctx context.Context) { @@ -394,6 +396,7 @@ func runActionWithParallelCheckFail(t *testing.T, ctx context.Context) { err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond) require.Error(t, err) errortest.AssertError(t, err, commonerrors.ErrCancelled) + 
assert.Equal(t, int32(1), counter.Load()) } func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) { @@ -410,9 +413,11 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) { err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond) require.Error(t, err) errortest.AssertError(t, err, commonerrors.ErrCancelled) + assert.GreaterOrEqual(t, counter.Load(), int32(1)) } func TestWaitUntil(t *testing.T) { + defer goleak.VerifyNone(t) verifiedCondition := func(ctx context.Context) (bool, error) { SleepWithContext(ctx, 50*time.Millisecond) return true, nil @@ -465,6 +470,7 @@ func TestWaitUntil(t *testing.T) { } func TestWorkerPool(t *testing.T) { + defer goleak.VerifyNone(t) for _, test := range []struct { name string numWorkers int @@ -562,3 +568,71 @@ func TestWorkerPool(t *testing.T) { errortest.AssertError(t, err, commonerrors.ErrCancelled) }) } + +func TestFilterReject(t *testing.T) { + defer goleak.VerifyNone(t) + nums := []int{1, 2, 3, 4, 5} + ctx := context.Background() + results, err := Filter(ctx, 3, nums, func(n int) bool { + return n%2 == 0 + }) + require.NoError(t, err) + assert.ElementsMatch(t, []int{2, 4}, results) + results, err = Reject(ctx, 3, nums, func(n int) bool { + return n%2 == 0 + }) + require.NoError(t, err) + assert.ElementsMatch(t, []int{1, 3, 5}, results) + results, err = Filter(ctx, 3, nums, func(n int) bool { + return n > 3 + }) + require.NoError(t, err) + assert.ElementsMatch(t, []int{4, 5}, results) + results, err = Reject(ctx, 3, nums, func(n int) bool { + return n > 3 + }) + require.NoError(t, err) + assert.ElementsMatch(t, []int{1, 2, 3}, results) + results2, err := Filter(ctx, 3, []string{"", "foo", "", "bar", ""}, func(x string) bool { + return len(x) > 0 + }) + + require.NoError(t, err) + assert.ElementsMatch(t, []string{"foo", "bar"}, results2) + results3, err := Reject(ctx, 3, []string{"", "foo", "", "bar", ""}, func(x string) bool { + return len(x) > 0 + }) 
+ require.NoError(t, err) + assert.ElementsMatch(t, []string{"", "", ""}, results3) + t.Run("cancelled context", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := Filter(cancelledCtx, 3, nums, func(n int) bool { + return n%2 == 0 + }) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) +} + +func TestMap(t *testing.T) { + defer goleak.VerifyNone(t) + ctx := context.Background() + mapped, err := Map(ctx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"Hello world 1", "Hello world 2"}, mapped) + mapped, err = Map(ctx, 3, []int64{1, 2, 3, 4}, func(x int64) string { + return strconv.FormatInt(x, 10) + }) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"1", "2", "3", "4"}, mapped) + t.Run("cancelled context", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := Map(cancelledCtx, 3, []int{1, 2}, func(i int) string { + return fmt.Sprintf("Hello world %v", i) + }) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) +} diff --git a/utils/proc/find/find.go b/utils/proc/find/find.go new file mode 100644 index 0000000000..a4abc7fb85 --- /dev/null +++ b/utils/proc/find/find.go @@ -0,0 +1,26 @@ +package find + +import ( + "context" + "fmt" + "regexp" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/proc" +) + +const numWorkers = 10 + +// FindProcessByRegex will search for the processes that match a specific regex +func FindProcessByRegex(ctx context.Context, re *regexp.Regexp) (processes []proc.IProcess, err error) { + if re == nil { + err = commonerrors.UndefinedVariable("regex to search") + return + } + return findProcessByRegex(ctx, re) +} + +// FindProcessByName will search for the processes that match a specific name +func FindProcessByName(ctx context.Context, name 
string) (processes []proc.IProcess, err error) { + return FindProcessByRegex(ctx, regexp.MustCompile(fmt.Sprintf(".*%v.*", regexp.QuoteMeta(name)))) +} diff --git a/utils/proc/find/find_linux.go b/utils/proc/find/find_linux.go index 8490f8b589..bf0e76082d 100644 --- a/utils/proc/find/find_linux.go +++ b/utils/proc/find/find_linux.go @@ -82,7 +82,7 @@ func FindProcessByRegexForFS(ctx context.Context, fs filesystem.FS, re *regexp.R return } - processes, err = parallelisation.WorkerPool(ctx, 10, procEntries, func(ctx context.Context, entry string) (p proc.IProcess, matches bool, err error) { + processes, err = parallelisation.WorkerPool(ctx, numWorkers, procEntries, func(ctx context.Context, entry string) (p proc.IProcess, matches bool, err error) { matches, err = checkProcessMatch(ctx, fs, re, entry) if err != nil || !matches { return @@ -100,12 +100,7 @@ func FindProcessByRegexForFS(ctx context.Context, fs filesystem.FS, re *regexp.R return } -// FindProcessByRegex will search for the processes that match a specific regex -func FindProcessByRegex(ctx context.Context, re *regexp.Regexp) (processes []proc.IProcess, err error) { +// findProcessByRegex will search for the processes that match a specific regex +func findProcessByRegex(ctx context.Context, re *regexp.Regexp) (processes []proc.IProcess, err error) { return FindProcessByRegexForFS(ctx, filesystem.GetGlobalFileSystem(), re) } - -// FindProcessByName will search for the processes that match a specific name -func FindProcessByName(ctx context.Context, name string) (processes []proc.IProcess, err error) { - return FindProcessByRegex(ctx, regexp.MustCompile(fmt.Sprintf(".*%v.*", regexp.QuoteMeta(name)))) -} diff --git a/utils/proc/find/find_other.go b/utils/proc/find/find_other.go new file mode 100644 index 0000000000..ac83f0d0ed --- /dev/null +++ b/utils/proc/find/find_other.go @@ -0,0 +1,32 @@ +//go:build !linux + +/* + * Copyright (C) 2020-2024 Arm Limited or its affiliates and Contributors. 
All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package find + +import ( + "context" + "regexp" + + "github.com/ARM-software/golang-utils/utils/collection" + "github.com/ARM-software/golang-utils/utils/parallelisation" + "github.com/ARM-software/golang-utils/utils/proc" +) + +func findProcessByRegex(ctx context.Context, re *regexp.Regexp) (processes []proc.IProcess, err error) { + ps, err := proc.Ps(ctx) + if err != nil || len(ps) == 0 { + return + } + + processes, err = parallelisation.Filter[proc.IProcess](ctx, 10, ps, func(iProcess proc.IProcess) bool { + if iProcess == nil { + return false + } + return collection.AnyTrue(re.MatchString(iProcess.Name()), re.MatchString(iProcess.Executable()), re.MatchString(iProcess.Cmdline())) + }) + return +} diff --git a/utils/proc/find/find_test.go b/utils/proc/find/find_test.go new file mode 100644 index 0000000000..8b2d28adf6 --- /dev/null +++ b/utils/proc/find/find_test.go @@ -0,0 +1,66 @@ +package find + +import ( + "context" + "fmt" + "os/exec" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" +) + +func TestFindProcessByName(t *testing.T) { + if runtime.GOOS != "linux" { + defer goleak.VerifyNone(t) + } + tests := []struct { + cmdWindows *exec.Cmd + cmdOther *exec.Cmd + processNameWindows string + processNameOther string + }{ + { + cmdWindows: exec.Command("cmd.exe", "/c", fmt.Sprintf("ping localhost -n %v > nul", time.Second.Seconds())), //nolint: gosec // G204 Subprocess launched with a potential tainted input or cmd arguments (gosec) + cmdOther: exec.Command("sh", "-c", fmt.Sprintf("sleep %v", time.Second.Seconds())), //nolint: gosec // G204 Subprocess launched with a potential tainted input or cmd arguments (gosec) + processNameWindows: "ping", + processNameOther: "sleep", + }, + 
} + + for i := range tests { + test := tests[i] + t.Run("subtest", func(t *testing.T) { + ctx := context.Background() + cmd := test.cmdOther + toFind := test.processNameOther + if runtime.GOOS == "windows" { + cmd = test.cmdWindows + toFind = test.processNameWindows + } + ps, err := FindProcessByName(ctx, toFind) + require.NoError(t, err) + numOfProcesses := len(ps) + require.NoError(t, cmd.Start()) + defer func() { _ = cmd.Process.Kill() }() + ps, err = FindProcessByName(ctx, toFind) + require.NoError(t, err) + assert.NotEmpty(t, ps) + + assert.GreaterOrEqual(t, len(ps), numOfProcesses) + require.NoError(t, cmd.Wait()) + t.Run("cancelled context", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + _, err = FindProcessByName(cancelCtx, toFind) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) + }) + } +} diff --git a/utils/proc/interfaces.go b/utils/proc/interfaces.go index b9a09fe2f0..53e2edfa7b 100644 --- a/utils/proc/interfaces.go +++ b/utils/proc/interfaces.go @@ -40,6 +40,9 @@ type IProcess interface { // Children returns the children of the process if any. Children(ctx context.Context) ([]IProcess, error) + // IsAZombie returns whether the process is a zombie process. See https://en.wikipedia.org/wiki/Zombie_process + IsAZombie() bool + // IsRunning returns whether the process is still running or not. 
IsRunning() bool diff --git a/utils/proc/interrupt.go b/utils/proc/interrupt.go index de7cbd73d3..621c6f9322 100644 --- a/utils/proc/interrupt.go +++ b/utils/proc/interrupt.go @@ -6,6 +6,9 @@ import ( "os/exec" "time" + "golang.org/x/sync/errgroup" + + "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" ) @@ -44,24 +47,82 @@ func InterruptProcess(ctx context.Context, pid int, signal InterruptType) (err e return } -// TerminateGracefully follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed. -func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) { - defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }() - err = InterruptProcess(ctx, pid, SigInt) +// TerminateGracefullyWithChildren follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed. +// It does not attempt to terminate the process group. If you wish to terminate the process group directly then send -pgid to TerminateGracefully but +// this does not guarantee that the group will be terminated gracefully. +// Instead, this function lists each child and attempts to kill them gracefully concurrently. It will then attempt to gracefully terminate itself. +// Due to the multi-stage process and the fact that the full grace period must pass for each stage specified above, the total maximum length of this +// function will be 2*gracePeriod not gracePeriod. 
+func TerminateGracefullyWithChildren(ctx context.Context, pid int, gracePeriod time.Duration) (err error) { + err = parallelisation.DetermineContextError(ctx) if err != nil { return } - err = InterruptProcess(ctx, pid, SigTerm) + + p, err := FindProcess(ctx, pid) + if err != nil { + if commonerrors.Any(err, commonerrors.ErrNotFound) { + err = nil + return + } + + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred whilst searching for process '%v'", pid) + return + } + + children, err := p.Children(ctx) if err != nil { + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not check for children for pid '%v'", pid) return } - _, fErr := FindProcess(ctx, pid) - if commonerrors.Any(fErr, commonerrors.ErrNotFound) { - // The process no longer exist. - // No need to wait the grace period + + if len(children) == 0 { + err = TerminateGracefully(ctx, pid, gracePeriod) return } - parallelisation.SleepWithContext(ctx, gracePeriod) + + childGroup, terminateCtx := errgroup.WithContext(ctx) + childGroup.SetLimit(len(children)) + for _, child := range children { + if child.IsRunning() { + childGroup.Go(func() error { return TerminateGracefullyWithChildren(terminateCtx, child.Pid(), gracePeriod) }) + } + } + err = childGroup.Wait() + if err != nil { + return + } + + err = TerminateGracefully(ctx, pid, gracePeriod) + return +} + +func terminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) { + err = InterruptProcess(ctx, pid, SigInt) + if err != nil { + return + } + err = InterruptProcess(ctx, pid, SigTerm) + if err != nil { + return + } + + return parallelisation.RunActionWithParallelCheck(ctx, + func(ctx context.Context) error { + parallelisation.SleepWithContext(ctx, gracePeriod) + return nil + }, + func(ctx context.Context) bool { + _, fErr := FindProcess(ctx, pid) + return commonerrors.Any(fErr, commonerrors.ErrNotFound) + + }, 200*time.Millisecond) +} + +// TerminateGracefully follows the 
pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed. +func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) { + defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }() + _ = terminateGracefully(ctx, pid, gracePeriod) err = InterruptProcess(ctx, pid, SigKill) return } @@ -94,3 +155,18 @@ func DefineCmdCancel(cmd *exec.Cmd) (*exec.Cmd, error) { } return cmd, nil } + +// WaitForCompletion will wait for a given process to complete. +// This allows check to work if the underlying process was stopped without needing the os.Process that started it. +func WaitForCompletion(ctx context.Context, pid int) (err error) { + pids, err := getGroupProcesses(ctx, pid) + if err != nil { + return + } + return parallelisation.WaitUntil(ctx, func(ctx2 context.Context) (bool, error) { + return collection.AnyFunc(pids, func(subPid int) bool { + p, _ := FindProcess(ctx2, subPid) + return p.IsRunning() // FindProcess will always return an instantiated process and any non-running state should exit without error + }), nil + }, 1000*time.Millisecond) +} diff --git a/utils/proc/interrupt_test.go b/utils/proc/interrupt_test.go index 5a3fa330fa..7a845e6207 100644 --- a/utils/proc/interrupt_test.go +++ b/utils/proc/interrupt_test.go @@ -14,61 +14,114 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/parallelisation" ) func TestTerminateGracefully(t *testing.T) { - defer goleak.VerifyNone(t) - t.Run("single process", func(t *testing.T) { - cmd := exec.Command("sleep", "50") - require.NoError(t, cmd.Start()) - defer func() { _ = cmd.Wait() }() - process, err := FindProcess(context.Background(), cmd.Process.Pid) - require.NoError(t, err) 
- assert.True(t, process.IsRunning()) - require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond)) - time.Sleep(500 * time.Millisecond) - process, err = FindProcess(context.Background(), cmd.Process.Pid) - if err == nil { - require.NotEmpty(t, process) - assert.False(t, process.IsRunning()) - } else { - errortest.AssertError(t, err, commonerrors.ErrNotFound) - assert.Empty(t, process) - } - }) - t.Run("process with children", func(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("test with bash") - } - // see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773 - // https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16 - cmd := exec.Command("bash", "-c", "watch date > date.txt 2>&1") - require.NoError(t, cmd.Start()) - defer func() { _ = cmd.Wait() }() - require.NotNil(t, cmd.Process) - p, err := FindProcess(context.Background(), cmd.Process.Pid) - require.NoError(t, err) - assert.True(t, p.IsRunning()) - require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond)) - p, err = FindProcess(context.Background(), cmd.Process.Pid) - if err == nil { - require.NotEmpty(t, p) - assert.False(t, p.IsRunning()) - } else { - errortest.AssertError(t, err, commonerrors.ErrNotFound) - assert.Empty(t, p) - } - }) - t.Run("no process", func(t *testing.T) { - random, err := faker.RandomInt(9000, 20000, 1) - require.NoError(t, err) - require.NoError(t, TerminateGracefully(context.Background(), random[0], 100*time.Millisecond)) - }) - t.Run("cancelled", func(t *testing.T) { - random, err := faker.RandomInt(9000, 20000, 1) - require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - errortest.AssertError(t, TerminateGracefully(ctx, random[0], 100*time.Millisecond), commonerrors.ErrCancelled) - }) + for _, test := range []struct { + name string + testFunc func(ctx context.Context, pid int, 
gracePeriod time.Duration) error + }{ + { + name: "TerminateGracefully", + testFunc: TerminateGracefully, + }, + { + name: "TerminateGracefullyWithChildren", + testFunc: TerminateGracefullyWithChildren, + }, + } { + t.Run(test.name, func(t *testing.T) { + defer goleak.VerifyNone(t) + t.Run("single process", func(t *testing.T) { + cmd := exec.Command("sleep", "50") + require.NoError(t, cmd.Start()) + defer func() { + p, _ := FindProcess(context.Background(), cmd.Process.Pid) + if p != nil && (p.IsRunning() || p.IsAZombie()) { + _ = cmd.Wait() + } + }() + process, err := FindProcess(context.Background(), cmd.Process.Pid) + require.NoError(t, err) + require.True(t, process.IsRunning()) + + now := time.Now() + gracePeriod := 10 * time.Second + require.NoError(t, test.testFunc(context.Background(), cmd.Process.Pid, gracePeriod)) + assert.Less(t, time.Since(now), gracePeriod) // this indicates that the process was closed by INT/SIG not KILL + + time.Sleep(500 * time.Millisecond) + process, err = FindProcess(context.Background(), cmd.Process.Pid) + if err == nil { + require.NotEmpty(t, process) + assert.False(t, process.IsRunning()) + } else { + errortest.AssertError(t, err, commonerrors.ErrNotFound) + assert.Empty(t, process) + } + }) + t.Run("process with children", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("test with bash") + } + // see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773 + // https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16 + cmd := exec.Command("bash", "-c", "watch date > date.txt 2>&1") + require.NoError(t, cmd.Start()) + defer func() { + p, _ := FindProcess(context.Background(), cmd.Process.Pid) + if p != nil && (p.IsRunning() || p.IsAZombie()) { + _ = cmd.Wait() + } + }() + time.Sleep(500 * time.Millisecond) + require.NotNil(t, cmd.Process) + + timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + 
require.NoError(t, parallelisation.WaitUntil(timeoutCtx, func(fCtx context.Context) (bool, error) { + p, fErr := FindProcess(fCtx, cmd.Process.Pid) + if fErr != nil { + return false, fErr + } + return p.IsRunning() || p.IsAZombie(), nil + }, 200*time.Millisecond)) + p, err := FindProcess(context.Background(), cmd.Process.Pid) + require.NoError(t, err) + require.True(t, p.IsRunning() || p.IsAZombie()) + children, err := p.Children(timeoutCtx) + require.NoError(t, err) + if !p.IsAZombie() { + assert.NotEmpty(t, children) + } + + now := time.Now() + gracePeriod := 10 * time.Second + require.NoError(t, test.testFunc(context.Background(), cmd.Process.Pid, gracePeriod)) + assert.Less(t, time.Since(now), gracePeriod) // this indicates that the process was closed by INT/SIG not KILL + + p, err = FindProcess(context.Background(), cmd.Process.Pid) + if err == nil { + require.NotEmpty(t, p) + assert.False(t, p.IsRunning()) + } else { + errortest.AssertError(t, err, commonerrors.ErrNotFound) + assert.Empty(t, p) + } + }) + t.Run("no process", func(t *testing.T) { + random, err := faker.RandomInt(9000, 20000, 1) + require.NoError(t, err) + require.NoError(t, test.testFunc(context.Background(), random[0], 100*time.Millisecond)) + }) + t.Run("cancelled", func(t *testing.T) { + random, err := faker.RandomInt(9000, 20000, 1) + require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, test.testFunc(ctx, random[0], 100*time.Millisecond), commonerrors.ErrCancelled) + }) + }) + } } diff --git a/utils/proc/process.go b/utils/proc/process.go index 0e0b751861..860d3fd159 100644 --- a/utils/proc/process.go +++ b/utils/proc/process.go @@ -20,6 +20,8 @@ const ( statusRunning = "running" statusSleep = "sleep" statusIdle = "idle" + statusZombie = "zombie" + workers = 10 ) // Ps returns all processes in a similar fashion to `ps` command on Unix. 
@@ -34,9 +36,7 @@ func Ps(ctx context.Context) (processes []IProcess, err error) { if err != nil { return } - for i := range pss { - processes = append(processes, wrapProcess(pss[i])) - } + processes, err = parallelisation.Map[*process.Process, IProcess](ctx, workers, pss, wrapProcess) return } @@ -92,6 +92,10 @@ func (p *ps) IsRunning() (running bool) { return } +func (p *ps) IsAZombie() bool { + return isProcessAZombie(p.imp) +} + func (p *ps) Cmdline() string { cmd, _ := p.imp.Cmdline() return cmd @@ -252,7 +256,7 @@ func isProcessRunning(p *process.Process) (running bool) { running = false return } - // On some platforms, such as *nix, a zombie process is reported as a running process by p.IsRunning() but this is not the case. Therefore, a further check is performed on the process status to verify a running process is actually in the expected running state. Nonetheless, status is not cross platform and is not implemented on Windows. For those platform, the status returned by IsRunning is then considered + // On some platforms, such as *nix, a zombie process is reported as a running process by p.IsRunning() but this is not the case. Therefore, a further check is performed on the process status to verify a running process is actually in the expected running state. Nonetheless, status is not cross-platform and is not implemented on Windows. 
For those platforms, the status returned by IsRunning is then considered sufficient. status, err := p.Status() if err != nil { return @@ -262,6 +266,25 @@ func isProcessRunning(p *process.Process) (running bool) { return } +// isProcessAZombie reports whether p still exists and its reported status includes the zombie state. +func isProcessAZombie(p *process.Process) (zombie bool) { + if p == nil { + return + } + exist, _ := process.PidExists(p.Pid) + if !exist { + zombie = false + return + } + status, err := p.Status() + if err != nil { + return + } + // https://github.com/shirou/gopsutil/blob/e230f528f075f78e713f167c28b692cc15307d19/process/process.go#L48 + _, zombie = collection.FindInSlice(false, status, statusZombie) + return +} + // NewProcess creates a new Process instance, it only stores the pid and // checks that the process exists. Other method on Process can be used // to get more information about the process. An error will be returned diff --git a/utils/proc/ps_posix.go b/utils/proc/ps_posix.go index 2cfbcc30bd..10c7e25cb1 100644 --- a/utils/proc/ps_posix.go +++ b/utils/proc/ps_posix.go @@ -10,9 +10,7 @@ package proc import ( "context" - "fmt" "syscall" - "time" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" @@ -41,24 +39,19 @@ func killGroup(ctx context.Context, pid int32) (err error) { // kill a whole process group by sending a signal to -xyz where xyz is the pgid // http://unix.stackexchange.com/questions/14815/process-descendants if pgid != int(pid) { - err = fmt.Errorf("%w: process #%v is not group leader", commonerrors.ErrUnexpected, pid) + err = commonerrors.Newf(commonerrors.ErrUnexpected, "process #%v is not group leader", pid) return } err = ConvertProcessError(syscall.Kill(-pgid, syscall.SIGKILL)) return } -// WaitForCompletion will wait for a given process to complete. -// This allows check to work if the underlying process was stopped without needing the os.Process that started it. 
-func WaitForCompletion(ctx context.Context, pid int) (err error) { +func getGroupProcesses(ctx context.Context, pid int) (pids []int, err error) { pgid, err := getpgid(pid) if err != nil { err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not get group PID for '%v'", pid) return } - - return parallelisation.WaitUntil(ctx, func(ctx2 context.Context) (bool, error) { - p, _ := FindProcess(ctx, pgid) - return p.IsRunning(), nil // FindProcess will always return an instantiated process and any non-runnning state should exit without error - }, 1000*time.Millisecond) + pids = append(pids, pgid) + return } diff --git a/utils/proc/ps_windows.go b/utils/proc/ps_windows.go index 1e54d4c77f..a62f5d6d31 100644 --- a/utils/proc/ps_windows.go +++ b/utils/proc/ps_windows.go @@ -10,12 +10,10 @@ package proc import ( "context" - "fmt" "os/exec" "strconv" "time" - "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" ) @@ -29,31 +27,26 @@ func killGroup(ctx context.Context, pid int32) (err error) { // setting the following to avoid having hanging subprocesses as described in https://github.com/golang/go/issues/24050 cmd.WaitDelay = 50 * time.Millisecond err = ConvertProcessError(cmd.Run()) - if commonerrors.Any(err, nil, commonerrors.ErrCancelled, commonerrors.ErrTimeout) { - return - } else { - err = fmt.Errorf("%w: could not kill process group (#%v): %v", commonerrors.ErrUnexpected, pid, err.Error()) + if err != nil { + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not kill process group (#%v)", pid) } return } -// WaitForCompletion will wait for a given process to complete. -// This allows check to work if the underlying process was stopped without needing the os.Process that started it. 
-func WaitForCompletion(ctx context.Context, pid int) (err error) { +func getGroupProcesses(ctx context.Context, pid int) (pids []int, err error) { parent, err := FindProcess(ctx, pid) + if err != nil { + return + } children, err := parent.Children(ctx) - + if err != nil { + return + } // Windows doesn't have group PIDs - var pids = make([]int, len(children)+1) + pids = make([]int, len(children)+1) pids[0] = parent.Pid() for i := range children { pids[i+1] = children[i].Pid() } - - return parallelisation.WaitUntil(ctx, func(ctx2 context.Context) (bool, error) { - return collection.AnyFunc(pids, func(pid int) bool { - p, _ := FindProcess(ctx2, pid) - return p.IsRunning() - }), nil - }, 1000*time.Millisecond) + return }