diff --git a/assertion/function/preprocess/cfg.go b/assertion/function/preprocess/cfg.go index 38e0716e..6b6e40e4 100644 --- a/assertion/function/preprocess/cfg.go +++ b/assertion/function/preprocess/cfg.go @@ -55,7 +55,14 @@ func (p *Preprocessor) CFG(graph *cfg.CFG, funcDecl *ast.FuncDecl) *cfg.CFG { failureBlock := &cfg.Block{Index: int32(len(graph.Blocks))} graph.Blocks = append(graph.Blocks, failureBlock) - // Perform the (series of) CFG transformations. + // Perform a series of CFG transformations here (for hooks and canonicalization). The order of + // these transformations matters due to canonicalization. Some transformations may expect the + // CFG to be in canonical form, and some transformations may change the CFG structure in a way + // that it needs to be re-canonicalized. + + // split blocks do not require the CFG to be in canonical form, and it may modify the CFG + // structure in a way that it needs to be re-canonicalized. Here, we cleverly bundles the two + // operations together such that we only need to run canonicalization once. for _, block := range graph.Blocks { if block.Live { p.splitBlockOnTrustedFuncs(graph, block, failureBlock) @@ -63,7 +70,15 @@ func (p *Preprocessor) CFG(graph *cfg.CFG, funcDecl *ast.FuncDecl) *cfg.CFG { } for _, block := range graph.Blocks { if block.Live { - p.restructureConditional(graph, block) + p.canonicalizeConditional(graph, block) + } + } + // Replacing conditionals in the CFG requires the CFG to be in canonical form (such that it + // does not have to handle "trustedFunc() && trustedFunc()"), and it will canonicalize the + // modified block by itself. + for _, block := range graph.Blocks { + if block.Live { + p.replaceConditional(graph, block) } } @@ -119,6 +134,10 @@ func copyGraph(graph *cfg.CFG) *cfg.CFG { return newGraph } +// splitBlockOnTrustedFuncs splits the CFG block into two parts upon seeing a trusted function +// from the hook framework (e.g., "require.Nil(t, arg)" to "if arg == nil { }". +// This does not expect the CFG to be in canonical form, and it may change the CFG structure in a +// way that it needs to be re-canonicalized. func (p *Preprocessor) splitBlockOnTrustedFuncs(graph *cfg.CFG, thisBlock, failureBlock *cfg.Block) { for i, node := range thisBlock.Nodes { expr, ok := node.(*ast.ExprStmt) @@ -153,47 +172,69 @@ func (p *Preprocessor) splitBlockOnTrustedFuncs(graph *cfg.CFG, thisBlock, failu } } -func (p *Preprocessor) restructureConditional(graph *cfg.CFG, thisBlock *cfg.Block) { - // We only restructure non-empty branching blocks. - if len(thisBlock.Nodes) == 0 || len(thisBlock.Succs) != 2 { +// replaceConditional calls the hook functions and replaces the conditional expressions in the CFG +// with the returned equivalent expression for analysis. +// +// This function expects the CFG to be in canonical form to fully function (otherwise it may miss +// cases like "trustedFunc() && trustedFunc()"). +// +// It also calls canonicalizeConditional to canonicalize the transformed block such that the CFG +// is still canonical. +func (p *Preprocessor) replaceConditional(graph *cfg.CFG, block *cfg.Block) { + // We only replace conditionals on branching blocks. + if len(block.Nodes) == 0 || len(block.Succs) != 2 { return } - cond, ok := thisBlock.Nodes[len(thisBlock.Nodes)-1].(ast.Expr) + call, ok := block.Nodes[len(block.Nodes)-1].(*ast.CallExpr) if !ok { return } + replaced := hook.ReplaceConditional(p.pass, call) + if replaced == nil { + return + } - // places a new given node into the last position of this block - replaceCond := func(node ast.Node) { - thisBlock.Nodes[len(thisBlock.Nodes)-1] = node + block.Nodes[len(block.Nodes)-1] = replaced + // The returned expression may be a binary expression, so we need to canonicalize the CFG again + // after such replacement. + p.canonicalizeConditional(graph, block) +} + +// canonicalizeConditional canonicalizes the conditional CFG structures to make it easier to reason +// about control flows later. For example, it rewrites +// `if !cond {T} {F}` to `if cond {F} {T}` (swap successors), and rewrites +// `if cond1 && cond2 {T} {F}` to `if cond1 {if cond2 {T} else {F}}{F}` (nesting). +func (p *Preprocessor) canonicalizeConditional(graph *cfg.CFG, thisBlock *cfg.Block) { + // We only restructure non-empty branching blocks. + if len(thisBlock.Nodes) == 0 || len(thisBlock.Succs) != 2 { + return } trueBranch := thisBlock.Succs[0] // type *cfg.Block falseBranch := thisBlock.Succs[1] // type *cfg.Block - replaceTrueBranch := func(block *cfg.Block) { - thisBlock.Succs[0] = block - } - replaceFalseBranch := func(block *cfg.Block) { - thisBlock.Succs[1] = block - } + // A few helper functions to make the code more readable. + replaceCond := func(node ast.Node) { thisBlock.Nodes[len(thisBlock.Nodes)-1] = node } // The conditional expr is the last node in the block. + replaceTrueBranch := func(block *cfg.Block) { thisBlock.Succs[0] = block } + replaceFalseBranch := func(block *cfg.Block) { thisBlock.Succs[1] = block } + swapTrueFalseBranches := func() { replaceTrueBranch(falseBranch); replaceFalseBranch(trueBranch) } - swapTrueFalseBranches := func() { - replaceTrueBranch(falseBranch) - replaceFalseBranch(trueBranch) + cond, ok := thisBlock.Nodes[len(thisBlock.Nodes)-1].(ast.Expr) + if !ok { + return } switch cond := cond.(type) { case *ast.ParenExpr: // if a parenexpr, strip and restart - this is done with recursion to account for ((((x)))) case replaceCond(cond.X) - p.restructureConditional(graph, thisBlock) // recur within parens + p.canonicalizeConditional(graph, thisBlock) // recur within parens case *ast.UnaryExpr: if cond.Op == token.NOT { // swap successors - i.e. swap true and false branches swapTrueFalseBranches() replaceCond(cond.X) - p.restructureConditional(graph, thisBlock) // recur within NOT + p.canonicalizeConditional(graph, thisBlock) // recur within NOT } case *ast.BinaryExpr: // Logical AND and Logical OR actually require the exact same short circuiting behavior @@ -214,8 +255,8 @@ func (p *Preprocessor) restructureConditional(graph *cfg.CFG, thisBlock *cfg.Blo replaceFalseBranch(newBlock) } graph.Blocks = append(graph.Blocks, newBlock) - p.restructureConditional(graph, thisBlock) - p.restructureConditional(graph, newBlock) + p.canonicalizeConditional(graph, thisBlock) + p.canonicalizeConditional(graph, newBlock) } // Standardize binary expressions to be of the form `expr OP literal` by swapping `x` and `y`, if `x` is a literal. @@ -277,8 +318,8 @@ func (p *Preprocessor) restructureConditional(graph *cfg.CFG, thisBlock *cfg.Blo Op: token.NOT, X: x, } - replaceCond(newCond) // replaces `ok != true` with `!ok` - p.restructureConditional(graph, thisBlock) // recur to swap true and false branches for the unary expr `!ok` + replaceCond(newCond) // replaces `ok != true` with `!ok` + p.canonicalizeConditional(graph, thisBlock) // recur to swap true and false branches for the unary expr `!ok` } case token.EQL: @@ -292,8 +333,8 @@ func (p *Preprocessor) restructureConditional(graph *cfg.CFG, thisBlock *cfg.Blo Op: token.NOT, X: x, } - replaceCond(newCond) // replaces `ok == false` with `!ok` - p.restructureConditional(graph, thisBlock) // recur to swap true and false branches for the unary expr `!ok` + replaceCond(newCond) // replaces `ok == false` with `!ok` + p.canonicalizeConditional(graph, thisBlock) // recur to swap true and false branches for the unary expr `!ok` } } } diff --git a/hook/replace_conditional.go b/hook/replace_conditional.go new file mode 100644 index 00000000..5efbf895 --- /dev/null +++ b/hook/replace_conditional.go @@ -0,0 +1,78 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hook + +import ( + "go/ast" + "go/token" + "regexp" + + "golang.org/x/tools/go/analysis" +) + +// ReplaceConditional replaces a call to a matched function with the returned expression. This is +// useful for modeling stdlib and 3rd party functions that return a single boolean value, which +// implies nilability of the arguments. For example, `errors.As(err, &target)` implies +// `target != nil`, so it can be replaced with `target != nil`. +// +// If the call does not match any known function, nil is returned. +func ReplaceConditional(pass *analysis.Pass, call *ast.CallExpr) ast.Expr { + for sig, act := range _replaceConditionals { + if sig.match(pass, call) { + return act(pass, call) + } + } + return nil +} + +type replaceConditionalAction func(pass *analysis.Pass, call *ast.CallExpr) ast.Expr + +// _errorAsAction replaces a call to `errors.As(err, &target)` with an equivalent expression +// `errors.As(err, &target) && target != nil`. Keeping the `errors.As(err, &target)` is important +// since `err` may contain complex expressions that may have nilness issues. +// +// Note that technically `target` can still be nil even if `errors.As(err, &target)` is true. For +// example, if err is a typed nil (e.g., `var err *exec.ExitError`), then `errors.As` would +// actually find a match, but `target` would be set to the typed nil value, resulting in a `nil` +// target. However, in practice this should rarely happen such that even the official documentation +// assumes the target is non-nil after such check [1]. So here we make this assumption as well. +// +// [1] https://pkg.go.dev/errors#As +var _errorAsAction replaceConditionalAction = func(_ *analysis.Pass, call *ast.CallExpr) ast.Expr { + if len(call.Args) != 2 { + return nil + } + unaryExpr, ok := call.Args[1].(*ast.UnaryExpr) + if !ok { + return nil + } + if unaryExpr.Op != token.AND { + return nil + } + return &ast.BinaryExpr{ + X: call, + Op: token.LAND, + OpPos: call.Pos(), + Y: newNilBinaryExpr(unaryExpr.X, token.NEQ), + } +} + +var _replaceConditionals = map[trustedFuncSig]replaceConditionalAction{ + { + kind: _func, + enclosingRegex: regexp.MustCompile(`^errors$`), + funcNameRegex: regexp.MustCompile(`^As$`), + }: _errorAsAction, +} diff --git a/testdata/src/go.uber.org/testing/trustedfuncs.go b/testdata/src/go.uber.org/testing/trustedfuncs.go index 2d45c3c0..b82c82e4 100644 --- a/testdata/src/go.uber.org/testing/trustedfuncs.go +++ b/testdata/src/go.uber.org/testing/trustedfuncs.go @@ -20,6 +20,9 @@ This package aims to test any nilaway behavior specific to accomdating tests, su package testing import ( + "errors" + "os/exec" + "go.uber.org/testing/github.com/stretchr/testify/assert" "go.uber.org/testing/github.com/stretchr/testify/require" "go.uber.org/testing/github.com/stretchr/testify/suite" @@ -954,3 +957,49 @@ func testEmpty(t *testing.T, i int, a []int, mp map[int]*int) interface{} { return 0 } + +// nilable(err) +func errorsAs(err error, num string, dummy bool) { + switch num { + case "simple": + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + print(*exitErr) + } + print(*exitErr) //want "unassigned variable `exitErr` dereferenced" + case "not in if block": + var exitErr *exec.ExitError + // Not checking the result of `errors.As` would not guard the variable. + errors.As(err, &exitErr) + print(*exitErr) //want "unassigned variable `exitErr` dereferenced" + case "two errors connected by AND": + var exitErr, anotherErr *exec.ExitError + if errors.As(err, &exitErr) && errors.As(err, &anotherErr) { + print(*exitErr) + print(*anotherErr) + } + case "errors.As with other conditionals connected by AND": + var exitErr *exec.ExitError + if errors.As(err, &exitErr) && dummy { + print(*exitErr) + } + case "errors.As with other conditionals connected by OR": + var exitErr *exec.ExitError + if errors.As(err, &exitErr) || dummy { + print(*exitErr) //want "unassigned variable `exitErr` dereferenced" + } + case "two errors connected by OR": + var exitErr, anotherErr *exec.ExitError + if errors.As(err, &exitErr) || errors.As(err, &anotherErr) { + // We do not know the nilability of either. + print(*exitErr) //want "unassigned variable `exitErr` dereferenced" + print(*anotherErr) //want "unassigned variable `anotherErr` dereferenced" + } + case "nil dereference in first argument": + var exitErr *exec.ExitError + var nilError *error + if errors.As(*nilError, &exitErr) { //want "unassigned variable `nilError` dereferenced" + print(*exitErr) // But this is fine! + } + } +}