Skip to content

Commit

Permalink
ociregistry: improve errors with respect to HTTP status
Browse files Browse the repository at this point in the history
TODO more explanation

Fixes #26.

Signed-off-by: Roger Peppe <rogpeppe@gmail.com>
Change-Id: I5a79c1c6fec9f22c1f565830d73486b406dd181d
  • Loading branch information
rogpeppe committed Mar 22, 2024
1 parent 16e7651 commit c7310d6
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 152 deletions.
194 changes: 193 additions & 1 deletion ociregistry/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,98 @@

package ociregistry

import (
"bytes"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"unicode"
)

// WireErrors is the JSON format used for error responses in
// the OCI HTTP API. It should always contain at least one
// error.
type WireErrors struct {
Errors []WireError `json:"errors"`
}

// Unwrap allows [errors.Is] and [errors.As] to
// see the errors inside e.
func (e *WireErrors) Unwrap() []error {
// TODO we could do this only once.
errs := make([]error, len(e.Errors))
for i := range e.Errors {
errs[i] = &e.Errors[i]
}
return errs
}

func (e *WireErrors) Error() string {
var buf strings.Builder
buf.WriteString(e.Errors[0].Error())
for i := range e.Errors[1:] {
buf.WriteString("; ")
buf.WriteString(e.Errors[i+1].Error())
}
return buf.String()
}

// WireError holds a single error in an OCI HTTP response.
type WireError struct {
Code_ string `json:"code"`
Message string `json:"message,omitempty"`
Detail_ json.RawMessage `json:"detail,omitempty"`
}

// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrBlobUnknown)`
// even when the error hasn't exactly wrapped that error.
func (e *WireError) Is(err error) bool {
var rerr Error
return errors.As(err, &rerr) && rerr.Code() == e.Code()
}

// Error implements the [error] interface.
func (e *WireError) Error() string {
var buf strings.Builder
for _, r := range e.Code_ {
if r == '_' {
buf.WriteByte(' ')
} else {
buf.WriteRune(unicode.ToLower(r))
}
}
if buf.Len() == 0 {
buf.WriteString("(no code)")
}
if e.Message != "" {
buf.WriteString(": ")
buf.WriteString(e.Message)
}
if len(e.Detail_) != 0 && !bytes.Equal(e.Detail_, []byte("null")) {
buf.WriteString("; detail: ")
buf.Write(e.Detail_)
}
return buf.String()
}

// Code implements [Error.Code].
func (e *WireError) Code() string {
return e.Code_
}

// Detail implements [Error.Detail].
func (e *WireError) Detail() any {
if len(e.Detail_) == 0 {
return nil
}
// TODO do this once only?
var d any
json.Unmarshal(e.Detail_, &d)
return d
}

// NewError returns a new error with the given code, message and detail.
func NewError(msg string, code string, detail any) Error {
return &registryError{
Expand All @@ -31,14 +123,110 @@ type Error interface {
// error.Error provides the error message.
error

// Code returns the error code. See
// Code returns the error code.
Code() string

// Detail returns any detail to be associated with the error; it should
// be JSON-marshable.
Detail() any
}

// HTTPError is optionally implemented by an error when
// the error has originated from an HTTP request
// or might be returned from one.
type HTTPError interface {
error

// StatusCode returns the HTTP status code of the response.
StatusCode() int

// Response holds the HTTP response that caused the HTTPError to
// be created. It will return nil if the error was not created
// as a result of an HTTP response.
//
// The caller should not read the response body or otherwise
// change the response (mutation of errors is a Bad Thing).
//
// Use the ResponseBody method to obtain the body of the
// response if needed.
Response() *http.Response

// ResponseBody returns the contents of the response body. It
// will return nil if the error was not created as a result of
// an HTTP response.
//
// The caller should not change or append to the returned data.
ResponseBody() []byte
}

// NewHTTPError returns an error that wraps err to make an [HTTPError]
// that represents the given status code, response and response body.
// Both response and body may be nil.
//
// A shallow copy is made of the response.
func NewHTTPError(err error, statusCode int, response *http.Response, body []byte) HTTPError {
herr := &httpError{
underlying: err,
statusCode: statusCode,
}
if response != nil {
herr.response = ref(*response)
herr.response.Body = nil
herr.body = body
}
return herr
}

type httpError struct {
underlying error
statusCode int
response *http.Response
body []byte
}

// Unwrap implements the [errors] Unwrap interface.
func (e *httpError) Unwrap(err error) error {
return e.underlying
}

// Is makes it possible for users to write `if errors.Is(err, ociregistry.ErrRangeInvalid)`
// even when the error hasn't exactly wrapped that error.
func (e *httpError) Is(err error) bool {
switch e.statusCode {
case http.StatusRequestedRangeNotSatisfiable:
return err == ErrRangeInvalid
}
return false
}

// Error implements [error.Error].
func (e *httpError) Error() string {
var buf strings.Builder
buf.WriteString(strconv.Itoa(e.statusCode))
buf.WriteString(" ")
buf.WriteString(http.StatusText(e.statusCode))
if e.underlying != nil {
buf.WriteString(": ")
buf.WriteString(e.underlying.Error())
}
return buf.String()
}

// StatusCode implements [HTTPEror.StatusCode].
func (e *httpError) StatusCode() int {
return e.statusCode
}

// Response implements [HTTPEror.Response].
func (e *httpError) Response() *http.Response {
return e.response
}

// ResponseBody implements [HTTPEror.ResponseBody].
func (e *httpError) ResponseBody() []byte {
return e.body
}

// The following values represent the known error codes.
var (
ErrBlobUnknown = NewError("blob unknown to registry", "BLOB_UNKNOWN", nil)
Expand Down Expand Up @@ -83,3 +271,7 @@ func (e *registryError) Error() string {
func (e *registryError) Detail() any {
return e.detail
}

func ref[T any](x T) *T {
return &x
}
2 changes: 1 addition & 1 deletion ociregistry/iter_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:build go1.23 || goexperiment.rangefunc
//go:build go1.23

package ociregistry

Expand Down
15 changes: 14 additions & 1 deletion ociregistry/ociauth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,20 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}
}
return r.transport.RoundTrip(req)
resp, err = r.transport.RoundTrip(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
// The server has responded with a 401 even though we've just
// provided a token that it gave us. Treat it as a 404 instead.
resp.Body.Close()
resp.Body = io.NopCloser(strings.NewReader("TODO"))
resp.StatusCode = http.StatusNotFound
resp.Status = http.StatusText(resp.StatusCode)
return resp, nil
}

// setAuthorization sets up authorization on the given request using any
Expand Down
73 changes: 73 additions & 0 deletions ociregistry/ociauth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,79 @@ func TestAuthNotAvailableAfterChallenge(t *testing.T) {
qt.Check(t, qt.Equals(requestCount, 1))
}

func Test401ResponseWithJustAcquiredToken(t *testing.T) {
// This tests the scenario where a server returns a 401 response
// when the client has just successfully acquired a token from
// the auth server.
//
// In this case, a "correct" server should return
// either 403 (access to the resource is forbidden because the
// client's credentials are not sufficient) or 404 (either the
// repository really doesn't exist or the credentials are insufficient
// and the server doesn't allow clients to see whether repositories
// they don't have access to might exist).
//
// However, some real-world servers instead return a 401 response
// erroneously indicating that the client needs to acquire
// authorization credentials, even though they have in fact just
// done so.
//
// As a workaround for this case, we treat the response as a 404.

testScope := ParseScope("repository:foo:pull")
authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
requestedScope := ParseScope(req.Form.Get("scope"))
if !runNonFatal(t, func(t testing.TB) {
qt.Assert(t, qt.DeepEquals(requestedScope, testScope))
qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"}))
}) {
return nil, &httpError{
statusCode: http.StatusInternalServerError,
}
}
return &wireToken{
Token: token{requestedScope}.String(),
}, nil
})
ts := newTargetServer(t, func(req *http.Request) *httpError {
if req.Header.Get("Authorization") == "" {
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
},
}
}
if !runNonFatal(t, func(t testing.TB) {
qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope))
}) {
return &httpError{
statusCode: http.StatusInternalServerError,
}
}
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
},
}
return nil
})
client := &http.Client{
Transport: NewStdTransport(StdTransportParams{
Config: configFunc(func(host string) (ConfigEntry, error) {
return ConfigEntry{}, nil
}),
}),
}
req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil)
qt.Assert(t, qt.IsNil(err))
resp, err := client.Do(req)
qt.Assert(t, qt.IsNil(err))
defer resp.Body.Close()
qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusNotFound))
}

func TestConfigHasAccessToken(t *testing.T) {
accessToken := "somevalue"
ts := newTargetServer(t, func(req *http.Request) *httpError {
Expand Down
Loading

0 comments on commit c7310d6

Please sign in to comment.