Skip to content

Commit

Permalink
chore (refactoring): revisit middleware types
Browse files Browse the repository at this point in the history
  • Loading branch information
mickael-kerjean-qantas committed Mar 25, 2024
1 parent 5bdb3f8 commit c06069c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 49 deletions.
14 changes: 8 additions & 6 deletions server/ctrl/webdav.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package ctrl

import (
. "github.com/mickael-kerjean/filestash/server/common"
"github.com/mickael-kerjean/filestash/server/model"
"github.com/mickael-kerjean/net/webdav"
"net/http"
"path/filepath"
"strings"

. "github.com/mickael-kerjean/filestash/server/common"
"github.com/mickael-kerjean/filestash/server/middleware"
"github.com/mickael-kerjean/filestash/server/model"
"github.com/mickael-kerjean/net/webdav"
)

func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -53,8 +55,8 @@ func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
* an imbecile and considering we can't even see the source code they are running, the best approach we
* could go on is: "crap in, crap out" where useless request coming in are identified and answer appropriatly
*/
func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func WebdavBlacklist(fn middleware.HandlerFunc) middleware.HandlerFunc {
return middleware.HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
base := filepath.Base(req.URL.String())

if req.Method == "PUT" || req.Method == "MKCOL" {
Expand Down Expand Up @@ -125,5 +127,5 @@ func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx
}
}
fn(ctx, res, req)
}
})
}
6 changes: 3 additions & 3 deletions server/middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strings"
)

func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
func BodyParser(fn HandlerFunc) HandlerFunc {
extractBody := func(req *http.Request) (map[string]interface{}, error) {
body := map[string]interface{}{}
byt, err := ioutil.ReadAll(req.Body)
Expand All @@ -25,14 +25,14 @@ func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
return body, nil
}

return func(ctx *App, res http.ResponseWriter, req *http.Request) {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
var err error
if ctx.Body, err = extractBody(req); err != nil {
SendErrorResult(res, ErrNotValid)
return
}
fn(ctx, res, req)
}
})
}

func GenerateRequestID(prefix string) string {
Expand Down
42 changes: 21 additions & 21 deletions server/middleware/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"strings"
)

func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func ApiHeaders(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
header := res.Header()
header.Set("Content-Type", "application/json")
header.Set("Cache-Control", "no-cache")
Expand All @@ -20,20 +20,20 @@ func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
header.Set("X-Request-ID", GenerateRequestID("API"))
}
fn(ctx, res, req)
}
})
}

func StaticHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func StaticHeaders(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
header := res.Header()
header.Set("Content-Type", GetMimeType(filepath.Ext(req.URL.Path)))
header.Set("Cache-Control", "max-age=2592000")
fn(ctx, res, req)
}
})
}

func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func IndexHeaders(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
header := res.Header()
header.Set("Content-Type", "text/html")
header.Set("Cache-Control", "no-cache")
Expand Down Expand Up @@ -65,23 +65,23 @@ func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
}
// header.Set("Content-Security-Policy", cspHeader)
fn(ctx, res, req)
}
})
}

func SecureHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func SecureHeaders(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
header := res.Header()
if Config.Get("general.force_ssl").Bool() {
header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
}
header.Set("X-Content-Type-Options", "nosniff")
header.Set("X-XSS-Protection", "1; mode=block")
fn(ctx, res, req)
}
})
}

func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func SecureOrigin(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
if host := Config.Get("general.host").String(); host != "" {
host = strings.TrimPrefix(host, "http://")
host = strings.TrimPrefix(host, "https://")
Expand All @@ -105,11 +105,11 @@ func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A

Log.Warning("Intrusion detection: %s - %s", RetrievePublicIp(req), req.URL.String())
SendErrorResult(res, ErrNotAllowed)
}
})
}

func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func WithPublicAPI(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
apiKey := req.URL.Query().Get("key")
if apiKey == "" {
fn(ctx, res, req)
Expand All @@ -132,13 +132,13 @@ func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *
return
}
fn(ctx, res, req)
}
})
}

var limiter = rate.NewLimiter(10, 1000)

func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func RateLimiter(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
if limiter.Allow() == false {
Log.Warning("middleware::http::ratelimit too many requests")
SendErrorResult(
Expand All @@ -148,7 +148,7 @@ func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *Ap
return
}
fn(ctx, res, req)
}
})
}

func EnableCors(req *http.Request, res http.ResponseWriter, host string) error {
Expand Down
2 changes: 1 addition & 1 deletion server/middleware/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

type HandlerFunc func(*App, http.ResponseWriter, *http.Request)
type Middleware func(func(*App, http.ResponseWriter, *http.Request)) func(*App, http.ResponseWriter, *http.Request)
type Middleware func(HandlerFunc) HandlerFunc

func init() {
Hooks.Register.Onload(func() {
Expand Down
36 changes: 18 additions & 18 deletions server/middleware/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ import (
"time"
)

func LoggedInOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func LoggedInOnly(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
if ctx.Backend == nil || ctx.Session == nil {
SendErrorResult(res, ErrPermissionDenied)
return
}
fn(ctx, res, req)
}
})
}

func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func AdminOnly(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
if admin := Config.Get("auth.admin").String(); admin != "" {
c, err := req.Cookie(COOKIE_NAME_ADMIN)
if err != nil {
Expand All @@ -47,11 +47,11 @@ func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App,
}
}
fn(ctx, res, req)
}
})
}

func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func SessionStart(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
var err error

if ctx.Share, err = _extractShare(req); err != nil {
Expand All @@ -72,21 +72,21 @@ func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
return
}
fn(ctx, res, req)
}
})
}

func SessionTry(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func SessionTry(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
ctx.Share, _ = _extractShare(req)
ctx.Authorization = _extractAuthorization(req)
ctx.Session, _ = _extractSession(req, ctx)
ctx.Backend, _ = _extractBackend(req, ctx)
fn(ctx, res, req)
}
})
}

func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func RedirectSharedLoginIfNeeded(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
share_id := _extractShareId(req)
if share_id == "" {
if mux.Vars(req)["share"] == "private" {
Expand All @@ -103,11 +103,11 @@ func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Reques
return
}
fn(ctx, res, req)
}
})
}

func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
func CanManageShare(fn HandlerFunc) HandlerFunc {
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
share_id := mux.Vars(req)["share"]
if share_id == "" {
Log.Debug("middleware::session::share 'invalid share id'")
Expand Down Expand Up @@ -167,7 +167,7 @@ func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx
}
SendErrorResult(res, ErrPermissionDenied)
return
}
})
}

func _extractAuthorization(req *http.Request) (token string) {
Expand Down

0 comments on commit c06069c

Please sign in to comment.