From a4f012734053fa4b4c7ab4bbd44eb608c0b91b45 Mon Sep 17 00:00:00 2001 From: Gianmaria Del Monte Date: Tue, 4 Apr 2023 15:12:11 +0200 Subject: [PATCH] Refactored downloader --- .../services/archiver/manager/archiver.go | 10 ++++++-- .../http/services/owncloud/ocdav/versions.go | 11 ++++++++- pkg/storage/utils/downloader/downloader.go | 23 +++++++++---------- .../utils/downloader/mock/downloader_mock.go | 10 +++----- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/internal/http/services/archiver/manager/archiver.go b/internal/http/services/archiver/manager/archiver.go index 1d9ab3fcc0..28c8009552 100644 --- a/internal/http/services/archiver/manager/archiver.go +++ b/internal/http/services/archiver/manager/archiver.go @@ -168,10 +168,13 @@ func (a *Archiver) CreateTar(ctx context.Context, dst io.Writer) error { } if !isDir { - err = a.downloader.Download(ctx, path, "", w) + r, err := a.downloader.Download(ctx, path, "") if err != nil { return err } + if _, err := io.Copy(w, r); err != nil { + return err + } } return nil }) @@ -239,10 +242,13 @@ func (a *Archiver) CreateZip(ctx context.Context, dst io.Writer) error { } if !isDir { - err = a.downloader.Download(ctx, path, "", dst) + r, err := a.downloader.Download(ctx, path, "") if err != nil { return err } + if _, err := io.Copy(dst, r); err != nil { + return err + } } return nil }) diff --git a/internal/http/services/owncloud/ocdav/versions.go b/internal/http/services/owncloud/ocdav/versions.go index 5fe0517da5..7c42d2900f 100644 --- a/internal/http/services/owncloud/ocdav/versions.go +++ b/internal/http/services/owncloud/ocdav/versions.go @@ -21,6 +21,7 @@ package ocdav import ( "context" "fmt" + "io" "net/http" "path" "path/filepath" @@ -253,7 +254,15 @@ func (h *VersionsHandler) doDownload(w http.ResponseWriter, r *http.Request, s * w.Header().Set("Content-Transfer-Encoding", "binary") down := downloader.NewDownloader(client) - if err := down.Download(ctx, resStat.Info.Path, key, w); err != nil { + d, err := down.Download(ctx, resStat.Info.Path, key) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer d.Close() + + _, err = io.Copy(w, d) + if err != nil { w.WriteHeader(http.StatusInternalServerError) return } diff --git a/pkg/storage/utils/downloader/downloader.go b/pkg/storage/utils/downloader/downloader.go index 35a9d2520a..badccd1201 100644 --- a/pkg/storage/utils/downloader/downloader.go +++ b/pkg/storage/utils/downloader/downloader.go @@ -36,7 +36,7 @@ import ( // Downloader is the interface implemented by the objects that are able to // download a path into a destination Writer. type Downloader interface { - Download(ctx context.Context, path string, versionKey string, w io.Writer) error + Download(ctx context.Context, path, versionKey string) (io.ReadCloser, error) } type revaDownloader struct { @@ -62,7 +62,7 @@ func getDownloadProtocol(protocols []*gateway.FileDownloadProtocol, prot string) } // Download downloads a resource given the path to the dst Writer. -func (r *revaDownloader) Download(ctx context.Context, path, versionKey string, dst io.Writer) error { +func (r *revaDownloader) Download(ctx context.Context, path, versionKey string) (io.ReadCloser, error) { req := &provider.InitiateFileDownloadRequest{ Ref: &provider.Reference{ Path: path, @@ -82,37 +82,36 @@ func (r *revaDownloader) Download(ctx context.Context, path, versionKey string, switch { case err != nil: - return err + return nil, err case downResp.Status.Code != rpc.Code_CODE_OK: - return errtypes.InternalError(downResp.Status.Message) + return nil, errtypes.InternalError(downResp.Status.Message) } p, err := getDownloadProtocol(downResp.Protocols, "simple") if err != nil { - return err + return nil, err } httpReq, err := rhttp.NewRequest(ctx, http.MethodGet, p.DownloadEndpoint, nil) if err != nil { - return err + return nil, err } httpReq.Header.Set(datagateway.TokenTransportHeader, p.Token) httpRes, err := r.httpClient.Do(httpReq) if err != nil { - return err + return nil, err } - defer httpRes.Body.Close() if httpRes.StatusCode != http.StatusOK { + defer httpRes.Body.Close() switch httpRes.StatusCode { case http.StatusNotFound: - return errtypes.NotFound(path) + return nil, errtypes.NotFound(path) default: - return errtypes.InternalError(httpRes.Status) + return nil, errtypes.InternalError(httpRes.Status) } } - _, err = io.Copy(dst, httpRes.Body) - return err + return httpRes.Body, nil } diff --git a/pkg/storage/utils/downloader/mock/downloader_mock.go b/pkg/storage/utils/downloader/mock/downloader_mock.go index 8368c9a99a..9fb905e9e9 100644 --- a/pkg/storage/utils/downloader/mock/downloader_mock.go +++ b/pkg/storage/utils/downloader/mock/downloader_mock.go @@ -19,7 +19,6 @@ package mock import ( - "bufio" "context" "io" "os" @@ -36,13 +35,10 @@ func NewDownloader() downloader.Downloader { } // Download copies the content of a local file into the dst Writer. -func (m *mockDownloader) Download(ctx context.Context, path, _ string, dst io.Writer) error { +func (m *mockDownloader) Download(ctx context.Context, path, _ string) (io.ReadCloser, error) { f, err := os.Open(path) if err != nil { - return err + return nil, err } - defer f.Close() - fr := bufio.NewReader(f) - _, err = io.Copy(dst, fr) - return err + return f, nil }