From 62658ac288c0eb1de502cc838b0070ea115a5284 Mon Sep 17 00:00:00 2001 From: rabbitstack Date: Fri, 21 Mar 2025 18:20:39 +0100 Subject: [PATCH] feat(filter): Thread pool filter fields Provides the implementation of the thread pool field accessor. --- pkg/filter/accessor.go | 28 ++++++---- pkg/filter/accessor_windows.go | 53 +++++++++++++++++++ pkg/filter/fields/fields_windows.go | 73 +++++++++++++++++++++---- pkg/filter/filter_test.go | 82 ++++++++++++++++++++++++++--- pkg/filter/filter_windows.go | 3 ++ pkg/filter/ql/function.go | 23 ++++---- 6 files changed, 221 insertions(+), 41 deletions(-) diff --git a/pkg/filter/accessor.go b/pkg/filter/accessor.go index 863f6c11c..49ea069bd 100644 --- a/pkg/filter/accessor.go +++ b/pkg/filter/accessor.go @@ -145,17 +145,18 @@ func (k *kevtAccessor) Get(f Field, kevt *kevent.Kevent) (kparams.Value, error) // referenced in the bound field. func (f *filter) narrowAccessors() { var ( - removeKevtAccessor = true - removePsAccessor = true - removeThreadAccessor = true - removeImageAccessor = true - removeFileAccessor = true - removeRegistryAccessor = true - removeNetworkAccessor = true - removeHandleAccessor = true - removePEAccessor = true - removeMemAccessor = true - removeDNSAccessor = true + removeKevtAccessor = true + removePsAccessor = true + removeThreadAccessor = true + removeImageAccessor = true + removeFileAccessor = true + removeRegistryAccessor = true + removeNetworkAccessor = true + removeHandleAccessor = true + removePEAccessor = true + removeMemAccessor = true + removeDNSAccessor = true + removeThreadpoolAccessor = true ) for _, field := range f.fields { @@ -182,6 +183,8 @@ func (f *filter) narrowAccessors() { removeMemAccessor = false case field.Name.IsDNSField(): removeDNSAccessor = false + case field.Name.IsThreadpoolField(): + removeThreadpoolAccessor = false } } @@ -218,6 +221,9 @@ func (f *filter) narrowAccessors() { if removeDNSAccessor { f.removeAccessor(&dnsAccessor{}) } + if removeThreadpoolAccessor { + f.removeAccessor(&threadpoolAccessor{}) + } for _, accessor := range f.accessors { accessor.SetFields(f.fields) diff --git a/pkg/filter/accessor_windows.go b/pkg/filter/accessor_windows.go index a642b212e..2b06b2eb8 100644 --- a/pkg/filter/accessor_windows.go +++ b/pkg/filter/accessor_windows.go @@ -65,6 +65,7 @@ func GetAccessors() []Accessor { newHandleAccessor(), newNetworkAccessor(), newRegistryAccessor(), + newThreadpoolAccessor(), } } @@ -1337,3 +1338,55 @@ func (*dnsAccessor) Get(f Field, kevt *kevent.Kevent) (kparams.Value, error) { return nil, nil } + +// threadpoolAccessor extracts values from thread pool events +type threadpoolAccessor struct{} + +func (threadpoolAccessor) SetFields([]Field) {} +func (threadpoolAccessor) SetSegments([]fields.Segment) {} +func (threadpoolAccessor) IsFieldAccessible(e *kevent.Kevent) bool { + return e.Category == ktypes.Threadpool +} + +func newThreadpoolAccessor() Accessor { + return &threadpoolAccessor{} +} + +func (*threadpoolAccessor) Get(f Field, e *kevent.Kevent) (kparams.Value, error) { + switch f.Name { + case fields.ThreadpoolPoolID: + return e.GetParamAsString(kparams.ThreadpoolPoolID), nil + case fields.ThreadpoolTaskID: + return e.GetParamAsString(kparams.ThreadpoolTaskID), nil + case fields.ThreadpoolCallbackAddress: + return e.GetParamAsString(kparams.ThreadpoolCallback), nil + case fields.ThreadpoolCallbackSymbol: + return e.GetParamAsString(kparams.ThreadpoolCallbackSymbol), nil + case fields.ThreadpoolCallbackModule: + return e.GetParamAsString(kparams.ThreadpoolCallbackModule), nil + case fields.ThreadpoolCallbackContext: + return e.GetParamAsString(kparams.ThreadpoolContext), nil + case fields.ThreadpoolCallbackContextRip: + return e.GetParamAsString(kparams.ThreadpoolContextRip), nil + case fields.ThreadpoolCallbackContextRipSymbol: + return e.GetParamAsString(kparams.ThreadpoolContextRipSymbol), nil + case fields.ThreadpoolCallbackContextRipModule: + return e.GetParamAsString(kparams.ThreadpoolContextRipModule), nil + case fields.ThreadpoolSubprocessTag: + return e.GetParamAsString(kparams.ThreadpoolSubprocessTag), nil + case fields.ThreadpoolTimer: + return e.GetParamAsString(kparams.ThreadpoolTimer), nil + case fields.ThreadpoolTimerSubqueue: + return e.GetParamAsString(kparams.ThreadpoolTimerSubqueue), nil + case fields.ThreadpoolTimerDuetime: + return e.Kparams.GetUint64(kparams.ThreadpoolTimerDuetime) + case fields.ThreadpoolTimerPeriod: + return e.Kparams.GetUint32(kparams.ThreadpoolTimerPeriod) + case fields.ThreadpoolTimerWindow: + return e.Kparams.GetUint32(kparams.ThreadpoolTimerWindow) + case fields.ThreadpoolTimerAbsolute: + return e.Kparams.GetBool(kparams.ThreadpoolTimerAbsolute) + } + + return nil, nil +} diff --git a/pkg/filter/fields/fields_windows.go b/pkg/filter/fields/fields_windows.go index 83d4dceeb..3cee4fe7a 100644 --- a/pkg/filter/fields/fields_windows.go +++ b/pkg/filter/fields/fields_windows.go @@ -508,22 +508,56 @@ const ( DNSAnswers Field = "dns.answers" // DNSRcode identifies the field that represents the DNS response code DNSRcode Field = "dns.rcode" + + // ThreadpoolPoolID identifies the field that represents the thread pool identifier + ThreadpoolPoolID = "threadpool.id" + // ThreadpoolTaskID identifies the field that represents the thread pool task identifier + ThreadpoolTaskID = "threadpool.task.id" + // ThreadpoolCallbackAddress identifies the field that represents the address of the callback function + ThreadpoolCallbackAddress = "threadpool.callback.address" + // ThreadpoolCallbackSymbol identifies the field that represents the callback symbol + ThreadpoolCallbackSymbol = "threadpool.callback.symbol" + // ThreadpoolCallbackModule identifies the field that represents the module containing the callback symbol + ThreadpoolCallbackModule = "threadpool.callback.module" + // ThreadpoolCallbackContext identifies the field that represents the address of the callback context + ThreadpoolCallbackContext = "threadpool.callback.context" + // ThreadpoolCallbackContextRip identifies the field that represents the value of instruction pointer contained in the callback context + ThreadpoolCallbackContextRip = "threadpool.callback.context.rip" + // ThreadpoolCallbackContextRipSymbol identifies the field that represents the symbol name associated with the instruction pointer in callback context + ThreadpoolCallbackContextRipSymbol = "threadpool.callback.context.rip.symbol" + // ThreadpoolCallbackContextRipModule identifies the field that represents the module name associated with the instruction pointer in callback context + ThreadpoolCallbackContextRipModule = "threadpool.callback.context.rip.module" + // ThreadpoolSubprocessTag identifies the field that represents the service identifier associated with the thread pool + ThreadpoolSubprocessTag = "threadpool.subprocess_tag" + // ThreadpoolTimerDuetime identifies the field that represents the timer due time + ThreadpoolTimerDuetime = "threadpool.timer.duetime" + // ThreadpoolTimerSubqueue identifies the field that represents the memory address of the timer subqueue + ThreadpoolTimerSubqueue = "threadpool.timer.subqueue" + // ThreadpoolTimer identifies the field that represents the memory address of the timer object + ThreadpoolTimer = "threadpool.timer.address" + // ThreadpoolTimerPeriod identifies the field that represents the period of the timer + ThreadpoolTimerPeriod = "threadpool.timer.period" + // ThreadpoolTimerWindow identifies the field that represents the timer tolerate period + ThreadpoolTimerWindow = "threadpool.timer.window" + // ThreadpoolTimerAbsolute identifies the field that indicates if the timer is absolute or relative + ThreadpoolTimerAbsolute = "threadpool.timer.is_absolute" ) // String casts the field type to string. func (f Field) String() string { return string(f) } -func (f Field) IsPsField() bool { return strings.HasPrefix(string(f), "ps.") } -func (f Field) IsKevtField() bool { return strings.HasPrefix(string(f), "kevt.") } -func (f Field) IsThreadField() bool { return strings.HasPrefix(string(f), "thread.") } -func (f Field) IsImageField() bool { return strings.HasPrefix(string(f), "image.") } -func (f Field) IsFileField() bool { return strings.HasPrefix(string(f), "file.") } -func (f Field) IsRegistryField() bool { return strings.HasPrefix(string(f), "registry.") } -func (f Field) IsNetworkField() bool { return strings.HasPrefix(string(f), "net.") } -func (f Field) IsHandleField() bool { return strings.HasPrefix(string(f), "handle.") } -func (f Field) IsPeField() bool { return strings.HasPrefix(string(f), "pe.") || f == PsChildPeFilename } -func (f Field) IsMemField() bool { return strings.HasPrefix(string(f), "mem.") } -func (f Field) IsDNSField() bool { return strings.HasPrefix(string(f), "dns.") } +func (f Field) IsPsField() bool { return strings.HasPrefix(string(f), "ps.") } +func (f Field) IsKevtField() bool { return strings.HasPrefix(string(f), "kevt.") } +func (f Field) IsThreadField() bool { return strings.HasPrefix(string(f), "thread.") } +func (f Field) IsImageField() bool { return strings.HasPrefix(string(f), "image.") } +func (f Field) IsFileField() bool { return strings.HasPrefix(string(f), "file.") } +func (f Field) IsRegistryField() bool { return strings.HasPrefix(string(f), "registry.") } +func (f Field) IsNetworkField() bool { return strings.HasPrefix(string(f), "net.") } +func (f Field) IsHandleField() bool { return strings.HasPrefix(string(f), "handle.") } +func (f Field) IsPeField() bool { return strings.HasPrefix(string(f), "pe.") || f == PsChildPeFilename } +func (f Field) IsMemField() bool { return strings.HasPrefix(string(f), "mem.") } +func (f Field) IsDNSField() bool { return strings.HasPrefix(string(f), "dns.") } +func (f Field) IsThreadpoolField() bool { return strings.HasPrefix(string(f), "threadpool.") } func (f Field) IsPeSection() bool { return f == PeNumSections } func (f Field) IsPeSymbol() bool { return f == PeSymbols || f == PeNumSymbols || f == PeImports } @@ -966,6 +1000,23 @@ var fields = map[Field]FieldInfo{ DNSOptions: {DNSOptions, "dns query options", kparams.Flags64, []string{"dns.options in ('ADDRCONFIG', 'DUAL_ADDR')"}, nil, nil}, DNSRcode: {DNSRR, "dns response status", kparams.AnsiString, []string{"dns.rcode = 'NXDOMAIN'"}, nil, nil}, DNSAnswers: {DNSAnswers, "dns response answers", kparams.Slice, []string{"dns.answers in ('o.lencr.edgesuite.net', 'a1887.dscq.akamai.net')"}, nil, nil}, + + ThreadpoolPoolID: {ThreadpoolPoolID, "thread pool identifier", kparams.Address, []string{"threadpool.id = '20f5fc02440'"}, nil, nil}, + ThreadpoolTaskID: {ThreadpoolTaskID, "thread pool task identifier", kparams.Address, []string{"threadpool.task.id = '20f7ecd21f8'"}, nil, nil}, + ThreadpoolCallbackAddress: {ThreadpoolCallbackAddress, "thread pool callback address", kparams.Address, []string{"threadpool.callback.address = '7ff868739ed0'"}, nil, nil}, + ThreadpoolCallbackSymbol: {ThreadpoolCallbackSymbol, "thread pool callback symbol", kparams.UnicodeString, []string{"threadpool.callback.symbol = 'RtlDestroyQueryDebugBuffer'"}, nil, nil}, + ThreadpoolCallbackModule: {ThreadpoolCallbackModule, "thread pool module containing the callback symbol", kparams.UnicodeString, []string{"threadpool.callback.module contains 'ntdll.dll'"}, nil, nil}, + ThreadpoolCallbackContext: {ThreadpoolCallbackContext, "thread pool callback context address", kparams.Address, []string{"threadpool.callback.context = '1df41e07bd0'"}, nil, nil}, + ThreadpoolCallbackContextRip: {ThreadpoolCallbackContextRip, "thread pool callback thread context instruction pointer", kparams.Address, []string{"threadpool.callback.context.rip = '1df42ffc1f8'"}, nil, nil}, + ThreadpoolCallbackContextRipSymbol: {ThreadpoolCallbackContextRipSymbol, "thread pool callback thread context instruction pointer symbol", kparams.UnicodeString, []string{"threadpool.callback.context.rip.symbol = 'VirtualProtect'"}, nil, nil}, + ThreadpoolCallbackContextRipModule: {ThreadpoolCallbackContextRipModule, "thread pool callback thread context instruction pointer symbol module", kparams.UnicodeString, []string{"threadpool.callback.context.rip.module contains 'ntdll.dll'"}, nil, nil}, + ThreadpoolSubprocessTag: {ThreadpoolSubprocessTag, "thread pool service identifier", kparams.Address, []string{"threadpool.subprocess_tag = '10d'"}, nil, nil}, + ThreadpoolTimerDuetime: {ThreadpoolTimerDuetime, "thread pool timer due time", kparams.Uint64, []string{"threadpool.timer.duetime > 10"}, nil, nil}, + ThreadpoolTimerSubqueue: {ThreadpoolTimerSubqueue, "thread pool timer subqueue address", kparams.Address, []string{"threadpool.timer.subqueue = '1db401703e8'"}, nil, nil}, + ThreadpoolTimer: {ThreadpoolTimer, "thread pool timer address", kparams.Address, []string{"threadpool.timer.address = '3e8'"}, nil, nil}, + ThreadpoolTimerPeriod: {ThreadpoolTimerPeriod, "thread pool timer period", kparams.Uint32, []string{"threadpool.timer.period = 0'"}, nil, nil}, + ThreadpoolTimerWindow: {ThreadpoolTimerWindow, "thread pool timer tolerate period", kparams.Uint32, []string{"threadpool.timer.window = 0'"}, nil, nil}, + ThreadpoolTimerAbsolute: {ThreadpoolTimerAbsolute, "indicates if the thread pool timer is absolute or relative", kparams.Bool, []string{"threadpool.timer.is_absolute = true'"}, nil, nil}, } // ArgumentOf returns argument data for the specified field. diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go index 6e2eae2da..1a64e3636 100644 --- a/pkg/filter/filter_test.go +++ b/pkg/filter/filter_test.go @@ -47,14 +47,15 @@ import ( var cfg = &config.Config{ Kstream: config.KstreamConfig{ - EnableHandleKevents: true, - EnableNetKevents: true, - EnableRegistryKevents: true, - EnableFileIOKevents: true, - EnableImageKevents: true, - EnableThreadKevents: true, - EnableMemKevents: true, - EnableDNSEvents: true, + EnableHandleKevents: true, + EnableNetKevents: true, + EnableRegistryKevents: true, + EnableFileIOKevents: true, + EnableImageKevents: true, + EnableThreadKevents: true, + EnableMemKevents: true, + EnableDNSEvents: true, + EnableThreadpoolEvents: true, }, Filters: &config.Filters{}, PE: pe.Config{Enabled: true}, @@ -1218,6 +1219,71 @@ func TestDNSFilter(t *testing.T) { } } +func TestThreadpoolFilter(t *testing.T) { + e := &kevent.Kevent{ + Type: ktypes.SubmitThreadpoolCallback, + Tid: 2484, + PID: 1023, + CPU: 1, + Seq: 2, + Name: "SubmitThreadpoolCallback", + Timestamp: time.Now(), + Category: ktypes.Threadpool, + Kparams: kevent.Kparams{ + kparams.ThreadpoolPoolID: {Name: kparams.ThreadpoolPoolID, Type: kparams.Address, Value: uint64(0x20f5fc02440)}, + kparams.ThreadpoolTaskID: {Name: kparams.ThreadpoolTaskID, Type: kparams.Address, Value: uint64(0x20f7ecd21f8)}, + kparams.ThreadpoolCallback: {Name: kparams.ThreadpoolCallback, Type: kparams.Address, Value: uint64(0x7ffb3138592e)}, + kparams.ThreadpoolContext: {Name: kparams.ThreadpoolContext, Type: kparams.Address, Value: uint64(0x14d0d16fed8)}, + kparams.ThreadpoolContextRip: {Name: kparams.ThreadpoolContextRip, Type: kparams.Address, Value: uint64(0x143c9b07bd0)}, + kparams.ThreadpoolSubprocessTag: {Name: kparams.ThreadpoolSubprocessTag, Type: kparams.Address, Value: uint64(0x10d)}, + kparams.ThreadpoolContextRipSymbol: {Name: kparams.ThreadpoolContextRipSymbol, Type: kparams.UnicodeString, Value: "VirtualProtect"}, + kparams.ThreadpoolContextRipModule: {Name: kparams.ThreadpoolContextRipModule, Type: kparams.UnicodeString, Value: "C:\\Windows\\System32\\kernelbase.dll"}, + kparams.ThreadpoolCallbackSymbol: {Name: kparams.ThreadpoolCallbackSymbol, Type: kparams.UnicodeString, Value: "RtlDestroyQueryDebugBuffer"}, + kparams.ThreadpoolCallbackModule: {Name: kparams.ThreadpoolCallbackModule, Type: kparams.UnicodeString, Value: "C:\\Windows\\System32\\ntdll.dll"}, + kparams.ThreadpoolTimerSubqueue: {Name: kparams.ThreadpoolTimerSubqueue, Type: kparams.Address, Value: uint64(0x1db401703e8)}, + kparams.ThreadpoolTimerDuetime: {Name: kparams.ThreadpoolTimerDuetime, Type: kparams.Uint64, Value: uint64(18446744073699551616)}, + kparams.ThreadpoolTimer: {Name: kparams.ThreadpoolTimer, Type: kparams.Address, Value: uint64(0x3e8)}, + kparams.ThreadpoolTimerPeriod: {Name: kparams.ThreadpoolTimerPeriod, Type: kparams.Uint32, Value: uint32(100)}, + kparams.ThreadpoolTimerWindow: {Name: kparams.ThreadpoolTimerWindow, Type: kparams.Uint32, Value: uint32(50)}, + kparams.ThreadpoolTimerAbsolute: {Name: kparams.ThreadpoolTimerAbsolute, Type: kparams.Bool, Value: true}, + }, + } + + var tests = []struct { + filter string + matches bool + }{ + + {`threadpool.id = '20f5fc02440'`, true}, + {`threadpool.task.id = '20f7ecd21f8'`, true}, + {`threadpool.callback.address = '7ffb3138592e'`, true}, + {`threadpool.callback.symbol = 'RtlDestroyQueryDebugBuffer'`, true}, + {`threadpool.callback.module = 'C:\\Windows\\System32\\ntdll.dll'`, true}, + {`threadpool.callback.context = '14d0d16fed8'`, true}, + {`threadpool.callback.context.rip = '143c9b07bd0'`, true}, + {`threadpool.callback.context.rip.symbol = 'VirtualProtect'`, true}, + {`threadpool.callback.context.rip.module = 'C:\\Windows\\System32\\kernelbase.dll'`, true}, + {`threadpool.timer.address = '3e8'`, true}, + {`threadpool.timer.subqueue = '1db401703e8'`, true}, + {`threadpool.timer.duetime = 18446744073699551616`, true}, + {`threadpool.timer.period = 100`, true}, + {`threadpool.timer.window = 50`, true}, + {`threadpool.timer.is_absolute = true`, true}, + } + + for i, tt := range tests { + f := New(tt.filter, cfg) + err := f.Compile() + if err != nil { + t.Fatal(err) + } + matches := f.Run(e) + if matches != tt.matches { + t.Errorf("%d. %q threadpool filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) + } + } +} + func TestInterpolateFields(t *testing.T) { var tests = []struct { original string diff --git a/pkg/filter/filter_windows.go b/pkg/filter/filter_windows.go index dae664161..12cc19b17 100644 --- a/pkg/filter/filter_windows.go +++ b/pkg/filter/filter_windows.go @@ -84,6 +84,9 @@ func New(expr string, config *config.Config, options ...Option) Filter { if kconfig.EnableDNSEvents { accessors = append(accessors, newDNSAccessor()) } + if kconfig.EnableThreadpoolEvents { + accessors = append(accessors, newThreadpoolAccessor()) + } var parser *ql.Parser if fconfig.HasMacros() { diff --git a/pkg/filter/ql/function.go b/pkg/filter/ql/function.go index d6f54630f..799aac899 100644 --- a/pkg/filter/ql/function.go +++ b/pkg/filter/ql/function.go @@ -297,17 +297,18 @@ func (f *Foreach) Desc() functions.FunctionDesc { e := args[2] // expression var reserved = map[string]bool{ // reserved bound variable names - "$ps": true, - "$pe": true, - "$file": true, - "$image": true, - "$thread": true, - "$registry": true, - "$net": true, - "$mem": true, - "$handle": true, - "$dns": true, - "$kevt": true, + "$ps": true, + "$pe": true, + "$file": true, + "$image": true, + "$thread": true, + "$threadpool": true, + "$registry": true, + "$net": true, + "$mem": true, + "$handle": true, + "$dns": true, + "$kevt": true, } if reserved[v] {