Skip to content

Commit

Permalink
Refactor SQL rules for better extensibility (#841)
Browse files Browse the repository at this point in the history
Remove hardwired assumption and heuristics on index of arg taking a SQL
string, be explicit about it instead.
  • Loading branch information
scop authored Aug 2, 2022
1 parent 1b0873a commit 6a26c23
Showing 1 changed file with 58 additions and 19 deletions.
77 changes: 58 additions & 19 deletions rules/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
package rules

import (
"fmt"
"go/ast"
"regexp"
"strings"

"github.com/securego/gosec/v2"
)
Expand All @@ -30,6 +30,51 @@ type sqlStatement struct {
patterns []*regexp.Regexp
}

var sqlCallIdents = map[string]map[string]int{
"*database/sql.DB": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
"*database/sql.Tx": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
}

// findQueryArg locates the argument taking raw SQL
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}
i := -1
if ni, ok := sqlCallIdents[typeName]; ok {
if i, ok = ni[fnName]; !ok {
i = -1
}
}
if i == -1 {
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
}
if i >= len(call.Args) {
return nil, nil
}
query := call.Args[i]
return query, nil
}

func (s *sqlStatement) ID() string {
return s.MetaData.ID
}
Expand Down Expand Up @@ -69,16 +114,10 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {

// checkQuery verifies if the query parameters is a string concatenation
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}

if be, ok := query.(*ast.BinaryExpr); ok {
operands := gosec.GetBinaryExprOperands(be)
Expand Down Expand Up @@ -137,8 +176,11 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
},
}

rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
for s, si := range sqlCallIdents {
for i := range si {
rule.Add(s, i)
}
}
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
}

Expand Down Expand Up @@ -171,16 +213,10 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
}

func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}

if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
decl := ident.Obj.Decl
Expand Down Expand Up @@ -306,8 +342,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
},
},
}
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
for s, si := range sqlCallIdents {
for i := range si {
rule.Add(s, i)
}
}
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
Expand Down

0 comments on commit 6a26c23

Please sign in to comment.