Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250714171923.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:bug: `[subprocess]` make sure child processes are killed on context cancellation
1 change: 1 addition & 0 deletions changes/20250715105927.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `[proc]` Add utilities to kill processes gracefully
13 changes: 6 additions & 7 deletions utils/proc/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package proc

import (
"fmt"
"os"
"os/exec"
"syscall"
Expand Down Expand Up @@ -33,17 +32,17 @@ func ConvertProcessError(err error) error {
// ESRCH is "no such process", meaning the process has already exited.
return nil
case commonerrors.Any(err, exec.ErrWaitDelay):
return fmt.Errorf("%w: %v", commonerrors.ErrTimeout, err.Error())
return commonerrors.WrapError(commonerrors.ErrTimeout, err, "")
case commonerrors.Any(err, exec.ErrDot, exec.ErrNotFound):
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
case commonerrors.Any(process.ErrorNotPermitted):
return fmt.Errorf("%w: %v", commonerrors.ErrForbidden, err.Error())
return commonerrors.WrapError(commonerrors.ErrForbidden, err, "")
case commonerrors.Any(process.ErrorProcessNotRunning):
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
case commonerrors.CorrespondTo(err, errAccessDenied):
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
case commonerrors.CorrespondTo(err, errNotImplemented):
return fmt.Errorf("%w: %v", commonerrors.ErrNotImplemented, err.Error())
return commonerrors.WrapError(commonerrors.ErrNotImplemented, err, "")
default:
return err
}
Expand Down
64 changes: 64 additions & 0 deletions utils/proc/interrupt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package proc

import (
"context"
"time"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/parallelisation"
)

// InterruptType identifies the kind of interrupt that InterruptProcess can deliver to a process.
//
//go:generate go run github.com/dmarkham/enumer -type=InterruptType -text -json -yaml
type InterruptType int

// The numeric values match the conventional POSIX signal numbers of the same name.
const (
// SigInt asks the process to interrupt; InterruptProcess maps it to the process's Interrupt method.
SigInt InterruptType = 2
// SigKill forcibly stops the process; InterruptProcess maps it to KillWithChildren.
SigKill InterruptType = 9
// SigTerm asks the process to terminate; InterruptProcess maps it to the process's Terminate method.
SigTerm InterruptType = 15
)

// InterruptProcess delivers the requested interrupt to the process identified by pid.
//
// The context is checked first and its error, if any, is returned without
// touching the process. When the process lookup fails or yields no process,
// the lookup error is returned after being filtered through
// commonerrors.Ignore for commonerrors.ErrNotFound, so a process that is
// already gone is not reported as a failure.
//
// SigInt is mapped to Interrupt, SigKill to KillWithChildren and SigTerm to
// Terminate. Any other value results in a commonerrors.ErrInvalid error.
func InterruptProcess(ctx context.Context, pid int, signal InterruptType) (err error) {
	if err = parallelisation.DetermineContextError(ctx); err != nil {
		return err
	}

	target, findErr := FindProcess(ctx, pid)
	if findErr != nil || target == nil {
		return commonerrors.Ignore(findErr, commonerrors.ErrNotFound)
	}

	switch signal {
	case SigInt:
		return target.Interrupt(ctx)
	case SigKill:
		return target.KillWithChildren(ctx)
	case SigTerm:
		return target.Terminate(ctx)
	default:
		return commonerrors.New(commonerrors.ErrInvalid, "unknown interrupt type for process")
	}
}

// TerminateGracefully stops the process identified by pid by escalating
// interrupts, loosely following the pattern set by
// [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination).
//
// The sequence is:
//  1. SigInt, immediately followed by SigTerm;
//  2. if the process can no longer be found, return straight away;
//  3. otherwise sleep for gracePeriod (via parallelisation.SleepWithContext)
//     and then send SigKill.
//
// The first error returned by any of these steps (including an error due to
// ctx being done) is returned to the caller. Whatever the outcome, a final
// SigKill is attempted on return using a background context, so the process is
// not left running when ctx has been cancelled; the result of that last
// attempt is deliberately discarded.
func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
	// Safety net: always try to kill on the way out, independently of ctx.
	// Best effort only, hence the ignored error.
	defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }()

	if err = InterruptProcess(ctx, pid, SigInt); err != nil {
		return err
	}
	if err = InterruptProcess(ctx, pid, SigTerm); err != nil {
		return err
	}

	if _, findErr := FindProcess(ctx, pid); commonerrors.Any(findErr, commonerrors.ErrNotFound) {
		// The process no longer exists: no need to wait for the grace period.
		return nil
	}

	parallelisation.SleepWithContext(ctx, gracePeriod)
	return InterruptProcess(ctx, pid, SigKill)
}
74 changes: 74 additions & 0 deletions utils/proc/interrupt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package proc

import (
"context"
"os/exec"
"runtime"
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
)

// TestTerminateGracefully checks TerminateGracefully against a single process,
// a process with children, a (most likely) non-existent PID, and an already
// cancelled context. It also verifies that no goroutine is leaked.
func TestTerminateGracefully(t *testing.T) {
	defer goleak.VerifyNone(t)

	// assertStopped accepts either outcome of a successful termination: the
	// process is still listed but not running, or it cannot be found any more.
	assertStopped := func(t *testing.T, pid int) {
		t.Helper()
		found, err := FindProcess(context.Background(), pid)
		if err != nil {
			errortest.AssertError(t, err, commonerrors.ErrNotFound)
			assert.Empty(t, found)
			return
		}
		require.NotEmpty(t, found)
		assert.False(t, found.IsRunning())
	}

	t.Run("single process", func(t *testing.T) {
		sleeper := exec.Command("sleep", "50")
		require.NoError(t, sleeper.Start())
		// Reap the child once the test is over.
		defer func() { _ = sleeper.Wait() }()

		pid := sleeper.Process.Pid
		running, err := FindProcess(context.Background(), pid)
		require.NoError(t, err)
		assert.True(t, running.IsRunning())

		require.NoError(t, TerminateGracefully(context.Background(), pid, 100*time.Millisecond))
		time.Sleep(500 * time.Millisecond)
		assertStopped(t, pid)
	})

	t.Run("process with children", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("test with bash")
		}
		// see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
		// https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16
		shell := exec.Command("bash", "-c", "watch date > date.txt 2>&1")
		require.NoError(t, shell.Start())
		// Reap the child once the test is over.
		defer func() { _ = shell.Wait() }()
		require.NotNil(t, shell.Process)

		pid := shell.Process.Pid
		running, err := FindProcess(context.Background(), pid)
		require.NoError(t, err)
		assert.True(t, running.IsRunning())

		require.NoError(t, TerminateGracefully(context.Background(), pid, 100*time.Millisecond))
		assertStopped(t, pid)
	})

	t.Run("no process", func(t *testing.T) {
		pids, err := faker.RandomInt(9000, 20000, 1)
		require.NoError(t, err)
		require.NoError(t, TerminateGracefully(context.Background(), pids[0], 100*time.Millisecond))
	})

	t.Run("cancelled", func(t *testing.T) {
		pids, err := faker.RandomInt(9000, 20000, 1)
		require.NoError(t, err)

		cancelledCtx, cancel := context.WithCancel(context.Background())
		cancel()

		errortest.AssertError(t, TerminateGracefully(cancelledCtx, pids[0], 100*time.Millisecond), commonerrors.ErrCancelled)
	})
}
72 changes: 31 additions & 41 deletions utils/subprocess/command_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ package subprocess

import (
"context"
"fmt"
"os/exec"
"time"

"github.com/sasha-s/go-deadlock"
"go.uber.org/atomic"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/logs"
Expand All @@ -20,6 +20,8 @@ import (
commandUtils "github.com/ARM-software/golang-utils/utils/subprocess/command"
)

const subprocessTerminationGracePeriod = 10 * time.Millisecond

// INTERNAL
// wrapper over an exec cmd.
type cmdWrapper struct {
Expand All @@ -45,7 +47,7 @@ func (c *cmdWrapper) Start() error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.cmd == nil {
return fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
return commonerrors.UndefinedVariable("command")
}
return ConvertCommandError(c.cmd.Start())
}
Expand All @@ -54,45 +56,31 @@ func (c *cmdWrapper) Run() error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.cmd == nil {
return fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
return commonerrors.UndefinedVariable("command")
}
return ConvertCommandError(c.cmd.Run())
}

type interruptType int

const (
sigint interruptType = 2
sigkill interruptType = 9
sigterm interruptType = 15
)

func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt interruptType) error {
func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt proc.InterruptType) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.cmd == nil {
return commonerrors.New(commonerrors.ErrUndefined, "undefined command")
return commonerrors.UndefinedVariable("command")
}
subprocess := c.cmd.Process
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var stopErr error
stopErr := atomic.NewError(nil)
if subprocess != nil {
pid := subprocess.Pid
parallelisation.ScheduleAfter(ctx, 10*time.Millisecond, func(time.Time) {
process, err := proc.FindProcess(ctx, pid)
if process == nil || err != nil {
parallelisation.ScheduleAfter(ctx, subprocessTerminationGracePeriod, func(time.Time) {
process, sErr := proc.FindProcess(ctx, pid)
if process == nil || sErr != nil {
return
}
switch interrupt {
case sigint:
_ = process.Interrupt(ctx)
case sigkill:
_ = process.KillWithChildren(ctx)
case sigterm:
_ = process.Terminate(ctx)
default:
stopErr = commonerrors.New(commonerrors.ErrInvalid, "unknown interrupt type for process")
sErr = proc.InterruptProcess(ctx, pid, interrupt)
if commonerrors.Any(sErr, commonerrors.ErrInvalid, commonerrors.ErrCancelled, commonerrors.ErrTimeout) {
stopErr.Store(sErr)
}
})
}
Expand All @@ -102,31 +90,31 @@ func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt interru
return err
}

return stopErr
return stopErr.Load()
}

func (c *cmdWrapper) interrupt(interrupt interruptType) error {
func (c *cmdWrapper) interrupt(interrupt proc.InterruptType) error {
return c.interruptWithContext(context.Background(), interrupt)
}

func (c *cmdWrapper) Stop() error {
return c.interrupt(sigkill)
return c.interrupt(proc.SigKill)
}

func (c *cmdWrapper) Interrupt(ctx context.Context) error {
return c.interruptWithContext(ctx, sigint)
return c.interruptWithContext(ctx, proc.SigInt)
}

func (c *cmdWrapper) Pid() (pid int, err error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.cmd == nil {
err = fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
err = commonerrors.UndefinedVariable("command")
return
}
subprocess := c.cmd.Process
if subprocess == nil {
err = fmt.Errorf("%w:undefined subprocess", commonerrors.ErrUndefined)
err = commonerrors.UndefinedVariable("subprocess")
return
}
pid = subprocess.Pid
Expand All @@ -146,6 +134,13 @@ type command struct {
func (c *command) createCommand(cmdCtx context.Context) *exec.Cmd {
newCmd, newArgs := c.as.Redefine(c.cmd, c.args...)
cmd := exec.CommandContext(cmdCtx, newCmd, newArgs...) //nolint:gosec
cmd.Cancel = func() error {
p := cmd.Process
if p == nil {
return nil
}
return proc.TerminateGracefully(context.Background(), p.Pid, subprocessTerminationGracePeriod)
}
cmd.Stdout = newOutStreamer(cmdCtx, c.loggers)
cmd.Stderr = newErrLogStreamer(cmdCtx, c.loggers)
cmd.Env = cmd.Environ()
Expand All @@ -169,11 +164,11 @@ func (c *command) Reset() {

func (c *command) Check() (err error) {
if c.cmd == "" {
err = fmt.Errorf("missing command: %w", commonerrors.ErrUndefined)
err = commonerrors.UndefinedVariable("command")
return
}
if c.as == nil {
err = fmt.Errorf("missing command translator: %w", commonerrors.ErrUndefined)
err = commonerrors.UndefinedVariable("command translator")
return
}
if c.loggers == nil {
Expand All @@ -199,6 +194,7 @@ func ConvertCommandError(err error) error {
return proc.ConvertProcessError(err)
}

// CleanKillOfCommand tries to terminate a command gracefully.
func CleanKillOfCommand(ctx context.Context, cmd *exec.Cmd) (err error) {
if cmd == nil {
return
Expand All @@ -212,13 +208,7 @@ func CleanKillOfCommand(ctx context.Context, cmd *exec.Cmd) (err error) {
thisP := cmd.Process
if thisP == nil {
return
} else {
p, subErr := proc.FindProcess(ctx, thisP.Pid)
if subErr != nil {
err = subErr
return
}
err = p.KillWithChildren(ctx)
}
err = proc.TerminateGracefully(ctx, thisP.Pid, subprocessTerminationGracePeriod)
return
}
11 changes: 2 additions & 9 deletions utils/subprocess/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ package subprocess

import (
"context"
"fmt"

"github.com/sasha-s/go-deadlock"
"go.uber.org/atomic"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/logs"
"github.com/ARM-software/golang-utils/utils/platform"
"github.com/ARM-software/golang-utils/utils/proc"
commandUtils "github.com/ARM-software/golang-utils/utils/subprocess/command"
)
Expand Down Expand Up @@ -167,7 +165,7 @@ func (s *Subprocess) check() (err error) {
// In Go, there are no reentrant locks, so this follows the approach described at
// https://groups.google.com/forum/#!msg/golang-nuts/XqW1qcuZgKg/Ui3nQkeLV80J
if s.command == nil {
err = fmt.Errorf("missing command: %w", commonerrors.ErrUndefined)
err = commonerrors.UndefinedVariable("command")
return
}
err = s.command.Check()
Expand Down Expand Up @@ -204,11 +202,6 @@ func (s *Subprocess) Wait(ctx context.Context) (err error) {
return commonerrors.New(commonerrors.ErrConflict, "command not started")
}

// FIXME: verify proc.WaitForCompletion works on windows. Remove this platform check once this is verified.
if platform.IsWindows() {
return s.command.cmdWrapper.cmd.Wait()
}

return proc.WaitForCompletion(ctx, pid)
}

Expand Down Expand Up @@ -267,7 +260,7 @@ func (s *Subprocess) Execute() (err error) {
defer s.Cancel()

if s.IsOn() {
return fmt.Errorf("process is already started: %w", commonerrors.ErrConflict)
return commonerrors.New(commonerrors.ErrConflict, "process is already started")
}
s.processMonitoring.Reset()
s.command.Reset()
Expand Down
Loading
Loading