Skip to content

Commit 35a92b4

Browse files
authored
feat(sql): enhance SQL injection detection with improved string concatenation checks (#1454)
* feat(sql): enhance SQL injection detection with improved string concatenation checks * optimize: only one ast.Inspect loop, use slices.ContainsFunc * refactor(sql): streamline SQL argument retrieval, replace constObject with TryResolve, minor cleanup * feat(sql): enhance query mutation checks for shadowed variables and add regression tests * remove deprecated ast.Object
1 parent bc9d2bc commit 35a92b4

File tree

4 files changed

+569
-165
lines changed

4 files changed

+569
-165
lines changed

helpers.go

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -385,29 +385,44 @@ func GetPkgAbsPath(pkgPath string) (string, error) {
385385
return absPath, nil
386386
}
387387

388-
// ConcatString recursively concatenates strings from a binary expression
389-
func ConcatString(n *ast.BinaryExpr) (string, bool) {
390-
var s string
391-
// sub expressions are found in X object, Y object is always last BasicLit
392-
if rightOperand, ok := n.Y.(*ast.BasicLit); ok {
393-
if str, err := GetString(rightOperand); err == nil {
394-
s = str + s
395-
}
396-
} else {
388+
// ConcatString recursively concatenates constant strings from an expression
389+
// if the entire chain is fully constant-derived (using TryResolve).
390+
// Returns the concatenated string and true if successful.
391+
func ConcatString(expr ast.Expr, ctx *Context) (string, bool) {
392+
if expr == nil || !TryResolve(expr, ctx) {
397393
return "", false
398394
}
399-
if leftOperand, ok := n.X.(*ast.BinaryExpr); ok {
400-
if recursion, ok := ConcatString(leftOperand); ok {
401-
s = recursion + s
402-
}
403-
} else if leftOperand, ok := n.X.(*ast.BasicLit); ok {
404-
if str, err := GetString(leftOperand); err == nil {
405-
s = str + s
395+
396+
var build strings.Builder
397+
var traverse func(ast.Expr) bool
398+
traverse = func(e ast.Expr) bool {
399+
switch node := e.(type) {
400+
case *ast.BasicLit:
401+
if str, err := GetString(node); err == nil {
402+
build.WriteString(str)
403+
return true
404+
}
405+
return false
406+
case *ast.Ident:
407+
values := GetIdentStringValuesRecursive(node)
408+
for _, v := range values {
409+
build.WriteString(v)
410+
}
411+
return len(values) > 0
412+
case *ast.BinaryExpr:
413+
if node.Op != token.ADD {
414+
return false
415+
}
416+
return traverse(node.X) && traverse(node.Y)
417+
default:
418+
return false
406419
}
407-
} else {
408-
return "", false
409420
}
410-
return s, true
421+
422+
if traverse(expr) {
423+
return build.String(), true
424+
}
425+
return "", false
411426
}
412427

413428
// FindVarIdentities returns array of all variable identities in a given binary expression
@@ -574,3 +589,18 @@ func CLIBuildTags(buildTags []string) []string {
574589

575590
return buildFlags
576591
}
592+
593+
// ContainingFile returns the *ast.File from ctx.PkgFiles that contains the given node.
594+
// Returns nil if not found (shouldn't happen for nodes from the analyzed package).
595+
func ContainingFile(n ast.Node, ctx *Context) *ast.File {
596+
if n == nil {
597+
return nil
598+
}
599+
pos := n.Pos()
600+
for _, f := range ctx.PkgFiles {
601+
if f.Pos() <= pos && pos < f.End() {
602+
return f
603+
}
604+
}
605+
return nil
606+
}

0 commit comments

Comments
 (0)