Skip to content

Commit 96466ba

Browse files
authored
🐛 [subprocess] make sure child processes are killed on context cancellation (#651)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description Improvements about processes cancellation ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent 88e15e1 commit 96466ba

File tree

8 files changed

+289
-109
lines changed

8 files changed

+289
-109
lines changed

changes/20250714171923.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:bug: `[subprocess]` make sure child processes are killed on context cancellation

changes/20250715105927.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `[proc]` Add utilities to kill processes gracefully

utils/proc/errors.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package proc
66

77
import (
8-
"fmt"
98
"os"
109
"os/exec"
1110
"syscall"
@@ -33,17 +32,17 @@ func ConvertProcessError(err error) error {
3332
// ESRCH is "no such process", meaning the process has already exited.
3433
return nil
3534
case commonerrors.Any(err, exec.ErrWaitDelay):
36-
return fmt.Errorf("%w: %v", commonerrors.ErrTimeout, err.Error())
35+
return commonerrors.WrapError(commonerrors.ErrTimeout, err, "")
3736
case commonerrors.Any(err, exec.ErrDot, exec.ErrNotFound):
38-
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
37+
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
3938
case commonerrors.Any(process.ErrorNotPermitted):
40-
return fmt.Errorf("%w: %v", commonerrors.ErrForbidden, err.Error())
39+
return commonerrors.WrapError(commonerrors.ErrForbidden, err, "")
4140
case commonerrors.Any(process.ErrorProcessNotRunning):
42-
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
41+
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
4342
case commonerrors.CorrespondTo(err, errAccessDenied):
44-
return fmt.Errorf("%w: %v", commonerrors.ErrNotFound, err.Error())
43+
return commonerrors.WrapError(commonerrors.ErrNotFound, err, "")
4544
case commonerrors.CorrespondTo(err, errNotImplemented):
46-
return fmt.Errorf("%w: %v", commonerrors.ErrNotImplemented, err.Error())
45+
return commonerrors.WrapError(commonerrors.ErrNotImplemented, err, "")
4746
default:
4847
return err
4948
}

utils/proc/interrupt.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package proc
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/ARM-software/golang-utils/utils/commonerrors"
8+
"github.com/ARM-software/golang-utils/utils/parallelisation"
9+
)
10+
11+
//go:generate go run github.com/dmarkham/enumer -type=InterruptType -text -json -yaml
12+
type InterruptType int
13+
14+
const (
15+
SigInt InterruptType = 2
16+
SigKill InterruptType = 9
17+
SigTerm InterruptType = 15
18+
)
19+
20+
func InterruptProcess(ctx context.Context, pid int, signal InterruptType) (err error) {
21+
err = parallelisation.DetermineContextError(ctx)
22+
if err != nil {
23+
return
24+
}
25+
process, err := FindProcess(ctx, pid)
26+
if err != nil || process == nil {
27+
err = commonerrors.Ignore(err, commonerrors.ErrNotFound)
28+
return
29+
}
30+
31+
switch signal {
32+
case SigInt:
33+
err = process.Interrupt(ctx)
34+
case SigKill:
35+
err = process.KillWithChildren(ctx)
36+
case SigTerm:
37+
err = process.Terminate(ctx)
38+
default:
39+
err = commonerrors.New(commonerrors.ErrInvalid, "unknown interrupt type for process")
40+
}
41+
return
42+
}
43+
44+
// TerminateGracefully follows the pattern set by [kubernetes](https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-termination) and terminates processes gracefully by first sending a SIGTERM and then a SIGKILL after the grace period has elapsed.
45+
func TerminateGracefully(ctx context.Context, pid int, gracePeriod time.Duration) (err error) {
46+
defer func() { _ = InterruptProcess(context.Background(), pid, SigKill) }()
47+
err = InterruptProcess(ctx, pid, SigInt)
48+
if err != nil {
49+
return
50+
}
51+
err = InterruptProcess(ctx, pid, SigTerm)
52+
if err != nil {
53+
return
54+
}
55+
_, fErr := FindProcess(ctx, pid)
56+
if commonerrors.Any(fErr, commonerrors.ErrNotFound) {
57+
// The process no longer exist.
58+
// No need to wait the grace period
59+
return
60+
}
61+
parallelisation.SleepWithContext(ctx, gracePeriod)
62+
err = InterruptProcess(ctx, pid, SigKill)
63+
return
64+
}

utils/proc/interrupt_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package proc
2+
3+
import (
4+
"context"
5+
"os/exec"
6+
"runtime"
7+
"testing"
8+
"time"
9+
10+
"github.com/go-faker/faker/v4"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/goleak"
14+
15+
"github.com/ARM-software/golang-utils/utils/commonerrors"
16+
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
17+
)
18+
19+
func TestTerminateGracefully(t *testing.T) {
20+
defer goleak.VerifyNone(t)
21+
t.Run("single process", func(t *testing.T) {
22+
cmd := exec.Command("sleep", "50")
23+
require.NoError(t, cmd.Start())
24+
defer func() { _ = cmd.Wait() }()
25+
process, err := FindProcess(context.Background(), cmd.Process.Pid)
26+
require.NoError(t, err)
27+
assert.True(t, process.IsRunning())
28+
require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond))
29+
time.Sleep(500 * time.Millisecond)
30+
process, err = FindProcess(context.Background(), cmd.Process.Pid)
31+
if err == nil {
32+
require.NotEmpty(t, process)
33+
assert.False(t, process.IsRunning())
34+
} else {
35+
errortest.AssertError(t, err, commonerrors.ErrNotFound)
36+
assert.Empty(t, process)
37+
}
38+
})
39+
t.Run("process with children", func(t *testing.T) {
40+
if runtime.GOOS == "windows" {
41+
t.Skip("test with bash")
42+
}
43+
// see https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
44+
// https://forum.golangbridge.org/t/killing-child-process-on-timeout-in-go-code/995/16
45+
cmd := exec.Command("bash", "-c", "watch date > date.txt 2>&1")
46+
require.NoError(t, cmd.Start())
47+
defer func() { _ = cmd.Wait() }()
48+
require.NotNil(t, cmd.Process)
49+
p, err := FindProcess(context.Background(), cmd.Process.Pid)
50+
require.NoError(t, err)
51+
assert.True(t, p.IsRunning())
52+
require.NoError(t, TerminateGracefully(context.Background(), cmd.Process.Pid, 100*time.Millisecond))
53+
p, err = FindProcess(context.Background(), cmd.Process.Pid)
54+
if err == nil {
55+
require.NotEmpty(t, p)
56+
assert.False(t, p.IsRunning())
57+
} else {
58+
errortest.AssertError(t, err, commonerrors.ErrNotFound)
59+
assert.Empty(t, p)
60+
}
61+
})
62+
t.Run("no process", func(t *testing.T) {
63+
random, err := faker.RandomInt(9000, 20000, 1)
64+
require.NoError(t, err)
65+
require.NoError(t, TerminateGracefully(context.Background(), random[0], 100*time.Millisecond))
66+
})
67+
t.Run("cancelled", func(t *testing.T) {
68+
random, err := faker.RandomInt(9000, 20000, 1)
69+
require.NoError(t, err)
70+
ctx, cancel := context.WithCancel(context.Background())
71+
cancel()
72+
errortest.AssertError(t, TerminateGracefully(ctx, random[0], 100*time.Millisecond), commonerrors.ErrCancelled)
73+
})
74+
}

utils/subprocess/command_wrapper.go

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ package subprocess
77

88
import (
99
"context"
10-
"fmt"
1110
"os/exec"
1211
"time"
1312

1413
"github.com/sasha-s/go-deadlock"
14+
"go.uber.org/atomic"
1515

1616
"github.com/ARM-software/golang-utils/utils/commonerrors"
1717
"github.com/ARM-software/golang-utils/utils/logs"
@@ -20,6 +20,8 @@ import (
2020
commandUtils "github.com/ARM-software/golang-utils/utils/subprocess/command"
2121
)
2222

23+
const subprocessTerminationGracePeriod = 10 * time.Millisecond
24+
2325
// INTERNAL
2426
// wrapper over an exec cmd.
2527
type cmdWrapper struct {
@@ -45,7 +47,7 @@ func (c *cmdWrapper) Start() error {
4547
c.mu.RLock()
4648
defer c.mu.RUnlock()
4749
if c.cmd == nil {
48-
return fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
50+
return commonerrors.UndefinedVariable("command")
4951
}
5052
return ConvertCommandError(c.cmd.Start())
5153
}
@@ -54,45 +56,31 @@ func (c *cmdWrapper) Run() error {
5456
c.mu.RLock()
5557
defer c.mu.RUnlock()
5658
if c.cmd == nil {
57-
return fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
59+
return commonerrors.UndefinedVariable("command")
5860
}
5961
return ConvertCommandError(c.cmd.Run())
6062
}
6163

62-
type interruptType int
63-
64-
const (
65-
sigint interruptType = 2
66-
sigkill interruptType = 9
67-
sigterm interruptType = 15
68-
)
69-
70-
func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt interruptType) error {
64+
func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt proc.InterruptType) error {
7165
c.mu.RLock()
7266
defer c.mu.RUnlock()
7367
if c.cmd == nil {
74-
return commonerrors.New(commonerrors.ErrUndefined, "undefined command")
68+
return commonerrors.UndefinedVariable("command")
7569
}
7670
subprocess := c.cmd.Process
7771
ctx, cancel := context.WithCancel(ctx)
7872
defer cancel()
79-
var stopErr error
73+
stopErr := atomic.NewError(nil)
8074
if subprocess != nil {
8175
pid := subprocess.Pid
82-
parallelisation.ScheduleAfter(ctx, 10*time.Millisecond, func(time.Time) {
83-
process, err := proc.FindProcess(ctx, pid)
84-
if process == nil || err != nil {
76+
parallelisation.ScheduleAfter(ctx, subprocessTerminationGracePeriod, func(time.Time) {
77+
process, sErr := proc.FindProcess(ctx, pid)
78+
if process == nil || sErr != nil {
8579
return
8680
}
87-
switch interrupt {
88-
case sigint:
89-
_ = process.Interrupt(ctx)
90-
case sigkill:
91-
_ = process.KillWithChildren(ctx)
92-
case sigterm:
93-
_ = process.Terminate(ctx)
94-
default:
95-
stopErr = commonerrors.New(commonerrors.ErrInvalid, "unknown interrupt type for process")
81+
sErr = proc.InterruptProcess(ctx, pid, interrupt)
82+
if commonerrors.Any(sErr, commonerrors.ErrInvalid, commonerrors.ErrCancelled, commonerrors.ErrTimeout) {
83+
stopErr.Store(sErr)
9684
}
9785
})
9886
}
@@ -102,31 +90,31 @@ func (c *cmdWrapper) interruptWithContext(ctx context.Context, interrupt interru
10290
return err
10391
}
10492

105-
return stopErr
93+
return stopErr.Load()
10694
}
10795

108-
func (c *cmdWrapper) interrupt(interrupt interruptType) error {
96+
func (c *cmdWrapper) interrupt(interrupt proc.InterruptType) error {
10997
return c.interruptWithContext(context.Background(), interrupt)
11098
}
11199

112100
func (c *cmdWrapper) Stop() error {
113-
return c.interrupt(sigkill)
101+
return c.interrupt(proc.SigKill)
114102
}
115103

116104
func (c *cmdWrapper) Interrupt(ctx context.Context) error {
117-
return c.interruptWithContext(ctx, sigint)
105+
return c.interruptWithContext(ctx, proc.SigInt)
118106
}
119107

120108
func (c *cmdWrapper) Pid() (pid int, err error) {
121109
c.mu.RLock()
122110
defer c.mu.RUnlock()
123111
if c.cmd == nil {
124-
err = fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined)
112+
err = commonerrors.UndefinedVariable("command")
125113
return
126114
}
127115
subprocess := c.cmd.Process
128116
if subprocess == nil {
129-
err = fmt.Errorf("%w:undefined subprocess", commonerrors.ErrUndefined)
117+
err = commonerrors.UndefinedVariable("subprocess")
130118
return
131119
}
132120
pid = subprocess.Pid
@@ -146,6 +134,13 @@ type command struct {
146134
func (c *command) createCommand(cmdCtx context.Context) *exec.Cmd {
147135
newCmd, newArgs := c.as.Redefine(c.cmd, c.args...)
148136
cmd := exec.CommandContext(cmdCtx, newCmd, newArgs...) //nolint:gosec
137+
cmd.Cancel = func() error {
138+
p := cmd.Process
139+
if p == nil {
140+
return nil
141+
}
142+
return proc.TerminateGracefully(context.Background(), p.Pid, subprocessTerminationGracePeriod)
143+
}
149144
cmd.Stdout = newOutStreamer(cmdCtx, c.loggers)
150145
cmd.Stderr = newErrLogStreamer(cmdCtx, c.loggers)
151146
cmd.Env = cmd.Environ()
@@ -169,11 +164,11 @@ func (c *command) Reset() {
169164

170165
func (c *command) Check() (err error) {
171166
if c.cmd == "" {
172-
err = fmt.Errorf("missing command: %w", commonerrors.ErrUndefined)
167+
err = commonerrors.UndefinedVariable("command")
173168
return
174169
}
175170
if c.as == nil {
176-
err = fmt.Errorf("missing command translator: %w", commonerrors.ErrUndefined)
171+
err = commonerrors.UndefinedVariable("command translator")
177172
return
178173
}
179174
if c.loggers == nil {
@@ -199,6 +194,7 @@ func ConvertCommandError(err error) error {
199194
return proc.ConvertProcessError(err)
200195
}
201196

197+
// CleanKillOfCommand tries to terminate a command gracefully.
202198
func CleanKillOfCommand(ctx context.Context, cmd *exec.Cmd) (err error) {
203199
if cmd == nil {
204200
return
@@ -212,13 +208,7 @@ func CleanKillOfCommand(ctx context.Context, cmd *exec.Cmd) (err error) {
212208
thisP := cmd.Process
213209
if thisP == nil {
214210
return
215-
} else {
216-
p, subErr := proc.FindProcess(ctx, thisP.Pid)
217-
if subErr != nil {
218-
err = subErr
219-
return
220-
}
221-
err = p.KillWithChildren(ctx)
222211
}
212+
err = proc.TerminateGracefully(ctx, thisP.Pid, subprocessTerminationGracePeriod)
223213
return
224214
}

utils/subprocess/executor.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@ package subprocess
88

99
import (
1010
"context"
11-
"fmt"
1211

1312
"github.com/sasha-s/go-deadlock"
1413
"go.uber.org/atomic"
1514

1615
"github.com/ARM-software/golang-utils/utils/commonerrors"
1716
"github.com/ARM-software/golang-utils/utils/logs"
18-
"github.com/ARM-software/golang-utils/utils/platform"
1917
"github.com/ARM-software/golang-utils/utils/proc"
2018
commandUtils "github.com/ARM-software/golang-utils/utils/subprocess/command"
2119
)
@@ -167,7 +165,7 @@ func (s *Subprocess) check() (err error) {
167165
// In GO, there is no reentrant locks and so following what is described there
168166
// https://groups.google.com/forum/#!msg/golang-nuts/XqW1qcuZgKg/Ui3nQkeLV80J
169167
if s.command == nil {
170-
err = fmt.Errorf("missing command: %w", commonerrors.ErrUndefined)
168+
err = commonerrors.UndefinedVariable("command")
171169
return
172170
}
173171
err = s.command.Check()
@@ -204,11 +202,6 @@ func (s *Subprocess) Wait(ctx context.Context) (err error) {
204202
return commonerrors.New(commonerrors.ErrConflict, "command not started")
205203
}
206204

207-
// FIXME: verify proc.WaitForCompletion works on windows. Remove this platform check once this is verified.
208-
if platform.IsWindows() {
209-
return s.command.cmdWrapper.cmd.Wait()
210-
}
211-
212205
return proc.WaitForCompletion(ctx, pid)
213206
}
214207

@@ -267,7 +260,7 @@ func (s *Subprocess) Execute() (err error) {
267260
defer s.Cancel()
268261

269262
if s.IsOn() {
270-
return fmt.Errorf("process is already started: %w", commonerrors.ErrConflict)
263+
return commonerrors.New(commonerrors.ErrConflict, "process is already started")
271264
}
272265
s.processMonitoring.Reset()
273266
s.command.Reset()

0 commit comments

Comments
 (0)