diff --git a/ociregistry/error.go b/ociregistry/error.go index 09368e0..466eba7 100644 --- a/ociregistry/error.go +++ b/ociregistry/error.go @@ -14,6 +14,104 @@ package ociregistry +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "unicode" +) + +// WireErrors is the JSON format used for error responses in +// the OCI HTTP API. It should always contain at least one +// error. +type WireErrors struct { + Errors []WireError `json:"errors"` +} + +// Unwrap allows [errors.Is] and [errors.As] to +// see the errors inside e. +func (e *WireErrors) Unwrap() []error { + // TODO we could do this only once. + errs := make([]error, len(e.Errors)) + for i := range e.Errors { + errs[i] = &e.Errors[i] + } + return errs +} + +func (e *WireErrors) Error() string { + var buf strings.Builder + buf.WriteString(e.Errors[0].Error()) + for i := range e.Errors[1:] { + buf.WriteString("; ") + buf.WriteString(e.Errors[i+1].Error()) + } + return buf.String() +} + +// WireError holds a single error in an OCI HTTP response. +type WireError struct { + Code_ string `json:"code"` + Message string `json:"message,omitempty"` + // Detail_ holds the JSON detail for the message. + // It's assumed to be valid JSON if non-empty. + Detail_ json.RawMessage `json:"detail,omitempty"` +} + +// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrBlobUnknown)` +// even when the error hasn't exactly wrapped that error. +func (e *WireError) Is(err error) bool { + var rerr Error + return errors.As(err, &rerr) && rerr.Code() == e.Code() +} + +// Error implements the [error] interface. +func (e *WireError) Error() string { + var buf strings.Builder + for _, r := range e.Code_ { + if r == '_' { + buf.WriteByte(' ') + } else { + buf.WriteRune(unicode.ToLower(r)) + } + } + if buf.Len() == 0 { + buf.WriteString("(no code)") + } + if e.Message != "" { + buf.WriteString(": ") + buf.WriteString(e.Message) + } + if len(e.Detail_) != 0 && !bytes.Equal(e.Detail_, []byte("null")) { + buf.WriteString("; detail: ") + buf.Write(e.Detail_) + } + return buf.String() +} + +// Code implements [Error.Code]. +func (e *WireError) Code() string { + return e.Code_ +} + +// Detail implements [Error.Detail]. +// It panics if e.Detail_ contains invalid JSON. +func (e *WireError) Detail() any { + if len(e.Detail_) == 0 { + return nil + } + // TODO do this once only? + var d any + if err := json.Unmarshal(e.Detail_, &d); err != nil { + panic(fmt.Errorf("invalid error detail JSON %q: %v", e.Detail_, err)) + } + return d +} + // NewError returns a new error with the given code, message and detail. func NewError(msg string, code string, detail any) Error { return ®istryError{ @@ -31,7 +129,7 @@ type Error interface { // error.Error provides the error message. error - // Code returns the error code. See + // Code returns the error code. Code() string // Detail returns any detail to be associated with the error; it should @@ -39,6 +137,102 @@ type Error interface { Detail() any } +// HTTPError is optionally implemented by an error when +// the error has originated from an HTTP request +// or might be returned from one. +type HTTPError interface { + error + + // StatusCode returns the HTTP status code of the response. + StatusCode() int + + // Response holds the HTTP response that caused the HTTPError to + // be created. It will return nil if the error was not created + // as a result of an HTTP response. + // + // The caller should not read the response body or otherwise + // change the response (mutation of errors is a Bad Thing). + // + // Use the ResponseBody method to obtain the body of the + // response if needed. + Response() *http.Response + + // ResponseBody returns the contents of the response body. It + // will return nil if the error was not created as a result of + // an HTTP response. + // + // The caller should not change or append to the returned data. + ResponseBody() []byte +} + +// NewHTTPError returns an error that wraps err to make an [HTTPError] +// that represents the given status code, response and response body. +// Both response and body may be nil. +// +// A shallow copy is made of the response. +func NewHTTPError(err error, statusCode int, response *http.Response, body []byte) HTTPError { + herr := &httpError{ + underlying: err, + statusCode: statusCode, + } + if response != nil { + herr.response = ref(*response) + herr.response.Body = nil + herr.body = body + } + return herr +} + +type httpError struct { + underlying error + statusCode int + response *http.Response + body []byte +} + +// Unwrap implements the [errors] Unwrap interface. +func (e *httpError) Unwrap() error { + return e.underlying +} + +// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrRangeInvalid)` +// even when the error hasn't exactly wrapped that error. +func (e *httpError) Is(err error) bool { + switch e.statusCode { + case http.StatusRequestedRangeNotSatisfiable: + return err == ErrRangeInvalid + } + return false +} + +// Error implements [error.Error]. +func (e *httpError) Error() string { + var buf strings.Builder + buf.WriteString(strconv.Itoa(e.statusCode)) + buf.WriteString(" ") + buf.WriteString(http.StatusText(e.statusCode)) + if e.underlying != nil { + buf.WriteString(": ") + buf.WriteString(e.underlying.Error()) + } + return buf.String() +} + +// StatusCode implements [HTTPError.StatusCode]. +func (e *httpError) StatusCode() int { + return e.statusCode +} + +// Response implements [HTTPError.Response]. +func (e *httpError) Response() *http.Response { + return e.response +} + +// ResponseBody implements [HTTPError.ResponseBody]. +func (e *httpError) ResponseBody() []byte { + return e.body +} + // The following values represent the known error codes. var ( ErrBlobUnknown = NewError("blob unknown to registry", "BLOB_UNKNOWN", nil) @@ -83,3 +277,7 @@ func (e *registryError) Error() string { func (e *registryError) Detail() any { return e.detail } + +func ref[T any](x T) *T { + return &x +} diff --git a/ociregistry/ociauth/auth.go b/ociregistry/ociauth/auth.go index 70f00ba..efbc8e0 100644 --- a/ociregistry/ociauth/auth.go +++ b/ociregistry/ociauth/auth.go @@ -12,6 +12,8 @@ import ( "strings" "sync" "time" + + "cuelabs.dev/go/oci/ociregistry" ) // TODO decide on a good value for this. @@ -298,8 +300,8 @@ func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantSc scope := requiredScope.Union(wantScope) tok, err := r.acquireToken(ctx, scope) if err != nil { - var rerr *responseError - if !errors.As(err, &rerr) || rerr.statusCode != http.StatusUnauthorized { + var herr ociregistry.HTTPError + if !errors.As(err, &herr) || herr.StatusCode() != http.StatusUnauthorized { return "", err } // The documentation says this: @@ -372,8 +374,8 @@ func (r *registry) acquireToken(ctx context.Context, scope Scope) (*wireToken, e if err == nil { return tok, nil } - var rerr *responseError - if !errors.As(err, &rerr) || rerr.statusCode != http.StatusNotFound { + var herr ociregistry.HTTPError + if !errors.As(err, &herr) || herr.StatusCode() != http.StatusNotFound { return tok, err } // The request to the endpoint returned 404 from the POST request, @@ -449,7 +451,8 @@ func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) { } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errorFromResponse(resp) + // TODO include body of response in error message. + return nil, ociregistry.NewHTTPError(nil, resp.StatusCode, resp, nil) } data, err := io.ReadAll(resp.Body) if err != nil { @@ -462,22 +465,6 @@ func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) { return &tok, nil } -type responseError struct { - statusCode int - msg string -} - -func errorFromResponse(resp *http.Response) error { - // TODO include body of response in error message. - return &responseError{ - statusCode: resp.StatusCode, - } -} - -func (e *responseError) Error() string { - return fmt.Sprintf("unexpected HTTP response %d", e.statusCode) -} - // deleteExpiredTokens removes all tokens from r that expire after the given // time. // TODO ask the store to remove expired tokens? diff --git a/ociregistry/ociclient/error.go b/ociregistry/ociclient/error.go index 71d3cdd..bf2c9e7 100644 --- a/ociregistry/ociclient/error.go +++ b/ociregistry/ociclient/error.go @@ -15,16 +15,12 @@ package ociclient import ( - "bytes" "encoding/json" - "errors" "fmt" "io" "mime" "net/http" - "strconv" "strings" - "unicode" "cuelabs.dev/go/oci/ociregistry" ) @@ -34,98 +30,28 @@ import ( // bytes. Hence, 8 KiB should be sufficient. const errorBodySizeLimit = 8 * 1024 -type wireError struct { - Code_ string `json:"code"` - Message string `json:"message,omitempty"` - Detail_ json.RawMessage `json:"detail,omitempty"` -} - -func (e *wireError) Error() string { - var buf strings.Builder - for _, r := range e.Code_ { - if r == '_' { - buf.WriteByte(' ') +// makeError forms an error from a non-OK response. +// +// It reads but does not close resp.Body. +func makeError(resp *http.Response) error { + var data []byte + var err error + if resp.Body != nil { + data, err = io.ReadAll(io.LimitReader(resp.Body, errorBodySizeLimit+1)) + if err != nil { + err = fmt.Errorf("cannot read error body: %v", err) + } else if len(data) > errorBodySizeLimit { + // TODO include some part of the body + err = fmt.Errorf("error body too large") } else { - buf.WriteRune(unicode.ToLower(r)) + err = makeError1(resp, data) } } - if buf.Len() == 0 { - buf.WriteString("(no code)") - } - if e.Message != "" { - buf.WriteString(": ") - buf.WriteString(e.Message) - } - if len(e.Detail_) != 0 && !bytes.Equal(e.Detail_, []byte("null")) { - buf.WriteString("; detail: ") - buf.Write(e.Detail_) - } - return buf.String() -} - -// Code implements [ociregistry.Error.Code]. -func (e *wireError) Code() string { - return e.Code_ -} - -// Detail implements [ociregistry.Error.Detail]. -func (e *wireError) Detail() any { - if len(e.Detail_) == 0 { - return nil - } - // TODO do this once only? - var d any - json.Unmarshal(e.Detail_, &d) - return d -} - -// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrBlobUnknown)` -// even when the error hasn't exactly wrapped that error. -func (e *wireError) Is(err error) bool { - var rerr ociregistry.Error - return errors.As(err, &rerr) && rerr.Code() == e.Code() -} - -type wireErrors struct { - httpStatusCode int - Errors []wireError `json:"errors"` -} - -func (e *wireErrors) Unwrap() []error { - // TODO we could do this only once. - errs := make([]error, len(e.Errors)) - for i := range e.Errors { - errs[i] = &e.Errors[i] - } - return errs -} - -// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrRangeInvalid)` -// even when the error hasn't exactly wrapped that error. -func (e *wireErrors) Is(err error) bool { - switch e.httpStatusCode { - case http.StatusRequestedRangeNotSatisfiable: - return err == ociregistry.ErrRangeInvalid - } - return false + // We always include the status code and response in the error. + return ociregistry.NewHTTPError(err, resp.StatusCode, resp, data) } -func (e *wireErrors) Error() string { - var buf strings.Builder - buf.WriteString(strconv.Itoa(e.httpStatusCode)) - buf.WriteString(" ") - buf.WriteString(http.StatusText(e.httpStatusCode)) - buf.WriteString(": ") - buf.WriteString(e.Errors[0].Error()) - for i := range e.Errors[1:] { - buf.WriteString("; ") - buf.WriteString(e.Errors[i+1].Error()) - } - return buf.String() -} - -// makeError forms an error from a non-OK response. -func makeError(resp *http.Response) error { +func makeError1(resp *http.Response, bodyData []byte) error { if resp.Request.Method == "HEAD" { // When we've made a HEAD request, we can't see any of // the actual error, so we'll have to make up something @@ -143,31 +69,21 @@ func makeError(resp *http.Response) error { case http.StatusBadRequest: err = ociregistry.ErrUnsupported default: - return fmt.Errorf("error response: %v", resp.Status) + // Our caller will turn this into a non-nil error. + return nil } return fmt.Errorf("error response: %v: %w", resp.Status, err) } - if !isJSONMediaType(resp.Header.Get("Content-Type")) || resp.Request.Method == "HEAD" { - // TODO include some of the body in this case? - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("error response: %v; body: %q", resp.Status, data) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, errorBodySizeLimit+1)) - if err != nil { - return fmt.Errorf("%s: cannot read error body: %v", resp.Status, err) - } - if len(data) > errorBodySizeLimit { - // TODO include some part of the body - return fmt.Errorf("error body too large") + if ctype := resp.Header.Get("Content-Type"); !isJSONMediaType(ctype) { + return fmt.Errorf("non-JSON error response %q; body %q", ctype, bodyData) } - var errs wireErrors - if err := json.Unmarshal(data, &errs); err != nil { + var errs ociregistry.WireErrors + if err := json.Unmarshal(bodyData, &errs); err != nil { return fmt.Errorf("%s: malformed error response: %v", resp.Status, err) } if len(errs.Errors) == 0 { return fmt.Errorf("%s: no errors in body (probably a server issue)", resp.Status) } - errs.httpStatusCode = resp.StatusCode return &errs } diff --git a/ociregistry/ociclient/error_test.go b/ociregistry/ociclient/error_test.go new file mode 100644 index 0000000..9c9e44f --- /dev/null +++ b/ociregistry/ociclient/error_test.go @@ -0,0 +1,40 @@ +package ociclient + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "cuelabs.dev/go/oci/ociregistry" + "cuelabs.dev/go/oci/ociregistry/ocitest" + "github.com/go-quicktest/qt" +) + +func TestNonJSONErrorResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusTeapot) + w.Write([]byte("some body")) + })) + defer srv.Close() + + srvURL, _ := url.Parse(srv.URL) + r, err := New(srvURL.Host, &Options{ + Insecure: true, + }) + qt.Assert(t, qt.IsNil(err)) + // TODO(go1.23) for method, call := range ocitest.MethodCalls() { + ocitest.MethodCalls()(func(method ocitest.Method, call ocitest.MethodCall) bool { + t.Run(method.String(), func(t *testing.T) { + err := call(context.Background(), r) + t.Logf("call error: %v", err) + var herr ociregistry.HTTPError + ok := errors.As(err, &herr) + qt.Assert(t, qt.IsTrue(ok)) + qt.Assert(t, qt.Equals(herr.StatusCode(), http.StatusTeapot)) + }) + return true + }) +} diff --git a/ociregistry/ociclient/writer.go b/ociregistry/ociclient/writer.go index 1acf197..3d8fc3c 100644 --- a/ociregistry/ociclient/writer.go +++ b/ociregistry/ociclient/writer.go @@ -22,6 +22,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "sync" "github.com/opencontainers/go-digest" @@ -223,6 +224,14 @@ func (c *client) PushBlobChunkedResume(ctx context.Context, repo string, id stri if err != nil { return nil, fmt.Errorf("provided ID is not a valid location URL") } + if !strings.HasPrefix(location.Path, "/") { + // Our BlobWriter.ID method always returns a fully + // qualified absolute URL, so this must be a mistake + // on the part of the caller. + // We allow a relative URL even though we don't + // ever return one to make things a bit easier for tests. + return nil, fmt.Errorf("provided upload ID %q has unexpected relative URL path", id) + } } ctx = ociauth.ContextWithRequestInfo(ctx, ociauth.RequestInfo{ RequiredScope: ociauth.NewScope(ociauth.ResourceScope{ @@ -378,7 +387,7 @@ func (w *blobWriter) Commit(digest ociregistry.Digest) (ociregistry.Descriptor, w.mu.Lock() defer w.mu.Unlock() if err := w.flush(nil, digest); err != nil { - return ociregistry.Descriptor{}, fmt.Errorf("cannot flush data before commit: %v", err) + return ociregistry.Descriptor{}, fmt.Errorf("cannot flush data before commit: %w", err) } return ociregistry.Descriptor{ MediaType: "application/octet-stream", diff --git a/ociregistry/ociserver/error.go b/ociregistry/ociserver/error.go index 2df35e2..2fcb0e6 100644 --- a/ociregistry/ociserver/error.go +++ b/ociregistry/ociserver/error.go @@ -23,41 +23,47 @@ import ( "cuelabs.dev/go/oci/ociregistry" ) -type wireError struct { - Code string `json:"code"` - Message string `json:"message,omitempty"` - Detail any `json:"detail,omitempty"` -} - -type wireErrors struct { - Errors []wireError `json:"errors"` -} - func writeError(resp http.ResponseWriter, err error) { - e := wireError{ + e := ociregistry.WireError{ Message: err.Error(), } + // TODO perhaps we should iterate through all the + // errors instead of just choosing one. + // See https://github.com/golang/go/issues/66455 var ociErr ociregistry.Error if errors.As(err, &ociErr) { - e.Code = ociErr.Code() - e.Detail = ociErr.Detail() + e.Code_ = ociErr.Code() + if detail := ociErr.Detail(); detail != nil { + data, err := json.Marshal(detail) + if err != nil { + panic(fmt.Errorf("cannot marshal error detail: %v", err)) + } + e.Detail_ = json.RawMessage(data) + } } else { // This is contrary to spec, but it's what the Docker registry // does, so it can't be too bad. - e.Code = "UNKNOWN" + e.Code_ = "UNKNOWN" } + + // Use the HTTP status code from the error only when there isn't + // one implied from the error code. This means that the HTTP status + // is always consistent with the error code, but still allows a registry + // to return custom HTTP status codes for other codes. httpStatus := http.StatusInternalServerError - var statusErr *httpStatusError - if errors.As(err, &statusErr) { - httpStatus = statusErr.status - } else if status, ok := errorStatuses[e.Code]; ok { + if status, ok := errorStatuses[e.Code_]; ok { httpStatus = status + } else { + var httpErr ociregistry.HTTPError + if errors.As(err, &httpErr) { + httpStatus = httpErr.StatusCode() + } } resp.Header().Set("Content-Type", "application/json") resp.WriteHeader(httpStatus) - data, err := json.Marshal(wireErrors{ - Errors: []wireError{e}, + data, err := json.Marshal(ociregistry.WireErrors{ + Errors: []ociregistry.WireError{e}, }) if err != nil { // TODO log @@ -83,29 +89,10 @@ var errorStatuses = map[string]int{ ociregistry.ErrRangeInvalid.Code(): http.StatusRequestedRangeNotSatisfiable, } -func badAPIUseError(f string, a ...any) error { - return ociregistry.NewError(fmt.Sprintf(f, a...), ociregistry.ErrUnsupported.Code(), nil) -} - -func withHTTPCode(status int, err error) error { - if err == nil { - panic("expected error to wrap") - } - return &httpStatusError{ - err: err, - status: status, - } -} - -type httpStatusError struct { - err error - status int +func withHTTPCode(statusCode int, err error) error { + return ociregistry.NewHTTPError(err, statusCode, nil, nil) } -func (e *httpStatusError) Unwrap() error { - return e.err -} - -func (e *httpStatusError) Error() string { - return e.err.Error() +func badAPIUseError(f string, a ...any) error { + return ociregistry.NewError(fmt.Sprintf(f, a...), ociregistry.ErrUnsupported.Code(), nil) } diff --git a/ociregistry/ociserver/error_test.go b/ociregistry/ociserver/error_test.go new file mode 100644 index 0000000..ceed1a4 --- /dev/null +++ b/ociregistry/ociserver/error_test.go @@ -0,0 +1,75 @@ +// Copyright 2023 CUE Labs AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ociserver + +import ( + "context" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "cuelabs.dev/go/oci/ociregistry" + + "github.com/go-quicktest/qt" +) + +func TestHTTPStatusOverriddenByErrorCode(t *testing.T) { + // Test that if an Interface method returns an HTTPError error, the + // HTTP status code is derived from the OCI error code in preference + // to the HTTPError status code. + r := New(&ociregistry.Funcs{ + GetTag_: func(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { + return nil, ociregistry.NewHTTPError(ociregistry.ErrNameUnknown, http.StatusUnauthorized, nil, nil) + }, + }, nil) + s := httptest.NewServer(r) + defer s.Close() + resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag") + qt.Assert(t, qt.IsNil(err)) + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusNotFound)) + qt.Assert(t, qt.JSONEquals(body, &ociregistry.WireErrors{ + Errors: []ociregistry.WireError{{ + Code_: ociregistry.ErrNameUnknown.Code(), + Message: "401 Unauthorized: repository name not known to registry", + }}, + })) +} + +func TestHTTPStatusUsedForUnknownErrorCode(t *testing.T) { + // Test that if an Interface method returns an HTTPError error, that + // HTTP status code is used when the code isn't known to be + // associated with a particular HTTP status. + r := New(&ociregistry.Funcs{ + GetTag_: func(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { + return nil, ociregistry.NewHTTPError(ociregistry.NewError("foo", "SOMECODE", nil), http.StatusTeapot, nil, nil) + }, + }, nil) + s := httptest.NewServer(r) + defer s.Close() + resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag") + qt.Assert(t, qt.IsNil(err)) + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusTeapot)) + qt.Assert(t, qt.JSONEquals(body, &ociregistry.WireErrors{ + Errors: []ociregistry.WireError{{ + Code_: "SOMECODE", + Message: "418 I'm a teapot: foo", + }}, + })) +} diff --git a/ociregistry/ociserver/registry.go b/ociregistry/ociserver/registry.go index dde50c5..8585b7b 100644 --- a/ociregistry/ociserver/registry.go +++ b/ociregistry/ociserver/registry.go @@ -111,6 +111,17 @@ var debugID int32 // If opts is nil, it's equivalent to passing new(Options). // // The returned handler should be registered at the site root. +// +// # Errors +// +// All HTTP responses will be JSON, formatted according to the +// OCI spec. If an error returned from backend conforms to +// [ociregistry.Error], the associated code and detail will be used. +// +// The HTTP response code will be determined from the error +// code when possible. If it can't be determined and the +// error implements [ociregistry.HTTPError], the code returned +// by StatusCode will be used as the HTTP response code. func New(backend ociregistry.Interface, opts *Options) http.Handler { if opts == nil { opts = new(Options)