diff --git a/changelog/unreleased/unify-datagateway-method-handling.md b/changelog/unreleased/unify-datagateway-method-handling.md new file mode 100644 index 0000000000..44397d74b7 --- /dev/null +++ b/changelog/unreleased/unify-datagateway-method-handling.md @@ -0,0 +1,5 @@ +Bugfix: unify datagateway method handling + +The datagateway now unpacks and forwards all HTTP methods + +https://github.com/cs3org/reva/pull/4527 diff --git a/internal/http/services/datagateway/datagateway.go b/internal/http/services/datagateway/datagateway.go index 13411fc36d..53ba79af81 100644 --- a/internal/http/services/datagateway/datagateway.go +++ b/internal/http/services/datagateway/datagateway.go @@ -50,8 +50,6 @@ func init() { const ( // TokenTransportHeader holds the header key for the reva transfer token TokenTransportHeader = "X-Reva-Transfer" - // UploadExpiresHeader holds the timestamp for the transport token expiry, defined in https://tus.io/protocols/resumable-upload.html#expiration - UploadExpiresHeader = "Upload-Expires" ) func init() { @@ -133,31 +131,13 @@ func (s *svc) setHandler() { semconv.HTTPURLKey.String(r.URL.String()), ) r = r.WithContext(ctx) - switch r.Method { - case "HEAD": - s.doHead(w, r) - return - case "GET": - s.doGet(w, r) - return - case "PUT": - s.doPut(w, r) - return - case "PATCH": - s.doPatch(w, r) - return - case "OPTIONS": - s.doOptions(w, r) - return - default: - w.WriteHeader(http.StatusNotImplemented) - return - } + s.doRequest(w, r) }) } +// verify extracts the transfer token from the request +// If it is not set as header we assume that it's the last path segment instead. func (s *svc) verify(ctx context.Context, r *http.Request) (*transferClaims, error) { - // Extract transfer token from request header. If not existing, assume that it's the last path segment instead. token := r.Header.Get(TokenTransportHeader) if token == "" { token = path.Base(r.URL.Path) @@ -180,112 +160,7 @@ func (s *svc) verify(ctx context.Context, r *http.Request) (*transferClaims, err return nil, err } -func (s *svc) doHead(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := appctx.GetLogger(ctx) - - claims, err := s.verify(ctx, r) - if err != nil { - err = errors.Wrap(err, "datagateway: error validating transfer token") - log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token") - w.WriteHeader(http.StatusForbidden) - return - } - - log.Debug().Str("target", claims.Target).Msg("sending request to internal data server") - - httpClient := s.client - httpReq, err := rhttp.NewRequest(ctx, "HEAD", claims.Target, nil) - if err != nil { - log.Error().Err(err).Msg("wrong request") - w.WriteHeader(http.StatusInternalServerError) - return - } - httpReq.Header = r.Header - - httpRes, err := httpClient.Do(httpReq) - if err != nil { - log.Error().Err(err).Msg("error doing HEAD request to data service") - w.WriteHeader(http.StatusInternalServerError) - return - } - defer httpRes.Body.Close() - - copyHeader(w.Header(), httpRes.Header) - - // add upload expiry / transfer token expiry header for tus https://tus.io/protocols/resumable-upload.html#expiration - w.Header().Set(UploadExpiresHeader, time.Unix(claims.ExpiresAt, 0).Format(time.RFC1123)) - - if httpRes.StatusCode != http.StatusOK { - // swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it - w.Header().Set("Content-Length", "0") - w.WriteHeader(httpRes.StatusCode) - return - } - - w.WriteHeader(http.StatusOK) -} - -func (s *svc) doGet(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := appctx.GetLogger(ctx) - - claims, err := s.verify(ctx, r) - if err != nil { - err = errors.Wrap(err, "datagateway: error validating transfer token") - log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token") - w.WriteHeader(http.StatusForbidden) - return - } - - log.Debug().Str("target", claims.Target).Msg("sending request to internal data server") - - httpClient := s.client - httpReq, err := rhttp.NewRequest(ctx, "GET", claims.Target, nil) - if err != nil { - log.Error().Err(err).Msg("wrong request") - w.WriteHeader(http.StatusInternalServerError) - return - } - httpReq.Header = r.Header - - httpRes, err := httpClient.Do(httpReq) - if err != nil { - log.Error().Err(err).Msg("error doing GET request to data service") - w.WriteHeader(http.StatusInternalServerError) - return - } - defer httpRes.Body.Close() - - copyHeader(w.Header(), httpRes.Header) - switch httpRes.StatusCode { - case http.StatusOK: - case http.StatusPartialContent: - default: - // swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it - w.Header().Set("Content-Length", "0") - w.WriteHeader(httpRes.StatusCode) - return - } - w.WriteHeader(httpRes.StatusCode) - - var c int64 - c, err = io.Copy(w, httpRes.Body) - if err != nil { - log.Error().Err(err).Msg("error writing body after headers were sent") - } - if httpRes.Header.Get("Content-Length") != "" { - i, err := strconv.ParseInt(httpRes.Header.Get("Content-Length"), 10, 64) - if err != nil { - log.Error().Err(err).Str("content-length", httpRes.Header.Get("Content-Length")).Msg("invalid content length in dataprovider response") - } - if i != c { - log.Error().Int64("content-length", i).Int64("transferred-bytes", c).Msg("content length vs transferred bytes mismatch") - } - } -} - -func (s *svc) doPut(w http.ResponseWriter, r *http.Request) { +func (s *svc) doRequest(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := appctx.GetLogger(ctx) @@ -309,10 +184,9 @@ func (s *svc) doPut(w http.ResponseWriter, r *http.Request) { targetURL.RawQuery = r.URL.RawQuery target = targetURL.String() - log.Debug().Str("target", claims.Target).Msg("sending request to internal data server") + log.Debug().Str("target", target).Msg("sending request to internal data server") - httpClient := s.client - httpReq, err := rhttp.NewRequest(ctx, "PUT", target, r.Body) + httpReq, err := rhttp.NewRequest(ctx, r.Method, target, r.Body) if err != nil { log.Err(err).Msg("wrong request") w.WriteHeader(http.StatusInternalServerError) @@ -321,68 +195,9 @@ func (s *svc) doPut(w http.ResponseWriter, r *http.Request) { httpReq.Header = r.Header httpReq.ContentLength = r.ContentLength - httpRes, err := httpClient.Do(httpReq) + httpRes, err := s.client.Do(httpReq) if err != nil { - log.Err(err).Msg("error doing PUT request to data service") - w.WriteHeader(http.StatusInternalServerError) - return - } - defer httpRes.Body.Close() - - copyHeader(w.Header(), httpRes.Header) - if httpRes.StatusCode != http.StatusOK { - // swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it - w.Header().Set("Content-Length", "0") - w.WriteHeader(httpRes.StatusCode) - return - } - - w.WriteHeader(http.StatusOK) - _, err = io.Copy(w, httpRes.Body) - if err != nil { - log.Err(err).Msg("error writing body after header were set") - } -} - -// TODO: put and post code is pretty much the same. Should be solved in a nicer way in the long run. -func (s *svc) doPatch(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := appctx.GetLogger(ctx) - - claims, err := s.verify(ctx, r) - if err != nil { - err = errors.Wrap(err, "datagateway: error validating transfer token") - log.Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token") - w.WriteHeader(http.StatusForbidden) - return - } - - target := claims.Target - // add query params to target, clients can send checksums and other information. - targetURL, err := url.Parse(target) - if err != nil { - log.Err(err).Msg("datagateway: error parsing target url") - w.WriteHeader(http.StatusInternalServerError) - return - } - - targetURL.RawQuery = r.URL.RawQuery - target = targetURL.String() - - log.Debug().Str("target", claims.Target).Msg("sending request to internal data server") - - httpClient := s.client - httpReq, err := rhttp.NewRequest(ctx, "PATCH", target, r.Body) - if err != nil { - log.Err(err).Msg("wrong request") - w.WriteHeader(http.StatusInternalServerError) - return - } - httpReq.Header = r.Header - - httpRes, err := httpClient.Do(httpReq) - if err != nil { - log.Err(err).Msg("error doing PATCH request to data service") + log.Err(err).Msg("error doing " + r.Method + " request to data service") w.WriteHeader(http.StatusInternalServerError) return } @@ -395,58 +210,22 @@ func (s *svc) doPatch(w http.ResponseWriter, r *http.Request) { w.WriteHeader(httpRes.StatusCode) return } - w.WriteHeader(httpRes.StatusCode) - _, err = io.Copy(w, httpRes.Body) - if err != nil { - log.Err(err).Msg("error writing body after header were set") - } -} - -func (s *svc) doOptions(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := appctx.GetLogger(ctx) - - claims, err := s.verify(ctx, r) - if err != nil { - err = errors.Wrap(err, "datagateway: error validating transfer token") - log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token") - w.WriteHeader(http.StatusForbidden) - return - } - - log.Debug().Str("target", claims.Target).Msg("sending request to internal data server") - httpClient := s.client - httpReq, err := rhttp.NewRequest(ctx, "OPTIONS", claims.Target, nil) - if err != nil { - log.Error().Err(err).Msg("wrong request") - w.WriteHeader(http.StatusInternalServerError) - return - } - httpReq.Header = r.Header - - httpRes, err := httpClient.Do(httpReq) + var c int64 + c, err = io.Copy(w, httpRes.Body) if err != nil { - log.Error().Err(err).Msg("error doing OPTIONS request to data service") - w.WriteHeader(http.StatusInternalServerError) - return + log.Err(err).Msg("error writing body after header were set") } - defer httpRes.Body.Close() - - copyHeader(w.Header(), httpRes.Header) - - // add upload expiry / transfer token expiry header for tus https://tus.io/protocols/resumable-upload.html#expiration - w.Header().Set(UploadExpiresHeader, time.Unix(claims.ExpiresAt, 0).Format(time.RFC1123)) - - if httpRes.StatusCode != http.StatusOK { - // swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it - w.Header().Set("Content-Length", "0") - w.WriteHeader(httpRes.StatusCode) - return + if httpRes.Header.Get("Content-Length") != "" { + i, err := strconv.ParseInt(httpRes.Header.Get("Content-Length"), 10, 64) + if err != nil { + log.Error().Err(err).Str("content-length", httpRes.Header.Get("Content-Length")).Msg("invalid content length in dataprovider response") + } + if i != c { + log.Error().Int64("content-length", i).Int64("transferred-bytes", c).Msg("content length vs transferred bytes mismatch") + } } - - w.WriteHeader(http.StatusOK) } func copyHeader(dst, src http.Header) {