diff --git a/wrapcheck/testdata/return_function_call/main.go b/wrapcheck/testdata/return_function_call/main.go index c1a5911..d01b966 100644 --- a/wrapcheck/testdata/return_function_call/main.go +++ b/wrapcheck/testdata/return_function_call/main.go @@ -2,14 +2,52 @@ package main import ( "encoding/json" + "errors" "fmt" + "strings" ) func main() { _, err := do() fmt.Println(err) + + doNoErr() + doThroughInt() + doThroughIntWithoutWrap() + doThroughIntWithWrap() } func do() ([]byte, error) { - return json.Marshal(struct{}{}) // TODO want `error returned from external package is unwrapped` + return json.Marshal(struct{}{}) // want `error returned from external package is unwrapped` +} + +func doNoErr() bool { + return strings.HasPrefix("hello world", "hello") +} + +func doThroughInt() (bool, bool, bool, bool, error) { + return testInt(impl{}).MultipleReturn() // want `error returned from interface method should be wrapped` +} + +func doThroughIntWithoutWrap() error { + return testInt(impl{}).ErrorReturn() // want `error returned from interface method should be wrapped` +} + +func doThroughIntWithWrap() error { + return fmt.Errorf("failed: %v", testInt(impl{}).ErrorReturn()) +} + +type testInt interface { + MultipleReturn() (bool, bool, bool, bool, error) + ErrorReturn() error +} + +type impl struct{} + +func (_ impl) MultipleReturn() (bool, bool, bool, bool, error) { + return true, true, true, true, errors.New("uh oh") +} + +func (_ impl) ErrorReturn() error { + return errors.New("uh oh") } diff --git a/wrapcheck/wrapcheck.go b/wrapcheck/wrapcheck.go index e249a7e..63f274f 100644 --- a/wrapcheck/wrapcheck.go +++ b/wrapcheck/wrapcheck.go @@ -2,6 +2,7 @@ package wrapcheck import ( "go/ast" + "go/token" "go/types" "golang.org/x/tools/go/analysis" @@ -37,6 +38,38 @@ func run(pass *analysis.Pass) (interface{}, error) { // Iterate over the values to be returned looking for errors for _, expr := range ret.Results { + // Check if the return expression is a function call, if it is, we need + // to handle it by checking the return params of the function. + retFn, ok := expr.(*ast.CallExpr) + if ok { + // If the return type of the function is a single error. This will not + // match an error within multiple return values, for that, the below + // tuple check is required. + if isError(pass.TypesInfo.TypeOf(expr)) { + reportUnwrapped(pass, retFn, retFn.Pos()) + return true + } + + // Check if one of the return values from the function is an error + tup, ok := pass.TypesInfo.TypeOf(expr).(*types.Tuple) + if !ok { + return true + } + + // Iterate over the return values of the function looking for error + // types + for i := 0; i < tup.Len(); i++ { + v := tup.At(i) + if v == nil { + return true + } + if isError(v.Type()) { + reportUnwrapped(pass, retFn, expr.Pos()) + return true + } + } + } + if !isError(pass.TypesInfo.TypeOf(expr)) { continue } @@ -95,25 +128,7 @@ func run(pass *analysis.Pass) (interface{}, error) { return true } - sel, ok := call.Fun.(*ast.SelectorExpr) - if !ok { - return true - } - - // Check if the underlying type of the "x" in x.y.z is an interface, as - // errors returned from interface types should be wrapped. - if isInterface(pass, sel, ident) { - pass.Reportf(ident.NamePos, "error returned from interface method should be wrapped") - return true - } - - // Check whether the function being called comes from another package, - // as functions called across package boundaries which returns errors - // should be wrapped - if isFromOtherPkg(pass, sel, ident) { - pass.Reportf(ident.NamePos, "error returned from external package is unwrapped") - return true - } + reportUnwrapped(pass, call, ident.NamePos) } return true @@ -123,14 +138,38 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } +// Report unwrapped takes a call expression and an identifier and reports +// if the call is unwrapped. +func reportUnwrapped(pass *analysis.Pass, call *ast.CallExpr, tokenPos token.Pos) { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + + // Check if the underlying type of the "x" in x.y.z is an interface, as + // errors returned from interface types should be wrapped. + if isInterface(pass, sel) { + pass.Reportf(tokenPos, "error returned from interface method should be wrapped") + return + } + + // Check whether the function being called comes from another package, + // as functions called across package boundaries which returns errors + // should be wrapped + if isFromOtherPkg(pass, sel) { + pass.Reportf(tokenPos, "error returned from external package is unwrapped") + return + } +} + // isInterface returns whether the function call is one defined on an interface. -func isInterface(pass *analysis.Pass, sel *ast.SelectorExpr, ident *ast.Ident) bool { +func isInterface(pass *analysis.Pass, sel *ast.SelectorExpr) bool { _, ok := pass.TypesInfo.TypeOf(sel.X).Underlying().(*types.Interface) return ok } -func isFromOtherPkg(pass *analysis.Pass, sel *ast.SelectorExpr, ident *ast.Ident) bool { +func isFromOtherPkg(pass *analysis.Pass, sel *ast.SelectorExpr) bool { // The package of the function that we are calling which returns the error fn := pass.TypesInfo.ObjectOf(sel.Sel)