Skip to content

Commit

Permalink
Add check for returning invoked functions returning err
Browse files Browse the repository at this point in the history
  • Loading branch information
tomarrell committed Mar 17, 2021
1 parent a605974 commit 148861c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 22 deletions.
40 changes: 39 additions & 1 deletion wrapcheck/testdata/return_function_call/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,52 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"strings"
)

func main() {
_, err := do()
fmt.Println(err)

doNoErr()
doThroughInt()
doThroughIntWithoutWrap()
doThroughIntWithWrap()
}

func do() ([]byte, error) {
return json.Marshal(struct{}{}) // TODO want `error returned from external package is unwrapped`
return json.Marshal(struct{}{}) // want `error returned from external package is unwrapped`
}

func doNoErr() bool {
return strings.HasPrefix("hello world", "hello")
}

func doThroughInt() (bool, bool, bool, bool, error) {
return testInt(impl{}).MultipleReturn() // want `error returned from interface method should be wrapped`
}

func doThroughIntWithoutWrap() error {
return testInt(impl{}).ErrorReturn() // want `error returned from interface method should be wrapped`
}

func doThroughIntWithWrap() error {
return fmt.Errorf("failed: %v", testInt(impl{}).ErrorReturn())
}

type testInt interface {
MultipleReturn() (bool, bool, bool, bool, error)
ErrorReturn() error
}

type impl struct{}

func (_ impl) MultipleReturn() (bool, bool, bool, bool, error) {
return true, true, true, true, errors.New("uh oh")
}

func (_ impl) ErrorReturn() error {
return errors.New("uh oh")
}
81 changes: 60 additions & 21 deletions wrapcheck/wrapcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package wrapcheck

import (
"go/ast"
"go/token"
"go/types"

"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -37,6 +38,38 @@ func run(pass *analysis.Pass) (interface{}, error) {

// Iterate over the values to be returned looking for errors
for _, expr := range ret.Results {
// Check if the return expression is a function call, if it is, we need
// to handle it by checking the return params of the function.
retFn, ok := expr.(*ast.CallExpr)
if ok {
// If the return type of the function is a single error. This will not
// match an error within multiple return values, for that, the below
// tuple check is required.
if isError(pass.TypesInfo.TypeOf(expr)) {
reportUnwrapped(pass, retFn, retFn.Pos())
return true
}

// Check if one of the return values from the function is an error
tup, ok := pass.TypesInfo.TypeOf(expr).(*types.Tuple)
if !ok {
return true
}

// Iterate over the return values of the function looking for error
// types
for i := 0; i < tup.Len(); i++ {
v := tup.At(i)
if v == nil {
return true
}
if isError(v.Type()) {
reportUnwrapped(pass, retFn, expr.Pos())
return true
}
}
}

if !isError(pass.TypesInfo.TypeOf(expr)) {
continue
}
Expand Down Expand Up @@ -95,25 +128,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
return true
}

sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
}

// Check if the underlying type of the "x" in x.y.z is an interface, as
// errors returned from interface types should be wrapped.
if isInterface(pass, sel, ident) {
pass.Reportf(ident.NamePos, "error returned from interface method should be wrapped")
return true
}

// Check whether the function being called comes from another package,
// as functions called across package boundaries which returns errors
// should be wrapped
if isFromOtherPkg(pass, sel, ident) {
pass.Reportf(ident.NamePos, "error returned from external package is unwrapped")
return true
}
reportUnwrapped(pass, call, ident.NamePos)
}

return true
Expand All @@ -123,14 +138,38 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}

// Report unwrapped takes a call expression and an identifier and reports
// if the call is unwrapped.
func reportUnwrapped(pass *analysis.Pass, call *ast.CallExpr, tokenPos token.Pos) {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return
}

// Check if the underlying type of the "x" in x.y.z is an interface, as
// errors returned from interface types should be wrapped.
if isInterface(pass, sel) {
pass.Reportf(tokenPos, "error returned from interface method should be wrapped")
return
}

// Check whether the function being called comes from another package,
// as functions called across package boundaries which returns errors
// should be wrapped
if isFromOtherPkg(pass, sel) {
pass.Reportf(tokenPos, "error returned from external package is unwrapped")
return
}
}

// isInterface returns whether the function call is one defined on an interface.
func isInterface(pass *analysis.Pass, sel *ast.SelectorExpr, ident *ast.Ident) bool {
func isInterface(pass *analysis.Pass, sel *ast.SelectorExpr) bool {
_, ok := pass.TypesInfo.TypeOf(sel.X).Underlying().(*types.Interface)

return ok
}

func isFromOtherPkg(pass *analysis.Pass, sel *ast.SelectorExpr, ident *ast.Ident) bool {
func isFromOtherPkg(pass *analysis.Pass, sel *ast.SelectorExpr) bool {
// The package of the function that we are calling which returns the error
fn := pass.TypesInfo.ObjectOf(sel.Sel)

Expand Down

0 comments on commit 148861c

Please sign in to comment.