diff --git a/changes/20250729164750.feature b/changes/20250729164750.feature new file mode 100644 index 0000000000..391d969045 --- /dev/null +++ b/changes/20250729164750.feature @@ -0,0 +1 @@ +:sparkles: Add support for finding processes by their name/the command that was started diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index e295e2ba6c..9550de6eb6 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -12,6 +12,7 @@ import ( "time" "go.uber.org/atomic" + "golang.org/x/sync/errgroup" "github.com/ARM-software/golang-utils/utils/commonerrors" ) @@ -267,3 +268,58 @@ func WaitUntil(ctx context.Context, evalCondition func(ctx2 context.Context) (bo SleepWithContext(ctx, pauseBetweenEvaluations) } } + +func newWorker[JobType, ResultType any](ctx context.Context, f func(context.Context, JobType) (ResultType, bool, error), jobs chan JobType, results chan ResultType) (err error) { + for job := range jobs { + result, ok, subErr := f(ctx, job) + if subErr != nil { + err = commonerrors.WrapError(commonerrors.ErrUnexpected, subErr, "an error occurred whilst handling a job") + return + } + + err = DetermineContextError(ctx) + if err != nil { + return + } + + if ok { + results <- result + } + } + + return +} + +// WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size +func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f func(context.Context, InputType) (ResultType, bool, error)) (results []ResultType, err error) { + if numWorkers < 1 { + err = commonerrors.New(commonerrors.ErrInvalid, "numWorkers must be greater than or equal to 1") + return + } + + numJobs := len(jobs) + jobsChan := make(chan InputType, numJobs) + resultsChan := make(chan ResultType, numJobs) + + g, gCtx := errgroup.WithContext(ctx) + g.SetLimit(numWorkers) + for range numWorkers { + g.Go(func() error { return newWorker(gCtx, f, jobsChan, resultsChan) }) + } + for _, job := range jobs { + jobsChan <- job + } + + close(jobsChan) + err = g.Wait() + close(resultsChan) + if err != nil { + return + } + + for result := range resultsChan { + results = append(results, result) + } + + return +} diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index ffda26e15f..54fdb17f0c 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -463,3 +463,102 @@ func TestWaitUntil(t *testing.T) { errortest.AssertError(t, err, commonerrors.ErrUnexpected) }) } + +func TestWorkerPool(t *testing.T) { + for _, test := range []struct { + name string + numWorkers int + jobs []int + results []int + workerFunc func(context.Context, int) (int, bool, error) + err error + }{ + { + name: "Success", + numWorkers: 3, + jobs: []int{1, 2, 3, 4, 5}, + results: []int{2, 4, 6, 8, 10}, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + return job * 2, true, nil + }, + err: nil, + }, + { + name: "Invalid Num Workers", + numWorkers: 0, + jobs: []int{1, 2, 3}, + results: nil, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + return 0, true, nil + }, + err: commonerrors.ErrInvalid, + }, + { + name: "Worker Returns Error", + numWorkers: 2, + jobs: []int{1, 2, 3}, + results: nil, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + if job == 2 { + return 0, false, errors.New("fail") + } + return job, true, nil + }, + err: commonerrors.ErrUnexpected, + }, + { + name: "Some ok False", + numWorkers: 1, + jobs: []int{1, 2, 3}, + results: []int{1, 3}, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + return job, job != 2, nil + }, + err: nil, + }, + { + name: "All ok False", + numWorkers: 1, + jobs: []int{1, 2, 3}, + results: []int{}, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + return job, false, nil + }, + err: nil, + }, + { + name: "Empty Jobs", + numWorkers: 2, + jobs: []int{}, + results: []int{}, + workerFunc: func(ctx context.Context, job int) (int, bool, error) { + return job, true, nil + }, + err: nil, + }, + } { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + + results, err := WorkerPool(ctx, test.numWorkers, test.jobs, test.workerFunc) + + if test.err != nil { + errortest.AssertError(t, err, test.err) + } else { + require.NoError(t, err) + assert.ElementsMatch(t, test.results, results) + } + }) + } + + t.Run("Context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := WorkerPool(ctx, 100, []int{1, 2, 3}, func(ctx context.Context, job int) (int, bool, error) { + return job, true, nil + }) + + errortest.AssertError(t, err, commonerrors.ErrCancelled) + }) +} diff --git a/utils/proc/find/find_linux.go b/utils/proc/find/find_linux.go new file mode 100644 index 0000000000..23667a9ef7 --- /dev/null +++ b/utils/proc/find/find_linux.go @@ -0,0 +1,111 @@ +//go:build linux + +package find + +import ( + "bytes" + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/filesystem" + "github.com/ARM-software/golang-utils/utils/parallelisation" + "github.com/ARM-software/golang-utils/utils/proc" +) + +const ( + procFS = "/proc" + procDataFile = "cmdline" +) + +func checkProcessMatch(ctx context.Context, fs filesystem.FS, re *regexp.Regexp, procEntry string) (ok bool, err error) { + err = parallelisation.DetermineContextError(ctx) + if err != nil { + return + } + + data, err := fs.ReadFile(procEntry) + if err != nil { + if commonerrors.CorrespondTo(err, "no bytes were read") { + err = nil + return // ignore special descriptors since our cmdline will have content (we still have to check since all files in proc have size zero) + } + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not read proc entry '%v'", procEntry) + return + } + + data = bytes.ReplaceAll(data, []byte{0}, []byte{' '}) // https://man7.org/linux/man-pages/man5/proc_pid_cmdline.5.html + + ok = re.Match(data) + return +} + +func parseProcess(ctx context.Context, entry string) (p proc.IProcess, err error) { + err = parallelisation.DetermineContextError(ctx) + if err != nil { + return + } + + pid, err := strconv.Atoi(strings.Trim(strings.TrimSuffix(strings.TrimPrefix(entry, procFS), fmt.Sprintf("%v", procDataFile)), "/")) + if err != nil { + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not parse PID from proc path '%v'", entry) + return + } + + p, err = proc.FindProcess(ctx, pid) + if err != nil { + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "could not find process '%v'", pid) + return + } + + return +} + +// FindProcessByRegexForFS will search a given filesystem for the processes that match a specific regex +func FindProcessByRegexForFS(ctx context.Context, fs filesystem.FS, re *regexp.Regexp) (processes []proc.IProcess, err error) { + if !filesystem.Exists(procFS) { + err = commonerrors.Newf(commonerrors.ErrNotFound, "the proc filesystem was not found at '%v'", procFS) + return + } + err = parallelisation.DetermineContextError(ctx) + if err != nil { + return + } + + searchGlobTerm := fmt.Sprintf("%v/*/%v", procFS, procDataFile) + procEntries, err := fs.Glob(searchGlobTerm) + if err != nil { + err = commonerrors.WrapErrorf(commonerrors.ErrUnexpected, err, "an error occurred when searching for processes using the following glob '%v'", searchGlobTerm) + return + } + + processes, err = parallelisation.WorkerPool(ctx, 10, procEntries, func(ctx context.Context, entry string) (p proc.IProcess, matches bool, err error) { + matches, err = checkProcessMatch(ctx, fs, re, entry) + if err != nil || !matches { + return + } + + p, err = parseProcess(ctx, entry) + if err != nil { + return + } + + matches = true + return + }) + + return +} + +// FindProcessByRegex will search for the processes that match a specific regex +func FindProcessByRegex(ctx context.Context, re *regexp.Regexp) (processes []proc.IProcess, err error) { + return FindProcessByRegexForFS(ctx, filesystem.GetGlobalFileSystem(), re) +} + +// FindProcessByName will search for the processes that match a specific name +func FindProcessByName(ctx context.Context, name string) (processes []proc.IProcess, err error) { + return FindProcessByRegex(ctx, regexp.MustCompile(fmt.Sprintf(".*%v.*", regexp.QuoteMeta(name)))) +} diff --git a/utils/proc/find/find_linux_test.go b/utils/proc/find/find_linux_test.go new file mode 100644 index 0000000000..be2a9ff421 --- /dev/null +++ b/utils/proc/find/find_linux_test.go @@ -0,0 +1,90 @@ +//go:build linux + +package find + +import ( + "context" + "fmt" + "testing" + + "github.com/go-faker/faker/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/logs" + "github.com/ARM-software/golang-utils/utils/logs/logstest" + "github.com/ARM-software/golang-utils/utils/subprocess" +) + +func TestFind(t *testing.T) { + for _, test := range []struct { + name string + processes int + }{ + { + name: "One process", + processes: 1, + }, + { + name: "Many processes", + processes: 10, + }, + { + name: "No process", + processes: 0, + }, + } { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + processString := faker.Sentence() + for range test.processes { + l, err := logs.NewLogrLogger(logstest.NewStdTestLogger(), test.name) + require.NoError(t, err) + + cmd, err := subprocess.New(ctx, l, "start", "success", "failed", "sh", "-c", fmt.Sprintf("sleep 10 ; echo '%v'", processString)) + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + } + + processes, err := FindProcessByName(ctx, processString) + assert.NoError(t, err) + assert.Len(t, processes, test.processes) + + // stopping processes shows they were parsed correctly + for _, process := range processes { + err = process.Terminate(ctx) + require.NoError(t, err) + } + processes, err = FindProcessByName(ctx, processString) + require.NoError(t, err) + assert.Empty(t, processes) + }) + } + + t.Run("Cancel context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + processString := faker.Sentence() + + l, err := logs.NewLogrLogger(logstest.NewStdTestLogger(), "context cancelled") + require.NoError(t, err) + + cmd, err := subprocess.New(ctx, l, "start", "success", "failed", "sh", "-c", fmt.Sprintf("sleep 10 ; echo '%v'", processString)) + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + cancel() + + processes, err := FindProcessByName(ctx, processString) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + assert.Empty(t, processes) + }) + +}