Skip to content
This repository has been archived by the owner on Jul 12, 2023. It is now read-only.

Separate loading realm and require realm #530

Merged
merged 1 commit into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/server/assets/realms/select.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

{{$csrfField := .csrfField}}
{{$realms := .realms}}
{{$currentRealm := .currentRealm}}

<!doctype html>
<html lang="en">
Expand All @@ -26,7 +27,7 @@ <h1>Select your realm</h1>
<form action="/realm/select" method="POST">
{{$csrfField}}
<input type="hidden" name="realm" value="{{$realm.ID}}" />
<a href="#" class="w-100 d-flex flex-row justify-content-between align-items-center align-self-center list-group-item list-group-item-action" data-submit-form>
<a href="#" class="w-100 d-flex flex-row justify-content-between align-items-center align-self-center list-group-item list-group-item-action {{if $currentRealm }}{{if eq $currentRealm.Name $realm.Name}}active{{end}}{{end}}" data-submit-form>
<div>
<h5 class="mb-1">{{$realm.Name}}</h5>
{{if $realm.RegionCode}}
Expand Down
10 changes: 9 additions & 1 deletion cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ func realMain(ctx context.Context) error {
requireAuth := middleware.RequireAuth(ctx, cacher, auth, db, h, config.SessionDuration)
requireVerified := middleware.RequireVerified(ctx, auth, db, h, config.SessionDuration)
requireAdmin := middleware.RequireRealmAdmin(ctx, h)
requireRealm := middleware.RequireRealm(ctx, cacher, db, h)
loadCurrentRealm := middleware.LoadCurrentRealm(ctx, cacher, db, h)
requireRealm := middleware.RequireRealm(ctx, h)
requireSystemAdmin := middleware.RequireAdmin(ctx, h)
requireMFA := middleware.RequireMFA(ctx, h)
rateLimit := httplimiter.Handle
Expand Down Expand Up @@ -238,6 +239,7 @@ func realMain(ctx context.Context) error {
sub := r.PathPrefix("/realm").Subrouter()
sub.Use(requireAuth)
sub.Use(requireVerified)
sub.Use(loadCurrentRealm)
sub.Use(rateLimit)

// Realms - list and select.
Expand All @@ -250,6 +252,7 @@ func realMain(ctx context.Context) error {
sub := r.PathPrefix("/home").Subrouter()
sub.Use(requireAuth)
sub.Use(requireVerified)
sub.Use(loadCurrentRealm)
sub.Use(requireRealm)
sub.Use(requireMFA)
sub.Use(rateLimit)
Expand All @@ -269,6 +272,7 @@ func realMain(ctx context.Context) error {
sub := r.PathPrefix("/code").Subrouter()
sub.Use(requireAuth)
sub.Use(requireVerified)
sub.Use(loadCurrentRealm)
sub.Use(requireRealm)
sub.Use(requireMFA)
sub.Use(rateLimit)
Expand All @@ -284,6 +288,7 @@ func realMain(ctx context.Context) error {
sub := r.PathPrefix("/apikeys").Subrouter()
sub.Use(requireAuth)
sub.Use(requireVerified)
sub.Use(loadCurrentRealm)
sub.Use(requireRealm)
sub.Use(requireMFA)
sub.Use(requireAdmin)
Expand All @@ -305,6 +310,7 @@ func realMain(ctx context.Context) error {
userSub := r.PathPrefix("/users").Subrouter()
userSub.Use(requireAuth)
userSub.Use(requireVerified)
userSub.Use(loadCurrentRealm)
userSub.Use(requireRealm)
userSub.Use(requireMFA)
userSub.Use(requireAdmin)
Expand All @@ -325,6 +331,7 @@ func realMain(ctx context.Context) error {
realmSub := r.PathPrefix("/realm").Subrouter()
realmSub.Use(requireAuth)
realmSub.Use(requireVerified)
realmSub.Use(loadCurrentRealm)
realmSub.Use(requireRealm)
realmSub.Use(requireMFA)
realmSub.Use(requireAdmin)
Expand Down Expand Up @@ -366,6 +373,7 @@ func realMain(ctx context.Context) error {
adminSub := r.PathPrefix("/admin").Subrouter()
adminSub.Use(requireAuth)
adminSub.Use(requireVerified)
adminSub.Use(loadCurrentRealm)
adminSub.Use(requireSystemAdmin)
adminSub.Use(rateLimit)

Expand Down
49 changes: 34 additions & 15 deletions pkg/controller/middleware/realm.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ import (
"github.com/gorilla/mux"
)

// RequireRealm requires a realm to exist in the session. It also ensures the
// realm is set as currentRealm in the template map. It must come after
// RequireAuth so that a user is set on the context.
func RequireRealm(ctx context.Context, cacher cache.Cacher, db *database.Database, h *render.Renderer) mux.MiddlewareFunc {
// LoadCurrentRealm loads the selected realm from the cache to the context
func LoadCurrentRealm(ctx context.Context, cacher cache.Cacher, db *database.Database, h *render.Renderer) mux.MiddlewareFunc {
logger := logging.FromContext(ctx).Named("middleware.RequireRealm")

cacheTTL := 5 * time.Minute
Expand All @@ -42,12 +40,6 @@ func RequireRealm(ctx context.Context, cacher cache.Cacher, db *database.Databas
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

user := controller.UserFromContext(ctx)
if user == nil {
controller.MissingUser(w, r, h)
return
}

session := controller.SessionFromContext(ctx)
if session == nil {
controller.MissingSession(w, r, h)
Expand All @@ -57,7 +49,7 @@ func RequireRealm(ctx context.Context, cacher cache.Cacher, db *database.Databas
realmID := controller.RealmIDFromSession(session)
if realmID == 0 {
logger.Debugw("realm does not exist in session")
controller.MissingRealm(w, r, h)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should comment this noting why we call next.ServerHTTP

next.ServeHTTP(w, r)
return
}

Expand All @@ -79,6 +71,37 @@ func RequireRealm(ctx context.Context, cacher cache.Cacher, db *database.Databas
return
}

// Save the realm on the context.
ctx = controller.WithRealm(ctx, &realm)
*r = *r.WithContext(ctx)

next.ServeHTTP(w, r)
})
}
}

// RequireRealm requires a realm to exist in the session. It also ensures the
// realm is set as currentRealm in the template map. It must come after
// RequireAuth so that a user is set on the context.
func RequireRealm(ctx context.Context, h *render.Renderer) mux.MiddlewareFunc {
logger := logging.FromContext(ctx).Named("middleware.RequireRealm")

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

user := controller.UserFromContext(ctx)
if user == nil {
controller.MissingUser(w, r, h)
return
}

realm := controller.RealmFromContext(ctx)
if realm == nil {
controller.MissingRealm(w, r, h)
return
}

if !user.CanViewRealm(realm.ID) {
logger.Debugw("user cannot view realm")
// Technically this is unauthorized, but we don't want to leak the
Expand All @@ -87,10 +110,6 @@ func RequireRealm(ctx context.Context, cacher cache.Cacher, db *database.Databas
return
}

// Save the realm on the context.
ctx = controller.WithRealm(ctx, &realm)
*r = *r.WithContext(ctx)

next.ServeHTTP(w, r)
})
}
Expand Down