diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go
index 80ee86407..b433e3527 100644
--- a/pkg/filter/filter.go
+++ b/pkg/filter/filter.go
@@ -32,7 +32,6 @@ import (
 	"github.com/rabbitstack/fibratus/pkg/event"
 	"github.com/rabbitstack/fibratus/pkg/filter/fields"
 	"github.com/rabbitstack/fibratus/pkg/filter/ql"
-	"github.com/rabbitstack/fibratus/pkg/util/hashers"
 )
 
 var (
@@ -207,12 +206,17 @@ func (f *filter) Compile() error {
 		ql.WalkFunc(f.expr, walk)
 	} else {
 		if f.seq.By != nil {
-			f.addField(f.seq.By)
+			for _, fld := range f.seq.By.Fields {
+				f.addField(fld)
+			}
 		}
 		for _, expr := range f.seq.Expressions {
 			ql.WalkFunc(expr.Expr, walk)
-			if expr.By != nil {
-				f.addField(expr.By)
+			if expr.By == nil {
+				continue
+			}
+			for _, fld := range expr.By.Fields {
+				f.addField(fld)
 			}
 		}
 	}
@@ -302,16 +306,16 @@ func (f *filter) evalBoundSequence(
 		// evaluate the expression with the current valuer state
 		if ql.Eval(expr.Expr, valuer, f.hasFunctions) {
 			// compute sequence key hash to stich events
-			hash := make([]byte, 0)
+			values := make([]any, 0)
 			for _, fld := range flds {
 				if !strings.HasPrefix(fld.BoundVar, "$") {
 					continue
 				}
-				hash = appendHash(hash, valuer[fld.Value])
+				values = append(values, valuer[fld.Value])
 			}
-			fnv := hashers.FnvUint64(hash)
-			e.AddSequenceLink(fnv)
-			evt.AddSequenceLink(fnv)
+			hash := hashFields(values)
+			e.AddSequenceLink(hash)
+			evt.AddSequenceLink(hash)
 			return true
 		}
 	}
@@ -319,6 +323,62 @@ func (f *filter) evalBoundSequence(
 	return false
 }
 
+// evalSequence evaluates the sequence with one, multiple or
+// no join links. The sequence link is first consulted for the
+// global sequence definition, and if it is not defined then
+// the expression sequence link is used. On match, the resolved
+// link identifier is added to the event sequence links, unless
+// the link evaluates to nil.
+func (f *filter) evalSequence(
+	e *event.Event,
+	seqID int,
+	expr *ql.SequenceExpr,
+	partials map[int][]*event.Event,
+	valuer ql.MapValuer,
+) bool {
+	// top-level sequence link is defined
+	by := f.seq.By
+	if by == nil {
+		// otherwise, use the expression link
+		by = expr.By
+	}
+
+	var (
+		match  bool
+		linkID any
+	)
+	if seqID >= 1 && by != nil {
+		// resolve the link identifier once and reuse it for
+		// the join comparison and for linking the event
+		linkID = makeSequenceLinkID(valuer, by)
+		// traverse upstream partials for join equality
+		joins := make([]bool, seqID)
+	outer:
+		for i := range seqID {
+			for _, p := range partials[i] {
+				if CompareSeqLink(linkID, p.SequenceLinks()) {
+					joins[i] = true
+					continue outer
+				}
+			}
+		}
+		match = joinsEqual(joins) && ql.Eval(expr.Expr, valuer, f.hasFunctions)
+	} else {
+		match = ql.Eval(expr.Expr, valuer, f.hasFunctions)
+		if match && by != nil {
+			linkID = makeSequenceLinkID(valuer, by)
+		}
+	}
+
+	// a nil link means the valuer holds no value for the link
+	// field, so there is nothing to join on for this event
+	if match && linkID != nil {
+		e.AddSequenceLink(linkID)
+	}
+
+	return match
+}
+
 func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*event.Event, rawMatch bool) bool {
 	if f.seq == nil {
 		return false
@@ -343,45 +403,10 @@ func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*even
 		match = f.evalBoundSequence(e, seqID, &expr, partials, valuer)
 	} else {
 		// evaluate constrained/unconstrained sequences
-		by := f.seq.By
-		if by == nil {
-			by = expr.By
-		}
-
-		if seqID >= 1 && by != nil {
-			// traverse upstream partials for join equality
-			joins := make([]bool, seqID)
-			joinID := valuer[by.Value]
-		outer:
-			for i := range seqID {
-				for _, p := range partials[i] {
-					if CompareSeqLink(joinID, p.SequenceLinks()) {
-						joins[i] = true
-						continue outer
-					}
-				}
-			}
-			match = joinsEqual(joins) && ql.Eval(expr.Expr, valuer, f.hasFunctions)
-		} else {
-			match = ql.Eval(expr.Expr, valuer, f.hasFunctions)
-		}
-
-		if match && by != nil {
-			if v := valuer[by.Value]; v != nil {
-				e.AddSequenceLink(v)
-			}
-		}
+		match = f.evalSequence(e, seqID, &expr, partials, valuer)
 	}
-	return match
-}
 
-func joinsEqual(joins []bool) bool {
-	for _, j := range joins {
-		if !j {
-			return false
-		}
-	}
-	return true
+	return match
 }
 
 func (f *filter) GetStringFields() map[fields.Field][]string { return f.stringFields }
@@ -564,3 +589,19 @@ func (f *filter) checkBoundRefs() error {
 
 	return nil
 }
+
+// makeSequenceLinkID builds the identifier used to join sequence
+// events. For a single-field link the raw field value is returned
+// as produced by the valuer (nil if the field is absent). For compound
+// links, the values of all fields are combined by hashFields in the
+// order the fields are declared.
+func makeSequenceLinkID(valuer ql.MapValuer, link *ql.SequenceLink) any {
+	if !link.IsCompound() {
+		return valuer[link.First()]
+	}
+	values := make([]any, 0, len(link.Fields))
+	for _, fld := range link.Fields {
+		values = append(values, valuer[fld.Value])
+	}
+	return hashFields(values)
+}
diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go
index 9d17d7b87..cb00d0f5c 100644
--- a/pkg/filter/filter_test.go
+++ b/pkg/filter/filter_test.go
@@ -19,6 +19,7 @@
 package filter
 
 import (
+	"fmt"
 	"net"
 	"os"
 	"path/filepath"
@@ -33,6 +34,7 @@ import (
 	"github.com/rabbitstack/fibratus/pkg/event"
 	"github.com/rabbitstack/fibratus/pkg/event/params"
 	"github.com/rabbitstack/fibratus/pkg/filter/fields"
+	"github.com/rabbitstack/fibratus/pkg/filter/ql"
 	"github.com/rabbitstack/fibratus/pkg/fs"
 	"github.com/rabbitstack/fibratus/pkg/pe"
 	"github.com/rabbitstack/fibratus/pkg/ps"
@@ -110,6 +112,45 @@ func TestStringFields(t *testing.T) {
 	assert.Len(t, f.GetStringFields()[fields.PsName], 1)
 }
 
+func TestMakeSequenceLinkID(t *testing.T) {
+	var tests = []struct {
+		valuer  ql.MapValuer
+		seqLink *ql.SequenceLink
+		id      any
+	}{
+		{ql.MapValuer{
+			"ps.uuid": uint64(123232454234232132),
+			"ps.exe":  "C:\\Windows\\System32\\cmd.exe"},
+			&ql.SequenceLink{Fields: []*ql.FieldLiteral{{Value: "ps.exe"}, {Value: "ps.uuid"}}},
+			"433a5c57696e646f77735c53797374656d33325c636d642e65786544556ea343cfb501",
+		},
+		{ql.MapValuer{
+			"ps.uuid":        uint64(123232454234232132),
+			"module.address": uint64(0xfff32343)},
+			&ql.SequenceLink{Fields: []*ql.FieldLiteral{{Value: "ps.uuid"}, {Value: "module.address"}}},
+			"44556ea343cfb5014323f3ff00000000",
+		},
+		{ql.MapValuer{
+			"ps.uuid": uint64(123232454234232132),
+			"ps.exe":  "C:\\Windows\\System32\\cmd.exe"},
+			&ql.SequenceLink{Fields: []*ql.FieldLiteral{{Value: "ps.exe"}}},
+			"C:\\Windows\\System32\\cmd.exe",
+		},
+		{ql.MapValuer{
+			"ps.uuid": uint64(123232454234232132),
+			"ps.exe":  "C:\\Windows\\System32\\cmd.exe"},
+			&ql.SequenceLink{Fields: []*ql.FieldLiteral{{Value: "ps.uuid"}}},
+			uint64(123232454234232132),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(fmt.Sprintf("%v", tt.valuer), func(t *testing.T) {
+			assert.Equal(t, tt.id, makeSequenceLinkID(tt.valuer, tt.seqLink))
+		})
+	}
+}
+
 func TestProcFilter(t *testing.T) {
 	parent := &pstypes.PS{
 		Name: "svchost.exe",
diff --git a/pkg/filter/ql/literal.go b/pkg/filter/ql/literal.go
index d102d851e..99aac363a 100644
--- a/pkg/filter/ql/literal.go
+++ b/pkg/filter/ql/literal.go
@@ -19,14 +19,15 @@
 package ql
 
 import (
-	"github.com/rabbitstack/fibratus/pkg/event"
-	"github.com/rabbitstack/fibratus/pkg/filter/fields"
 	"net"
 	"reflect"
 	"strconv"
 	"strings"
 	"time"
 
+	"github.com/rabbitstack/fibratus/pkg/event"
+	"github.com/rabbitstack/fibratus/pkg/filter/fields"
+
 	"github.com/rabbitstack/fibratus/pkg/filter/ql/functions"
 )
 
@@ -271,11 +272,11 @@ func (f *Function) validate() error {
 // SequenceExpr represents a single binary expression within the sequence.
 type SequenceExpr struct {
 	Expr Expr
-	// By contains the field literal if the sequence expression is constrained.
-	By *FieldLiteral
+	// By contains the expression link if the sequence is constrained.
+	By *SequenceLink
 	// BoundFields is a group of bound fields referenced in the sequence expression.
 	BoundFields []*BoundFieldLiteral
-	// Alias represents the sequence expression alias.
+	// Alias represents the sequence expression alias when bound fields are used.
 	Alias string
 
 	bitsets event.BitSets
@@ -381,10 +382,32 @@ func (e *SequenceExpr) HasBoundFields() bool {
 	return len(e.BoundFields) > 0
 }
 
+// SequenceLink represents a single or
+// a collection of fields that are used to
+// build the sequence join link.
+type SequenceLink struct {
+	Fields []*FieldLiteral
+}
+
+// IsCompound indicates if the sequence expression
+// uses multiple fields for the join link.
+func (l *SequenceLink) IsCompound() bool {
+	return len(l.Fields) > 1
+}
+
+// First returns the name of the only field in the link,
+// or an empty string if the link is empty or compound.
+func (l *SequenceLink) First() string {
+	if len(l.Fields) == 1 {
+		return l.Fields[0].Value
+	}
+	return ""
+}
+
 // Sequence is a collection of two or more sequence expressions.
 type Sequence struct {
 	MaxSpan     time.Duration
-	By          *FieldLiteral
+	By          *SequenceLink
 	Expressions []SequenceExpr
 	IsUnordered bool
 }
diff --git a/pkg/filter/ql/parser.go b/pkg/filter/ql/parser.go
index 536e20578..99399de4d 100644
--- a/pkg/filter/ql/parser.go
+++ b/pkg/filter/ql/parser.go
@@ -23,13 +23,14 @@ package ql
 import (
 	"errors"
 	"fmt"
-	"github.com/rabbitstack/fibratus/pkg/config"
-	"github.com/rabbitstack/fibratus/pkg/filter/fields"
-	"github.com/rabbitstack/fibratus/pkg/util/multierror"
 	"net"
 	"strconv"
 	"strings"
 	"time"
+
+	"github.com/rabbitstack/fibratus/pkg/config"
+	"github.com/rabbitstack/fibratus/pkg/filter/fields"
+	"github.com/rabbitstack/fibratus/pkg/util/multierror"
 )
 
 // Parser builds the binary expression tree from the filter string.
@@ -71,7 +72,7 @@ func (p *Parser) ParseSequence() (*Sequence, error) {
 		p.unscan()
 	}
 
-	// parse optional global join
+	// parse optional global link
 	tok, _, _ = p.scanIgnoreWhitespace()
 	if tok == By {
 		tok, pos, lit := p.scanIgnoreWhitespace()
@@ -79,10 +80,33 @@ func (p *Parser) ParseSequence() (*Sequence, error) {
 			return nil, newParseError(tokstr(tok, lit), []string{"field"}, pos, p.expr)
 		}
 		var err error
-		seq.By, err = p.parseField(lit)
+		field, err := p.parseField(lit)
 		if err != nil {
 			return nil, err
 		}
+
+		seqLink := &SequenceLink{Fields: []*FieldLiteral{field}}
+
+		// handle multiple join fields separated by comma
+		for {
+			if tok, _, _ := p.scanIgnoreWhitespace(); tok != Comma {
+				p.unscan()
+				break
+			}
+
+			tok, pos, lit := p.scanIgnoreWhitespace()
+			if !fields.IsField(lit) {
+				return nil, newParseError(tokstr(tok, lit), []string{"field"}, pos, p.expr)
+			}
+			field, err := p.parseField(lit)
+			if err != nil {
+				return nil, err
+			}
+
+			seqLink.Fields = append(seqLink.Fields, field)
+		}
+
+		seq.By = seqLink
 	} else {
 		p.unscan()
 	}
@@ -127,7 +151,7 @@ func (p *Parser) ParseSequence() (*Sequence, error) {
 
 		var seqexpr SequenceExpr
 
-		// parse sequence BY or AS constraints
+		// parse sequence BY or AS constraints (links)
 		tok, _, _ = p.scanIgnoreWhitespace()
 		switch tok {
 		case By:
@@ -139,7 +163,28 @@ func (p *Parser) ParseSequence() (*Sequence, error) {
 			if err != nil {
 				return nil, err
 			}
-			seqexpr = SequenceExpr{Expr: expr, By: field}
+
+			seqLink := &SequenceLink{Fields: []*FieldLiteral{field}}
+
+			// handle multiple join fields separated by comma
+			for {
+				if tok, _, _ := p.scanIgnoreWhitespace(); tok != Comma {
+					p.unscan()
+					break
+				}
+
+				tok, pos, lit := p.scanIgnoreWhitespace()
+				if !fields.IsField(lit) {
+					return nil, newParseError(tokstr(tok, lit), []string{"field"}, pos, p.expr)
+				}
+				field, err := p.parseField(lit)
+				if err != nil {
+					return nil, err
+				}
+
+				seqLink.Fields = append(seqLink.Fields, field)
+			}
+			seqexpr = SequenceExpr{Expr: expr, By: seqLink}
 		case As:
 			tok, pos, lit := p.scanIgnoreWhitespace()
 			if tok != Ident {
diff --git a/pkg/filter/ql/parser_test.go b/pkg/filter/ql/parser_test.go
index 0875d77a1..c8e2b1669 100644
--- a/pkg/filter/ql/parser_test.go
+++ b/pkg/filter/ql/parser_test.go
@@ -20,6 +20,8 @@ package ql
 
 import (
 	"errors"
+	"fmt"
+	"strings"
 	"testing"
 	"time"
 
@@ -271,6 +273,31 @@ func TestParseSequence(t *testing.T) {
 			time.Duration(0),
 			true,
 		},
+		{
+			`|evt.name = 'CreateProcess'| by ps.exe, ps.uuid
+			 |evt.name = 'CreateFile'| by file.name, ps.uuid
+			`,
+			nil,
+			time.Duration(0),
+			true,
+		},
+		{
+			`by ps.exe, ps.uuid
+			 |evt.name = 'CreateProcess'|
+			 |evt.name = 'CreateFile'|
+			`,
+			nil,
+			time.Duration(0),
+			true,
+		},
+		{
+			`|evt.name = 'CreateProcess'| by ps.exe,
+			 |evt.name = 'CreateFile'| by file.name, ps.uuid
+			`,
+			errors.New("expected field"),
+			time.Duration(0),
+			true,
+		},
 		{
 
 			`by ps.pid
@@ -336,7 +363,7 @@ func TestParseSequence(t *testing.T) {
 			 |evt.name = 'CreateProcess'| as e1
 			 |evt.name = 'CreateFile' and $e1.ps.ame = file.name |
 			`,
-			errors.New("expected field after bound ref"),
+			errors.New("expected field/segment after bound ref"),
 			time.Second * 30,
 			false,
 		},
@@ -352,8 +379,8 @@ func TestParseSequence(t *testing.T) {
 		},
 		{
 
-			`by ps.uuid
-			 maxspan 2m
+			`maxspan 2m
+			 by ps.uuid
 			 |evt.name = 'CreateProcess'| by ps.uuid
 			 |evt.name = 'CreateFile'| by ps.uuid
 			`,
@@ -372,6 +399,10 @@ func TestParseSequence(t *testing.T) {
 			t.Errorf("%d. exp=%s got error=\n%v", i, tt.expr, err)
 		}
 
+		if err != nil && tt.err != nil {
+			assert.True(t, strings.Contains(err.Error(), tt.err.Error()), fmt.Sprintf("error '%v' should contain '%v'", err, tt.err))
+		}
+
 		if seq != nil {
 			if seq.MaxSpan != tt.maxSpan {
 				t.Errorf("%d. exp=%s maxspan=%s got maxspan=%v", i, tt.expr, tt.maxSpan, seq.MaxSpan)
diff --git a/pkg/filter/util.go b/pkg/filter/util.go
index 98e694f9b..ccf4a05e7 100644
--- a/pkg/filter/util.go
+++ b/pkg/filter/util.go
@@ -19,6 +19,7 @@
 package filter
 
 import (
+	"encoding/hex"
 	"net"
 	"path/filepath"
 	"strings"
@@ -208,34 +209,52 @@ func compareSeqLink(lhs any, rhs any) bool {
 	return false
 }
 
-// appendHash appends the value's hashable bytes to buf.
-func appendHash(buf []byte, v any) []byte {
-	switch val := v.(type) {
-	case uint8:
-		return append(buf, val)
-	case uint16:
-		return append(buf, bytes.WriteUint16(val)...)
-	case uint32:
-		return append(buf, bytes.WriteUint32(val)...)
-	case uint64:
-		return append(buf, bytes.WriteUint64(val)...)
-	case int8:
-		return append(buf, byte(val))
-	case int16:
-		return append(buf, bytes.WriteUint16(uint16(val))...)
-	case int32:
-		return append(buf, bytes.WriteUint32(uint32(val))...)
-	case int64:
-		return append(buf, bytes.WriteUint64(uint64(val))...)
-	case int:
-		return append(buf, bytes.WriteUint64(uint64(val))...)
-	case uint:
-		return append(buf, bytes.WriteUint64(uint64(val))...)
-	case string:
-		return append(buf, val...)
-	case net.IP:
-		return append(buf, val...)
-	default:
-		return buf
+// hashFields builds a composite key from the given field values by
+// concatenating their byte representations in order and returning the
+// hex-encoded result. Values of unsupported types, including nil, are
+// skipped. No separator is written between values, so adjacent
+// variable-length values (strings, IPs) can yield the same key for
+// different inputs, e.g. ("ab", "c") and ("a", "bc").
+func hashFields(values []any) string {
+	buf := make([]byte, 0)
+	for _, v := range values {
+		switch val := v.(type) {
+		case uint8:
+			buf = append(buf, val)
+		case uint16:
+			buf = append(buf, bytes.WriteUint16(val)...)
+		case uint32:
+			buf = append(buf, bytes.WriteUint32(val)...)
+		case uint64:
+			buf = append(buf, bytes.WriteUint64(val)...)
+		case int8:
+			buf = append(buf, byte(val))
+		case int16:
+			buf = append(buf, bytes.WriteUint16(uint16(val))...)
+		case int32:
+			buf = append(buf, bytes.WriteUint32(uint32(val))...)
+		case int64:
+			buf = append(buf, bytes.WriteUint64(uint64(val))...)
+		case int:
+			buf = append(buf, bytes.WriteUint64(uint64(val))...)
+		case uint:
+			buf = append(buf, bytes.WriteUint64(uint64(val))...)
+		case string:
+			buf = append(buf, val...)
+		case net.IP:
+			buf = append(buf, val...)
+		}
+	}
+	return hex.EncodeToString(buf)
+}
+
+// joinsEqual reports whether all join flags are set. An empty
+// slice yields true.
+func joinsEqual(joins []bool) bool {
+	for _, j := range joins {
+		if !j {
+			return false
+		}
 	}
+	return true
 }
diff --git a/pkg/rules/sequence_test.go b/pkg/rules/sequence_test.go
index 3d47d610c..cbb0602cc 100644
--- a/pkg/rules/sequence_test.go
+++ b/pkg/rules/sequence_test.go
@@ -443,6 +443,56 @@ func TestSimpleSequenceDeadline(t *testing.T) {
 	require.True(t, ss.runSequence(e2))
 }
 
+func TestSequenceMultiLinks(t *testing.T) {
+	log.SetLevel(log.DebugLevel)
+
+	c := &config.FilterConfig{Name: "Command shell created a temp file"}
+	f := filter.New(`
+		sequence
+		maxspan 100ms
+			|evt.name = 'CreateProcess' and ps.name = 'cmd.exe'| by ps.exe, ps.pid
+			|evt.name = 'CreateFile' and file.path icontains 'temp'| by file.path, ps.pid
+	`, &config.Config{EventSource: config.EventSourceConfig{EnableFileIOEvents: true}, Filters: &config.Filters{}})
+	require.NoError(t, f.Compile())
+
+	ss := newSequenceState(f, c, new(ps.SnapshotterMock))
+
+	e1 := &event.Event{
+		Type:      event.CreateProcess,
+		Timestamp: time.Now(),
+		Name:      "CreateProcess",
+		Tid:       2484,
+		PID:       859,
+		PS: &pstypes.PS{
+			Name: "cmd.exe",
+			Exe:  "C:\\Windows\\system32\\svchost-temp.exe",
+		},
+		Params: event.Params{
+			params.ProcessID: {Name: params.ProcessID, Type: params.Uint32, Value: uint32(4143)},
+		},
+		Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"},
+	}
+	require.False(t, ss.runSequence(e1))
+
+	e2 := &event.Event{
+		Type:      event.CreateFile,
+		Timestamp: time.Now(),
+		Name:      "CreateFile",
+		Tid:       2484,
+		PID:       859,
+		Category:  event.File,
+		PS: &pstypes.PS{
+			Name: "cmd.exe",
+			Exe:  "C:\\Windows\\system32\\svchost.exe",
+		},
+		Params: event.Params{
+			params.FilePath: {Name: params.FilePath, Type: params.UnicodeString, Value: "C:\\Windows\\system32\\svchost-temp.exe"},
+		},
+		Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"},
+	}
+	require.True(t, ss.runSequence(e2))
+}
+
 func TestComplexSequence(t *testing.T) {
 	log.SetLevel(log.DebugLevel)
 