@@ -9,58 +9,60 @@ import (
99 "globstar.dev/analysis"
1010)
1111
12+ var AvoidUnsanitizedSQL = & analysis.Analyzer {
13+ Name : "avoid-unsanitized-sql" ,
14+ Language : analysis .LangPy ,
15+ Description : "Check if SQL query is sanitized" ,
16+ Category : analysis .CategorySecurity ,
17+ Severity : analysis .SeverityCritical ,
18+ Run : checkSQLInjection ,
19+ }
20+
1221// checkSQLInjection is the rule callback that inspects each call node.
13- func checkSQLInjection (r analysis.Rule , ana * analysis.Analyzer , node * sitter.Node ) {
14- source := ana .FileContext .Source
22+ func checkSQLInjection (pass * analysis.Pass ) (interface {}, error ) {
23+ analysis .Preorder (pass , func (node * sitter.Node ) {
24+ source := pass .FileContext .Source
1525
16- // Only process call nodes.
17- if node .Type () != "call" {
18- return
19- }
26+ // Only process call nodes.
27+ if node .Type () != "call" {
28+ return
29+ }
2030
21- // Extract the function part (e.g. cursor.execute).
22- functionNode := node .ChildByFieldName ("function" )
23- if functionNode == nil {
24- return
25- }
31+ // Extract the function part (e.g. cursor.execute).
32+ functionNode := node .ChildByFieldName ("function" )
33+ if functionNode == nil {
34+ return
35+ }
2636
27- // Proceed only if the function is one of our recognized SQL methods.
28- if ! isSQLExecuteMethod (functionNode , source ) {
29- return
30- }
37+ // Proceed only if the function is one of our recognized SQL methods.
38+ if ! isSQLExecuteMethod (functionNode , source ) {
39+ return
40+ }
3141
32- // Check the first argument.
33- argsNode := node .ChildByFieldName ("arguments" )
34- if argsNode == nil {
35- return
36- }
37- firstArg := getNthChild (argsNode , 0 )
38- if firstArg == nil {
39- return
40- }
42+ // Check the first argument.
43+ argsNode := node .ChildByFieldName ("arguments" )
44+ if argsNode == nil {
45+ return
46+ }
47+ firstArg := getNthChild (argsNode , 0 )
48+ if firstArg == nil {
49+ return
50+ }
4151
42- // If the query string is built unsafely, report an issue.
43- if isUnsafeString (firstArg , source ) {
44- ana .Report (& analysis.Issue {
45- Message : "Concatenated string in SQL query is an SQL injection threat!" ,
46- Category : analysis .CategorySecurity ,
47- Severity : analysis .SeverityCritical ,
48- Range : node .Range (),
49- })
50- return
51- }
52+ // If the query string is built unsafely, report an issue.
53+ if isUnsafeString (firstArg , source ) {
54+ pass .Report (pass , node , "Concatenated string in SQL query is an SQL injection threat" )
55+ return
56+ }
5257
53- // If the argument is an identifier, trace its origin.
54- if firstArg .Type () == "identifier" {
55- varName := firstArg .Content (source )
56- traceVariableOrigin (r , ana , varName , node , make (map [string ]bool ), make (map [string ]bool ), source )
57- }
58- }
58+ // If the argument is an identifier, trace its origin.
59+ if firstArg .Type () == "identifier" {
60+ varName := firstArg .Content (source )
61+ traceVariableOrigin (pass , varName , node , make (map [string ]bool ), make (map [string ]bool ), source )
62+ }
63+ })
5964
60- // SQLInjection registers the SQL injection rule.
61- func SQLInjection () analysis.Rule {
62- var entry analysis.VisitFn = checkSQLInjection
63- return analysis .CreateRule ("call" , analysis .LangPy , & entry , nil )
65+ return nil , nil
6466}
6567
6668// --- Helper Functions ---
@@ -107,34 +109,34 @@ func isUnsafeString(node *sitter.Node, source []byte) bool {
107109 return false
108110}
109111
110- func traceVariableOrigin (r analysis. Rule , ana * analysis.Analyzer , varName string , originalNode * sitter.Node ,
112+ func traceVariableOrigin (pass * analysis.Pass , varName string , originalNode * sitter.Node ,
111113 visitedVars map [string ]bool , visitedFiles map [string ]bool , source []byte ) {
112114
113115 if visitedVars [varName ] {
114116 return
115117 }
116118 visitedVars [varName ] = true
117119
118- if traceLocalAssignments (r , ana , varName , originalNode , visitedVars , visitedFiles , source ) {
120+ if traceLocalAssignments (pass , varName , originalNode , visitedVars , visitedFiles , source ) {
119121 return
120122 }
121123
122- traceCrossFileImports (r , ana , varName , originalNode , visitedVars , visitedFiles , source )
124+ traceCrossFileImports (pass , varName , originalNode , visitedVars , visitedFiles , source )
123125}
124126
125- func traceLocalAssignments (r analysis. Rule , ana * analysis.Analyzer , varName string , originalNode * sitter.Node ,
127+ func traceLocalAssignments (pass * analysis.Pass , varName string , originalNode * sitter.Node ,
126128 visitedVars map [string ]bool , visitedFiles map [string ]bool , source []byte ) bool {
127129
128130 query := `(assignment left: (identifier) @var right: (_) @value)`
129- q , err := sitter .NewQuery ([]byte (query ), ana . Language .Parser ())
131+ q , err := sitter .NewQuery ([]byte (query ), pass . Analyzer . Language .Grammar ())
130132 if err != nil {
131133 return false
132134 }
133135 defer q .Close ()
134136
135137 cursor := sitter .NewQueryCursor ()
136138 defer cursor .Close ()
137- cursor .Exec (q , ana .FileContext .Ast )
139+ cursor .Exec (q , pass .FileContext .Ast )
138140
139141 for {
140142 match , ok := cursor .NextMatch ()
@@ -143,8 +145,8 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri
143145 }
144146
145147 var varNode , valueNode * sitter.Node
146- for _ , capture := range match .Captures {
147- switch capture . Name {
148+ for idx , capture := range match .Captures {
149+ switch q . CaptureNameForId ( uint32 ( idx )) {
148150 case "var" :
149151 varNode = capture .Node
150152 case "value" :
@@ -154,24 +156,21 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri
154156
155157 if varNode != nil && varNode .Content (source ) == varName {
156158 if isUnsafeString (valueNode , source ) {
157- ana .Report (& analysis.Issue {
158- Message : fmt .Sprintf ("Variable '%s' originates from an unsafe string" , varName ),
159- Range : originalNode .Range (),
160- })
159+ pass .Report (pass , originalNode , fmt .Sprintf ("Variable '%s' originates from an unsafe string" , varName ))
161160 return true
162161 }
163162
164163 if valueNode .Type () == "identifier" {
165164 newVar := valueNode .Content (source )
166- traceVariableOrigin (r , ana , newVar , originalNode , visitedVars , visitedFiles , source )
165+ traceVariableOrigin (pass , newVar , originalNode , visitedVars , visitedFiles , source )
167166 return true
168167 }
169168 }
170169 }
171170 return false
172171}
173172
174- func traceCrossFileImports (r analysis. Rule , ana * analysis.Analyzer , varName string , originalNode * sitter.Node ,
173+ func traceCrossFileImports (pass * analysis.Pass , varName string , originalNode * sitter.Node ,
175174 visitedVars map [string ]bool , visitedFiles map [string ]bool , source []byte ) {
176175
177176 query := `(
@@ -180,15 +179,15 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
180179 name: (dotted_name) @imported_var
181180 ) @import
182181 )`
183- q , err := sitter .NewQuery ([]byte (query ), ana . Language .Parser ())
182+ q , err := sitter .NewQuery ([]byte (query ), pass . Analyzer . Language .Grammar ())
184183 if err != nil {
185184 return
186185 }
187186 defer q .Close ()
188187
189188 cursor := sitter .NewQueryCursor ()
190189 defer cursor .Close ()
191- cursor .Exec (q , ana .FileContext .Ast )
190+ cursor .Exec (q , pass .FileContext .Ast )
192191
193192 for {
194193 match , ok := cursor .NextMatch ()
@@ -197,8 +196,8 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
197196 }
198197
199198 var moduleNode , varNode * sitter.Node
200- for _ , capture := range match .Captures {
201- switch capture . Name {
199+ for idx , capture := range match .Captures {
200+ switch q . CaptureNameForId ( uint32 ( idx )) {
202201 case "module" :
203202 moduleNode = capture .Node
204203 case "imported_var" :
@@ -213,16 +212,16 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
213212 }
214213 visitedFiles [modulePath ] = true
215214
216- for _ , file := range ana .Files {
215+ for _ , file := range pass .Files {
217216 if strings .HasSuffix (file .FilePath , modulePath ) {
218217 // Create a temporary analyzer context for the imported file.
219- tempAna := & analysis.Analyzer {
220- Language : ana . Language ,
218+ tempPass := & analysis.Pass {
219+ Analyzer : pass . Analyzer ,
221220 FileContext : file ,
222- Files : ana .Files ,
223- Report : ana .Report , // Reuse the report function.
221+ Files : pass .Files ,
222+ Report : pass .Report , // Reuse the report function.
224223 }
225- traceVariableOrigin (r , tempAna , varName , originalNode , visitedVars , visitedFiles , file .Source )
224+ traceVariableOrigin (tempPass , varName , originalNode , visitedVars , visitedFiles , file .Source )
226225 }
227226 }
228227 }
0 commit comments