From beecb037f257cc1b6b3e8e4106a099bc01b079a9 Mon Sep 17 00:00:00 2001 From: rabbitstack Date: Fri, 26 Dec 2025 18:34:59 +0100 Subject: [PATCH] perf(rule_engine,filter): Improve bound sequences Refactor bound sequence evaluation logic to speed it up, most notably, by deferring the field hash calculation only when the event matches. Furthermore, the accessor is tied to the bound field avoiding iteration across filter's registered accessors. --- pkg/filter/accessor.go | 8 +- pkg/filter/fields/fields_windows.go | 3 +- pkg/filter/filter.go | 322 ++++++++++++---------------- pkg/filter/util.go | 126 +++++++++++ 4 files changed, 270 insertions(+), 189 deletions(-) diff --git a/pkg/filter/accessor.go b/pkg/filter/accessor.go index f1a4c0063..dc09675c2 100644 --- a/pkg/filter/accessor.go +++ b/pkg/filter/accessor.go @@ -153,7 +153,7 @@ func (*evtAccessor) Get(f Field, evt *event.Event) (params.Value, error) { // referenced in the bound field. func (f *filter) narrowAccessors() { var ( - removeKevtAccessor = true + removeEvtAccessor = true removePsAccessor = true removeThreadAccessor = true removeImageAccessor = true @@ -169,8 +169,8 @@ func (f *filter) narrowAccessors() { for _, field := range f.fields { switch { - case field.Name.IsKevtField(): - removeKevtAccessor = false + case field.Name.IsKevtField(), field.Name.IsEvtField(): + removeEvtAccessor = false case field.Name.IsPsField(): removePsAccessor = false case field.Name.IsThreadField(): @@ -196,7 +196,7 @@ func (f *filter) narrowAccessors() { } } - if removeKevtAccessor { + if removeEvtAccessor { f.removeAccessor(&evtAccessor{}) } if removePsAccessor { diff --git a/pkg/filter/fields/fields_windows.go b/pkg/filter/fields/fields_windows.go index a55819043..ec3ab55d5 100644 --- a/pkg/filter/fields/fields_windows.go +++ b/pkg/filter/fields/fields_windows.go @@ -563,7 +563,8 @@ func (f Field) String() string { return string(f) } func (f Field) Type() params.Type { return fields[f].Type } func (f Field) IsPsField() bool { return strings.HasPrefix(string(f), "ps.") } -func (f Field) IsKevtField() bool { return strings.HasPrefix(string(f), "evt.") } +func (f Field) IsKevtField() bool { return strings.HasPrefix(string(f), "kevt.") } +func (f Field) IsEvtField() bool { return strings.HasPrefix(string(f), "evt.") } func (f Field) IsThreadField() bool { return strings.HasPrefix(string(f), "thread.") } func (f Field) IsImageField() bool { return strings.HasPrefix(string(f), "image.") } func (f Field) IsFileField() bool { return strings.HasPrefix(string(f), "file.") } diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go index 49a6b5705..80ee86407 100644 --- a/pkg/filter/filter.go +++ b/pkg/filter/filter.go @@ -22,7 +22,7 @@ import ( "errors" "expvar" "fmt" - "net" + "reflect" "regexp" "strconv" "strings" @@ -32,7 +32,6 @@ import ( "github.com/rabbitstack/fibratus/pkg/event" "github.com/rabbitstack/fibratus/pkg/filter/fields" "github.com/rabbitstack/fibratus/pkg/filter/ql" - "github.com/rabbitstack/fibratus/pkg/util/bytes" "github.com/rabbitstack/fibratus/pkg/util/hashers" ) @@ -78,6 +77,49 @@ type BoundField struct { Field Field Value string BoundVar string + accessor Accessor +} + +// Accessor finds exactly one accessor that can serve the bound field. +func (b *BoundField) Accessor(f *filter) Accessor { + if b.accessor != nil { + return b.accessor + } + switch { + case b.Field.Name.IsKevtField(), b.Field.Name.IsEvtField(): + b.accessor = newEventAccessor() + case b.Field.Name.IsPsField(): + for _, accessor := range f.accessors { + if reflect.TypeOf(accessor) == reflect.TypeOf(&psAccessor{}) { + b.accessor = accessor + break + } + } + if b.accessor == nil { + b.accessor = newPSAccessor(nil) + } + case b.Field.Name.IsThreadField(): + b.accessor = newThreadAccessor() + case b.Field.Name.IsImageField(): + b.accessor = newImageAccessor() + case b.Field.Name.IsFileField(): + b.accessor = newFileAccessor() + case b.Field.Name.IsRegistryField(): + b.accessor = newRegistryAccessor() + case b.Field.Name.IsNetworkField(): + b.accessor = newNetworkAccessor() + case b.Field.Name.IsHandleField(): + b.accessor = newHandleAccessor() + case b.Field.Name.IsPeField(): + b.accessor = newPEAccessor() + case b.Field.Name.IsMemField(): + b.accessor = newMemAccessor() + case b.Field.Name.IsDNSField(): + b.accessor = newDNSAccessor() + case b.Field.Name.IsThreadpoolField(): + b.accessor = newThreadAccessor() + } + return b.accessor } type filter struct { @@ -191,6 +233,92 @@ func (f *filter) Run(e *event.Event) bool { return ql.Eval(f.expr, f.mapValuer(e), f.hasFunctions) } +// evalBoundSequence evaluates the sequence with bound fields +// and returns true if the sequence expression matches or false +// otherwise. +func (f *filter) evalBoundSequence( + e *event.Event, + seqID int, + expr *ql.SequenceExpr, + partials map[int][]*event.Event, + valuer ql.MapValuer, +) bool { + // map all partials to their sequence aliases + maxSlots := len(partials[seqID]) + aliasEvents := make(map[string][]*event.Event, seqID) + for i := range seqID { + alias := f.seq.Expressions[i].Alias + if alias == "" { + continue + } + aliasEvents[alias] = partials[i] + if l := len(partials[i]); l > maxSlots { + maxSlots = l + } + } + + // retrieve or compute bound fields for this sequence expression + flds, ok := f.seqBoundFields[seqID] + if !ok { + flds = f.addSeqBoundFields(seqID, expr.BoundFields) + } + + // iterate slot-by-slot across all bound aliases + for slot := 0; slot < maxSlots; slot++ { + // process each bound field in this sequence expression + var evt *event.Event + for _, fld := range flds { + evts := aliasEvents[fld.BoundVar] + switch { + case len(evts) == 0: + continue + case slot >= len(evts): + // pick the latest event if all + // events for this slot are consumed + evt = evts[len(evts)-1] + default: + evt = evts[slot] + } + + // extract bound variable value + accessor := fld.Accessor(f) + if accessor == nil { + continue + } + v, err := accessor.Get(fld.Field, evt) + if v == nil || err != nil { + if v == nil { + valuer[fld.Value] = defaultAccessorValue(fld.Field) + } + if err != nil && !errs.IsParamNotFound(err) { + valuer[fld.Value] = defaultAccessorValue(fld.Field) + accessorErrors.Add(err.Error(), 1) + } + continue + } + valuer[fld.Value] = v + } + + // evaluate the expression with the current valuer state + if ql.Eval(expr.Expr, valuer, f.hasFunctions) { + // compute sequence key hash to stich events + hash := make([]byte, 0) + for _, fld := range flds { + if !strings.HasPrefix(fld.BoundVar, "$") { + continue + } + hash = appendHash(hash, valuer[fld.Value]) + } + fnv := hashers.FnvUint64(hash) + e.AddSequenceLink(fnv) + evt.AddSequenceLink(fnv) + return true + } + } + + return false +} + func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*event.Event, rawMatch bool) bool { if f.seq == nil { return false @@ -211,96 +339,10 @@ func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*even var match bool if seqID >= 1 && expr.HasBoundFields() { - // if a sequence expression contains references to - // bound fields we map all partials to their sequence - // aliases - p := make(map[string][]*event.Event) - nslots := len(partials[seqID]) - for i := 0; i < seqID; i++ { - alias := f.seq.Expressions[i].Alias - if alias == "" { - continue - } - p[alias] = partials[i] - if len(p[alias]) > nslots { - nslots = len(p[alias]) - } - } - - flds, ok := f.seqBoundFields[seqID] - if !ok { - flds = f.addSeqBoundFields(seqID, expr.BoundFields) - } - - // process until partials from all slots are consumed - n := 0 - hash := make([]byte, 0) - for nslots > 0 { - nslots-- - var evt *event.Event - for _, field := range flds { - // get all events pertaining to the bounded event - evts := p[field.BoundVar] - if n > len(evts)-1 { - // pick the latest event if all - // events for this slot are consumed - evt = evts[len(evts)-1] - } else { - evt = evts[n] - } - - // resolve the bound field value - for _, accessor := range f.accessors { - if !accessor.IsFieldAccessible(evt) { - continue - } - v, err := accessor.Get(field.Field, evt) - if err != nil && !errs.IsParamNotFound(err) { - accessorErrors.Add(err.Error(), 1) - continue - } - if v != nil { - valuer[field.Value] = v - switch val := v.(type) { - case uint8: - hash = append(hash, val) - case uint16: - hash = append(hash, bytes.WriteUint16(val)...) - case uint32: - hash = append(hash, bytes.WriteUint32(val)...) - case uint64: - hash = append(hash, bytes.WriteUint64(val)...) - case int8: - hash = append(hash, byte(val)) - case int16: - hash = append(hash, bytes.WriteUint16(uint16(val))...) - case int32: - hash = append(hash, bytes.WriteUint32(uint32(val))...) - case int64: - hash = append(hash, bytes.WriteUint64(uint64(val))...) - case int: - hash = append(hash, bytes.WriteUint64(uint64(val))...) - case uint: - hash = append(hash, bytes.WriteUint64(uint64(val))...) - case string: - hash = append(hash, val...) - case net.IP: - hash = append(hash, val...) - } - break - } - } - } - n++ - match = ql.Eval(expr.Expr, valuer, f.hasFunctions) - if match { - // compute sequence key hash to tie the events - evt.AddSequenceLink(hashers.FnvUint64(hash)) - e.AddSequenceLink(hashers.FnvUint64(hash)) - break - } - } + // evaluate bound field driven sequences + match = f.evalBoundSequence(e, seqID, &expr, partials, valuer) } else { + // evaluate constrained/unconstrained sequences by := f.seq.By if by == nil { by = expr.By @@ -311,7 +353,7 @@ func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*even joins := make([]bool, seqID) joinID := valuer[by.Value] outer: - for i := 0; i < seqID; i++ { + for i := range seqID { for _, p := range partials[i] { if CompareSeqLink(joinID, p.SequenceLinks()) { joins[i] = true @@ -433,8 +475,11 @@ func (f *filter) mapValuer(evt *event.Event) map[string]any { } v, err := accessor.Get(field, evt) if v == nil || err != nil { - valuer[field.Value] = defaultAccessorValue(field) + if v == nil { + valuer[field.Value] = defaultAccessorValue(field) + } if err != nil && !errs.IsParamNotFound(err) { + valuer[field.Value] = defaultAccessorValue(field) accessorErrors.Add(err.Error(), 1) } continue @@ -519,94 +564,3 @@ func (f *filter) checkBoundRefs() error { return nil } - -// CompareSeqLink returns true if any value -// in the sequence link slice equals to the -// given LHS value. -func CompareSeqLink(lhs any, rhs []any) bool { - if lhs == nil || rhs == nil { - return false - } - for _, v := range rhs { - if compareSeqLink(lhs, v) { - return true - } - } - return false -} - -// CompareSeqLinks returns true any LHS sequence -// link values equal to the RHS sequence link values. -func CompareSeqLinks(lhs []any, rhs []any) bool { - if lhs == nil || rhs == nil { - return false - } - for _, v1 := range lhs { - for _, v2 := range rhs { - if compareSeqLink(v1, v2) { - return true - } - } - } - return false -} - -func compareSeqLink(lhs any, rhs any) bool { - if lhs == nil || rhs == nil { - return false - } - - switch v := lhs.(type) { - case string: - s, ok := rhs.(string) - if !ok { - return false - } - return strings.EqualFold(v, s) - case uint8: - n, ok := rhs.(uint8) - if !ok { - return false - } - return v == n - case uint16: - n, ok := rhs.(uint16) - if !ok { - return false - } - return v == n - case uint32: - n, ok := rhs.(uint32) - if !ok { - return false - } - return v == n - case uint64: - n, ok := rhs.(uint64) - if !ok { - return false - } - if v == n { - return true - } - case int: - n, ok := rhs.(int) - if !ok { - return false - } - return v == n - case uint: - n, ok := rhs.(uint) - if !ok { - return false - } - return v == n - case net.IP: - ip, ok := rhs.(net.IP) - if !ok { - return false - } - return v.Equal(ip) - } - return false -} diff --git a/pkg/filter/util.go b/pkg/filter/util.go index 5f2a89f94..98e694f9b 100644 --- a/pkg/filter/util.go +++ b/pkg/filter/util.go @@ -19,11 +19,14 @@ package filter import ( + "net" "path/filepath" + "strings" "github.com/rabbitstack/fibratus/pkg/event" "github.com/rabbitstack/fibratus/pkg/event/params" "github.com/rabbitstack/fibratus/pkg/filter/fields" + "github.com/rabbitstack/fibratus/pkg/util/bytes" "github.com/rabbitstack/fibratus/pkg/util/loldrivers" "github.com/rabbitstack/fibratus/pkg/util/signature" "github.com/rabbitstack/fibratus/pkg/util/va" @@ -113,3 +116,126 @@ func framePID(e *event.Event) uint32 { } return e.PID } + +// CompareSeqLink returns true if any value +// in the sequence link slice equals to the +// given LHS value. +func CompareSeqLink(lhs any, rhs []any) bool { + if lhs == nil || rhs == nil { + return false + } + for _, v := range rhs { + if compareSeqLink(lhs, v) { + return true + } + } + return false +} + +// CompareSeqLinks returns true any LHS sequence +// link values equal to the RHS sequence link values. +func CompareSeqLinks(lhs []any, rhs []any) bool { + if lhs == nil || rhs == nil { + return false + } + for _, v1 := range lhs { + for _, v2 := range rhs { + if compareSeqLink(v1, v2) { + return true + } + } + } + return false +} + +func compareSeqLink(lhs any, rhs any) bool { + if lhs == nil || rhs == nil { + return false + } + + switch v := lhs.(type) { + case string: + s, ok := rhs.(string) + if !ok { + return false + } + return strings.EqualFold(v, s) + case uint8: + n, ok := rhs.(uint8) + if !ok { + return false + } + return v == n + case uint16: + n, ok := rhs.(uint16) + if !ok { + return false + } + return v == n + case uint32: + n, ok := rhs.(uint32) + if !ok { + return false + } + return v == n + case uint64: + n, ok := rhs.(uint64) + if !ok { + return false + } + if v == n { + return true + } + case int: + n, ok := rhs.(int) + if !ok { + return false + } + return v == n + case uint: + n, ok := rhs.(uint) + if !ok { + return false + } + return v == n + case net.IP: + ip, ok := rhs.(net.IP) + if !ok { + return false + } + return v.Equal(ip) + } + return false +} + +// appendHash appends the value's hashable bytes to buf. +func appendHash(buf []byte, v any) []byte { + switch val := v.(type) { + case uint8: + return append(buf, val) + case uint16: + return append(buf, bytes.WriteUint16(val)...) + case uint32: + return append(buf, bytes.WriteUint32(val)...) + case uint64: + return append(buf, bytes.WriteUint64(val)...) + case int8: + return append(buf, byte(val)) + case int16: + return append(buf, bytes.WriteUint16(uint16(val))...) + case int32: + return append(buf, bytes.WriteUint32(uint32(val))...) + case int64: + return append(buf, bytes.WriteUint64(uint64(val))...) + case int: + return append(buf, bytes.WriteUint64(uint64(val))...) + case uint: + return append(buf, bytes.WriteUint64(uint64(val))...) + case string: + return append(buf, val...) + case net.IP: + return append(buf, val...) + default: + return buf + } +}