From c7310d6a89a7a59ab57a0c16e076c556741b9205 Mon Sep 17 00:00:00 2001 From: Roger Peppe Date: Fri, 22 Mar 2024 14:57:45 +0000 Subject: [PATCH] ociregistry: improve errors with respect to HTTP status TODO more explanation Fixes #26. Signed-off-by: Roger Peppe Change-Id: I5a79c1c6fec9f22c1f565830d73486b406dd181d --- ociregistry/error.go | 194 +++++++++++++++++++++++++++- ociregistry/iter_test.go | 2 +- ociregistry/ociauth/auth.go | 15 ++- ociregistry/ociauth/auth_test.go | 73 +++++++++++ ociregistry/ociclient/error.go | 126 +++--------------- ociregistry/ociserver/error.go | 74 +++++------ ociregistry/ociserver/error_test.go | 20 +++ ociregistry/ociserver/registry.go | 11 ++ 8 files changed, 363 insertions(+), 152 deletions(-) create mode 100644 ociregistry/ociserver/error_test.go diff --git a/ociregistry/error.go b/ociregistry/error.go index 09368e0..55d70eb 100644 --- a/ociregistry/error.go +++ b/ociregistry/error.go @@ -14,6 +14,98 @@ package ociregistry +import ( + "bytes" + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + "unicode" +) + +// WireErrors is the JSON format used for error responses in +// the OCI HTTP API. It should always contain at least one +// error. +type WireErrors struct { + Errors []WireError `json:"errors"` +} + +// Unwrap allows [errors.Is] and [errors.As] to +// see the errors inside e. +func (e *WireErrors) Unwrap() []error { + // TODO we could do this only once. + errs := make([]error, len(e.Errors)) + for i := range e.Errors { + errs[i] = &e.Errors[i] + } + return errs +} + +func (e *WireErrors) Error() string { + var buf strings.Builder + buf.WriteString(e.Errors[0].Error()) + for i := range e.Errors[1:] { + buf.WriteString("; ") + buf.WriteString(e.Errors[i+1].Error()) + } + return buf.String() +} + +// WireError holds a single error in an OCI HTTP response. +type WireError struct { + Code_ string `json:"code"` + Message string `json:"message,omitempty"` + Detail_ json.RawMessage `json:"detail,omitempty"` +} + +// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrBlobUnknown)` +// even when the error hasn't exactly wrapped that error. +func (e *WireError) Is(err error) bool { + var rerr Error + return errors.As(err, &rerr) && rerr.Code() == e.Code() +} + +// Error implements the [error] interface. +func (e *WireError) Error() string { + var buf strings.Builder + for _, r := range e.Code_ { + if r == '_' { + buf.WriteByte(' ') + } else { + buf.WriteRune(unicode.ToLower(r)) + } + } + if buf.Len() == 0 { + buf.WriteString("(no code)") + } + if e.Message != "" { + buf.WriteString(": ") + buf.WriteString(e.Message) + } + if len(e.Detail_) != 0 && !bytes.Equal(e.Detail_, []byte("null")) { + buf.WriteString("; detail: ") + buf.Write(e.Detail_) + } + return buf.String() +} + +// Code implements [Error.Code]. +func (e *WireError) Code() string { + return e.Code_ +} + +// Detail implements [Error.Detail]. +func (e *WireError) Detail() any { + if len(e.Detail_) == 0 { + return nil + } + // TODO do this once only? + var d any + json.Unmarshal(e.Detail_, &d) + return d +} + // NewError returns a new error with the given code, message and detail. func NewError(msg string, code string, detail any) Error { return ®istryError{ @@ -31,7 +123,7 @@ type Error interface { // error.Error provides the error message. error - // Code returns the error code. See + // Code returns the error code. Code() string // Detail returns any detail to be associated with the error; it should @@ -39,6 +131,102 @@ type Error interface { Detail() any } +// HTTPError is optionally implemented by an error when +// the error has originated from an HTTP request +// or might be returned from one. +type HTTPError interface { + error + + // StatusCode returns the HTTP status code of the response. + StatusCode() int + + // Response holds the HTTP response that caused the HTTPError to + // be created. It will return nil if the error was not created + // as a result of an HTTP response. + // + // The caller should not read the response body or otherwise + // change the response (mutation of errors is a Bad Thing). + // + // Use the ResponseBody method to obtain the body of the + // response if needed. + Response() *http.Response + + // ResponseBody returns the contents of the response body. It + // will return nil if the error was not created as a result of + // an HTTP response. + // + // The caller should not change or append to the returned data. + ResponseBody() []byte +} + +// NewHTTPError returns an error that wraps err to make an [HTTPError] +// that represents the given status code, response and response body. +// Both response and body may be nil. +// +// A shallow copy is made of the response. +func NewHTTPError(err error, statusCode int, response *http.Response, body []byte) HTTPError { + herr := &httpError{ + underlying: err, + statusCode: statusCode, + } + if response != nil { + herr.response = ref(*response) + herr.response.Body = nil + herr.body = body + } + return herr +} + +type httpError struct { + underlying error + statusCode int + response *http.Response + body []byte +} + +// Unwrap implements the [errors] Unwrap interface. +func (e *httpError) Unwrap(err error) error { + return e.underlying +} + +// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrRangeInvalid)` +// even when the error hasn't exactly wrapped that error. +func (e *httpError) Is(err error) bool { + switch e.statusCode { + case http.StatusRequestedRangeNotSatisfiable: + return err == ErrRangeInvalid + } + return false +} + +// Error implements [error.Error]. +func (e *httpError) Error() string { + var buf strings.Builder + buf.WriteString(strconv.Itoa(e.statusCode)) + buf.WriteString(" ") + buf.WriteString(http.StatusText(e.statusCode)) + if e.underlying != nil { + buf.WriteString(": ") + buf.WriteString(e.underlying.Error()) + } + return buf.String() +} + +// StatusCode implements [HTTPEror.StatusCode]. +func (e *httpError) StatusCode() int { + return e.statusCode +} + +// Response implements [HTTPEror.Response]. +func (e *httpError) Response() *http.Response { + return e.response +} + +// ResponseBody implements [HTTPEror.ResponseBody]. +func (e *httpError) ResponseBody() []byte { + return e.body +} + // The following values represent the known error codes. var ( ErrBlobUnknown = NewError("blob unknown to registry", "BLOB_UNKNOWN", nil) @@ -83,3 +271,7 @@ func (e *registryError) Error() string { func (e *registryError) Detail() any { return e.detail } + +func ref[T any](x T) *T { + return &x +} diff --git a/ociregistry/iter_test.go b/ociregistry/iter_test.go index 55382ee..119bc04 100644 --- a/ociregistry/iter_test.go +++ b/ociregistry/iter_test.go @@ -1,4 +1,4 @@ -//go:build go1.23 || goexperiment.rangefunc +//go:build go1.23 package ociregistry diff --git a/ociregistry/ociauth/auth.go b/ociregistry/ociauth/auth.go index 70f00ba..7c45b18 100644 --- a/ociregistry/ociauth/auth.go +++ b/ociregistry/ociauth/auth.go @@ -182,7 +182,20 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } } - return r.transport.RoundTrip(req) + resp, err = r.transport.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + // The server has responded with a 401 even though we've just + // provided a token that it gave us. Treat it as a 404 instead. + resp.Body.Close() + resp.Body = io.NopCloser(strings.NewReader("TODO")) + resp.StatusCode = http.StatusNotFound + resp.Status = http.StatusText(resp.StatusCode) + return resp, nil } // setAuthorization sets up authorization on the given request using any diff --git a/ociregistry/ociauth/auth_test.go b/ociregistry/ociauth/auth_test.go index 3b9a8f3..b4701e4 100644 --- a/ociregistry/ociauth/auth_test.go +++ b/ociregistry/ociauth/auth_test.go @@ -235,6 +235,79 @@ func TestAuthNotAvailableAfterChallenge(t *testing.T) { qt.Check(t, qt.Equals(requestCount, 1)) } +func Test401ResponseWithJustAcquiredToken(t *testing.T) { + // This tests the scenario where a server returns a 401 response + // when the client has just successfully acquired a token from + // the auth server. + // + // In this case, a "correct" server should return + // either 403 (access to the resource is forbidden because the + // client's credentials are not sufficient) or 404 (either the + // repository really doesn't exist or the credentials are insufficient + // and the server doesn't allow clients to see whether repositories + // they don't have access to might exist). + // + // However, some real-world servers instead return a 401 response + // erroneously indicating that the client needs to acquire + // authorization credentials, even though they have in fact just + // done so. + // + // As a workaround for this case, we treat the response as a 404. + + testScope := ParseScope("repository:foo:pull") + authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) { + requestedScope := ParseScope(req.Form.Get("scope")) + if !runNonFatal(t, func(t testing.TB) { + qt.Assert(t, qt.DeepEquals(requestedScope, testScope)) + qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"})) + }) { + return nil, &httpError{ + statusCode: http.StatusInternalServerError, + } + } + return &wireToken{ + Token: token{requestedScope}.String(), + }, nil + }) + ts := newTargetServer(t, func(req *http.Request) *httpError { + if req.Header.Get("Authorization") == "" { + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)}, + }, + } + } + if !runNonFatal(t, func(t testing.TB) { + qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope)) + }) { + return &httpError{ + statusCode: http.StatusInternalServerError, + } + } + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)}, + }, + } + return nil + }) + client := &http.Client{ + Transport: NewStdTransport(StdTransportParams{ + Config: configFunc(func(host string) (ConfigEntry, error) { + return ConfigEntry{}, nil + }), + }), + } + req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil) + qt.Assert(t, qt.IsNil(err)) + resp, err := client.Do(req) + qt.Assert(t, qt.IsNil(err)) + defer resp.Body.Close() + qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusNotFound)) +} + func TestConfigHasAccessToken(t *testing.T) { accessToken := "somevalue" ts := newTargetServer(t, func(req *http.Request) *httpError { diff --git a/ociregistry/ociclient/error.go b/ociregistry/ociclient/error.go index 71d3cdd..fb8b66b 100644 --- a/ociregistry/ociclient/error.go +++ b/ociregistry/ociclient/error.go @@ -15,16 +15,12 @@ package ociclient import ( - "bytes" "encoding/json" - "errors" "fmt" "io" "mime" "net/http" - "strconv" "strings" - "unicode" "cuelabs.dev/go/oci/ociregistry" ) @@ -34,98 +30,26 @@ import ( // bytes. Hence, 8 KiB should be sufficient. const errorBodySizeLimit = 8 * 1024 -type wireError struct { - Code_ string `json:"code"` - Message string `json:"message,omitempty"` - Detail_ json.RawMessage `json:"detail,omitempty"` -} - -func (e *wireError) Error() string { - var buf strings.Builder - for _, r := range e.Code_ { - if r == '_' { - buf.WriteByte(' ') +func makeError(resp *http.Response) error { + var data []byte + var err error + if resp.Body != nil { + data, err = io.ReadAll(io.LimitReader(resp.Body, errorBodySizeLimit+1)) + if err != nil { + err = fmt.Errorf("cannot read error body: %v", err) + } else if len(data) > errorBodySizeLimit { + // TODO include some part of the body + err = fmt.Errorf("error body too large") } else { - buf.WriteRune(unicode.ToLower(r)) + err = makeError1(resp, data) } } - if buf.Len() == 0 { - buf.WriteString("(no code)") - } - if e.Message != "" { - buf.WriteString(": ") - buf.WriteString(e.Message) - } - if len(e.Detail_) != 0 && !bytes.Equal(e.Detail_, []byte("null")) { - buf.WriteString("; detail: ") - buf.Write(e.Detail_) - } - return buf.String() -} - -// Code implements [ociregistry.Error.Code]. -func (e *wireError) Code() string { - return e.Code_ -} - -// Detail implements [ociregistry.Error.Detail]. -func (e *wireError) Detail() any { - if len(e.Detail_) == 0 { - return nil - } - // TODO do this once only? - var d any - json.Unmarshal(e.Detail_, &d) - return d -} - -// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrBlobUnknown)` -// even when the error hasn't exactly wrapped that error. -func (e *wireError) Is(err error) bool { - var rerr ociregistry.Error - return errors.As(err, &rerr) && rerr.Code() == e.Code() -} - -type wireErrors struct { - httpStatusCode int - Errors []wireError `json:"errors"` -} - -func (e *wireErrors) Unwrap() []error { - // TODO we could do this only once. - errs := make([]error, len(e.Errors)) - for i := range e.Errors { - errs[i] = &e.Errors[i] - } - return errs -} - -// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrRangeInvalid)` -// even when the error hasn't exactly wrapped that error. -func (e *wireErrors) Is(err error) bool { - switch e.httpStatusCode { - case http.StatusRequestedRangeNotSatisfiable: - return err == ociregistry.ErrRangeInvalid - } - return false -} - -func (e *wireErrors) Error() string { - var buf strings.Builder - buf.WriteString(strconv.Itoa(e.httpStatusCode)) - buf.WriteString(" ") - buf.WriteString(http.StatusText(e.httpStatusCode)) - buf.WriteString(": ") - buf.WriteString(e.Errors[0].Error()) - for i := range e.Errors[1:] { - buf.WriteString("; ") - buf.WriteString(e.Errors[i+1].Error()) - } - return buf.String() + // We always include the status code and response in the error. + return ociregistry.NewHTTPError(err, resp.StatusCode, resp, data) } // makeError forms an error from a non-OK response. -func makeError(resp *http.Response) error { +func makeError1(resp *http.Response, bodyData []byte) error { if resp.Request.Method == "HEAD" { // When we've made a HEAD request, we can't see any of // the actual error, so we'll have to make up something @@ -143,31 +67,21 @@ func makeError(resp *http.Response) error { case http.StatusBadRequest: err = ociregistry.ErrUnsupported default: - return fmt.Errorf("error response: %v", resp.Status) + // Our caller will turn this into a non-nil error. + return nil } return fmt.Errorf("error response: %v: %w", resp.Status, err) } - if !isJSONMediaType(resp.Header.Get("Content-Type")) || resp.Request.Method == "HEAD" { - // TODO include some of the body in this case? - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("error response: %v; body: %q", resp.Status, data) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, errorBodySizeLimit+1)) - if err != nil { - return fmt.Errorf("%s: cannot read error body: %v", resp.Status, err) - } - if len(data) > errorBodySizeLimit { - // TODO include some part of the body - return fmt.Errorf("error body too large") + if ctype := resp.Header.Get("Content-Type"); !isJSONMediaType(ctype) { + return fmt.Errorf("non-JSON error response %q", ctype) } - var errs wireErrors - if err := json.Unmarshal(data, &errs); err != nil { + var errs ociregistry.WireErrors + if err := json.Unmarshal(bodyData, &errs); err != nil { return fmt.Errorf("%s: malformed error response: %v", resp.Status, err) } if len(errs.Errors) == 0 { return fmt.Errorf("%s: no errors in body (probably a server issue)", resp.Status) } - errs.httpStatusCode = resp.StatusCode return &errs } diff --git a/ociregistry/ociserver/error.go b/ociregistry/ociserver/error.go index 2df35e2..ed332d9 100644 --- a/ociregistry/ociserver/error.go +++ b/ociregistry/ociserver/error.go @@ -23,41 +23,48 @@ import ( "cuelabs.dev/go/oci/ociregistry" ) -type wireError struct { - Code string `json:"code"` - Message string `json:"message,omitempty"` - Detail any `json:"detail,omitempty"` -} - -type wireErrors struct { - Errors []wireError `json:"errors"` -} - func writeError(resp http.ResponseWriter, err error) { - e := wireError{ + e := ociregistry.WireError{ Message: err.Error(), } + // TODO perhaps we should iterate through all the + // errors instead of just choosing one. + // See https://github.com/golang/go/issues/66455 var ociErr ociregistry.Error if errors.As(err, &ociErr) { - e.Code = ociErr.Code() - e.Detail = ociErr.Detail() + e.Code_ = ociErr.Code() + if detail := ociErr.Detail(); detail != nil { + data, err := json.Marshal(detail) + if err != nil { + panic(fmt.Errorf("cannot marshal error detail: %v", err)) + } + e.Detail_ = json.RawMessage(data) + } } else { // This is contrary to spec, but it's what the Docker registry // does, so it can't be too bad. - e.Code = "UNKNOWN" + e.Code_ = "UNKNOWN" } + + // Use the HTTP status code from the error only when there isn't + // one implied from the error code. This means that the HTTP status + // is always consistent with the error code, but still allows a registry + // to return custom HTTP status codes for other codes. + httpStatus := http.StatusInternalServerError - var statusErr *httpStatusError - if errors.As(err, &statusErr) { - httpStatus = statusErr.status - } else if status, ok := errorStatuses[e.Code]; ok { + if status, ok := errorStatuses[e.Code_]; ok { httpStatus = status + } else { + var httpErr ociregistry.HTTPError + if errors.As(err, &httpErr) { + httpStatus = httpErr.StatusCode() + } } resp.Header().Set("Content-Type", "application/json") resp.WriteHeader(httpStatus) - data, err := json.Marshal(wireErrors{ - Errors: []wireError{e}, + data, err := json.Marshal(ociregistry.WireErrors{ + Errors: []ociregistry.WireError{e}, }) if err != nil { // TODO log @@ -83,29 +90,10 @@ var errorStatuses = map[string]int{ ociregistry.ErrRangeInvalid.Code(): http.StatusRequestedRangeNotSatisfiable, } -func badAPIUseError(f string, a ...any) error { - return ociregistry.NewError(fmt.Sprintf(f, a...), ociregistry.ErrUnsupported.Code(), nil) +func withHTTPCode(statusCode int, err error) error { + return ociregistry.NewHTTPError(err, statusCode, nil, nil) } -func withHTTPCode(status int, err error) error { - if err == nil { - panic("expected error to wrap") - } - return &httpStatusError{ - err: err, - status: status, - } -} - -type httpStatusError struct { - err error - status int -} - -func (e *httpStatusError) Unwrap() error { - return e.err -} - -func (e *httpStatusError) Error() string { - return e.err.Error() +func badAPIUseError(f string, a ...any) error { + return ociregistry.NewError(fmt.Sprintf(f, a...), ociregistry.ErrUnsupported.Code(), nil) } diff --git a/ociregistry/ociserver/error_test.go b/ociregistry/ociserver/error_test.go new file mode 100644 index 0000000..5a8b7f8 --- /dev/null +++ b/ociregistry/ociserver/error_test.go @@ -0,0 +1,20 @@ +// Copyright 2023 CUE Labs AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ociserver + +import "testing" + +func TestHTTPStatusCode(t *testing.T) { +} diff --git a/ociregistry/ociserver/registry.go b/ociregistry/ociserver/registry.go index dde50c5..8585b7b 100644 --- a/ociregistry/ociserver/registry.go +++ b/ociregistry/ociserver/registry.go @@ -111,6 +111,17 @@ var debugID int32 // If opts is nil, it's equivalent to passing new(Options). // // The returned handler should be registered at the site root. +// +// # Errors +// +// All HTTP responses will be JSON, formatted according to the +// OCI spec. If an error returned from backend conforms to +// [ociregistry.Error], the associated code and detail will be used. +// +// The HTTP response code will be determined from the error +// code when possible. If it can't be determined and the +// error implements [ociregistry.HTTPError], the code returned +// by StatusCode will be used as the HTTP response code. func New(backend ociregistry.Interface, opts *Options) http.Handler { if opts == nil { opts = new(Options)