Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework middleware and add frontend proxy #331

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,36 @@ import (
"github.com/Southclaws/storyden/internal/config"
)

func WithChaos(cfg config.Config) func(http.Handler) http.Handler {
failRate := cfg.DevChaosFailRate
slowMode := cfg.DevChaosSlowMode
type Middleware struct {
enabled bool
failRate float64
slowMode time.Duration
}

disabled := failRate == 0 && slowMode == 0
func New(cfg config.Config) *Middleware {
return &Middleware{
enabled: cfg.DevChaosFailRate > 0 || cfg.DevChaosSlowMode == 0,
failRate: cfg.DevChaosFailRate,
slowMode: cfg.DevChaosSlowMode,
}
}

if disabled {
func (m *Middleware) WithChaos() func(http.Handler) http.Handler {
if !m.enabled {
return func(h http.Handler) http.Handler { return h }
}

return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if slowMode > 0 {
wait := time.Duration(rand.Intn(int(slowMode)))
if m.slowMode > 0 {
wait := time.Duration(rand.Intn(int(m.slowMode)))
fmt.Println("[DEV_CHAOS_SLOW_MODE] waiting", wait)
time.Sleep(wait)
}

if failRate > 0 {
if m.failRate > 0 {
chance := rand.Float64()
if chance < failRate {
if chance < m.failRate {
fmt.Println("[DEV_CHAOS_FAIL_RATE] crashing")
w.WriteHeader(http.StatusInternalServerError)
return
Expand Down
64 changes: 64 additions & 0 deletions app/transports/http/middleware/frontend/frontend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package frontend

import (
"net/http"
"net/http/httputil"
"strings"

"go.uber.org/fx"
"go.uber.org/zap"

"github.com/Southclaws/storyden/app/transports/http/middleware/session_cookie"
"github.com/Southclaws/storyden/internal/config"
)

type Provider struct {
handler func(http.ResponseWriter, *http.Request)
}

func New(
cfg config.Config,
logger *zap.Logger,
mux *http.ServeMux,
cj *session_cookie.Jar,
) *Provider {
if cfg.FrontendProxy == nil {
return &Provider{}
}

handler := func(p *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}
}

proxy := httputil.NewSingleHostReverseProxy(cfg.FrontendProxy)

return &Provider{
handler: handler(proxy),
}
}

func (p *Provider) WithFrontendProxy() func(next http.Handler) http.Handler {
if p.handler == nil {
return func(next http.Handler) http.Handler {
return next
}
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/api") {
next.ServeHTTP(w, r)
} else {
p.handler(w, r)
}
})
}
}

func Build() fx.Option {
return fx.Options(
fx.Provide(New),
)
}
85 changes: 45 additions & 40 deletions app/transports/http/middleware/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@ import (
)

const (
RateLimitLimit = "X-RateLimit-Limit"
RateLimitRemaining = "X-RateLimit-Remaining"
RateLimitReset = "X-RateLimit-Reset"
RetryAfter = "Retry-After"
RateLimitLimit = "X-RateLimit-Limit"
RateLimitRemaining = "X-RateLimit-Remaining"
RateLimitReset = "X-RateLimit-Reset"
RetryAfter = "Retry-After"
MaxRequestSizeBytes = 10 * 1024 * 1024
)

type Middleware struct {
logger *zap.Logger
rl rate.Limiter
kf KeyFunc
logger *zap.Logger
rl rate.Limiter
kf KeyFunc
sizeLimit int64
}

func New(
Expand All @@ -33,47 +35,50 @@ func New(
rl := f.NewLimiter(cfg.RateLimit, cfg.RateLimitPeriod, cfg.RateLimitExpire)

return &Middleware{
logger: logger,
rl: rl,
kf: fromIP("CF-Connecting-IP", "X-Real-IP", "True-Client-IP"),
logger: logger,
rl: rl,
kf: fromIP("CF-Connecting-IP", "X-Real-IP", "True-Client-IP"),
sizeLimit: MaxRequestSizeBytes, // TODO: cfg.MaxRequestSize
}
}

func (m *Middleware) WithRateLimit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
func (m *Middleware) WithRateLimit() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

key, err := m.kf(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
key, err := m.kf(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

// TODO: Generate costs per-operation from OpenAPI spec
cost := 1
// TODO: Generate costs per-operation from OpenAPI spec
cost := 1

status, allowed, err := m.rl.Increment(ctx, key, cost)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
status, allowed, err := m.rl.Increment(ctx, key, cost)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

limit := status.Limit
remaining := status.Remaining
resetTime := status.Reset.UTC().Format(time.RFC1123)
limit := status.Limit
remaining := status.Remaining
resetTime := status.Reset.UTC().Format(time.RFC1123)

w.Header().Set(RateLimitLimit, strconv.FormatUint(uint64(limit), 10))
w.Header().Set(RateLimitRemaining, strconv.FormatUint(uint64(remaining), 10))
w.Header().Set(RateLimitReset, resetTime)
w.Header().Set(RateLimitLimit, strconv.FormatUint(uint64(limit), 10))
w.Header().Set(RateLimitRemaining, strconv.FormatUint(uint64(remaining), 10))
w.Header().Set(RateLimitReset, resetTime)

if !allowed {
w.Header().Set(RetryAfter, resetTime)
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}
if !allowed {
w.Header().Set(RetryAfter, resetTime)
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}

next.ServeHTTP(w, r)
})
next.ServeHTTP(w, r)
})
}
}

type KeyFunc func(r *http.Request) (string, error)
Expand All @@ -94,10 +99,10 @@ func fromIP(headers ...string) KeyFunc {
}
}

func WithRequestSizeLimiter(bytes int64) func(http.Handler) http.Handler {
func (m *Middleware) WithRequestSizeLimiter() func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, bytes)
r.Body = http.MaxBytesReader(w, r.Body, m.sizeLimit)
h.ServeHTTP(w, r)
})
}
Expand Down
10 changes: 10 additions & 0 deletions app/transports/http/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@ package middleware
import (
"go.uber.org/fx"

"github.com/Southclaws/storyden/app/transports/http/middleware/chaos"
"github.com/Southclaws/storyden/app/transports/http/middleware/frontend"
"github.com/Southclaws/storyden/app/transports/http/middleware/limiter"
"github.com/Southclaws/storyden/app/transports/http/middleware/origin"
"github.com/Southclaws/storyden/app/transports/http/middleware/reqlog"
"github.com/Southclaws/storyden/app/transports/http/middleware/session_cookie"
"github.com/Southclaws/storyden/app/transports/http/middleware/useragent"
)

func Build() fx.Option {
return fx.Provide(
origin.New,
reqlog.New,
frontend.New,
useragent.New,
session_cookie.New,
limiter.New,
chaos.New,
)
}
10 changes: 9 additions & 1 deletion app/transports/http/middleware/origin/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ import (
"github.com/Southclaws/storyden/internal/config"
)

func WithCORS(cfg config.Config) func(next http.Handler) http.Handler {
type Middleware struct {
cfg config.Config
}

func New(cfg config.Config) *Middleware {
return &Middleware{cfg: cfg}
}

func (m *Middleware) WithCORS() func(next http.Handler) http.Handler {
allowedMethods := []string{
"GET",
"POST",
Expand Down
14 changes: 12 additions & 2 deletions app/transports/http/middleware/reqlog/requestlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ import (
"go.uber.org/zap"
)

type Middleware struct {
logger *zap.Logger
}

func New(logger *zap.Logger) *Middleware {
return &Middleware{
logger: logger,
}
}

type withStatus struct {
http.ResponseWriter
statusCode int
Expand All @@ -24,7 +34,7 @@ func (lrw *withStatus) WriteHeader(code int) {
lrw.ResponseWriter.WriteHeader(code)
}

func WithLogger(logger *zap.Logger) func(http.Handler) http.Handler {
func (m *Middleware) WithLogger() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
Expand All @@ -37,7 +47,7 @@ func WithLogger(logger *zap.Logger) func(http.Handler) http.Handler {
wr := &withStatus{ResponseWriter: w}

defer func() {
log := logger.With(
log := m.logger.With(
zap.Duration("duration", time.Since(start)),
zap.String("query", r.URL.Query().Encode()),
zap.Int64("body", r.ContentLength),
Expand Down
12 changes: 7 additions & 5 deletions app/transports/http/middleware/session_cookie/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ func (j *Jar) WithSession(r *http.Request) context.Context {
}

// WithAuth simply pulls out the session from the cookie and propagates it.
func (j *Jar) WithAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := j.WithSession(r)
func (j *Jar) WithAuth() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := j.WithSession(r)

next.ServeHTTP(w, r.WithContext(ctx))
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

func (j *Jar) GetCookieName() string {
Expand Down
26 changes: 17 additions & 9 deletions app/transports/http/middleware/useragent/ua.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,29 @@ import (
"github.com/mileusna/useragent"
)

type Middleware struct{}

func New() *Middleware {
return &Middleware{}
}

type uaKey struct{}

// UserAgentContext stores in the request context the user agent info.
func UserAgentContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// WithUserAgentContext stores in the request context the user agent info.
func (m *Middleware) WithUserAgentContext() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

ua := useragent.Parse(r.Header.Get("User-Agent"))
ua := useragent.Parse(r.Header.Get("User-Agent"))

newctx := context.WithValue(ctx, uaKey{}, ua)
newctx := context.WithValue(ctx, uaKey{}, ua)

r = r.WithContext(newctx)
r = r.WithContext(newctx)

next.ServeHTTP(w, r)
})
next.ServeHTTP(w, r)
})
}
}

func GetDeviceName(ctx context.Context) string {
Expand Down
Loading
Loading