From 6f4c2530d707027abc0152e5fd46e2feca4a3caf Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Wed, 29 May 2024 23:12:21 +0300 Subject: [PATCH] :sparkles: v3 (feature): add configuration support to c.SendFile() --- ctx.go | 117 +++++++++++++++++++++++++------ ctx_interface.go | 2 +- ctx_test.go | 174 ++++++++++++++++++++++++++++++++++++++++++++++- router.go | 6 ++ 4 files changed, 274 insertions(+), 25 deletions(-) diff --git a/ctx.go b/ctx.go index e630a345f6..549e07c025 100644 --- a/ctx.go +++ b/ctx.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" @@ -1398,47 +1399,99 @@ func (c *DefaultCtx) Send(body []byte) error { return nil } -var ( - sendFileOnce sync.Once - sendFileFS *fasthttp.FS - sendFileHandler fasthttp.RequestHandler -) +// SendFile defines configuration options when to transfer file with SendFileWithConfig. +type SendFile struct { + // FS is the file system to serve the static files from. + // You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. + // + // Optional. Default: nil + FS fs.FS + + // When set to true, the server tries minimizing CPU usage by caching compressed files. + // This works differently than the github.com/gofiber/compression middleware. + // Optional. Default value false + Compress bool `json:"compress"` + + // When set to true, enables byte range requests. + // Optional. Default value false + ByteRange bool `json:"byte_range"` + + // When set to true, enables direct download. + // + // Optional. Default: false. + Download bool `json:"download"` + + // Expiration duration for inactive file handlers. + // Use a negative time.Duration to disable it. + // + // Optional. Default value 10 * time.Second. + CacheDuration time.Duration `json:"cache_duration"` + + // The value for the Cache-Control HTTP-header + // that is set on the file response. MaxAge is defined in seconds. + // + // Optional. Default value 0. + MaxAge int `json:"max_age"` +} // SendFile transfers the file from the given path. // The file is not compressed by default, enable this by passing a 'true' argument // Sets the Content-Type response HTTP header field based on the filenames extension. -func (c *DefaultCtx) SendFile(file string, compress ...bool) error { +func (c *DefaultCtx) SendFile(file string, config ...SendFile) error { // Save the filename, we will need it in the error message if the file isn't found filename := file + route := c.Route() + + var cfg SendFile + if len(config) > 0 { + cfg = config[0] + } + + if cfg.CacheDuration == 0 { + cfg.CacheDuration = 10 * time.Second + } + // https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134 - sendFileOnce.Do(func() { - const cacheDuration = 10 * time.Second - sendFileFS = &fasthttp.FS{ + route.sendFileOnce.Do(func() { + route.sendFileFS = &fasthttp.FS{ Root: "", + FS: cfg.FS, AllowEmptyRoot: true, GenerateIndexPages: false, - AcceptByteRange: true, - Compress: true, + AcceptByteRange: cfg.ByteRange, + Compress: cfg.Compress, CompressedFileSuffix: c.app.config.CompressedFileSuffix, - CacheDuration: cacheDuration, + CacheDuration: cfg.CacheDuration, + SkipCache: cfg.CacheDuration < 0, IndexNames: []string{"index.html"}, PathNotFound: func(ctx *fasthttp.RequestCtx) { ctx.Response.SetStatusCode(StatusNotFound) }, } - sendFileHandler = sendFileFS.NewRequestHandler() + + if cfg.FS != nil { + route.sendFileFS.Root = "." + } + + route.sendFileHandler = route.sendFileFS.NewRequestHandler() + + maxAge := cfg.MaxAge + if maxAge > 0 { + route.sendFileCacheControlValue = "public, max-age=" + strconv.Itoa(maxAge) + } }) // Keep original path for mutable params c.pathOriginal = utils.CopyString(c.pathOriginal) + // Disable compression - if len(compress) == 0 || !compress[0] { - // https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L55 - c.fasthttp.Request.Header.Del(HeaderAcceptEncoding) + if cfg.Compress { + c.fasthttp.Request.Header.Set(HeaderAcceptEncoding, "gzip") } + // copy of https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L103-L121 with small adjustments - if len(file) == 0 || !filepath.IsAbs(file) { + if len(file) == 0 || (!filepath.IsAbs(file) && cfg.FS == nil) { // extend relative path to absolute path hasTrailingSlash := len(file) > 0 && (file[len(file)-1] == '/' || file[len(file)-1] == '\\') @@ -1451,6 +1504,7 @@ func (c *DefaultCtx) SendFile(file string, compress ...bool) error { file += "/" } } + // convert the path to forward slashes regardless the OS in order to set the URI properly // the handler will convert back to OS path separator before opening the file file = filepath.ToSlash(file) @@ -1458,22 +1512,43 @@ func (c *DefaultCtx) SendFile(file string, compress ...bool) error { // Restore the original requested URL originalURL := utils.CopyString(c.OriginalURL()) defer c.fasthttp.Request.SetRequestURI(originalURL) + // Set new URI for fileHandler c.fasthttp.Request.SetRequestURI(file) + // Save status code status := c.fasthttp.Response.StatusCode() + // Serve file - sendFileHandler(c.fasthttp) + route.sendFileHandler(c.fasthttp) + + // Sets the response Content-Disposition header to attachment if the Download option is true + if cfg.Download { + c.Attachment() + } + // Get the status code which is set by fasthttp fsStatus := c.fasthttp.Response.StatusCode() + + // Check for error + if status != StatusNotFound && fsStatus == StatusNotFound { + return NewError(StatusNotFound, fmt.Sprintf("sendfile: file %s not found", filename)) + } + // Set the status code set by the user if it is different from the fasthttp status code and 200 if status != fsStatus && status != StatusOK { c.Status(status) } - // Check for error - if status != StatusNotFound && fsStatus == StatusNotFound { - return NewError(StatusNotFound, fmt.Sprintf("sendfile: file %s not found", filename)) + + // Apply cache control header + if status != StatusNotFound && status != StatusForbidden { + if len(route.sendFileCacheControlValue) > 0 { + c.Context().Response.Header.Set(HeaderCacheControl, route.sendFileCacheControlValue) + } + + return nil } + return nil } diff --git a/ctx_interface.go b/ctx_interface.go index 2950c088de..5f48732b28 100644 --- a/ctx_interface.go +++ b/ctx_interface.go @@ -308,7 +308,7 @@ type Ctx interface { // SendFile transfers the file from the given path. // The file is not compressed by default, enable this by passing a 'true' argument // Sets the Content-Type response HTTP header field based on the filenames extension. - SendFile(file string, compress ...bool) error + SendFile(file string, config ...SendFile) error // SendStatus sets the HTTP status code and if the response body is empty, // it sets the correct status message in the body. diff --git a/ctx_test.go b/ctx_test.go index 524869bc2d..c89b7f6f52 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -11,6 +11,7 @@ import ( "compress/zlib" "context" "crypto/tls" + "embed" "encoding/xml" "errors" "fmt" @@ -2970,19 +2971,152 @@ func Test_Ctx_SendFile(t *testing.T) { app.ReleaseCtx(c) } +func Test_Ctx_SendFile_Download(t *testing.T) { + t.Parallel() + app := New() + + // fetch file content + f, err := os.Open("./ctx.go") + require.NoError(t, err) + defer func() { + require.NoError(t, f.Close()) + }() + expectFileContent, err := io.ReadAll(f) + require.NoError(t, err) + // fetch file info for the not modified test case + _, err = os.Stat("./ctx.go") + require.NoError(t, err) + + // simple test case + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + err = c.SendFile("ctx.go", SendFile{ + Download: true, + }) + // check expectation + require.NoError(t, err) + require.Equal(t, expectFileContent, c.Response().Body()) + require.Equal(t, "attachment", string(c.Response().Header.Peek(HeaderContentDisposition))) + require.Equal(t, StatusOK, c.Response().StatusCode()) + app.ReleaseCtx(c) +} + +func Test_Ctx_SendFile_MaxAge(t *testing.T) { + t.Parallel() + app := New() + + // fetch file content + f, err := os.Open("./ctx.go") + require.NoError(t, err) + defer func() { + require.NoError(t, f.Close()) + }() + expectFileContent, err := io.ReadAll(f) + require.NoError(t, err) + + // fetch file info for the not modified test case + _, err = os.Stat("./ctx.go") + require.NoError(t, err) + + // simple test case + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + err = c.SendFile("ctx.go", SendFile{ + MaxAge: 100, + }) + + // check expectation + require.NoError(t, err) + require.Equal(t, expectFileContent, c.Response().Body()) + require.Equal(t, "public, max-age=100", string(c.Context().Response.Header.Peek(HeaderCacheControl)), "CacheControl Control") + require.Equal(t, StatusOK, c.Response().StatusCode()) + app.ReleaseCtx(c) +} + +func Test_Ctx_SendFile_Compressed(t *testing.T) { + t.Parallel() + app := New() + + // fetch file content + f, err := os.Open("./ctx.go") + require.NoError(t, err) + + defer func() { + require.NoError(t, f.Close()) + + }() + expectFileContent, err := io.ReadAll(f) + require.NoError(t, err) + + // fetch file info for the not modified test case + _, err = os.Stat("./ctx.go") + require.NoError(t, err) + + // simple test case + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + err = c.SendFile("./ctx.go", SendFile{ + Compress: true, + }) + require.NoError(t, err) + + gz, err := gzip.NewReader(bytes.NewReader(c.Response().Body())) + require.NoError(t, err) + + body, err := io.ReadAll(gz) + require.NoError(t, err) + + require.Equal(t, expectFileContent, body) + require.Equal(t, "gzip", string(c.Response().Header.Peek(HeaderContentEncoding))) + require.Equal(t, StatusOK, c.Response().StatusCode()) + + app.ReleaseCtx(c) +} + +//go:embed ctx.go +var embedFile embed.FS + +func Test_Ctx_SendFile_EmbedFS(t *testing.T) { + t.Parallel() + app := New() + + f, err := os.Open("./ctx.go") + require.NoError(t, err) + + defer func() { + require.NoError(t, f.Close()) + }() + + expectFileContent, err := io.ReadAll(f) + require.NoError(t, err) + + app.Get("/test", func(c Ctx) error { + return c.SendFile("ctx.go", SendFile{ + FS: embedFile, + }) + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, expectFileContent, body) +} + // go test -race -run Test_Ctx_SendFile_404 func Test_Ctx_SendFile_404(t *testing.T) { t.Parallel() app := New() app.Get("/", func(c Ctx) error { - err := c.SendFile(filepath.FromSlash("john_dow.go/")) - require.Error(t, err) - return err + return c.SendFile("ctx12.go") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) require.NoError(t, err) require.Equal(t, StatusNotFound, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "sendfile: file ctx12.go not found", string(body)) } // go test -race -run Test_Ctx_SendFile_Immutable @@ -3050,6 +3184,40 @@ func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) { require.NoError(t, err2) } +func Test_SendFile_withRoutes(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/file", func(c Ctx) error { + return c.SendFile("ctx.go") + }) + + app.Get("/file/download", func(c Ctx) error { + return c.SendFile("ctx.go", SendFile{ + Download: true, + }) + }) + + app.Get("/file/fs", func(c Ctx) error { + return c.SendFile("ctx.go", SendFile{ + FS: os.DirFS("."), + }) + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/file", nil)) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/file/download", nil)) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) + require.Equal(t, "attachment", resp.Header.Get(HeaderContentDisposition)) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/file/fs", nil)) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) +} + // go test -run Test_Ctx_JSON func Test_Ctx_JSON(t *testing.T) { t.Parallel() diff --git a/router.go b/router.go index 26a2483f09..93f24ed59a 100644 --- a/router.go +++ b/router.go @@ -10,6 +10,7 @@ import ( "html" "sort" "strings" + "sync" "sync/atomic" "github.com/gofiber/utils/v2" @@ -60,6 +61,11 @@ type Route struct { Path string `json:"path"` // Original registered route path Params []string `json:"params"` // Case sensitive param keys Handlers []Handler `json:"-"` // Ctx handlers + + sendFileOnce sync.Once + sendFileFS *fasthttp.FS + sendFileHandler fasthttp.RequestHandler + sendFileCacheControlValue string } func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {