diff --git a/errgrpc/grpc.go b/errgrpc/grpc.go index e9cfce7..8900145 100644 --- a/errgrpc/grpc.go +++ b/errgrpc/grpc.go @@ -23,71 +23,175 @@ package errgrpc import ( "context" + "errors" "fmt" + "reflect" "strconv" "strings" + spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/protoadapt" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/containerd/typeurl/v2" "github.com/containerd/errdefs" "github.com/containerd/errdefs/internal/cause" + "github.com/containerd/errdefs/internal/types" ) -// ToGRPC will attempt to map the backend containerd error into a grpc error, -// using the original error message as a description. +// ToGRPC will attempt to map the error into a grpc error, from the error types +// defined in the the errdefs package and attempign to preserve the original +// description. Any type which does not resolve to a defined error type will +// be assigned the unknown error code. // // Further information may be extracted from certain errors depending on their -// type. +// type. The grpc error details will be used to attempt to preserve as much of +// the error structures and types as possible. +// +// Errors which can be marshaled using protobuf or typeurl will be considered +// for including as GRPC error details. +// Additionally, use the following interfaces in errors to preserve custom types: // -// If the error is unmapped, the original error will be returned to be handled -// by the regular grpc error handling stack. +// WrapError(error) error - Used to wrap the previous error +// JoinErrors(...error) error - Used to join all previous errors +// CollapseError() - Used for errors which carry information but +// should not have their error message shown. func ToGRPC(err error) error { if err == nil { return nil } - if isGRPCError(err) { + if _, ok := status.FromError(err); ok { // error has already been mapped to grpc return err } + st := statusFromError(err) + if st != nil { + if details := errorDetails(err, false); len(details) > 0 { + if ds, _ := st.WithDetails(details...); ds != nil { + st = ds + } + } + err = st.Err() + } + return err +} - switch { - case errdefs.IsInvalidArgument(err): - return status.Error(codes.InvalidArgument, err.Error()) - case errdefs.IsNotFound(err): - return status.Error(codes.NotFound, err.Error()) - case errdefs.IsAlreadyExists(err): - return status.Error(codes.AlreadyExists, err.Error()) - case errdefs.IsFailedPrecondition(err) || errdefs.IsConflict(err) || errdefs.IsNotModified(err): - return status.Error(codes.FailedPrecondition, err.Error()) - case errdefs.IsUnavailable(err): - return status.Error(codes.Unavailable, err.Error()) - case errdefs.IsNotImplemented(err): - return status.Error(codes.Unimplemented, err.Error()) - case errdefs.IsCanceled(err): - return status.Error(codes.Canceled, err.Error()) - case errdefs.IsDeadlineExceeded(err): - return status.Error(codes.DeadlineExceeded, err.Error()) - case errdefs.IsUnauthorized(err): - return status.Error(codes.Unauthenticated, err.Error()) - case errdefs.IsPermissionDenied(err): - return status.Error(codes.PermissionDenied, err.Error()) - case errdefs.IsInternal(err): - return status.Error(codes.Internal, err.Error()) - case errdefs.IsDataLoss(err): - return status.Error(codes.DataLoss, err.Error()) - case errdefs.IsAborted(err): - return status.Error(codes.Aborted, err.Error()) - case errdefs.IsOutOfRange(err): - return status.Error(codes.OutOfRange, err.Error()) - case errdefs.IsResourceExhausted(err): - return status.Error(codes.ResourceExhausted, err.Error()) - case errdefs.IsUnknown(err): - return status.Error(codes.Unknown, err.Error()) +func statusFromError(err error) *status.Status { + switch errdefs.Resolve(err) { + case errdefs.ErrInvalidArgument: + return status.New(codes.InvalidArgument, err.Error()) + case errdefs.ErrNotFound: + return status.New(codes.NotFound, err.Error()) + case errdefs.ErrAlreadyExists: + return status.New(codes.AlreadyExists, err.Error()) + case errdefs.ErrPermissionDenied: + return status.New(codes.PermissionDenied, err.Error()) + case errdefs.ErrResourceExhausted: + return status.New(codes.ResourceExhausted, err.Error()) + case errdefs.ErrFailedPrecondition, errdefs.ErrConflict, errdefs.ErrNotModified: + return status.New(codes.FailedPrecondition, err.Error()) + case errdefs.ErrAborted: + return status.New(codes.Aborted, err.Error()) + case errdefs.ErrOutOfRange: + return status.New(codes.OutOfRange, err.Error()) + case errdefs.ErrNotImplemented: + return status.New(codes.Unimplemented, err.Error()) + case errdefs.ErrInternal: + return status.New(codes.Internal, err.Error()) + case errdefs.ErrUnavailable: + return status.New(codes.Unavailable, err.Error()) + case errdefs.ErrDataLoss: + return status.New(codes.DataLoss, err.Error()) + case errdefs.ErrUnauthenticated: + return status.New(codes.Unauthenticated, err.Error()) + case context.DeadlineExceeded: + return status.New(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return status.New(codes.Canceled, err.Error()) + case errdefs.ErrUnknown: + return status.New(codes.Unknown, err.Error()) } + return nil +} - return err +// errorDetails returns an array of errors which make up the provided error. +// If firstIncluded is true, then all encodable errors will be used, otherwise +// the first error in an error list will be not be used, to account for the +// the base status error which details are added to via wrap or join. +// +// The errors are ordered in way that they can be applied in order by either +// wrapping or joining the errors to recreate an error with the same structure +// when `WrapError` and `JoinErrors` interfaces are used. +// +// The intent is that when re-applying the errors to create a single error, the +// results of calls to `Error()`, `errors.Is`, `errors.As`, and "%+v" formatting +// is the same as the original error. +func errorDetails(err error, firstIncluded bool) []protoadapt.MessageV1 { + switch uerr := err.(type) { + case interface{ Unwrap() error }: + details := errorDetails(uerr.Unwrap(), firstIncluded) + + // If the type is able to wrap, then include if proto + if _, ok := err.(interface{ WrapError(error) error }); ok { + // Get proto message + if protoErr := toProtoMessage(err); protoErr != nil { + details = append(details, protoErr) + } + } + + return details + case interface{ Unwrap() []error }: + var details []protoadapt.MessageV1 + for i, e := range uerr.Unwrap() { + details = append(details, errorDetails(e, firstIncluded || i > 0)...) + } + + if _, ok := err.(interface{ JoinErrors(...error) error }); ok { + // Get proto message + if protoErr := toProtoMessage(err); protoErr != nil { + details = append(details, protoErr) + } + } + return details + } + + if firstIncluded { + if protoErr := toProtoMessage(err); protoErr != nil { + return []protoadapt.MessageV1{protoErr} + } + if gs, ok := status.FromError(ToGRPC(err)); ok { + return []protoadapt.MessageV1{gs.Proto()} + } + // TODO: Else include unknown extra error type? + } + + return nil +} + +func toProtoMessage(err error) protoadapt.MessageV1 { + // Do not double encode proto messages, otherwise use Any + if pm, ok := err.(protoadapt.MessageV1); ok { + return pm + } + if pm, ok := err.(proto.Message); ok { + return protoadapt.MessageV1Of(pm) + } + + if reflect.TypeOf(err).Kind() == reflect.Ptr { + a, aerr := typeurl.MarshalAny(err) + if aerr == nil { + return &anypb.Any{ + TypeUrl: a.GetTypeUrl(), + Value: a.GetValue(), + } + } + } + return nil } // ToGRPCf maps the error to grpc error codes, assembling the formatting string @@ -98,17 +202,33 @@ func ToGRPCf(err error, format string, args ...interface{}) error { return ToGRPC(fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), err)) } -// ToNative returns the underlying error from a grpc service based on the grpc error code +// ToNative returns the underlying error from a grpc service based on the grpc +// error code. The grpc details are used to add wrap the error in more context +// or support multiple errors. func ToNative(err error) error { if err == nil { return nil } - desc := errDesc(err) + s, isGRPC := status.FromError(err) + + var ( + desc string + code codes.Code + ) + + if isGRPC { + desc = s.Message() + code = s.Code() + + } else { + desc = err.Error() + code = codes.Unknown + } var cls error // divide these into error classes, becomes the cause - switch code(err) { + switch code { case codes.InvalidArgument: cls = errdefs.ErrInvalidArgument case codes.AlreadyExists: @@ -118,6 +238,10 @@ func ToNative(err error) error { case codes.Unavailable: cls = errdefs.ErrUnavailable case codes.FailedPrecondition: + // TODO: Has suffix is not sufficient for conflict and not modified + // Message should start with ": " or be at beginning of a line + // Message should end with ": " or be at the end of a line + // Compile a regex if desc == errdefs.ErrConflict.Error() || strings.HasSuffix(desc, ": "+errdefs.ErrConflict.Error()) { cls = errdefs.ErrConflict } else if desc == errdefs.ErrNotModified.Error() || strings.HasSuffix(desc, ": "+errdefs.ErrNotModified.Error()) { @@ -147,7 +271,7 @@ func ToNative(err error) error { cls = errdefs.ErrResourceExhausted default: if idx := strings.LastIndex(desc, cause.UnexpectedStatusPrefix); idx > 0 { - if status, err := strconv.Atoi(desc[idx+len(cause.UnexpectedStatusPrefix):]); err == nil && status >= 200 && status < 600 { + if status, uerr := strconv.Atoi(desc[idx+len(cause.UnexpectedStatusPrefix):]); uerr == nil && status >= 200 && status < 600 { cls = cause.ErrUnexpectedStatus{Status: status} } } @@ -157,10 +281,59 @@ func ToNative(err error) error { } msg := rebaseMessage(cls, desc) - if msg != "" { + if msg == "" { + err = cls + } else if msg != desc { err = fmt.Errorf("%s: %w", msg, cls) + } else if wm, ok := cls.(interface{ WithMessage(string) error }); ok { + err = wm.WithMessage(msg) } else { - err = cls + err = fmt.Errorf("%s: %w", msg, cls) + } + + if isGRPC { + errs := []error{err} + for _, a := range s.Details() { + var derr error + + // First decode error if needed + if s, ok := a.(*spb.Status); ok { + derr = ToNative(status.ErrorProto(s)) + } else if e, ok := a.(error); ok { + derr = e + } else if dany, ok := a.(typeurl.Any); ok { + i, uerr := typeurl.UnmarshalAny(dany) + if uerr == nil { + if e, ok = i.(error); ok { + derr = e + } else { + derr = fmt.Errorf("non-error unmarshalled detail: %v", i) + } + } else { + derr = fmt.Errorf("error of type %q with failure to unmarshal: %v", dany.GetTypeUrl(), uerr) + } + } else { + derr = fmt.Errorf("non-error detail: %v", a) + } + + switch werr := derr.(type) { + case interface{ WrapError(error) error }: + errs[len(errs)-1] = werr.WrapError(errs[len(errs)-1]) + case interface{ JoinErrors(...error) error }: + // TODO: Consider whether this should support joining a subset + errs[0] = werr.JoinErrors(errs...) + case interface{ CollapseError() }: + errs[len(errs)-1] = types.CollapsedError(errs[len(errs)-1], derr) + default: + errs = append(errs, derr) + } + + } + if len(errs) > 1 { + err = errors.Join(errs...) + } else { + err = errs[0] + } } return err @@ -179,22 +352,3 @@ func rebaseMessage(cls error, desc string) string { return strings.TrimSuffix(desc, ": "+clss) } - -func isGRPCError(err error) bool { - _, ok := status.FromError(err) - return ok -} - -func code(err error) codes.Code { - if s, ok := status.FromError(err); ok { - return s.Code() - } - return codes.Unknown -} - -func errDesc(err error) string { - if s, ok := status.FromError(err); ok { - return s.Message() - } - return err.Error() -} diff --git a/errgrpc/grpc_test.go b/errgrpc/grpc_test.go index 7a3778c..bc7790d 100644 --- a/errgrpc/grpc_test.go +++ b/errgrpc/grpc_test.go @@ -20,11 +20,14 @@ import ( "context" "errors" "fmt" + "strings" "testing" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/containerd/typeurl/v2" + "github.com/containerd/errdefs" "github.com/containerd/errdefs/errhttp" "github.com/containerd/errdefs/internal/cause" @@ -106,14 +109,15 @@ func TestGRPCRoundTrip(t *testing.T) { str: "test test test: failed precondition", }, { + // Currently failing input: status.Errorf(codes.Unavailable, "should be not available"), cause: errdefs.ErrUnavailable, - str: "should be not available: unavailable", + str: "should be not available", }, { input: errShouldLeaveAlone, cause: errdefs.ErrUnknown, - str: errShouldLeaveAlone.Error() + ": " + errdefs.ErrUnknown.Error(), + str: errShouldLeaveAlone.Error(), }, { input: context.Canceled, @@ -172,3 +176,101 @@ func TestGRPCRoundTrip(t *testing.T) { }) } } + +type TestError struct { + Value string `json:"value"` +} + +func (*TestError) Error() string { + return "test error" +} + +func TestGRPCCustomDetails(t *testing.T) { + typeurl.Register(&TestError{}, t.Name()) + expected := &TestError{ + Value: "test 1", + } + + err := errors.Join(errdefs.ErrInternal, expected) + gerr := ToGRPC(err) + + s, ok := status.FromError(gerr) + if !ok { + t.Fatalf("Not GRPC error: %v", gerr) + } + if s.Code() != codes.Internal { + t.Fatalf("Unexpectd GRPC code %v, expected %v", s.Code(), codes.Internal) + } + + nerr := ToNative(gerr) + if !errors.Is(nerr, errdefs.ErrInternal) { + t.Fatalf("Expected internal error type, got %v", nerr) + } + if !errdefs.IsInternal(err) { + t.Fatalf("Expected internal error type, got %v", nerr) + } + terr := &TestError{} + if !errors.As(nerr, &terr) { + t.Fatalf("TestError not preserved, got %v", nerr) + } else if terr.Value != expected.Value { + t.Fatalf("Value not preserved, got %v", terr.Value) + } +} + +func TestGRPCMultiError(t *testing.T) { + err := errors.Join(errdefs.ErrPermissionDenied, errdefs.ErrDataLoss, errdefs.ErrConflict, fmt.Errorf("Was not changed at all!: %w", errdefs.ErrNotModified)) + + checkError := func(err error) { + t.Helper() + if !errors.Is(err, errdefs.ErrPermissionDenied) { + t.Fatal("Not permission denied") + } + if !errors.Is(err, errdefs.ErrDataLoss) { + t.Fatal("Not data loss") + } + if !errors.Is(err, errdefs.ErrConflict) { + t.Fatal("Not conflict") + } + if !errors.Is(err, errdefs.ErrNotModified) { + t.Fatal("Not not modified") + } + if errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Fatal("Should not be failed precondition") + } + if !strings.Contains(err.Error(), "Was not changed at all!") { + t.Fatalf("Not modified error message missing from:\n%v", err) + } + } + checkError(err) + + terr := ToNative(ToGRPC(err)) + + checkError(terr) + + // Try again with decoded error + checkError(ToNative(ToGRPC(terr))) +} + +func TestGRPCNestedError(t *testing.T) { + multiErr := errors.Join(fmt.Errorf("First error: %w", errdefs.ErrNotFound), fmt.Errorf("Second error: %w", errdefs.ErrResourceExhausted)) + + checkError := func(err error) { + t.Helper() + if !errors.Is(err, errdefs.ErrNotFound) { + t.Fatal("Not not found") + } + if !errors.Is(err, errdefs.ErrResourceExhausted) { + t.Fatal("Not resource exhausted") + } + if errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Fatal("Should not be failed precondition") + } + } + checkError(multiErr) + + werr := fmt.Errorf("Wrapping the error: %w", multiErr) + + checkError(werr) + + checkError(ToNative(ToGRPC(werr))) +} diff --git a/go.mod b/go.mod index 51cca1c..3fbcfa3 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,13 @@ go 1.20 require ( github.com/containerd/typeurl/v2 v2.1.1 + google.golang.org/genproto/googleapis/rpc v0.0.0-20231002182017-d307bd883b97 google.golang.org/grpc v1.58.3 + google.golang.org/protobuf v1.31.0 ) require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect golang.org/x/sys v0.13.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231002182017-d307bd883b97 // indirect - google.golang.org/protobuf v1.31.0 // indirect )