diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 8d81efbb..af878934 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -102,71 +102,74 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { handler.ServeHTTP(w, r) return } + m.RequireAccountHandler(w, r) + } + return http.HandlerFunc(fn) +} - // If we try to redirect when the original request is the ACS URL we'll - // end up in a loop. This is a programming error, so we panic here. In - // general this means a 500 to the user, which is preferable to a - // redirect loop. - if r.URL.Path == m.ServiceProvider.AcsURL.Path { - panic("don't wrap Middleware with RequireAccount") - } +func (m *Middleware) RequireAccountHandler(w http.ResponseWriter, r *http.Request) { + // If we try to redirect when the original request is the ACS URL we'll + // end up in a loop. This is a programming error, so we panic here. In + // general this means a 500 to the user, which is preferable to a + // redirect loop. + if r.URL.Path == m.ServiceProvider.AcsURL.Path { + panic("don't wrap Middleware with RequireAccount") + } - var binding, bindingLocation string - if m.Binding != "" { - binding = m.Binding + var binding, bindingLocation string + if m.Binding != "" { + binding = m.Binding + bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) + } else { + binding = saml.HTTPRedirectBinding + bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) + if bindingLocation == "" { + binding = saml.HTTPPostBinding bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) - } else { - binding = saml.HTTPRedirectBinding - bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) - if bindingLocation == "" { - binding = saml.HTTPPostBinding - bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) - } } + } - req, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + req, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - // relayState is limited to 80 bytes but also must be integrety protected. - // this means that we cannot use a JWT because it is way to long. Instead - // we set a cookie that corresponds to the state - relayState := base64.URLEncoding.EncodeToString(randomBytes(42)) + // relayState is limited to 80 bytes but also must be integrety protected. + // this means that we cannot use a JWT because it is way to long. Instead + // we set a cookie that corresponds to the state + relayState := base64.URLEncoding.EncodeToString(randomBytes(42)) - secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) - state := jwt.New(jwtSigningMethod) - claims := state.Claims.(jwt.MapClaims) - claims["id"] = req.ID - claims["uri"] = r.URL.String() - signedState, err := state.SignedString(secretBlock) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) + state := jwt.New(jwtSigningMethod) + claims := state.Claims.(jwt.MapClaims) + claims["id"] = req.ID + claims["uri"] = r.URL.String() + signedState, err := state.SignedString(secretBlock) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - m.ClientState.SetState(w, r, relayState, signedState) - if binding == saml.HTTPRedirectBinding { - redirectURL := req.Redirect(relayState) - w.Header().Add("Location", redirectURL.String()) - w.WriteHeader(http.StatusFound) - return - } - if binding == saml.HTTPPostBinding { - w.Header().Add("Content-Security-Policy", ""+ - "default-src; "+ - "script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+ - "reflected-xss block; referrer no-referrer;") - w.Header().Add("Content-type", "text/html") - w.Write([]byte(``)) - w.Write(req.Post(relayState)) - w.Write([]byte(``)) - return - } - panic("not reached") + m.ClientState.SetState(w, r, relayState, signedState) + if binding == saml.HTTPRedirectBinding { + redirectURL := req.Redirect(relayState) + w.Header().Add("Location", redirectURL.String()) + w.WriteHeader(http.StatusFound) + return } - return http.HandlerFunc(fn) + if binding == saml.HTTPPostBinding { + w.Header().Add("Content-Security-Policy", ""+ + "default-src; "+ + "script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+ + "reflected-xss block; referrer no-referrer;") + w.Header().Add("Content-type", "text/html") + w.Write([]byte(``)) + w.Write(req.Post(relayState)) + w.Write([]byte(``)) + return + } + panic("not reached") } func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string {