Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250731140445.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: Add support for gracefully killing child processes
3 changes: 3 additions & 0 deletions utils/parallelisation/parallelisation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ func runActionWithParallelCheckHappy(t *testing.T, ctx context.Context) {
}
err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond)
require.NoError(t, err)
assert.Equal(t, int32(15), counter.Load())
}

func runActionWithParallelCheckFail(t *testing.T, ctx context.Context) {
Expand All @@ -394,6 +395,7 @@ func runActionWithParallelCheckFail(t *testing.T, ctx context.Context) {
err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond)
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.Equal(t, int32(1), counter.Load())
}

func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
Expand All @@ -410,6 +412,7 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
err := RunActionWithParallelCheck(ctx, action, checkAction, 10*time.Millisecond)
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.GreaterOrEqual(t, counter.Load(), int32(1))
}

func TestWaitUntil(t *testing.T) {
Expand Down
80 changes: 70 additions & 10 deletions utils/proc/interrupt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os/exec"
"time"

"golang.org/x/sync/errgroup"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/parallelisation"
)
Expand Down Expand Up @@ -44,24 +46,82 @@ func InterruptProcess(ctx context.Context, pid int, signal InterruptType) (err e
return
}

// TerminateGracefully follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed.
func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }()
err = InterruptProcess(ctx, pid, SigInt)
// TerminateGracefullyWithChildren follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed.
// It does not attempt to terminate the process group. If you wish to terminate the process group directly then send -pgid to TerminateGracefully but
// this does not guarantee that the group will be terminated gracefully.
// Instead this function lists each child and attempts to kill them gracefully in a concurrently. It will then attempt to gracefully terminate itself.
// Due to the multi-stage process and the fact that the full grace period must pass for each stage specified above, the total maximum length of this
// function will be 2*gracePeriod not gracePeriod.
func TerminateGracefullyWithChildren(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
return
}
err = InterruptProcess(ctx, pid, SigTerm)

p, err := FindProcess(ctx, pid)
if err != nil {
if commonerrors.Any(err, commonerrors.ErrNotFound) {
err = nil
return
}

err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred whilst searching for process '%v'", pid)
return
}

children, err := p.Children(ctx)
if err != nil {
err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not check for children for pid '%v'", pid)
return
}

if len(children) == 0 {
err = TerminateGracefully(ctx, pid, gracePeriod)
return
}

childGroup, terminateCtx := errgroup.WithContext(ctx)
childGroup.SetLimit(len(children))
for _, child := range children {
if child.IsRunning() {
childGroup.Go(func() error { return TerminateGracefullyWithChildren(terminateCtx, child.Pid(), gracePeriod) })
}
}
err = childGroup.Wait()
if err != nil {
return
}
_, fErr := FindProcess(ctx, pid)
if commonerrors.Any(fErr, commonerrors.ErrNotFound) {
// The process no longer exist.
// No need to wait the grace period

err = TerminateGracefully(ctx, pid, gracePeriod)
return
}

func terminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
err = InterruptProcess(ctx, pid, SigInt)
if err != nil {
return
}
parallelisation.SleepWithContext(ctx, gracePeriod)
err = InterruptProcess(ctx, pid, SigTerm)
if err != nil {
return
}

return parallelisation.RunActionWithParallelCheck(ctx,
func(ctx context.Context) error {
parallelisation.SleepWithContext(ctx, gracePeriod)
return nil
},
func(ctx context.Context) bool {
_, fErr := FindProcess(ctx, pid)
return commonerrors.Any(fErr, commonerrors.ErrNotFound)

}, 200*time.Millisecond)
}

// TerminateGracefully follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed.
func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }()
_ = terminateGracefully(ctx, pid, gracePeriod)
err = InterruptProcess(ctx, pid, SigKill)
return
}
Expand Down
151 changes: 97 additions & 54 deletions utils/proc/interrupt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,101 @@ import (
)

func TestTerminateGracefully(t *testing.T) {
defer goleak.VerifyNone(t)
t.Run("single process", func(t *testing.T) {
cmd := exec.Command("sleep", "50")
require.NoError(t, cmd.Start())
defer func() { _ = cmd.Wait() }()
process, err := FindProcess(context.Background(), cmd.Process.Pid)
require.NoError(t, err)
assert.True(t, process.IsRunning())
require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond))
time.Sleep(500 * time.Millisecond)
process, err = FindProcess(context.Background(), cmd.Process.Pid)
if err == nil {
require.NotEmpty(t, process)
assert.False(t, process.IsRunning())
} else {
errortest.AssertError(t, err, commonerrors.ErrNotFound)
assert.Empty(t, process)
}
})
t.Run("process with children", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test with bash")
}
// see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
// https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16
cmd := exec.Command("bash", "-c", "watch date > date.txt 2>&1")
require.NoError(t, cmd.Start())
defer func() { _ = cmd.Wait() }()
require.NotNil(t, cmd.Process)
p, err := FindProcess(context.Background(), cmd.Process.Pid)
require.NoError(t, err)
assert.True(t, p.IsRunning())
require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond))
p, err = FindProcess(context.Background(), cmd.Process.Pid)
if err == nil {
require.NotEmpty(t, p)
assert.False(t, p.IsRunning())
} else {
errortest.AssertError(t, err, commonerrors.ErrNotFound)
assert.Empty(t, p)
}
})
t.Run("no process", func(t *testing.T) {
random, err := faker.RandomInt(9000, 20000, 1)
require.NoError(t, err)
require.NoError(t, TerminateGracefully(context.Background(), random[0], 100*time.Millisecond))
})
t.Run("cancelled", func(t *testing.T) {
random, err := faker.RandomInt(9000, 20000, 1)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
cancel()
errortest.AssertError(t, TerminateGracefully(ctx, random[0], 100*time.Millisecond), commonerrors.ErrCancelled)
})
for _, test := range []struct {
name string
testFunc func(ctx context.Context, pid int, gracePeriod time.Duration) error
}{
{
name: "TerminateGracefully",
testFunc: TerminateGracefully,
},
{
name: "TerminateGracefullyWithChildren",
testFunc: TerminateGracefullyWithChildren,
},
} {
t.Run(test.name, func(t *testing.T) {
defer goleak.VerifyNone(t)
t.Run("single process", func(t *testing.T) {
cmd := exec.Command("sleep", "50")
require.NoError(t, cmd.Start())
defer func() {
p, _ := FindProcess(context.Background(), cmd.Process.Pid)
if p != nil && p.IsRunning() {
_ = cmd.Wait()
}
}()
process, err := FindProcess(context.Background(), cmd.Process.Pid)
require.NoError(t, err)
require.True(t, process.IsRunning())

now := time.Now()
gracePeriod := 10 * time.Second
require.NoError(t, test.testFunc(context.Background(), cmd.Process.Pid, gracePeriod))
assert.Less(t, time.Since(now), gracePeriod) // this indicates that the process was closed by INT/SIG not KILL

time.Sleep(500 * time.Millisecond)
process, err = FindProcess(context.Background(), cmd.Process.Pid)
if err == nil {
require.NotEmpty(t, process)
assert.False(t, process.IsRunning())
} else {
errortest.AssertError(t, err, commonerrors.ErrNotFound)
assert.Empty(t, process)
}
})
t.Run("process with children", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test with bash")
}
// see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
// https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16
cmd := exec.Command("bash", "-c", "watch date > date.txt 2>&1")
require.NoError(t, cmd.Start())
defer func() {
p, _ := FindProcess(context.Background(), cmd.Process.Pid)
if p != nil && p.IsRunning() {
_ = cmd.Wait()
}
}()
time.Sleep(500 * time.Millisecond)
require.NotNil(t, cmd.Process)
p, err := FindProcess(context.Background(), cmd.Process.Pid)
require.NoError(t, err)
for !p.IsRunning() {
time.Sleep(200 * time.Millisecond)
}
require.True(t, p.IsRunning())
children, err := p.Children(context.Background())
require.NoError(t, err)
require.Greater(t, len(children), 0)

now := time.Now()
gracePeriod := 10 * time.Second
require.NoError(t, test.testFunc(context.Background(), cmd.Process.Pid, gracePeriod))
assert.Less(t, time.Since(now), gracePeriod) // this indicates that the process was closed by INT/SIG not KILL

p, err = FindProcess(context.Background(), cmd.Process.Pid)
if err == nil {
require.NotEmpty(t, p)
assert.False(t, p.IsRunning())
} else {
errortest.AssertError(t, err, commonerrors.ErrNotFound)
assert.Empty(t, p)
}
})
t.Run("no process", func(t *testing.T) {
random, err := faker.RandomInt(9000, 20000, 1)
require.NoError(t, err)
require.NoError(t, test.testFunc(context.Background(), random[0], 100*time.Millisecond))
})
t.Run("cancelled", func(t *testing.T) {
random, err := faker.RandomInt(9000, 20000, 1)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
cancel()
errortest.AssertError(t, test.testFunc(ctx, random[0], 100*time.Millisecond), commonerrors.ErrCancelled)
})
})
}
}
Loading