Skip to content

Commit

Permalink
add support for other external middleware (crewjam#184)
Browse files Browse the repository at this point in the history
Changed samlsp/middleware.go to move the request handling part of the RequireAccount method out into a separate public method attached to the middleware type which can be used as a stand alone HTTP RequestHandler. That way the Handler can be used with other middleware chains which provide their own http.Handler wrapper methods.
  • Loading branch information
exedor authored and crewjam committed May 1, 2019
1 parent d7fbac2 commit 138fcb2
Showing 1 changed file with 59 additions and 56 deletions.
115 changes: 59 additions & 56 deletions samlsp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,71 +102,74 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {
handler.ServeHTTP(w, r)
return
}
m.RequireAccountHandler(w, r)
}
return http.HandlerFunc(fn)
}

// If we try to redirect when the original request is the ACS URL we'll
// end up in a loop. This is a programming error, so we panic here. In
// general this means a 500 to the user, which is preferable to a
// redirect loop.
if r.URL.Path == m.ServiceProvider.AcsURL.Path {
panic("don't wrap Middleware with RequireAccount")
}
func (m *Middleware) RequireAccountHandler(w http.ResponseWriter, r *http.Request) {
// If we try to redirect when the original request is the ACS URL we'll
// end up in a loop. This is a programming error, so we panic here. In
// general this means a 500 to the user, which is preferable to a
// redirect loop.
if r.URL.Path == m.ServiceProvider.AcsURL.Path {
panic("don't wrap Middleware with RequireAccount")
}

var binding, bindingLocation string
if m.Binding != "" {
binding = m.Binding
var binding, bindingLocation string
if m.Binding != "" {
binding = m.Binding
bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
} else {
binding = saml.HTTPRedirectBinding
bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
if bindingLocation == "" {
binding = saml.HTTPPostBinding
bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
} else {
binding = saml.HTTPRedirectBinding
bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
if bindingLocation == "" {
binding = saml.HTTPPostBinding
bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
}
}
}

req, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// relayState is limited to 80 bytes but also must be integrety protected.
// this means that we cannot use a JWT because it is way to long. Instead
// we set a cookie that corresponds to the state
relayState := base64.URLEncoding.EncodeToString(randomBytes(42))
// relayState is limited to 80 bytes but also must be integrety protected.
// this means that we cannot use a JWT because it is way to long. Instead
// we set a cookie that corresponds to the state
relayState := base64.URLEncoding.EncodeToString(randomBytes(42))

secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)
state := jwt.New(jwtSigningMethod)
claims := state.Claims.(jwt.MapClaims)
claims["id"] = req.ID
claims["uri"] = r.URL.String()
signedState, err := state.SignedString(secretBlock)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)
state := jwt.New(jwtSigningMethod)
claims := state.Claims.(jwt.MapClaims)
claims["id"] = req.ID
claims["uri"] = r.URL.String()
signedState, err := state.SignedString(secretBlock)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

m.ClientState.SetState(w, r, relayState, signedState)
if binding == saml.HTTPRedirectBinding {
redirectURL := req.Redirect(relayState)
w.Header().Add("Location", redirectURL.String())
w.WriteHeader(http.StatusFound)
return
}
if binding == saml.HTTPPostBinding {
w.Header().Add("Content-Security-Policy", ""+
"default-src; "+
"script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+
"reflected-xss block; referrer no-referrer;")
w.Header().Add("Content-type", "text/html")
w.Write([]byte(`<!DOCTYPE html><html><body>`))
w.Write(req.Post(relayState))
w.Write([]byte(`</body></html>`))
return
}
panic("not reached")
m.ClientState.SetState(w, r, relayState, signedState)
if binding == saml.HTTPRedirectBinding {
redirectURL := req.Redirect(relayState)
w.Header().Add("Location", redirectURL.String())
w.WriteHeader(http.StatusFound)
return
}
return http.HandlerFunc(fn)
if binding == saml.HTTPPostBinding {
w.Header().Add("Content-Security-Policy", ""+
"default-src; "+
"script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+
"reflected-xss block; referrer no-referrer;")
w.Header().Add("Content-type", "text/html")
w.Write([]byte(`<!DOCTYPE html><html><body>`))
w.Write(req.Post(relayState))
w.Write([]byte(`</body></html>`))
return
}
panic("not reached")
}

func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string {
Expand Down

0 comments on commit 138fcb2

Please sign in to comment.