Skip to content

Commit 6e24cec

Browse files
chore: migrate to the newer analysis runner
Signed-off-by: Sourya Vatsyayan <sourya@deepsource.io>
1 parent 084046e commit 6e24cec

File tree

1 file changed

+67
-68
lines changed

1 file changed

+67
-68
lines changed

checkers/python/avoid-unsanitized-sql.go

Lines changed: 67 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,58 +9,60 @@ import (
99
"globstar.dev/analysis"
1010
)
1111

12+
var AvoidUnsanitizedSQL = &analysis.Analyzer{
13+
Name: "avoid-unsanitized-sql",
14+
Language: analysis.LangPy,
15+
Description: "Check if SQL query is sanitized",
16+
Category: analysis.CategorySecurity,
17+
Severity: analysis.SeverityCritical,
18+
Run: checkSQLInjection,
19+
}
20+
1221
// checkSQLInjection is the rule callback that inspects each call node.
13-
func checkSQLInjection(r analysis.Rule, ana *analysis.Analyzer, node *sitter.Node) {
14-
source := ana.FileContext.Source
22+
func checkSQLInjection(pass *analysis.Pass) (interface{}, error) {
23+
analysis.Preorder(pass, func(node *sitter.Node) {
24+
source := pass.FileContext.Source
1525

16-
// Only process call nodes.
17-
if node.Type() != "call" {
18-
return
19-
}
26+
// Only process call nodes.
27+
if node.Type() != "call" {
28+
return
29+
}
2030

21-
// Extract the function part (e.g. cursor.execute).
22-
functionNode := node.ChildByFieldName("function")
23-
if functionNode == nil {
24-
return
25-
}
31+
// Extract the function part (e.g. cursor.execute).
32+
functionNode := node.ChildByFieldName("function")
33+
if functionNode == nil {
34+
return
35+
}
2636

27-
// Proceed only if the function is one of our recognized SQL methods.
28-
if !isSQLExecuteMethod(functionNode, source) {
29-
return
30-
}
37+
// Proceed only if the function is one of our recognized SQL methods.
38+
if !isSQLExecuteMethod(functionNode, source) {
39+
return
40+
}
3141

32-
// Check the first argument.
33-
argsNode := node.ChildByFieldName("arguments")
34-
if argsNode == nil {
35-
return
36-
}
37-
firstArg := getNthChild(argsNode, 0)
38-
if firstArg == nil {
39-
return
40-
}
42+
// Check the first argument.
43+
argsNode := node.ChildByFieldName("arguments")
44+
if argsNode == nil {
45+
return
46+
}
47+
firstArg := getNthChild(argsNode, 0)
48+
if firstArg == nil {
49+
return
50+
}
4151

42-
// If the query string is built unsafely, report an issue.
43-
if isUnsafeString(firstArg, source) {
44-
ana.Report(&analysis.Issue{
45-
Message: "Concatenated string in SQL query is an SQL injection threat!",
46-
Category: analysis.CategorySecurity,
47-
Severity: analysis.SeverityCritical,
48-
Range: node.Range(),
49-
})
50-
return
51-
}
52+
// If the query string is built unsafely, report an issue.
53+
if isUnsafeString(firstArg, source) {
54+
pass.Report(pass, node, "Concatenated string in SQL query is an SQL injection threat")
55+
return
56+
}
5257

53-
// If the argument is an identifier, trace its origin.
54-
if firstArg.Type() == "identifier" {
55-
varName := firstArg.Content(source)
56-
traceVariableOrigin(r, ana, varName, node, make(map[string]bool), make(map[string]bool), source)
57-
}
58-
}
58+
// If the argument is an identifier, trace its origin.
59+
if firstArg.Type() == "identifier" {
60+
varName := firstArg.Content(source)
61+
traceVariableOrigin(pass, varName, node, make(map[string]bool), make(map[string]bool), source)
62+
}
63+
})
5964

60-
// SQLInjection registers the SQL injection rule.
61-
func SQLInjection() analysis.Rule {
62-
var entry analysis.VisitFn = checkSQLInjection
63-
return analysis.CreateRule("call", analysis.LangPy, &entry, nil)
65+
return nil, nil
6466
}
6567

6668
// --- Helper Functions ---
@@ -107,34 +109,34 @@ func isUnsafeString(node *sitter.Node, source []byte) bool {
107109
return false
108110
}
109111

110-
func traceVariableOrigin(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node,
112+
func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitter.Node,
111113
visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) {
112114

113115
if visitedVars[varName] {
114116
return
115117
}
116118
visitedVars[varName] = true
117119

118-
if traceLocalAssignments(r, ana, varName, originalNode, visitedVars, visitedFiles, source) {
120+
if traceLocalAssignments(pass, varName, originalNode, visitedVars, visitedFiles, source) {
119121
return
120122
}
121123

122-
traceCrossFileImports(r, ana, varName, originalNode, visitedVars, visitedFiles, source)
124+
traceCrossFileImports(pass, varName, originalNode, visitedVars, visitedFiles, source)
123125
}
124126

125-
func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node,
127+
func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *sitter.Node,
126128
visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) bool {
127129

128130
query := `(assignment left: (identifier) @var right: (_) @value)`
129-
q, err := sitter.NewQuery([]byte(query), ana.Language.Parser())
131+
q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar())
130132
if err != nil {
131133
return false
132134
}
133135
defer q.Close()
134136

135137
cursor := sitter.NewQueryCursor()
136138
defer cursor.Close()
137-
cursor.Exec(q, ana.FileContext.Ast)
139+
cursor.Exec(q, pass.FileContext.Ast)
138140

139141
for {
140142
match, ok := cursor.NextMatch()
@@ -143,8 +145,8 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri
143145
}
144146

145147
var varNode, valueNode *sitter.Node
146-
for _, capture := range match.Captures {
147-
switch capture.Name {
148+
for idx, capture := range match.Captures {
149+
switch q.CaptureNameForId(uint32(idx)) {
148150
case "var":
149151
varNode = capture.Node
150152
case "value":
@@ -154,24 +156,21 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri
154156

155157
if varNode != nil && varNode.Content(source) == varName {
156158
if isUnsafeString(valueNode, source) {
157-
ana.Report(&analysis.Issue{
158-
Message: fmt.Sprintf("Variable '%s' originates from an unsafe string", varName),
159-
Range: originalNode.Range(),
160-
})
159+
pass.Report(pass, originalNode, fmt.Sprintf("Variable '%s' originates from an unsafe string", varName))
161160
return true
162161
}
163162

164163
if valueNode.Type() == "identifier" {
165164
newVar := valueNode.Content(source)
166-
traceVariableOrigin(r, ana, newVar, originalNode, visitedVars, visitedFiles, source)
165+
traceVariableOrigin(pass, newVar, originalNode, visitedVars, visitedFiles, source)
167166
return true
168167
}
169168
}
170169
}
171170
return false
172171
}
173172

174-
func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node,
173+
func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *sitter.Node,
175174
visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) {
176175

177176
query := `(
@@ -180,15 +179,15 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
180179
name: (dotted_name) @imported_var
181180
) @import
182181
)`
183-
q, err := sitter.NewQuery([]byte(query), ana.Language.Parser())
182+
q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar())
184183
if err != nil {
185184
return
186185
}
187186
defer q.Close()
188187

189188
cursor := sitter.NewQueryCursor()
190189
defer cursor.Close()
191-
cursor.Exec(q, ana.FileContext.Ast)
190+
cursor.Exec(q, pass.FileContext.Ast)
192191

193192
for {
194193
match, ok := cursor.NextMatch()
@@ -197,8 +196,8 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
197196
}
198197

199198
var moduleNode, varNode *sitter.Node
200-
for _, capture := range match.Captures {
201-
switch capture.Name {
199+
for idx, capture := range match.Captures {
200+
switch q.CaptureNameForId(uint32(idx)) {
202201
case "module":
203202
moduleNode = capture.Node
204203
case "imported_var":
@@ -213,16 +212,16 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri
213212
}
214213
visitedFiles[modulePath] = true
215214

216-
for _, file := range ana.Files {
215+
for _, file := range pass.Files {
217216
if strings.HasSuffix(file.FilePath, modulePath) {
218217
// Create a temporary analyzer context for the imported file.
219-
tempAna := &analysis.Analyzer{
220-
Language: ana.Language,
218+
tempPass := &analysis.Pass{
219+
Analyzer: pass.Analyzer,
221220
FileContext: file,
222-
Files: ana.Files,
223-
Report: ana.Report, // Reuse the report function.
221+
Files: pass.Files,
222+
Report: pass.Report, // Reuse the report function.
224223
}
225-
traceVariableOrigin(r, tempAna, varName, originalNode, visitedVars, visitedFiles, file.Source)
224+
traceVariableOrigin(tempPass, varName, originalNode, visitedVars, visitedFiles, file.Source)
226225
}
227226
}
228227
}

0 commit comments

Comments
 (0)